Andrew committed on
Commit ·
9e2d0e8
1
Parent(s): a5cfbf7
test 2?
Browse files- acestep/__init__.py +1 -0
- acestep/audio_utils.py +354 -0
- acestep/constants.py +193 -0
- acestep/constrained_logits_processor.py +0 -0
- acestep/dit_alignment_score.py +877 -0
- acestep/genres_vocab.txt +0 -0
- acestep/gpu_config.py +549 -0
- acestep/handler.py +0 -0
- acestep/inference.py +1310 -0
- acestep/llm_inference.py +0 -0
- acestep/model_downloader.py +634 -0
- handler.py +262 -272
- requirements.txt +8 -2
acestep/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""ACE-Step package."""
|
acestep/audio_utils.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Audio saving and transcoding utility module
|
| 3 |
+
|
| 4 |
+
Independent audio file operations outside of handler, supporting:
|
| 5 |
+
- Save audio tensor/numpy to files (default FLAC format, fast)
|
| 6 |
+
- Format conversion (FLAC/WAV/MP3)
|
| 7 |
+
- Batch processing
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import hashlib
|
| 12 |
+
import json
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Union, Optional, List, Tuple
|
| 15 |
+
import torch
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torchaudio
|
| 18 |
+
from loguru import logger
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class AudioSaver:
    """Audio saving and transcoding utility class.

    Saves audio tensors/arrays to FLAC/WAV/MP3 via torchaudio, with a
    soundfile fallback, and provides format conversion and batch saving.
    """

    def __init__(self, default_format: str = "flac"):
        """
        Initialize audio saver

        Args:
            default_format: Default save format ('flac', 'wav', 'mp3').
                Unsupported values fall back to 'flac' with a warning.
        """
        self.default_format = default_format.lower()
        if self.default_format not in ["flac", "wav", "mp3"]:
            logger.warning(f"Unsupported format {default_format}, using 'flac'")
            self.default_format = "flac"

    def save_audio(
        self,
        audio_data: Union[torch.Tensor, np.ndarray],
        output_path: Union[str, Path],
        sample_rate: int = 48000,
        format: Optional[str] = None,
        channels_first: bool = True,
    ) -> str:
        """
        Save audio data to file

        Args:
            audio_data: Audio data, torch.Tensor [channels, samples] or numpy.ndarray
            output_path: Output file path (extension can be omitted)
            sample_rate: Sample rate
            format: Audio format ('flac', 'wav', 'mp3'), defaults to default_format
            channels_first: If True, tensor format is [channels, samples], else [samples, channels]

        Returns:
            Actual saved file path

        Raises:
            Exception: re-raised if both torchaudio and the soundfile fallback fail.
        """
        # Normalize and validate the requested format, falling back to the default.
        format = (format or self.default_format).lower()
        if format not in ["flac", "wav", "mp3"]:
            logger.warning(f"Unsupported format {format}, using {self.default_format}")
            format = self.default_format

        # Ensure output path carries an extension matching a supported format.
        output_path = Path(output_path)
        if output_path.suffix.lower() not in ['.flac', '.wav', '.mp3']:
            output_path = output_path.with_suffix(f'.{format}')

        # Convert input to a [channels, samples] float torch tensor.
        if isinstance(audio_data, np.ndarray):
            if channels_first:
                # NOTE(review): numpy input with channels_first=True is transposed,
                # i.e. the array is assumed to arrive as [samples, channels]. This is
                # the opposite convention from the tensor branch below — confirm
                # against callers before relying on the flag for numpy inputs.
                audio_tensor = torch.from_numpy(audio_data.T).float()
            else:
                audio_tensor = torch.from_numpy(audio_data).float()
                # Heuristic layout fix: flips when the array is "wider than tall".
                # TODO(review): confirm the intended axis convention here.
                if audio_tensor.dim() == 2 and audio_tensor.shape[0] < audio_tensor.shape[1]:
                    audio_tensor = audio_tensor.T
        else:
            # torch tensor: move to CPU float32 so the save backends can consume it.
            audio_tensor = audio_data.cpu().float()
            if not channels_first and audio_tensor.dim() == 2:
                # [samples, channels] -> [channels, samples] (heuristic on shape).
                if audio_tensor.shape[0] > audio_tensor.shape[1]:
                    audio_tensor = audio_tensor.T

        # Ensure memory is contiguous before handing to the backend.
        audio_tensor = audio_tensor.contiguous()

        # Select backend and save.
        try:
            if format == "mp3":
                # MP3 requires the ffmpeg backend.
                torchaudio.save(
                    str(output_path),
                    audio_tensor,
                    sample_rate,
                    channels_first=True,
                    backend='ffmpeg',
                )
            elif format in ["flac", "wav"]:
                # FLAC and WAV use the soundfile backend (fastest).
                torchaudio.save(
                    str(output_path),
                    audio_tensor,
                    sample_rate,
                    channels_first=True,
                    backend='soundfile',
                )
            else:
                # Default backend (unreachable given the validation above; kept defensively).
                torchaudio.save(
                    str(output_path),
                    audio_tensor,
                    sample_rate,
                    channels_first=True,
                )

            logger.debug(f"[AudioSaver] Saved audio to {output_path} ({format}, {sample_rate}Hz)")
            return str(output_path)

        except Exception as primary_error:
            # FIX: the original discarded the torchaudio error (outer `e` was never
            # logged and was shadowed by the inner handler's `e`). Surface the
            # primary failure before attempting the soundfile fallback.
            logger.warning(
                f"[AudioSaver] torchaudio save failed ({primary_error}), trying soundfile fallback"
            )
            try:
                import soundfile as sf
                audio_np = audio_tensor.transpose(0, 1).numpy()  # -> [samples, channels]
                sf.write(str(output_path), audio_np, sample_rate, format=format.upper())
                logger.debug(f"[AudioSaver] Fallback soundfile Saved audio to {output_path} ({format}, {sample_rate}Hz)")
                return str(output_path)
            except Exception as e:
                logger.error(f"[AudioSaver] Failed to save audio: {e}")
                raise

    def convert_audio(
        self,
        input_path: Union[str, Path],
        output_path: Union[str, Path],
        output_format: str,
        remove_input: bool = False,
    ) -> str:
        """
        Convert audio format

        Args:
            input_path: Input audio file path
            output_path: Output audio file path
            output_format: Target format ('flac', 'wav', 'mp3')
            remove_input: Whether to delete input file

        Returns:
            Output file path

        Raises:
            FileNotFoundError: if the input file does not exist.
        """
        input_path = Path(input_path)
        output_path = Path(output_path)

        if not input_path.exists():
            raise FileNotFoundError(f"Input file not found: {input_path}")

        # Load audio at its native sample rate.
        audio_tensor, sample_rate = torchaudio.load(str(input_path))

        # Re-save in the target format (save_audio returns the final str path).
        output_path = self.save_audio(
            audio_tensor,
            output_path,
            sample_rate=sample_rate,
            format=output_format,
            channels_first=True
        )

        # Optionally delete the source file once the conversion succeeded.
        if remove_input:
            input_path.unlink()
            logger.debug(f"[AudioSaver] Removed input file: {input_path}")

        return output_path

    def save_batch(
        self,
        audio_batch: Union[List[torch.Tensor], torch.Tensor],
        output_dir: Union[str, Path],
        file_prefix: str = "audio",
        sample_rate: int = 48000,
        format: Optional[str] = None,
        channels_first: bool = True,
    ) -> List[str]:
        """
        Save audio batch

        Args:
            audio_batch: Audio batch, List[tensor] or tensor [batch, channels, samples]
            output_dir: Output directory (created if missing)
            file_prefix: File prefix for the generated names (prefix_0000, ...)
            sample_rate: Sample rate
            format: Audio format
            channels_first: Tensor format flag

        Returns:
            List of saved file paths
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Normalize the batch into a list of per-item tensors.
        if isinstance(audio_batch, torch.Tensor) and audio_batch.dim() == 3:
            # [batch, channels, samples] -> list of [channels, samples]
            audio_list = [audio_batch[i] for i in range(audio_batch.shape[0])]
        elif isinstance(audio_batch, list):
            audio_list = audio_batch
        else:
            # Single tensor: treat as a batch of one.
            audio_list = [audio_batch]

        saved_paths = []
        for i, audio in enumerate(audio_list):
            # Zero-padded index keeps lexicographic order == batch order.
            output_path = output_dir / f"{file_prefix}_{i:04d}"
            saved_path = self.save_audio(
                audio,
                output_path,
                sample_rate=sample_rate,
                format=format,
                channels_first=channels_first
            )
            saved_paths.append(saved_path)

        return saved_paths
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def get_audio_file_hash(audio_file) -> str:
    """
    Get hash identifier for an audio file.

    Args:
        audio_file: Path to audio file (str) or file-like object

    Returns:
        Hash string or empty string
    """
    if audio_file is None:
        return ""

    def _md5_of_text(text: str) -> str:
        # Hash a textual identifier when file contents are unavailable.
        return hashlib.md5(text.encode('utf-8')).hexdigest()

    try:
        if isinstance(audio_file, str):
            if not os.path.exists(audio_file):
                # Non-existent path: fall back to hashing the path string itself.
                return _md5_of_text(audio_file)
            with open(audio_file, 'rb') as fh:
                # Content hash identifies the actual audio bytes.
                return hashlib.md5(fh.read()).hexdigest()
        if hasattr(audio_file, 'name'):
            # File-like object: identify it by its name attribute.
            return _md5_of_text(str(audio_file.name))
        return _md5_of_text(str(audio_file))
    except Exception:
        # Best-effort fallback: hash the string form of whatever was passed in.
        return _md5_of_text(str(audio_file))
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def generate_uuid_from_params(params_dict) -> str:
    """
    Generate deterministic UUID from generation parameters.
    Same parameters will always generate the same UUID.

    Args:
        params_dict: Dictionary of parameters

    Returns:
        UUID string
    """
    # Canonical JSON (sorted keys, unicode preserved) makes the digest
    # independent of dict insertion order.
    canonical = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
    digest = hashlib.sha256(canonical.encode('utf-8')).hexdigest()
    # Slice the hex digest into the standard 8-4-4-4-12 UUID layout.
    parts = (digest[0:8], digest[8:12], digest[12:16], digest[16:20], digest[20:32])
    return "-".join(parts)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def generate_uuid_from_audio_data(
    audio_data: Union[torch.Tensor, np.ndarray],
    seed: Optional[int] = None
) -> str:
    """
    Generate UUID from audio data (for caching/deduplication)

    Args:
        audio_data: Audio data
        seed: Optional seed value

    Returns:
        UUID string
    """
    # Normalize to numpy so hashing works for both tensor and array inputs.
    if isinstance(audio_data, torch.Tensor):
        audio_np = audio_data.cpu().numpy()
    else:
        audio_np = audio_data

    # Digest over the raw sample bytes identifies the audio content.
    data_hash = hashlib.md5(audio_np.tobytes()).hexdigest()

    if seed is None:
        return data_hash

    # Mix the seed into the digest so identical audio generated with
    # different seeds yields distinct identifiers.
    return hashlib.md5(f"{data_hash}_{seed}".encode()).hexdigest()
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
# Global default instance shared by the module-level convenience helpers below.
_default_saver = AudioSaver(default_format="flac")
|
| 302 |
+
|
| 303 |
+
# Thresholds below which a signal is considered silent (near machine noise floor).
SILENT_RMS_THRESHOLD = 1e-5
SILENT_PEAK_THRESHOLD = 1e-5


def is_audio_silent(
    audio_data: Union[torch.Tensor, np.ndarray],
    rms_threshold: float = SILENT_RMS_THRESHOLD,
    peak_threshold: float = SILENT_PEAK_THRESHOLD,
    channels_first: bool = True,
) -> Tuple[bool, float, float]:
    """
    Check if audio is silent or near-silent (e.g. zeroed conditioning output).
    Returns (is_silent, rms, peak) where rms/peak are computed over the full signal.

    Note: channels_first is accepted for API symmetry but unused — the stats
    are computed over the flattened signal regardless of layout.
    """
    # Missing audio counts as silent.
    if audio_data is None:
        return True, 0.0, 0.0

    # Flatten to a 1-D float64 vector regardless of input container.
    if isinstance(audio_data, np.ndarray):
        samples = np.asarray(audio_data, dtype=np.float64).ravel()
    else:
        samples = audio_data.cpu().float().numpy().ravel()

    # An empty signal is silent by definition.
    if samples.size == 0:
        return True, 0.0, 0.0

    rms = float(np.sqrt(np.mean(samples * samples)))
    peak = float(np.max(np.abs(samples)))
    # Both the average level and the loudest sample must be under threshold.
    silent = rms <= rms_threshold and peak <= peak_threshold
    return silent, rms, peak
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def save_audio(
    audio_data: Union[torch.Tensor, np.ndarray],
    output_path: Union[str, Path],
    sample_rate: int = 48000,
    format: Optional[str] = None,
    channels_first: bool = True,
) -> str:
    """
    Convenience function: save audio (using default configuration)

    Args:
        audio_data: Audio data
        output_path: Output path
        sample_rate: Sample rate
        format: Format (default flac)
        channels_first: Tensor format flag

    Returns:
        Saved file path
    """
    # Thin wrapper delegating to the shared module-level AudioSaver instance.
    return _default_saver.save_audio(
        audio_data,
        output_path,
        sample_rate,
        format,
        channels_first,
    )
|
| 354 |
+
|
acestep/constants.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Constants for ACE-Step
|
| 3 |
+
Centralized constants used across the codebase
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
# ==============================================================================
|
| 7 |
+
# Language Constants
|
| 8 |
+
# ==============================================================================
|
| 9 |
+
|
| 10 |
+
# Supported languages for vocal generation and language detection
|
| 11 |
+
# Covers major world languages with good TTS support in the underlying model
|
| 12 |
+
# 'unknown' is used when language cannot be determined automatically
|
| 13 |
+
VALID_LANGUAGES = [
|
| 14 |
+
'ar', 'az', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en',
|
| 15 |
+
'es', 'fa', 'fi', 'fr', 'he', 'hi', 'hr', 'ht', 'hu', 'id',
|
| 16 |
+
'is', 'it', 'ja', 'ko', 'la', 'lt', 'ms', 'ne', 'nl', 'no',
|
| 17 |
+
'pa', 'pl', 'pt', 'ro', 'ru', 'sa', 'sk', 'sr', 'sv', 'sw',
|
| 18 |
+
'ta', 'te', 'th', 'tl', 'tr', 'uk', 'ur', 'vi', 'yue', 'zh',
|
| 19 |
+
'unknown'
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ==============================================================================
|
| 24 |
+
# Keyscale Constants
|
| 25 |
+
# ==============================================================================
|
| 26 |
+
|
| 27 |
+
# Musical note names using standard Western notation
|
| 28 |
+
KEYSCALE_NOTES = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
|
| 29 |
+
|
| 30 |
+
# Supported accidentals: natural, ASCII sharp/flat, Unicode sharp/flat
|
| 31 |
+
KEYSCALE_ACCIDENTALS = ['', '#', 'b', '♯', '♭'] # empty + ASCII sharp/flat + Unicode sharp/flat
|
| 32 |
+
|
| 33 |
+
# Major and minor scale modes
|
| 34 |
+
KEYSCALE_MODES = ['major', 'minor']
|
| 35 |
+
|
| 36 |
+
# Generate all valid keyscales: 7 notes × 5 accidentals × 2 modes = 70 combinations
|
| 37 |
+
# Examples: "C major", "F# minor", "B♭ major"
|
| 38 |
+
VALID_KEYSCALES = set()
|
| 39 |
+
for note in KEYSCALE_NOTES:
|
| 40 |
+
for acc in KEYSCALE_ACCIDENTALS:
|
| 41 |
+
for mode in KEYSCALE_MODES:
|
| 42 |
+
VALID_KEYSCALES.add(f"{note}{acc} {mode}")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ==============================================================================
|
| 46 |
+
# Metadata Range Constants
|
| 47 |
+
# ==============================================================================
|
| 48 |
+
|
| 49 |
+
# BPM (Beats Per Minute) range - covers most musical styles
|
| 50 |
+
# 30 BPM: Very slow ballads, ambient music
|
| 51 |
+
# 300 BPM: Fast electronic dance music, extreme metal
|
| 52 |
+
BPM_MIN = 30
|
| 53 |
+
BPM_MAX = 300
|
| 54 |
+
|
| 55 |
+
# Duration range (in seconds) - balances quality vs. computational cost
|
| 56 |
+
# 10s: Short loops, musical excerpts
|
| 57 |
+
# 600s: Full songs, extended compositions (10 minutes)
|
| 58 |
+
DURATION_MIN = 10
|
| 59 |
+
DURATION_MAX = 600
|
| 60 |
+
|
| 61 |
+
# Valid time signatures - common musical meter patterns
|
| 62 |
+
# 2: 2/4 time (marches, polka)
|
| 63 |
+
# 3: 3/4 time (waltzes, ballads)
|
| 64 |
+
# 4: 4/4 time (most pop, rock, hip-hop)
|
| 65 |
+
# 6: 6/8 time (compound time, folk dances)
|
| 66 |
+
VALID_TIME_SIGNATURES = [2, 3, 4, 6]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# ==============================================================================
|
| 70 |
+
# Task Type Constants
|
| 71 |
+
# ==============================================================================
|
| 72 |
+
|
| 73 |
+
# All supported generation tasks across different model variants
|
| 74 |
+
TASK_TYPES = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
|
| 75 |
+
|
| 76 |
+
# Task types available for turbo models (optimized subset for speed)
|
| 77 |
+
# - text2music: Generate from text descriptions
|
| 78 |
+
# - repaint: Selective audio editing/regeneration
|
| 79 |
+
# - cover: Style transfer using reference audio
|
| 80 |
+
TASK_TYPES_TURBO = ["text2music", "repaint", "cover"]
|
| 81 |
+
|
| 82 |
+
# Task types available for base models (full feature set)
|
| 83 |
+
# Additional tasks requiring more computational resources:
|
| 84 |
+
# - extract: Separate individual tracks/stems from audio
|
| 85 |
+
# - lego: Multi-track generation (add layers)
|
| 86 |
+
# - complete: Automatic completion of partial audio
|
| 87 |
+
TASK_TYPES_BASE = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# ==============================================================================
|
| 91 |
+
# Instruction Constants
|
| 92 |
+
# ==============================================================================
|
| 93 |
+
|
| 94 |
+
# Default instructions
|
| 95 |
+
DEFAULT_DIT_INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
|
| 96 |
+
DEFAULT_LM_INSTRUCTION = "Generate audio semantic tokens based on the given conditions:"
|
| 97 |
+
DEFAULT_LM_UNDERSTAND_INSTRUCTION = "Understand the given musical conditions and describe the audio semantics accordingly:"
|
| 98 |
+
DEFAULT_LM_INSPIRED_INSTRUCTION = "Expand the user's input into a more detailed and specific musical description:"
|
| 99 |
+
DEFAULT_LM_REWRITE_INSTRUCTION = "Format the user's input into a more detailed and specific musical description:"
|
| 100 |
+
|
| 101 |
+
# Instruction templates for each task type
|
| 102 |
+
# Note: Some instructions use placeholders like {TRACK_NAME} or {TRACK_CLASSES}
|
| 103 |
+
# These should be formatted using .format() or f-strings when used
|
| 104 |
+
TASK_INSTRUCTIONS = {
|
| 105 |
+
"text2music": "Fill the audio semantic mask based on the given conditions:",
|
| 106 |
+
"repaint": "Repaint the mask area based on the given conditions:",
|
| 107 |
+
"cover": "Generate audio semantic tokens based on the given conditions:",
|
| 108 |
+
"extract": "Extract the {TRACK_NAME} track from the audio:",
|
| 109 |
+
"extract_default": "Extract the track from the audio:",
|
| 110 |
+
"lego": "Generate the {TRACK_NAME} track based on the audio context:",
|
| 111 |
+
"lego_default": "Generate the track based on the audio context:",
|
| 112 |
+
"complete": "Complete the input track with {TRACK_CLASSES}:",
|
| 113 |
+
"complete_default": "Complete the input track:",
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# ==============================================================================
|
| 118 |
+
# Track/Instrument Constants
|
| 119 |
+
# ==============================================================================
|
| 120 |
+
|
| 121 |
+
# Supported instrumental track types for multi-track generation and extraction
|
| 122 |
+
# Organized by instrument families for logical grouping:
|
| 123 |
+
# - Wind instruments: woodwinds, brass
|
| 124 |
+
# - Electronic: fx (effects), synth (synthesizer)
|
| 125 |
+
# - String instruments: strings, guitar, bass
|
| 126 |
+
# - Rhythm section: percussion, drums, keyboard
|
| 127 |
+
# - Vocals: backing_vocals, vocals (lead vocals)
|
| 128 |
+
TRACK_NAMES = [
|
| 129 |
+
"woodwinds", "brass", "fx", "synth", "strings", "percussion",
|
| 130 |
+
"keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"
|
| 131 |
+
]
|
| 132 |
+
|
| 133 |
+
# Template for SFT (Supervised Fine-Tuning) model prompts
|
| 134 |
+
# Used to format inputs for the language model with instruction, caption, and metadata
|
| 135 |
+
SFT_GEN_PROMPT = """# Instruction
|
| 136 |
+
{}
|
| 137 |
+
|
| 138 |
+
# Caption
|
| 139 |
+
{}
|
| 140 |
+
|
| 141 |
+
# Metas
|
| 142 |
+
{}<|endoftext|>
|
| 143 |
+
"""
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# ==============================================================================
|
| 147 |
+
# GPU Memory Configuration Constants
|
| 148 |
+
# ==============================================================================
|
| 149 |
+
|
| 150 |
+
# GPU tier thresholds (in GB)
|
| 151 |
+
GPU_TIER_THRESHOLDS = {
|
| 152 |
+
"tier1": 4, # <= 4GB
|
| 153 |
+
"tier2": 6, # 4-6GB
|
| 154 |
+
"tier3": 8, # 6-8GB
|
| 155 |
+
"tier4": 12, # 8-12GB
|
| 156 |
+
"tier5": 16, # 12-16GB
|
| 157 |
+
"tier6": 24, # 16-24GB
|
| 158 |
+
# "unlimited" for >= 24GB
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
# LM model memory requirements (in GB)
|
| 162 |
+
LM_MODEL_MEMORY_GB = {
|
| 163 |
+
"0.6B": 3.0,
|
| 164 |
+
"1.7B": 8.0,
|
| 165 |
+
"4B": 12.0,
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
# LM model names mapping
|
| 169 |
+
LM_MODEL_NAMES = {
|
| 170 |
+
"0.6B": "acestep-5Hz-lm-0.6B",
|
| 171 |
+
"1.7B": "acestep-5Hz-lm-1.7B",
|
| 172 |
+
"4B": "acestep-5Hz-lm-4B",
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# ==============================================================================
|
| 177 |
+
# Debug Constants
|
| 178 |
+
# ==============================================================================
|
| 179 |
+
|
| 180 |
+
# Tensor debug mode (values: "OFF" | "ON" | "VERBOSE")
|
| 181 |
+
TENSOR_DEBUG_MODE = "OFF"
|
| 182 |
+
|
| 183 |
+
# Placeholder debug switches for other main functionality (default "OFF")
|
| 184 |
+
# Update names/usage as features adopt them.
|
| 185 |
+
DEBUG_API_SERVER = "OFF"
|
| 186 |
+
DEBUG_INFERENCE = "OFF"
|
| 187 |
+
DEBUG_TRAINING = "OFF"
|
| 188 |
+
DEBUG_DATASET = "OFF"
|
| 189 |
+
DEBUG_AUDIO = "OFF"
|
| 190 |
+
DEBUG_LLM = "OFF"
|
| 191 |
+
DEBUG_UI = "OFF"
|
| 192 |
+
DEBUG_MODEL_LOADING = "OFF"
|
| 193 |
+
DEBUG_GPU = "OFF"
|
acestep/constrained_logits_processor.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
acestep/dit_alignment_score.py
ADDED
|
@@ -0,0 +1,877 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DiT Alignment Score Module
|
| 3 |
+
|
| 4 |
+
This module provides lyrics-to-audio alignment using cross-attention matrices
|
| 5 |
+
from DiT model for generating LRC timestamps.
|
| 6 |
+
|
| 7 |
+
Refactored from lyrics_alignment_infos.py for integration with ACE-Step.
|
| 8 |
+
"""
|
| 9 |
+
import numba
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from dataclasses import dataclass, asdict
|
| 14 |
+
from typing import List, Dict, Any, Optional, Tuple, Union
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ================= Data Classes =================
|
| 18 |
+
@dataclass
class TokenTimestamp:
    """Stores per-token timing information.

    One instance corresponds to a single tokenizer token placed on the
    audio timeline by the DTW alignment.
    """
    token_id: int      # tokenizer vocabulary id for this token
    text: str          # decoded text contributed by this token (may be "" for partial UTF-8)
    start: float       # start time in seconds
    end: float         # end time in seconds (>= start)
    probability: float # per-token confidence; currently always 0.0 placeholder
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
class SentenceTimestamp:
    """Stores per-sentence timing information with token list.

    Sentences are runs of tokens terminated by a newline in the decoded
    lyrics; `confidence` is later min-max normalized across all sentences.
    """
    text: str                     # stripped sentence text decoded from the token ids
    start: float                  # start time in seconds (rounded to ms)
    end: float                    # end time in seconds (rounded to ms)
    tokens: List[TokenTimestamp]  # the tokens that make up this sentence
    confidence: float             # mean of positive token probabilities, then normalized
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ================= DTW Algorithm (Numba Optimized) =================
|
| 39 |
+
@numba.jit(nopython=True)
def dtw_cpu(x: np.ndarray):
    """
    Dynamic Time Warping algorithm optimized with Numba.

    Computes the minimum-cost monotonic path through the cost matrix and
    returns it via `_backtrace`.

    Args:
        x: Cost matrix of shape [N, M]

    Returns:
        Tuple of (text_indices, time_indices) arrays
    """
    N, M = x.shape
    # Use float32 for memory efficiency
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)
    cost[0, 0] = 0

    # Column-major fill of the DP table; cost[i, j] = x[i-1, j-1] + best
    # of the three predecessors (diagonal, up, left).
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            c0 = cost[i - 1, j - 1]
            c1 = cost[i - 1, j]
            c2 = cost[i, j - 1]

            # Tie-breaking: strict '<' means ties fall through to the
            # later branch (left move preferred on a full tie).
            if c0 < c1 and c0 < c2:
                c, t = c0, 0
            elif c1 < c0 and c1 < c2:
                c, t = c1, 1
            else:
                c, t = c2, 2

            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    return _backtrace(trace, N, M)
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@numba.jit(nopython=True)
def _backtrace(trace: np.ndarray, N: int, M: int):
    """
    Optimized backtrace function for DTW.

    NOTE: mutates `trace` in place (its boundary row/column are overwritten
    to force the walk back to the origin).

    Args:
        trace: Trace matrix of shape (N+1, M+1); values 0=diag, 1=up, 2=left
        N, M: Original matrix dimensions

    Returns:
        Path array of shape (2, path_len) - first row is text indices, second is time indices
    """
    # Boundary handling: once the walk hits an edge it can only slide
    # along that edge toward (0, 0).
    trace[0, :] = 2
    trace[:, 0] = 1

    # Pre-allocate array, max path length is N+M; the path is written
    # back-to-front so the populated tail is contiguous.
    max_path_len = N + M
    path = np.zeros((2, max_path_len), dtype=np.int32)

    i, j = N, M
    path_idx = max_path_len - 1

    while i > 0 or j > 0:
        path[0, path_idx] = i - 1  # text index
        path[1, path_idx] = j - 1  # time index
        path_idx -= 1

        t = trace[i, j]
        if t == 0:
            i -= 1
            j -= 1
        elif t == 1:
            i -= 1
        elif t == 2:
            j -= 1
        else:
            # Uninitialized cell (-1): abort defensively.
            break

    # Return only the populated tail of the pre-allocated buffer.
    # (Removed dead local `actual_len`, which was computed but never used.)
    return path[:, path_idx + 1:max_path_len]
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# ================= Utility Functions =================
|
| 119 |
+
def median_filter(x: torch.Tensor, filter_width: int) -> torch.Tensor:
    """
    Apply a sliding median filter along the last dimension.

    The input is reflect-padded so the output has the same trailing length.
    2-D inputs are temporarily promoted to 3-D and returned as 2-D.

    Args:
        x: Input tensor
        filter_width: Width of the median window

    Returns:
        Filtered tensor
    """
    half = filter_width // 2
    # Too short to reflect-pad: hand the tensor back untouched.
    if x.shape[-1] <= half:
        return x

    signal = x[None, :] if x.ndim == 2 else x
    padded = F.pad(signal, (half, half, 0, 0), mode="reflect")
    # Sliding windows over the last dim; the sorted middle element is the median.
    windows = padded.unfold(-1, filter_width, 1)
    filtered = windows.sort()[0][..., half]
    return filtered.squeeze(0) if filtered.ndim > 2 else filtered
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# ================= Main Aligner Class =================
|
| 143 |
+
class MusicStampsAligner:
    """
    Aligner class for generating lyrics timestamps from cross-attention matrices.

    Pipeline: select configured attention heads, denoise them with a
    bidirectional consensus, run DTW over the token/frame matrix, then emit
    token-, sentence-, and LRC-level timing information.
    """

    def __init__(self, tokenizer):
        """
        Initialize the aligner.

        Args:
            tokenizer: Text tokenizer for decoding tokens (must implement .decode())
        """
        self.tokenizer = tokenizer

    def _apply_bidirectional_consensus(
        self,
        weights_stack: torch.Tensor,
        violence_level: float,
        medfilt_width: int
    ) -> tuple:
        """
        Core denoising logic using bidirectional consensus.

        Args:
            weights_stack: Attention weights [Heads, Tokens, Frames]
            violence_level: Denoising strength coefficient
            medfilt_width: Median filter width

        Returns:
            Tuple of (calc_matrix, energy_matrix) as numpy arrays
        """
        # A. Bidirectional consensus: a cell survives only when the token
        # attends to the frame AND the frame attends back to the token.
        row_prob = F.softmax(weights_stack, dim=-1)  # Token -> Frame
        col_prob = F.softmax(weights_stack, dim=-2)  # Frame -> Token
        processed = row_prob * col_prob

        # B1. Row suppression (kill horizontal crossing lines)
        row_medians = torch.quantile(processed, 0.5, dim=-1, keepdim=True)
        processed = torch.relu(processed - (violence_level * row_medians))

        # B2. Column suppression (kill vertical crossing lines)
        col_medians = torch.quantile(processed, 0.5, dim=-2, keepdim=True)
        processed = torch.relu(processed - (violence_level * col_medians))

        # C. Power sharpening
        processed = processed ** 2

        # Energy matrix (pre-normalization) kept for confidence scoring
        energy_matrix = processed.mean(dim=0).cpu().numpy()

        # D. Z-Score normalization across the whole stack
        std, mean = torch.std_mean(processed, unbiased=False)
        weights_processed = (processed - mean) / (std + 1e-9)

        # E. Median filtering, then average the selected heads
        weights_processed = median_filter(weights_processed, filter_width=medfilt_width)
        calc_matrix = weights_processed.mean(dim=0).numpy()

        return calc_matrix, energy_matrix

    def _preprocess_attention(
        self,
        attention_matrix: torch.Tensor,
        custom_config: Dict[int, List[int]],
        violence_level: float,
        medfilt_width: int = 7
    ) -> tuple:
        """
        Preprocess attention matrix for alignment.

        Args:
            attention_matrix: Attention tensor [Layers, Heads, Tokens, Frames]
            custom_config: Dict mapping layer indices to head indices
            violence_level: Denoising strength
            medfilt_width: Median filter width

        Returns:
            Tuple of (calc_matrix, energy_matrix, visual_matrix); all None when
            no configured head exists in the tensor.
        """
        if not isinstance(attention_matrix, torch.Tensor):
            weights = torch.tensor(attention_matrix)
        else:
            weights = attention_matrix.clone()

        weights = weights.cpu().float()

        # Keep only heads that actually exist in the tensor.
        selected_tensors = [
            weights[layer_idx, head_idx]
            for layer_idx, head_indices in custom_config.items()
            for head_idx in head_indices
            if layer_idx < weights.shape[0] and head_idx < weights.shape[1]
        ]

        if not selected_tensors:
            return None, None, None

        # Stack selected heads: [Heads, Tokens, Frames]
        weights_stack = torch.stack(selected_tensors, dim=0)
        visual_matrix = weights_stack.mean(dim=0).numpy()

        calc_matrix, energy_matrix = self._apply_bidirectional_consensus(
            weights_stack, violence_level, medfilt_width
        )

        return calc_matrix, energy_matrix, visual_matrix

    def stamps_align_info(
        self,
        attention_matrix: torch.Tensor,
        lyrics_tokens: List[int],
        total_duration_seconds: float,
        custom_config: Dict[int, List[int]],
        return_matrices: bool = False,
        violence_level: float = 2.0,
        medfilt_width: int = 1
    ) -> Dict[str, Any]:
        """
        Get alignment information from attention matrix.

        Args:
            attention_matrix: Cross-attention tensor [Layers, Heads, Tokens, Frames]
            lyrics_tokens: List of lyrics token IDs
            total_duration_seconds: Total audio duration in seconds
            custom_config: Dict mapping layer indices to head indices
            return_matrices: Whether to return intermediate matrices
            violence_level: Denoising strength
            medfilt_width: Median filter width

        Returns:
            Dict containing calc_matrix, lyrics_tokens, total_duration_seconds,
            and optionally energy_matrix and vis_matrix. On failure, calc_matrix
            is None and an "error" key is present.
        """
        calc_matrix, energy_matrix, visual_matrix = self._preprocess_attention(
            attention_matrix, custom_config, violence_level, medfilt_width
        )

        if calc_matrix is None:
            return {
                "calc_matrix": None,
                "lyrics_tokens": lyrics_tokens,
                "total_duration_seconds": total_duration_seconds,
                "error": "No valid attention heads found"
            }

        return_dict = {
            "calc_matrix": calc_matrix,
            "lyrics_tokens": lyrics_tokens,
            "total_duration_seconds": total_duration_seconds
        }

        if return_matrices:
            return_dict['energy_matrix'] = energy_matrix
            return_dict['vis_matrix'] = visual_matrix

        return return_dict

    def _decode_tokens_incrementally(self, token_ids: List[int]) -> List[str]:
        """
        Decode tokens incrementally to properly handle multi-byte UTF-8 characters.

        For Chinese and other multi-byte characters, the tokenizer may split them
        into multiple byte-level tokens. Decoding each token individually produces
        invalid UTF-8 sequences (showing as \ufffd). This method uses byte-level
        comparison to correctly track which characters each token contributes.

        Args:
            token_ids: List of token IDs

        Returns:
            List of decoded text for each token position ("" for tokens that do
            not complete a character)
        """
        decoded_tokens = []
        prev_bytes = b""

        for i in range(len(token_ids)):
            # Decode tokens from start to current position
            current_text = self.tokenizer.decode(token_ids[:i+1], skip_special_tokens=False)
            current_bytes = current_text.encode('utf-8', errors='surrogatepass')

            # The contribution of current token is the new bytes added
            if len(current_bytes) >= len(prev_bytes):
                new_bytes = current_bytes[len(prev_bytes):]
                # Try to decode the new bytes; if incomplete, use empty string
                try:
                    token_text = new_bytes.decode('utf-8')
                except UnicodeDecodeError:
                    # Incomplete UTF-8 sequence, this token doesn't complete a character
                    token_text = ""
            else:
                # Edge case: current decode is shorter (shouldn't happen normally)
                token_text = ""

            decoded_tokens.append(token_text)
            prev_bytes = current_bytes

        return decoded_tokens

    def token_timestamps(
        self,
        calc_matrix: np.ndarray,
        lyrics_tokens: List[int],
        total_duration_seconds: float
    ) -> List[TokenTimestamp]:
        """
        Generate per-token timestamps using DTW.

        Args:
            calc_matrix: Processed attention matrix [Tokens, Frames]
            lyrics_tokens: List of token IDs
            total_duration_seconds: Total audio duration

        Returns:
            List of TokenTimestamp objects (one per input token)
        """
        n_frames = calc_matrix.shape[-1]
        # DTW minimizes cost, so negate the similarity matrix.
        text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float64))

        seconds_per_frame = total_duration_seconds / n_frames
        alignment_results = []

        # Use incremental decoding to properly handle multi-byte UTF-8 characters
        decoded_tokens = self._decode_tokens_incrementally(lyrics_tokens)

        for i in range(len(lyrics_tokens)):
            mask = (text_indices == i)
            # Confidence is not currently derived from the path; 0.0 placeholder.
            token_conf = 0.0

            if not np.any(mask):
                # Token got no frames on the DTW path: give it a zero-length
                # span starting where the previous token ended.
                start = alignment_results[-1].end if alignment_results else 0.0
                end = start
            else:
                times = time_indices[mask] * seconds_per_frame
                start = times[0]
                end = times[-1]

            if end < start:
                end = start

            alignment_results.append(TokenTimestamp(
                token_id=lyrics_tokens[i],
                text=decoded_tokens[i],
                start=float(start),
                end=float(end),
                probability=token_conf
            ))

        return alignment_results

    def _decode_sentence_from_tokens(self, tokens: List[TokenTimestamp]) -> str:
        """
        Decode a sentence by decoding all token IDs together.
        This avoids UTF-8 encoding issues from joining individual token texts.

        Args:
            tokens: List of TokenTimestamp objects

        Returns:
            Properly decoded sentence text
        """
        token_ids = [t.token_id for t in tokens]
        return self.tokenizer.decode(token_ids, skip_special_tokens=False)

    def _make_sentence(self, tokens: List[TokenTimestamp]) -> Optional[SentenceTimestamp]:
        """
        Build a SentenceTimestamp from a run of tokens.

        Returns None when the decoded text is blank after stripping, so
        callers can skip empty sentences uniformly.
        """
        # Decode all token IDs together to avoid UTF-8 issues
        full_text = self._decode_sentence_from_tokens(tokens)
        if not full_text.strip():
            return None

        valid_scores = [t.probability for t in tokens if t.probability > 0]
        sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0

        return SentenceTimestamp(
            text=full_text.strip(),
            start=round(tokens[0].start, 3),
            end=round(tokens[-1].end, 3),
            tokens=list(tokens),
            confidence=sent_conf
        )

    def sentence_timestamps(
        self,
        token_alignment: List[TokenTimestamp]
    ) -> List[SentenceTimestamp]:
        """
        Group token timestamps into sentence timestamps.

        Sentences are split on tokens whose text contains a newline; confidence
        scores are min-max normalized across the resulting sentences.

        Args:
            token_alignment: List of TokenTimestamp objects

        Returns:
            List of SentenceTimestamp objects
        """
        results = []
        current_tokens = []

        for token in token_alignment:
            current_tokens.append(token)

            if '\n' in token.text:
                sentence = self._make_sentence(current_tokens)
                if sentence is not None:
                    results.append(sentence)
                current_tokens = []

        # Handle last sentence (no trailing newline)
        if current_tokens:
            sentence = self._make_sentence(current_tokens)
            if sentence is not None:
                results.append(sentence)

        # Normalize confidence scores to [0, 1] across all sentences
        if results:
            all_scores = [s.confidence for s in results]
            min_score = min(all_scores)
            max_score = max(all_scores)
            score_range = max_score - min_score

            if score_range > 1e-9:
                for s in results:
                    normalized_score = (s.confidence - min_score) / score_range
                    s.confidence = round(normalized_score, 2)
            else:
                # All scores identical: just round them.
                for s in results:
                    s.confidence = round(s.confidence, 2)

        return results

    def format_lrc(
        self,
        sentence_timestamps: List[SentenceTimestamp],
        include_end_time: bool = False
    ) -> str:
        """
        Format sentence timestamps as LRC lyrics format.

        Args:
            sentence_timestamps: List of SentenceTimestamp objects
            include_end_time: Whether to include end time (enhanced LRC format)

        Returns:
            LRC formatted string
        """
        lines = []

        for sentence in sentence_timestamps:
            # Convert seconds to mm:ss.xx format
            start_minutes = int(sentence.start // 60)
            start_seconds = sentence.start % 60

            if include_end_time:
                end_minutes = int(sentence.end // 60)
                end_seconds = sentence.end % 60
                timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}][{end_minutes:02d}:{end_seconds:05.2f}]"
            else:
                timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}]"

            # Sentence text is emitted as-is (structural tags like [verse]
            # are NOT stripped here).
            text = sentence.text

            lines.append(f"{timestamp}{text}")

        return "\n".join(lines)

    def get_timestamps_and_lrc(
        self,
        calc_matrix: np.ndarray,
        lyrics_tokens: List[int],
        total_duration_seconds: float
    ) -> Dict[str, Any]:
        """
        Convenience method to get both timestamps and LRC in one call.

        Args:
            calc_matrix: Processed attention matrix
            lyrics_tokens: List of token IDs
            total_duration_seconds: Total audio duration

        Returns:
            Dict containing token_timestamps, sentence_timestamps, and lrc_text
        """
        token_stamps = self.token_timestamps(
            calc_matrix=calc_matrix,
            lyrics_tokens=lyrics_tokens,
            total_duration_seconds=total_duration_seconds
        )

        sentence_stamps = self.sentence_timestamps(token_stamps)
        lrc_text = self.format_lrc(sentence_stamps)

        return {
            "token_timestamps": token_stamps,
            "sentence_timestamps": sentence_stamps,
            "lrc_text": lrc_text
        }
| 547 |
+
|
| 548 |
+
|
| 549 |
+
class MusicLyricScorer:
|
| 550 |
+
"""
|
| 551 |
+
Scorer class for evaluating lyrics-to-audio alignment quality.
|
| 552 |
+
|
| 553 |
+
Focuses on calculating alignment quality metrics (Coverage, Monotonicity, Confidence)
|
| 554 |
+
using tensor operations for potential differentiability or GPU acceleration.
|
| 555 |
+
"""
|
| 556 |
+
|
| 557 |
+
def __init__(self, tokenizer: Any):
|
| 558 |
+
"""
|
| 559 |
+
Initialize the aligner.
|
| 560 |
+
|
| 561 |
+
Args:
|
| 562 |
+
tokenizer: Tokenizer instance (must implement .decode()).
|
| 563 |
+
"""
|
| 564 |
+
self.tokenizer = tokenizer
|
| 565 |
+
|
| 566 |
+
def _generate_token_type_mask(self, token_ids: List[int]) -> np.ndarray:
|
| 567 |
+
"""
|
| 568 |
+
Generate a mask distinguishing lyrics (1) from structural tags (0).
|
| 569 |
+
Uses self.tokenizer to decode tokens.
|
| 570 |
+
|
| 571 |
+
Args:
|
| 572 |
+
token_ids: List of token IDs.
|
| 573 |
+
|
| 574 |
+
Returns:
|
| 575 |
+
Numpy array of shape [len(token_ids)] with 1 or 0.
|
| 576 |
+
"""
|
| 577 |
+
decoded_tokens = [self.tokenizer.decode([tid]) for tid in token_ids]
|
| 578 |
+
mask = np.ones(len(token_ids), dtype=np.int32)
|
| 579 |
+
in_bracket = False
|
| 580 |
+
|
| 581 |
+
for i, token_str in enumerate(decoded_tokens):
|
| 582 |
+
if '[' in token_str:
|
| 583 |
+
in_bracket = True
|
| 584 |
+
if in_bracket:
|
| 585 |
+
mask[i] = 0
|
| 586 |
+
if ']' in token_str:
|
| 587 |
+
in_bracket = False
|
| 588 |
+
mask[i] = 0
|
| 589 |
+
return mask
|
| 590 |
+
|
| 591 |
+
def _preprocess_attention(
|
| 592 |
+
self,
|
| 593 |
+
attention_matrix: Union[torch.Tensor, np.ndarray],
|
| 594 |
+
custom_config: Dict[int, List[int]],
|
| 595 |
+
medfilt_width: int = 1
|
| 596 |
+
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[torch.Tensor]]:
|
| 597 |
+
"""
|
| 598 |
+
Extracts and normalizes the attention matrix.
|
| 599 |
+
|
| 600 |
+
Logic V4: Uses Min-Max normalization to highlight energy differences.
|
| 601 |
+
|
| 602 |
+
Args:
|
| 603 |
+
attention_matrix: Raw attention tensor [Layers, Heads, Tokens, Frames].
|
| 604 |
+
custom_config: Config mapping layers to heads.
|
| 605 |
+
medfilt_width: Width for median filtering.
|
| 606 |
+
|
| 607 |
+
Returns:
|
| 608 |
+
Tuple of (calc_matrix, energy_matrix, avg_weights_tensor).
|
| 609 |
+
"""
|
| 610 |
+
# 1. Prepare Tensor
|
| 611 |
+
if not isinstance(attention_matrix, torch.Tensor):
|
| 612 |
+
weights = torch.tensor(attention_matrix)
|
| 613 |
+
else:
|
| 614 |
+
weights = attention_matrix.clone()
|
| 615 |
+
weights = weights.cpu().float()
|
| 616 |
+
|
| 617 |
+
# 2. Select Heads based on config
|
| 618 |
+
selected_tensors = []
|
| 619 |
+
for layer_idx, head_indices in custom_config.items():
|
| 620 |
+
for head_idx in head_indices:
|
| 621 |
+
if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
|
| 622 |
+
selected_tensors.append(weights[layer_idx, head_idx])
|
| 623 |
+
|
| 624 |
+
if not selected_tensors:
|
| 625 |
+
return None, None, None
|
| 626 |
+
|
| 627 |
+
weights_stack = torch.stack(selected_tensors, dim=0)
|
| 628 |
+
|
| 629 |
+
# 3. Average Heads
|
| 630 |
+
avg_weights = weights_stack.mean(dim=0) # [Tokens, Frames]
|
| 631 |
+
|
| 632 |
+
# 4. Preprocessing Logic
|
| 633 |
+
# Min-Max normalization preserving energy distribution
|
| 634 |
+
# Median filter is applied to the energy matrix
|
| 635 |
+
energy_tensor = median_filter(avg_weights, filter_width=medfilt_width)
|
| 636 |
+
energy_matrix = energy_tensor.numpy()
|
| 637 |
+
|
| 638 |
+
e_min, e_max = energy_matrix.min(), energy_matrix.max()
|
| 639 |
+
|
| 640 |
+
if e_max - e_min > 1e-9:
|
| 641 |
+
energy_matrix = (energy_matrix - e_min) / (e_max - e_min)
|
| 642 |
+
else:
|
| 643 |
+
energy_matrix = np.zeros_like(energy_matrix)
|
| 644 |
+
|
| 645 |
+
# Contrast enhancement for DTW pathfinding
|
| 646 |
+
# calc_matrix is used for pathfinding, energy_matrix for scoring
|
| 647 |
+
calc_matrix = energy_matrix ** 2
|
| 648 |
+
|
| 649 |
+
return calc_matrix, energy_matrix, avg_weights
|
| 650 |
+
|
| 651 |
+
def _compute_alignment_metrics(
|
| 652 |
+
self,
|
| 653 |
+
energy_matrix: torch.Tensor,
|
| 654 |
+
path_coords: torch.Tensor,
|
| 655 |
+
type_mask: torch.Tensor,
|
| 656 |
+
time_weight: float = 0.01,
|
| 657 |
+
overlap_frames: float = 9.0,
|
| 658 |
+
instrumental_weight: float = 1.0
|
| 659 |
+
) -> Tuple[float, float, float]:
|
| 660 |
+
"""
|
| 661 |
+
Core metric calculation logic using high-precision Tensor operations.
|
| 662 |
+
|
| 663 |
+
Args:
|
| 664 |
+
energy_matrix: Normalized energy [Rows, Cols].
|
| 665 |
+
path_coords: DTW path coordinates [Steps, 2].
|
| 666 |
+
type_mask: Token type mask [Rows] (1=Lyrics, 0=Tags).
|
| 667 |
+
time_weight: Minimum energy threshold for monotonicity.
|
| 668 |
+
overlap_frames: Allowed overlap for monotonicity check.
|
| 669 |
+
instrumental_weight: Weight for non-lyric tokens in confidence calc.
|
| 670 |
+
|
| 671 |
+
Returns:
|
| 672 |
+
Tuple of (coverage, monotonicity, confidence).
|
| 673 |
+
"""
|
| 674 |
+
# Ensure high precision for internal calculation
|
| 675 |
+
energy_matrix = energy_matrix.to(dtype=torch.float64)
|
| 676 |
+
path_coords = path_coords.long()
|
| 677 |
+
type_mask = type_mask.long()
|
| 678 |
+
|
| 679 |
+
device = energy_matrix.device
|
| 680 |
+
rows, cols = energy_matrix.shape
|
| 681 |
+
|
| 682 |
+
is_lyrics_row = (type_mask == 1)
|
| 683 |
+
|
| 684 |
+
# ================= A. Coverage Score =================
|
| 685 |
+
# Ratio of lyric lines that have significant energy peak
|
| 686 |
+
row_max_energies = energy_matrix.max(dim=1).values
|
| 687 |
+
total_sung_rows = is_lyrics_row.sum().double()
|
| 688 |
+
|
| 689 |
+
coverage_threshold = 0.1
|
| 690 |
+
valid_sung_mask = is_lyrics_row & (row_max_energies > coverage_threshold)
|
| 691 |
+
valid_sung_rows = valid_sung_mask.sum().double()
|
| 692 |
+
|
| 693 |
+
if total_sung_rows > 0:
|
| 694 |
+
coverage_score = valid_sung_rows / total_sung_rows
|
| 695 |
+
else:
|
| 696 |
+
coverage_score = torch.tensor(1.0, device=device, dtype=torch.float64)
|
| 697 |
+
|
| 698 |
+
# ================= B. Monotonicity Score =================
|
| 699 |
+
# Check if the "center of mass" of lyric lines moves forward in time
|
| 700 |
+
col_indices = torch.arange(cols, device=device, dtype=torch.float64)
|
| 701 |
+
|
| 702 |
+
# Zero out low energy noise
|
| 703 |
+
weights = torch.where(
|
| 704 |
+
energy_matrix > time_weight,
|
| 705 |
+
energy_matrix,
|
| 706 |
+
torch.zeros_like(energy_matrix)
|
| 707 |
+
)
|
| 708 |
+
|
| 709 |
+
sum_w = weights.sum(dim=1)
|
| 710 |
+
sum_t = (weights * col_indices).sum(dim=1)
|
| 711 |
+
|
| 712 |
+
# Calculate centroids
|
| 713 |
+
centroids = torch.full((rows,), -1.0, device=device, dtype=torch.float64)
|
| 714 |
+
valid_w_mask = sum_w > 1e-9
|
| 715 |
+
centroids[valid_w_mask] = sum_t[valid_w_mask] / sum_w[valid_w_mask]
|
| 716 |
+
|
| 717 |
+
# Extract sequence of valid lyrics centroids
|
| 718 |
+
valid_sequence_mask = is_lyrics_row & (centroids >= 0)
|
| 719 |
+
sung_centroids = centroids[valid_sequence_mask]
|
| 720 |
+
|
| 721 |
+
cnt = sung_centroids.shape[0]
|
| 722 |
+
if cnt > 1:
|
| 723 |
+
curr_c = sung_centroids[:-1]
|
| 724 |
+
next_c = sung_centroids[1:]
|
| 725 |
+
|
| 726 |
+
# Check non-decreasing order with overlap tolerance
|
| 727 |
+
non_decreasing = (next_c >= (curr_c - overlap_frames)).double().sum()
|
| 728 |
+
pairs = torch.tensor(cnt - 1, device=device, dtype=torch.float64)
|
| 729 |
+
monotonicity_score = non_decreasing / pairs
|
| 730 |
+
else:
|
| 731 |
+
monotonicity_score = torch.tensor(1.0, device=device, dtype=torch.float64)
|
| 732 |
+
|
| 733 |
+
# ================= C. Path Confidence =================
|
| 734 |
+
# Average energy along the optimal path
|
| 735 |
+
if path_coords.shape[0] > 0:
|
| 736 |
+
p_rows = path_coords[:, 0]
|
| 737 |
+
p_cols = path_coords[:, 1]
|
| 738 |
+
|
| 739 |
+
path_energies = energy_matrix[p_rows, p_cols]
|
| 740 |
+
step_weights = torch.ones_like(path_energies)
|
| 741 |
+
|
| 742 |
+
# Lower weight for instrumental/tag steps
|
| 743 |
+
is_inst_step = (type_mask[p_rows] == 0)
|
| 744 |
+
step_weights[is_inst_step] = instrumental_weight
|
| 745 |
+
|
| 746 |
+
total_energy = (path_energies * step_weights).sum()
|
| 747 |
+
total_steps = step_weights.sum()
|
| 748 |
+
|
| 749 |
+
if total_steps > 0:
|
| 750 |
+
path_confidence = total_energy / total_steps
|
| 751 |
+
else:
|
| 752 |
+
path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
|
| 753 |
+
else:
|
| 754 |
+
path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
|
| 755 |
+
|
| 756 |
+
return coverage_score.item(), monotonicity_score.item(), path_confidence.item()
|
| 757 |
+
|
| 758 |
+
def lyrics_alignment_info(
|
| 759 |
+
self,
|
| 760 |
+
attention_matrix: Union[torch.Tensor, np.ndarray],
|
| 761 |
+
token_ids: List[int],
|
| 762 |
+
custom_config: Dict[int, List[int]],
|
| 763 |
+
return_matrices: bool = False,
|
| 764 |
+
medfilt_width: int = 1
|
| 765 |
+
) -> Dict[str, Any]:
|
| 766 |
+
"""
|
| 767 |
+
Generates alignment path and processed matrices.
|
| 768 |
+
|
| 769 |
+
Args:
|
| 770 |
+
attention_matrix: Input attention tensor.
|
| 771 |
+
token_ids: Corresponding token IDs.
|
| 772 |
+
custom_config: Layer/Head configuration.
|
| 773 |
+
return_matrices: If True, returns matrices in the output.
|
| 774 |
+
medfilt_width: Median filter width.
|
| 775 |
+
|
| 776 |
+
Returns:
|
| 777 |
+
Dict or AlignmentInfo object containing path and masks.
|
| 778 |
+
"""
|
| 779 |
+
calc_matrix, energy_matrix, vis_matrix = self._preprocess_attention(
|
| 780 |
+
attention_matrix, custom_config, medfilt_width
|
| 781 |
+
)
|
| 782 |
+
|
| 783 |
+
if calc_matrix is None:
|
| 784 |
+
return {
|
| 785 |
+
"calc_matrix": None,
|
| 786 |
+
"error": "No valid attention heads found"
|
| 787 |
+
}
|
| 788 |
+
|
| 789 |
+
# 1. Generate Semantic Mask (1=Lyrics, 0=Tags)
|
| 790 |
+
# Uses self.tokenizer internally
|
| 791 |
+
type_mask = self._generate_token_type_mask(token_ids)
|
| 792 |
+
|
| 793 |
+
# Safety check for shape mismatch
|
| 794 |
+
if len(type_mask) != energy_matrix.shape[0]:
|
| 795 |
+
# Fallback to all lyrics if shapes don't align
|
| 796 |
+
type_mask = np.ones(energy_matrix.shape[0], dtype=np.int32)
|
| 797 |
+
|
| 798 |
+
# 2. DTW Pathfinding
|
| 799 |
+
# Using negative calc_matrix because DTW minimizes cost
|
| 800 |
+
text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float32))
|
| 801 |
+
path_coords = np.stack([text_indices, time_indices], axis=1)
|
| 802 |
+
|
| 803 |
+
return_dict = {
|
| 804 |
+
"path_coords": path_coords,
|
| 805 |
+
"type_mask": type_mask,
|
| 806 |
+
"energy_matrix": energy_matrix
|
| 807 |
+
}
|
| 808 |
+
if return_matrices:
|
| 809 |
+
return_dict['calc_matrix'] = calc_matrix
|
| 810 |
+
return_dict['vis_matrix'] = vis_matrix
|
| 811 |
+
|
| 812 |
+
return return_dict
|
| 813 |
+
|
| 814 |
+
def calculate_score(
    self,
    energy_matrix: Union[torch.Tensor, np.ndarray],
    type_mask: Union[torch.Tensor, np.ndarray],
    path_coords: Union[torch.Tensor, np.ndarray],
    time_weight: float = 0.01,
    overlap_frames: float = 9.0,
    instrumental_weight: float = 1.0
) -> Dict[str, Any]:
    """
    Compute the final lyric-alignment score from pre-computed components.

    Args:
        energy_matrix: Processed energy matrix (tensor or array).
        type_mask: Token type mask (1 = lyrics, 0 = tags).
        path_coords: DTW path coordinates.
        time_weight: Minimum energy threshold for the monotonicity metric.
        overlap_frames: Allowed backward movement in frames.
        instrumental_weight: Weight applied to non-lyric path steps.

    Returns:
        Dict with a single "lyrics_score" entry, rounded to 4 decimals.
    """
    # Coerce the energy matrix to a tensor. When we have to create one from
    # scratch, place it on the best accelerator available (CUDA > MPS > CPU).
    if not isinstance(energy_matrix, torch.Tensor):
        if torch.cuda.is_available():
            target_device = "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            target_device = "mps"
        else:
            target_device = "cpu"
        energy_matrix = torch.tensor(energy_matrix, device=target_device, dtype=torch.float32)

    # All remaining inputs follow the energy matrix's device.
    device = energy_matrix.device

    if isinstance(type_mask, torch.Tensor):
        type_mask = type_mask.to(device=device, dtype=torch.long)
    else:
        type_mask = torch.tensor(type_mask, device=device, dtype=torch.long)

    if isinstance(path_coords, torch.Tensor):
        path_coords = path_coords.to(device=device, dtype=torch.long)
    else:
        path_coords = torch.tensor(path_coords, device=device, dtype=torch.long)

    # Delegate the metric computation to the shared helper.
    coverage, monotonicity, confidence = self._compute_alignment_metrics(
        energy_matrix=energy_matrix,
        path_coords=path_coords,
        type_mask=type_mask,
        time_weight=time_weight,
        overlap_frames=overlap_frames,
        instrumental_weight=instrumental_weight
    )

    # Final score: coverage^2 * monotonicity^2 * confidence, clamped to [0, 1].
    raw_score = (coverage ** 2) * (monotonicity ** 2) * confidence
    final_score = float(np.clip(raw_score, 0.0, 1.0))

    return {
        "lyrics_score": round(final_score, 4)
    }
|
acestep/genres_vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
acestep/gpu_config.py
ADDED
|
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GPU Configuration Module
|
| 3 |
+
Centralized GPU memory detection and adaptive configuration management
|
| 4 |
+
|
| 5 |
+
Debug Mode:
|
| 6 |
+
Set environment variable MAX_CUDA_VRAM to simulate different GPU memory sizes.
|
| 7 |
+
Example: MAX_CUDA_VRAM=8 python acestep # Simulates 8GB GPU
|
| 8 |
+
|
| 9 |
+
For MPS testing, use MAX_MPS_VRAM to simulate MPS memory.
|
| 10 |
+
Example: MAX_MPS_VRAM=16 python acestep # Simulates 16GB MPS
|
| 11 |
+
|
| 12 |
+
This is useful for testing GPU tier configurations on high-end hardware.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Optional, List, Dict, Tuple
|
| 19 |
+
from loguru import logger
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Environment variables for debugging/testing different GPU memory configurations.
# Values are interpreted as GB by get_gpu_memory_gb() and override detection.
DEBUG_MAX_CUDA_VRAM_ENV = "MAX_CUDA_VRAM"
DEBUG_MAX_MPS_VRAM_ENV = "MAX_MPS_VRAM"

# Tolerance for 16GB detection: reported VRAM like 15.5GB is effectively 16GB hardware.
# Real-world 16GB GPUs often report 15.7-15.9GB due to system/driver reservations.
VRAM_16GB_TOLERANCE_GB = 0.5
VRAM_16GB_MIN_GB = 16.0 - VRAM_16GB_TOLERANCE_GB  # treat as 16GB class if >= this

# PyTorch installation URLs quoted in the no-GPU diagnostic messages.
PYTORCH_CUDA_INSTALL_URL = "https://download.pytorch.org/whl/cu121"
PYTORCH_ROCM_INSTALL_URL = "https://download.pytorch.org/whl/rocm6.0"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
class GPUConfig:
    """GPU configuration based on available memory.

    Built by get_gpu_config() from the GPU_TIER_CONFIGS table. All limits are
    split by whether the language model (LM) is loaded alongside generation.
    """
    tier: str  # "tier1", "tier2", etc. or "unlimited"
    gpu_memory_gb: float  # detected (or debug-simulated) GPU memory in GB

    # Duration limits (in seconds)
    max_duration_with_lm: int  # When LM is initialized
    max_duration_without_lm: int  # When LM is not initialized

    # Batch size limits
    max_batch_size_with_lm: int  # max batch when LM is initialized
    max_batch_size_without_lm: int  # max batch when LM is not initialized

    # LM configuration
    init_lm_default: bool  # Whether to initialize LM by default
    available_lm_models: List[str]  # Available LM models for this tier

    # LM memory allocation (GB) for each model size
    lm_memory_gb: Dict[str, float]  # e.g., {"0.6B": 3, "1.7B": 8, "4B": 12}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# GPU tier configurations. Keys mirror GPUConfig fields (minus tier and
# gpu_memory_gb); get_gpu_config() copies each entry into a GPUConfig.
GPU_TIER_CONFIGS = {
    "tier1": {  # <= 4GB (also used for CPU-only mode)
        "max_duration_with_lm": 180,  # 3 minutes
        "max_duration_without_lm": 180,  # 3 minutes
        "max_batch_size_with_lm": 1,
        "max_batch_size_without_lm": 1,
        "init_lm_default": False,
        "available_lm_models": [],  # no LM fits in this tier
        "lm_memory_gb": {},
    },
    "tier2": {  # 4-6GB
        "max_duration_with_lm": 360,  # 6 minutes
        "max_duration_without_lm": 360,  # 6 minutes
        "max_batch_size_with_lm": 1,
        "max_batch_size_without_lm": 1,
        "init_lm_default": False,
        "available_lm_models": [],  # no LM fits in this tier
        "lm_memory_gb": {},
    },
    "tier3": {  # 6-8GB
        "max_duration_with_lm": 240,  # 4 minutes with LM
        "max_duration_without_lm": 360,  # 6 minutes without LM
        "max_batch_size_with_lm": 1,
        "max_batch_size_without_lm": 2,
        "init_lm_default": False,  # Don't init by default due to limited memory
        "available_lm_models": ["acestep-5Hz-lm-0.6B"],
        "lm_memory_gb": {"0.6B": 3},
    },
    "tier4": {  # 8-12GB
        "max_duration_with_lm": 240,  # 4 minutes with LM
        "max_duration_without_lm": 360,  # 6 minutes without LM
        "max_batch_size_with_lm": 2,
        "max_batch_size_without_lm": 4,
        "init_lm_default": False,  # Don't init by default
        "available_lm_models": ["acestep-5Hz-lm-0.6B"],
        "lm_memory_gb": {"0.6B": 3},
    },
    "tier5": {  # 12-16GB
        "max_duration_with_lm": 240,  # 4 minutes with LM
        "max_duration_without_lm": 360,  # 6 minutes without LM
        "max_batch_size_with_lm": 2,
        "max_batch_size_without_lm": 4,
        "init_lm_default": True,
        "available_lm_models": ["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B"],
        "lm_memory_gb": {"0.6B": 3, "1.7B": 8},
    },
    "tier6": {  # 16-24GB (15.5GB+ counts as 16GB class, see get_gpu_tier)
        "max_duration_with_lm": 480,  # 8 minutes
        "max_duration_without_lm": 480,  # 8 minutes
        "max_batch_size_with_lm": 4,
        "max_batch_size_without_lm": 8,
        "init_lm_default": True,
        "available_lm_models": ["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B", "acestep-5Hz-lm-4B"],
        "lm_memory_gb": {"0.6B": 3, "1.7B": 8, "4B": 12},
    },
    "unlimited": {  # >= 24GB
        "max_duration_with_lm": 600,  # 10 minutes (max supported)
        "max_duration_without_lm": 600,  # 10 minutes
        "max_batch_size_with_lm": 8,
        "max_batch_size_without_lm": 8,
        "init_lm_default": True,
        "available_lm_models": ["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B", "acestep-5Hz-lm-4B"],
        "lm_memory_gb": {"0.6B": 3, "1.7B": 8, "4B": 12},
    },
}
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def get_gpu_memory_gb() -> float:
    """
    Get GPU memory in GB. Returns 0 if no GPU is available.

    Detection order: debug env overrides, then CUDA/ROCm, then Intel XPU,
    then Apple MPS (with a unified-memory fallback), else CPU (0).

    Debug Mode:
        Set environment variable MAX_CUDA_VRAM to override the detected GPU memory.
        Example: MAX_CUDA_VRAM=8 python acestep  # Simulates 8GB GPU

        For MPS testing, set MAX_MPS_VRAM to override MPS memory detection.
        Example: MAX_MPS_VRAM=16 python acestep  # Simulates 16GB MPS

        This allows testing different GPU tier configurations on high-end hardware.
    """
    # Debug overrides win over real detection; an unparseable value is ignored.
    debug_vram = os.environ.get(DEBUG_MAX_CUDA_VRAM_ENV)
    if debug_vram is not None:
        try:
            simulated_gb = float(debug_vram)
            logger.warning(f"⚠️ DEBUG MODE: Simulating GPU memory as {simulated_gb:.1f}GB (set via {DEBUG_MAX_CUDA_VRAM_ENV} environment variable)")
            return simulated_gb
        except ValueError:
            logger.warning(f"Invalid {DEBUG_MAX_CUDA_VRAM_ENV} value: {debug_vram}, ignoring")
    # NOTE(review): the MPS override is honored even when MPS is not available;
    # presumably intentional for tier testing on non-Apple hardware — confirm.
    debug_mps_vram = os.environ.get(DEBUG_MAX_MPS_VRAM_ENV)
    if debug_mps_vram is not None:
        try:
            simulated_gb = float(debug_mps_vram)
            logger.warning(f"⚠️ DEBUG MODE: Simulating MPS memory as {simulated_gb:.1f}GB (set via {DEBUG_MAX_MPS_VRAM_ENV} environment variable)")
            return simulated_gb
        except ValueError:
            logger.warning(f"Invalid {DEBUG_MAX_MPS_VRAM_ENV} value: {debug_mps_vram}, ignoring")

    try:
        # Imported lazily so importing this module never requires torch.
        import torch
        if torch.cuda.is_available():
            # Get total memory of the first GPU in GB
            total_memory = torch.cuda.get_device_properties(0).total_memory
            memory_gb = total_memory / (1024**3)  # Convert bytes to GB
            device_name = torch.cuda.get_device_name(0)
            # ROCm builds expose the CUDA API; torch.version.hip tells them apart.
            is_rocm = hasattr(torch.version, 'hip') and torch.version.hip is not None
            if is_rocm:
                logger.info(f"ROCm GPU detected: {device_name} ({memory_gb:.1f} GB, HIP {torch.version.hip})")
            else:
                logger.info(f"CUDA GPU detected: {device_name} ({memory_gb:.1f} GB)")
            return memory_gb
        elif hasattr(torch, 'xpu') and torch.xpu.is_available():
            # Get total memory of the first XPU (Intel GPU) in GB
            total_memory = torch.xpu.get_device_properties(0).total_memory
            memory_gb = total_memory / (1024**3)  # Convert bytes to GB
            return memory_gb
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            # MPS memory reporting varies by torch version; probe the APIs
            # that exist before falling back to system RAM.
            mps_module = getattr(torch, "mps", None)
            try:
                if mps_module is not None and hasattr(mps_module, "recommended_max_memory"):
                    total_memory = mps_module.recommended_max_memory()
                    memory_gb = total_memory / (1024**3)  # Convert bytes to GB
                    return memory_gb
                if mps_module is not None and hasattr(mps_module, "get_device_properties"):
                    props = mps_module.get_device_properties(0)
                    total_memory = getattr(props, "total_memory", None)
                    if total_memory:
                        memory_gb = total_memory / (1024**3)
                        return memory_gb
            except Exception as e:
                logger.warning(f"Failed to detect MPS memory: {e}")

            # Fallback: estimate from system unified memory (Apple Silicon shares CPU/GPU RAM)
            try:
                import subprocess
                result = subprocess.run(
                    ["sysctl", "-n", "hw.memsize"],
                    capture_output=True, text=True, timeout=5
                )
                total_system_bytes = int(result.stdout.strip())
                # MPS can use up to ~75% of unified memory for GPU workloads
                memory_gb = (total_system_bytes / (1024**3)) * 0.75
                return memory_gb
            except Exception:
                logger.warning(f"MPS available but total memory not exposed. Set {DEBUG_MAX_MPS_VRAM_ENV} to enable tiering.")
                # Conservative fallback for M1/M2
                return 8.0
        else:
            # No GPU detected - provide diagnostic information
            _log_gpu_diagnostic_info(torch)
            return 0
    except Exception as e:
        # Broad catch: a failed detection should degrade to CPU mode, not crash.
        logger.warning(f"Failed to detect GPU memory: {e}")
        return 0
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def _log_gpu_diagnostic_info(torch_module):
    """
    Log diagnostic information when GPU is not detected to help users troubleshoot.

    Branches on the PyTorch build type (ROCm / CUDA / CPU-only) and emits
    targeted troubleshooting guidance via the module logger. Side effect only;
    returns None.

    Args:
        torch_module: The torch module to inspect for build information
    """
    logger.warning("=" * 80)
    logger.warning("⚠️ GPU NOT DETECTED - DIAGNOSTIC INFORMATION")
    logger.warning("=" * 80)

    # Check PyTorch build type; ROCm builds set torch.version.hip,
    # CUDA builds set torch.version.cuda, CPU-only builds set neither.
    is_rocm_build = hasattr(torch_module.version, 'hip') and torch_module.version.hip is not None
    is_cuda_build = hasattr(torch_module.version, 'cuda') and torch_module.version.cuda is not None

    if is_rocm_build:
        # ROCm build but no usable device: driver/env configuration problems.
        logger.warning("✓ PyTorch ROCm build detected")
        logger.warning(f" HIP version: {torch_module.version.hip}")
        logger.warning("")
        logger.warning("❌ torch.cuda.is_available() returned False")
        logger.warning("")
        logger.warning("Common causes for AMD/ROCm GPUs:")
        logger.warning(" 1. ROCm drivers not installed or not properly configured")
        logger.warning(" 2. GPU not supported by installed ROCm version")
        logger.warning(" 3. Missing or incorrect HSA_OVERRIDE_GFX_VERSION environment variable")
        logger.warning(" 4. ROCm runtime libraries not in system path")
        logger.warning("")

        # Check for common environment variables
        hsa_override = os.environ.get('HSA_OVERRIDE_GFX_VERSION')
        if hsa_override:
            logger.warning(f" HSA_OVERRIDE_GFX_VERSION is set to: {hsa_override}")
        else:
            logger.warning(" ⚠️ HSA_OVERRIDE_GFX_VERSION is not set")
            logger.warning(" For RDNA3 GPUs (RX 7000 series, RX 9000 series):")
            logger.warning(" - RX 7900 XT/XTX, RX 9070 XT: set HSA_OVERRIDE_GFX_VERSION=11.0.0")
            logger.warning(" - RX 7800 XT, RX 7700 XT: set HSA_OVERRIDE_GFX_VERSION=11.0.1")
            logger.warning(" - RX 7600: set HSA_OVERRIDE_GFX_VERSION=11.0.2")

        logger.warning("")
        logger.warning("Troubleshooting steps:")
        logger.warning(" 1. Verify ROCm installation:")
        logger.warning(" rocm-smi # Should list your GPU")
        logger.warning(" 2. Check PyTorch ROCm build:")
        logger.warning(" python -c \"import torch; print(f'ROCm: {torch.version.hip}')\"")
        logger.warning(" 3. Set HSA_OVERRIDE_GFX_VERSION for your GPU (see above)")
        logger.warning(" 4. On Windows: Use start_gradio_ui_rocm.bat which sets required env vars")
        logger.warning(" 5. See docs/en/ACE-Step1.5-Rocm-Manual-Linux.md for Linux setup")
        logger.warning(" 6. See requirements-rocm.txt for Windows ROCm setup instructions")

    elif is_cuda_build:
        # CUDA build but no usable device: driver or runtime mismatch.
        logger.warning("✓ PyTorch CUDA build detected")
        logger.warning(f" CUDA version: {torch_module.version.cuda}")
        logger.warning("")
        logger.warning("❌ torch.cuda.is_available() returned False")
        logger.warning("")
        logger.warning("Common causes for NVIDIA GPUs:")
        logger.warning(" 1. NVIDIA drivers not installed")
        logger.warning(" 2. CUDA runtime not installed or version mismatch")
        logger.warning(" 3. GPU not supported by installed CUDA version")
        logger.warning("")
        logger.warning("Troubleshooting steps:")
        logger.warning(" 1. Verify NVIDIA driver installation:")
        logger.warning(" nvidia-smi # Should list your GPU")
        logger.warning(" 2. Check CUDA version compatibility")
        logger.warning(" 3. Reinstall PyTorch with CUDA support:")
        logger.warning(f" pip install torch --index-url {PYTORCH_CUDA_INSTALL_URL}")

    else:
        # CPU-only wheel installed: point the user at GPU-enabled builds.
        logger.warning("⚠️ PyTorch build type: CPU-only")
        logger.warning("")
        logger.warning("You have installed a CPU-only version of PyTorch!")
        logger.warning("")
        logger.warning("For NVIDIA GPUs:")
        logger.warning(f" pip install torch --index-url {PYTORCH_CUDA_INSTALL_URL}")
        logger.warning("")
        logger.warning("For AMD GPUs with ROCm:")
        logger.warning(" Windows: See requirements-rocm.txt for detailed instructions")
        logger.warning(f" Linux: pip install torch --index-url {PYTORCH_ROCM_INSTALL_URL}")
        logger.warning("")
        logger.warning("For more information, see README.md section 'AMD / ROCm GPUs'")

    logger.warning("=" * 80)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def get_gpu_tier(gpu_memory_gb: float) -> str:
    """
    Map an amount of GPU memory to a configuration tier name.

    Args:
        gpu_memory_gb: GPU memory in GB (0 or negative means CPU-only mode)

    Returns:
        Tier string: "tier1", "tier2", "tier3", "tier4", "tier5", "tier6", or "unlimited"
    """
    # CPU mode shares the most conservative tier.
    if gpu_memory_gb <= 0:
        return "tier1"

    # Fixed upper bounds for the lower tiers, checked smallest first.
    for limit_gb, tier_name in ((4, "tier1"), (6, "tier2"), (8, "tier3"), (12, "tier4")):
        if gpu_memory_gb <= limit_gb:
            return tier_name

    # Below the 16GB tolerance window: still tier5.
    if gpu_memory_gb < VRAM_16GB_MIN_GB:
        return "tier5"

    # 16GB-class (within tolerance) up to 24GB: tier6.
    if gpu_memory_gb <= 24:
        if gpu_memory_gb < 16.0:
            logger.info(f"Detected {gpu_memory_gb:.2f}GB VRAM — treating as 16GB class GPU")
        return "tier6"

    # Anything above 24GB is effectively unconstrained.
    return "unlimited"
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def get_gpu_config(gpu_memory_gb: Optional[float] = None) -> GPUConfig:
    """
    Build the GPUConfig for the detected (or explicitly provided) GPU memory.

    Args:
        gpu_memory_gb: GPU memory in GB. If None, will be auto-detected.

    Returns:
        GPUConfig object with all configuration parameters
    """
    mem_gb = get_gpu_memory_gb() if gpu_memory_gb is None else gpu_memory_gb
    tier = get_gpu_tier(mem_gb)

    # The tier table keys match the remaining GPUConfig fields one-to-one,
    # so the entry can be expanded directly into the constructor.
    tier_params = GPU_TIER_CONFIGS[tier]
    return GPUConfig(tier=tier, gpu_memory_gb=mem_gb, **tier_params)
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def get_lm_model_size(model_path: str) -> str:
    """
    Extract the LM model size tag from a model path.

    Args:
        model_path: Model path string (e.g., "acestep-5Hz-lm-0.6B")

    Returns:
        Model size string: "0.6B", "1.7B", or "4B"
    """
    # Probe the known size tags in the same precedence order as before.
    for size_tag in ("0.6B", "1.7B", "4B"):
        if size_tag in model_path:
            return size_tag
    # Unknown paths default to the smallest model assumption.
    return "0.6B"
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def get_lm_gpu_memory_ratio(model_path: str, total_gpu_memory_gb: float) -> Tuple[float, float]:
    """
    Calculate GPU memory utilization ratio for an LM model.

    Args:
        model_path: LM model path (e.g., "acestep-5Hz-lm-0.6B")
        total_gpu_memory_gb: Total GPU memory in GB

    Returns:
        Tuple of (gpu_memory_utilization_ratio, target_memory_gb)
    """
    # Target VRAM budget (GB) per model size; unknown sizes use the 0.6B budget.
    target_by_size = {"0.6B": 3.0, "1.7B": 8.0, "4B": 12.0}
    target_gb = target_by_size.get(get_lm_model_size(model_path), 3.0)

    # Large GPUs (>= 24GB) keep a higher minimum ratio so the model runs
    # efficiently; smaller GPUs clamp harder to protect the rest of the pipeline.
    # The ceiling of 0.9 applies in both cases.
    min_ratio = 0.2 if total_gpu_memory_gb >= 24 else 0.1
    ratio = min(0.9, max(min_ratio, target_gb / total_gpu_memory_gb))

    return ratio, target_gb
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def check_duration_limit(
    duration: float,
    gpu_config: GPUConfig,
    lm_initialized: bool
) -> Tuple[bool, str]:
    """
    Check if requested duration is within limits for current GPU configuration.

    Args:
        duration: Requested duration in seconds
        gpu_config: Current GPU configuration
        lm_initialized: Whether LM is initialized

    Returns:
        Tuple of (is_valid, warning_message); the message is empty when valid.
    """
    # The cap depends on whether the LM is sharing GPU memory.
    if lm_initialized:
        max_duration = gpu_config.max_duration_with_lm
    else:
        max_duration = gpu_config.max_duration_without_lm

    # Within bounds: nothing to report.
    if duration <= max_duration:
        return True, ""

    warning_msg = (
        f"⚠️ Requested duration ({duration:.0f}s) exceeds the limit for your GPU "
        f"({gpu_config.gpu_memory_gb:.1f}GB). Maximum allowed: {max_duration}s "
        f"({'with' if lm_initialized else 'without'} LM). "
        f"Duration will be clamped to {max_duration}s."
    )
    return False, warning_msg
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def check_batch_size_limit(
    batch_size: int,
    gpu_config: GPUConfig,
    lm_initialized: bool
) -> Tuple[bool, str]:
    """
    Check if requested batch size is within limits for current GPU configuration.

    Args:
        batch_size: Requested batch size
        gpu_config: Current GPU configuration
        lm_initialized: Whether LM is initialized

    Returns:
        Tuple of (is_valid, warning_message); the message is empty when valid.
    """
    # The cap depends on whether the LM is sharing GPU memory.
    if lm_initialized:
        max_batch_size = gpu_config.max_batch_size_with_lm
    else:
        max_batch_size = gpu_config.max_batch_size_without_lm

    # Within bounds: nothing to report.
    if batch_size <= max_batch_size:
        return True, ""

    warning_msg = (
        f"⚠️ Requested batch size ({batch_size}) exceeds the limit for your GPU "
        f"({gpu_config.gpu_memory_gb:.1f}GB). Maximum allowed: {max_batch_size} "
        f"({'with' if lm_initialized else 'without'} LM). "
        f"Batch size will be clamped to {max_batch_size}."
    )
    return False, warning_msg
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def is_lm_model_supported(model_path: str, gpu_config: GPUConfig) -> Tuple[bool, str]:
    """
    Check if the specified LM model is supported for current GPU configuration.

    Args:
        model_path: LM model path
        gpu_config: Current GPU configuration

    Returns:
        Tuple of (is_supported, warning_message); the message is empty when supported.
    """
    # Tiers with no LM entries cannot host any language model at all.
    if not gpu_config.available_lm_models:
        return False, (
            f"⚠️ Your GPU ({gpu_config.gpu_memory_gb:.1f}GB) does not have enough memory "
            f"to run any LM model. Please disable LM initialization."
        )

    model_size = get_lm_model_size(model_path)

    # Supported when any tier-approved model name carries the same size tag.
    if any(model_size in candidate for candidate in gpu_config.available_lm_models):
        return True, ""

    return False, (
        f"⚠️ LM model {model_path} ({model_size}) is not supported for your GPU "
        f"({gpu_config.gpu_memory_gb:.1f}GB). Available models: {', '.join(gpu_config.available_lm_models)}"
    )
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
def get_recommended_lm_model(gpu_config: GPUConfig) -> Optional[str]:
    """
    Get recommended LM model for current GPU configuration.

    Args:
        gpu_config: Current GPU configuration

    Returns:
        Recommended LM model path, or None if LM is not supported
    """
    models = gpu_config.available_lm_models
    # The tier lists are ordered smallest -> largest; recommend the largest
    # model the tier allows, or None when the tier supports no LM.
    return models[-1] if models else None
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def print_gpu_config_info(gpu_config: GPUConfig) -> None:
    """Print GPU configuration information for debugging.

    Emits one logger.info line per GPUConfig field; side effect only.
    """
    logger.info(f"GPU Configuration:")
    logger.info(f" - GPU Memory: {gpu_config.gpu_memory_gb:.1f} GB")
    logger.info(f" - Tier: {gpu_config.tier}")
    logger.info(f" - Max Duration (with LM): {gpu_config.max_duration_with_lm}s ({gpu_config.max_duration_with_lm // 60} min)")
    logger.info(f" - Max Duration (without LM): {gpu_config.max_duration_without_lm}s ({gpu_config.max_duration_without_lm // 60} min)")
    logger.info(f" - Max Batch Size (with LM): {gpu_config.max_batch_size_with_lm}")
    logger.info(f" - Max Batch Size (without LM): {gpu_config.max_batch_size_without_lm}")
    logger.info(f" - Init LM by Default: {gpu_config.init_lm_default}")
    logger.info(f" - Available LM Models: {gpu_config.available_lm_models or 'None'}")
|
| 533 |
+
|
| 534 |
+
# Global GPU config instance (initialized lazily on first access).
# Not thread-safe: concurrent first calls could both run detection, but they
# would produce equivalent configs, so the race is benign.
_global_gpu_config: Optional[GPUConfig] = None


def get_global_gpu_config() -> GPUConfig:
    """Get the global GPU configuration, initializing if necessary.

    First call triggers GPU detection via get_gpu_config(); later calls
    return the cached instance.
    """
    global _global_gpu_config
    if _global_gpu_config is None:
        _global_gpu_config = get_gpu_config()
    return _global_gpu_config


def set_global_gpu_config(config: GPUConfig) -> None:
    """Set the global GPU configuration (overrides any cached/detected value)."""
    global _global_gpu_config
    _global_gpu_config = config
|
acestep/handler.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
acestep/inference.py
ADDED
|
@@ -0,0 +1,1310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ACE-Step Inference API Module
|
| 3 |
+
|
| 4 |
+
This module provides a standardized inference interface for music generation,
|
| 5 |
+
designed for third-party integration. It offers both a simplified API and
|
| 6 |
+
backward-compatible Gradio UI support.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
import os
|
| 11 |
+
import tempfile
|
| 12 |
+
import shutil
|
| 13 |
+
import subprocess
|
| 14 |
+
import sys
|
| 15 |
+
from typing import Optional, Union, List, Dict, Any, Tuple
|
| 16 |
+
from dataclasses import dataclass, field, asdict
|
| 17 |
+
from loguru import logger
|
| 18 |
+
|
| 19 |
+
from acestep.audio_utils import AudioSaver, generate_uuid_from_params, is_audio_silent
|
| 20 |
+
from acestep.constants import TASK_INSTRUCTIONS
|
| 21 |
+
from acestep.gpu_config import get_gpu_config
|
| 22 |
+
|
| 23 |
+
# HuggingFace Space environment detection
|
| 24 |
+
IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
|
| 25 |
+
|
| 26 |
+
def _get_spaces_gpu_decorator(duration=180):
|
| 27 |
+
"""
|
| 28 |
+
Get the @spaces.GPU decorator if running in HuggingFace Space environment.
|
| 29 |
+
Returns identity decorator if not in Space environment.
|
| 30 |
+
"""
|
| 31 |
+
if IS_HUGGINGFACE_SPACE:
|
| 32 |
+
try:
|
| 33 |
+
import spaces
|
| 34 |
+
return spaces.GPU(duration=duration)
|
| 35 |
+
except ImportError:
|
| 36 |
+
logger.warning("spaces package not found, GPU decorator disabled")
|
| 37 |
+
return lambda func: func
|
| 38 |
+
return lambda func: func
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass
class GenerationParams:
    """Configuration for music generation parameters.

    Attributes:
        # Text Inputs
        caption: A short text prompt describing the desired music (main prompt). < 512 characters
        lyrics: Lyrics for the music. Use "[Instrumental]" for instrumental songs. < 4096 characters
        instrumental: If True, generate instrumental music regardless of lyrics.

        # Music Metadata
        bpm: BPM (beats per minute), e.g., 120. Set to None for automatic estimation. 30 ~ 300
        keyscale: Musical key (e.g., "C Major", "Am"). Leave empty for auto-detection. A-G, #/♭, major/minor
        timesignature: Time signature (2 for '2/4', 3 for '3/4', 4 for '4/4', 6 for '6/8'). Leave empty for auto-detection.
        vocal_language: Language code for vocals, e.g., "en", "zh", "ja", or "unknown". see acestep/constants.py:VALID_LANGUAGES
        duration: Target audio length in seconds. If <0 or None, model chooses automatically. 10 ~ 600

        # Generation Parameters
        inference_steps: Number of diffusion steps (e.g., 8 for turbo, 32–100 for base model).
        guidance_scale: CFG (classifier-free guidance) strength. Higher means following the prompt more strictly. Only supported for non-turbo model.
        seed: Integer seed for reproducibility. -1 means use random seed each time.

        # Advanced DiT Parameters
        use_adg: Whether to use Adaptive Dual Guidance (only works for base model).
        cfg_interval_start: Start ratio (0.0–1.0) to apply CFG.
        cfg_interval_end: End ratio (0.0–1.0) to apply CFG.
        shift: Timestep shift factor (default 1.0). When != 1.0, applies t = shift * t / (1 + (shift - 1) * t) to timesteps.

        # Task-Specific Parameters
        task_type: Type of generation task. One of: "text2music", "cover", "repaint", "lego", "extract", "complete".
        reference_audio: Path to a reference audio file for style transfer or cover tasks.
        src_audio: Path to a source audio file for audio-to-audio tasks.
        audio_codes: Audio semantic codes as a string (advanced use, for code-control generation).
        repainting_start: For repaint/lego tasks: start time in seconds for region to repaint.
        repainting_end: For repaint/lego tasks: end time in seconds for region to repaint (-1 for until end).
        audio_cover_strength: Strength of reference audio/codes influence (range 0.0–1.0). Set smaller (0.2) for style transfer tasks.
        instruction: Optional task instruction prompt. If empty, auto-generated by system.

        # 5Hz Language Model Parameters for CoT reasoning
        thinking: If True, enable 5Hz Language Model "Chain-of-Thought" reasoning for semantic/music metadata and codes.
        lm_temperature: Sampling temperature for the LLM (0.0–2.0). Higher = more creative/varied results.
        lm_cfg_scale: Classifier-free guidance scale for the LLM.
        lm_top_k: LLM top-k sampling (0 = disabled).
        lm_top_p: LLM top-p nucleus sampling (1.0 = disabled).
        lm_negative_prompt: Negative prompt to use for LLM (for control).
        use_cot_metas: Whether to let LLM generate music metadata via CoT reasoning.
        use_cot_caption: Whether to let LLM rewrite or format the input caption via CoT reasoning.
        use_cot_language: Whether to let LLM detect vocal language via CoT.
    """
    # Required Inputs
    task_type: str = "text2music"
    instruction: str = "Fill the audio semantic mask based on the given conditions:"

    # Audio Uploads
    reference_audio: Optional[str] = None
    src_audio: Optional[str] = None

    # LM Codes Hints
    audio_codes: str = ""

    # Text Inputs
    caption: str = ""
    lyrics: str = ""
    instrumental: bool = False

    # Metadata
    vocal_language: str = "unknown"
    bpm: Optional[int] = None
    keyscale: str = ""
    timesignature: str = ""
    duration: float = -1.0

    # Advanced Settings
    inference_steps: int = 8
    seed: int = -1
    guidance_scale: float = 7.0
    use_adg: bool = False
    cfg_interval_start: float = 0.0
    cfg_interval_end: float = 1.0
    shift: float = 1.0
    infer_method: str = "ode"  # "ode" or "sde" - diffusion inference method
    # Custom timesteps (parsed from string like "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
    # If provided, overrides inference_steps and shift
    timesteps: Optional[List[float]] = None

    repainting_start: float = 0.0
    repainting_end: float = -1
    audio_cover_strength: float = 1.0

    # 5Hz Language Model Parameters
    thinking: bool = True
    lm_temperature: float = 0.85
    lm_cfg_scale: float = 2.0
    lm_top_k: int = 0
    lm_top_p: float = 0.9
    lm_negative_prompt: str = "NO USER INPUT"
    use_cot_metas: bool = True
    use_cot_caption: bool = True
    use_cot_lyrics: bool = False  # TODO: not used yet
    use_cot_language: bool = True
    use_constrained_decoding: bool = True

    # CoT result fields — presumably populated from the LM's chain-of-thought
    # output rather than set by callers; TODO(review) confirm against handler.
    cot_bpm: Optional[int] = None
    cot_keyscale: str = ""
    cot_timesignature: str = ""
    cot_duration: Optional[float] = None
    cot_vocal_language: str = "unknown"
    cot_caption: str = ""
    cot_lyrics: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to dictionary for JSON serialization."""
        return asdict(self)
|
| 155 |
+
|
| 156 |
+
@dataclass
class GenerationConfig:
    """Configuration for music generation.

    Attributes:
        batch_size: Number of audio samples to generate
        allow_lm_batch: Whether to allow batch processing in LM
        use_random_seed: Whether to use random seed
        seeds: Seed(s) for batch generation. Can be:
            - None: Use random seeds (when use_random_seed=True) or params.seed (when use_random_seed=False)
            - List[int]: List of seeds, will be padded with random seeds if fewer than batch_size
            - int: Single seed value (will be converted to list and padded)
        lm_batch_chunk_size: Batch chunk size for LM processing
        constrained_decoding_debug: Whether to enable constrained decoding debug
        audio_format: Output audio format, one of "mp3", "wav", "flac". Default: "flac"
    """
    batch_size: int = 2
    allow_lm_batch: bool = False
    use_random_seed: bool = True
    # NOTE: annotated Optional[List[int]], but generate_music() also accepts a
    # bare int here (it explicitly converts it to a string seed list).
    seeds: Optional[List[int]] = None
    lm_batch_chunk_size: int = 8
    constrained_decoding_debug: bool = False
    audio_format: str = "flac"  # Default to FLAC for fast saving

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to dictionary for JSON serialization."""
        return asdict(self)
| 184 |
+
|
| 185 |
+
@dataclass
class GenerationResult:
    """Result of music generation.

    Attributes:
        # Audio Outputs
        audios: List of audio dictionaries with paths, keys, params
        status_message: Status message from generation
        extra_outputs: Extra outputs from generation
        success: Whether generation completed successfully
        error: Error message if generation failed
    """

    # Audio Outputs
    # Each entry describes one generated sample (path/key/params per the
    # docstring); the exact dict keys are produced by the generation pipeline.
    audios: List[Dict[str, Any]] = field(default_factory=list)
    # Generation Information
    status_message: str = ""
    extra_outputs: Dict[str, Any] = field(default_factory=dict)
    # Success Status — when success is False, error carries the failure reason.
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert result to dictionary for JSON serialization."""
        return asdict(self)
| 211 |
+
|
| 212 |
+
@dataclass
class UnderstandResult:
    """Result of music understanding from audio codes.

    Attributes:
        # Metadata Fields
        caption: Generated caption describing the music
        lyrics: Generated or extracted lyrics
        bpm: Beats per minute (None if not detected)
        duration: Duration in seconds (None if not detected)
        keyscale: Musical key (e.g., "C Major")
        language: Vocal language code (e.g., "en", "zh")
        timesignature: Time signature (e.g., "4/4")

        # Status
        status_message: Status message from understanding
        success: Whether understanding completed successfully
        error: Error message if understanding failed
    """
    # Metadata Fields — defaults (empty string / None) mean "not detected".
    caption: str = ""
    lyrics: str = ""
    bpm: Optional[int] = None
    duration: Optional[float] = None
    keyscale: str = ""
    language: str = ""
    timesignature: str = ""

    # Status — when success is False, error carries the failure reason.
    status_message: str = ""
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert result to dictionary for JSON serialization."""
        return asdict(self)
| 249 |
+
|
| 250 |
+
def _update_metadata_from_lm(
|
| 251 |
+
metadata: Dict[str, Any],
|
| 252 |
+
bpm: Optional[int],
|
| 253 |
+
key_scale: str,
|
| 254 |
+
time_signature: str,
|
| 255 |
+
audio_duration: Optional[float],
|
| 256 |
+
vocal_language: str,
|
| 257 |
+
caption: str,
|
| 258 |
+
lyrics: str,
|
| 259 |
+
) -> Tuple[Optional[int], str, str, Optional[float], str, str, str]:
|
| 260 |
+
"""Update metadata fields from LM output if not provided by user."""
|
| 261 |
+
|
| 262 |
+
if bpm is None and metadata.get('bpm'):
|
| 263 |
+
bpm_value = metadata.get('bpm')
|
| 264 |
+
if bpm_value not in ["N/A", ""]:
|
| 265 |
+
try:
|
| 266 |
+
bpm = int(bpm_value)
|
| 267 |
+
except (ValueError, TypeError):
|
| 268 |
+
pass
|
| 269 |
+
|
| 270 |
+
if not key_scale and metadata.get('keyscale'):
|
| 271 |
+
key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
|
| 272 |
+
if key_scale_value != "N/A":
|
| 273 |
+
key_scale = key_scale_value
|
| 274 |
+
|
| 275 |
+
if not time_signature and metadata.get('timesignature'):
|
| 276 |
+
time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
|
| 277 |
+
if time_signature_value != "N/A":
|
| 278 |
+
time_signature = time_signature_value
|
| 279 |
+
|
| 280 |
+
if audio_duration is None or audio_duration <= 0:
|
| 281 |
+
audio_duration_value = metadata.get('duration', -1)
|
| 282 |
+
if audio_duration_value not in ["N/A", ""]:
|
| 283 |
+
try:
|
| 284 |
+
audio_duration = float(audio_duration_value)
|
| 285 |
+
except (ValueError, TypeError):
|
| 286 |
+
pass
|
| 287 |
+
|
| 288 |
+
if not vocal_language and metadata.get('vocal_language'):
|
| 289 |
+
vocal_language = metadata.get('vocal_language')
|
| 290 |
+
if not caption and metadata.get('caption'):
|
| 291 |
+
caption = metadata.get('caption')
|
| 292 |
+
if not lyrics and metadata.get('lyrics'):
|
| 293 |
+
lyrics = metadata.get('lyrics')
|
| 294 |
+
return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
@_get_spaces_gpu_decorator(duration=180)
|
| 298 |
+
def generate_music(
|
| 299 |
+
dit_handler,
|
| 300 |
+
llm_handler,
|
| 301 |
+
params: GenerationParams,
|
| 302 |
+
config: GenerationConfig,
|
| 303 |
+
save_dir: Optional[str] = None,
|
| 304 |
+
progress=None,
|
| 305 |
+
) -> GenerationResult:
|
| 306 |
+
"""Generate music using ACE-Step model with optional LM reasoning.
|
| 307 |
+
|
| 308 |
+
Args:
|
| 309 |
+
dit_handler: Initialized DiT model handler (AceStepHandler instance)
|
| 310 |
+
llm_handler: Initialized LLM handler (LLMHandler instance)
|
| 311 |
+
params: Generation parameters (GenerationParams instance)
|
| 312 |
+
config: Generation configuration (GenerationConfig instance)
|
| 313 |
+
|
| 314 |
+
Returns:
|
| 315 |
+
GenerationResult with generated audio files and metadata
|
| 316 |
+
"""
|
| 317 |
+
try:
|
| 318 |
+
# Phase 1: LM-based metadata and code generation (if enabled)
|
| 319 |
+
audio_code_string_to_use = params.audio_codes
|
| 320 |
+
lm_generated_metadata = None
|
| 321 |
+
lm_generated_audio_codes_list = []
|
| 322 |
+
lm_total_time_costs = {
|
| 323 |
+
"phase1_time": 0.0,
|
| 324 |
+
"phase2_time": 0.0,
|
| 325 |
+
"total_time": 0.0,
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
# Extract mutable copies of metadata (will be updated by LM if needed)
|
| 329 |
+
bpm = params.bpm
|
| 330 |
+
key_scale = params.keyscale
|
| 331 |
+
time_signature = params.timesignature
|
| 332 |
+
audio_duration = params.duration
|
| 333 |
+
dit_input_caption = params.caption
|
| 334 |
+
dit_input_vocal_language = params.vocal_language
|
| 335 |
+
dit_input_lyrics = params.lyrics
|
| 336 |
+
# Determine if we need to generate audio codes
|
| 337 |
+
# If user has provided audio_codes, we don't need to generate them
|
| 338 |
+
# Otherwise, check if we need audio codes (lm_dit mode) or just metas (dit mode)
|
| 339 |
+
user_provided_audio_codes = bool(params.audio_codes and str(params.audio_codes).strip())
|
| 340 |
+
|
| 341 |
+
# Safety: cover task without any source audio or codes produces silence.
|
| 342 |
+
if params.task_type == "cover":
|
| 343 |
+
no_src_audio = not (params.reference_audio or params.src_audio)
|
| 344 |
+
if no_src_audio and not user_provided_audio_codes:
|
| 345 |
+
logger.warning("Cover task requested without source audio or audio codes. Falling back to text2music.")
|
| 346 |
+
params.task_type = "text2music"
|
| 347 |
+
if params.instruction == TASK_INSTRUCTIONS.get("cover"):
|
| 348 |
+
params.instruction = TASK_INSTRUCTIONS.get("text2music", params.instruction)
|
| 349 |
+
|
| 350 |
+
# Determine infer_type: use "llm_dit" if we need audio codes, "dit" if only metas needed
|
| 351 |
+
# For now, we use "llm_dit" if batch mode or if user hasn't provided codes
|
| 352 |
+
# Use "dit" if user has provided codes (only need metas) or if explicitly only need metas
|
| 353 |
+
# Note: This logic can be refined based on specific requirements
|
| 354 |
+
need_audio_codes = not user_provided_audio_codes
|
| 355 |
+
|
| 356 |
+
# Determine if we should use chunk-based LM generation (always use chunks for consistency)
|
| 357 |
+
# Determine actual batch size for chunk processing
|
| 358 |
+
actual_batch_size = config.batch_size if config.batch_size is not None else 1
|
| 359 |
+
|
| 360 |
+
# Prepare seeds for batch generation
|
| 361 |
+
# Use config.seed if provided, otherwise fallback to params.seed
|
| 362 |
+
# Convert config.seed (None, int, or List[int]) to format that prepare_seeds accepts
|
| 363 |
+
seed_for_generation = ""
|
| 364 |
+
# Original code (commented out because it crashes on int seeds):
|
| 365 |
+
# if config.seeds is not None and len(config.seeds) > 0:
|
| 366 |
+
# if isinstance(config.seeds, list):
|
| 367 |
+
# # Convert List[int] to comma-separated string
|
| 368 |
+
# seed_for_generation = ",".join(str(s) for s in config.seeds)
|
| 369 |
+
|
| 370 |
+
if config.seeds is not None:
|
| 371 |
+
if isinstance(config.seeds, list) and len(config.seeds) > 0:
|
| 372 |
+
# Convert List[int] to comma-separated string
|
| 373 |
+
seed_for_generation = ",".join(str(s) for s in config.seeds)
|
| 374 |
+
elif isinstance(config.seeds, int):
|
| 375 |
+
# Fix: Explicitly handle single integer seeds by converting to string.
|
| 376 |
+
# Previously, this would crash because 'len()' was called on an int.
|
| 377 |
+
seed_for_generation = str(config.seeds)
|
| 378 |
+
|
| 379 |
+
# Use dit_handler.prepare_seeds to handle seed list generation and padding
|
| 380 |
+
# This will handle all the logic: padding with random seeds if needed, etc.
|
| 381 |
+
actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)
|
| 382 |
+
|
| 383 |
+
# LM-based Chain-of-Thought reasoning
|
| 384 |
+
# Skip LM for cover/repaint tasks - these tasks use reference/src audio directly
|
| 385 |
+
# and don't need LM to generate audio codes
|
| 386 |
+
skip_lm_tasks = {"cover", "repaint"}
|
| 387 |
+
|
| 388 |
+
# Determine if we should use LLM
|
| 389 |
+
# LLM is needed for:
|
| 390 |
+
# 1. thinking=True: generate audio codes via LM
|
| 391 |
+
# 2. use_cot_caption=True: enhance/generate caption via CoT
|
| 392 |
+
# 3. use_cot_language=True: detect vocal language via CoT
|
| 393 |
+
# 4. use_cot_metas=True: fill missing metadata via CoT
|
| 394 |
+
need_lm_for_cot = params.use_cot_caption or params.use_cot_language or params.use_cot_metas
|
| 395 |
+
use_lm = (params.thinking or need_lm_for_cot) and llm_handler is not None and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks
|
| 396 |
+
lm_status = []
|
| 397 |
+
|
| 398 |
+
if params.task_type in skip_lm_tasks:
|
| 399 |
+
logger.info(f"Skipping LM for task_type='{params.task_type}' - using DiT directly")
|
| 400 |
+
|
| 401 |
+
logger.info(f"[generate_music] LLM usage decision: thinking={params.thinking}, "
|
| 402 |
+
f"use_cot_caption={params.use_cot_caption}, use_cot_language={params.use_cot_language}, "
|
| 403 |
+
f"use_cot_metas={params.use_cot_metas}, need_lm_for_cot={need_lm_for_cot}, "
|
| 404 |
+
f"llm_initialized={llm_handler.llm_initialized if llm_handler else False}, use_lm={use_lm}")
|
| 405 |
+
|
| 406 |
+
def _infer_audio_duration_seconds(audio_path: str) -> Optional[float]:
|
| 407 |
+
"""Best-effort duration inference for common audio formats."""
|
| 408 |
+
if not audio_path:
|
| 409 |
+
return None
|
| 410 |
+
# Try torchaudio (supports more formats when ffmpeg backend is available)
|
| 411 |
+
try:
|
| 412 |
+
import torchaudio
|
| 413 |
+
info = torchaudio.info(audio_path)
|
| 414 |
+
if info and info.num_frames and info.sample_rate:
|
| 415 |
+
return float(info.num_frames) / float(info.sample_rate)
|
| 416 |
+
except Exception:
|
| 417 |
+
pass
|
| 418 |
+
# Try soundfile (fast for wav/flac)
|
| 419 |
+
try:
|
| 420 |
+
import soundfile as sf
|
| 421 |
+
info = sf.info(audio_path)
|
| 422 |
+
if info and info.frames and info.samplerate:
|
| 423 |
+
return float(info.frames) / float(info.samplerate)
|
| 424 |
+
except Exception:
|
| 425 |
+
pass
|
| 426 |
+
# macOS fallback: use afinfo for m4a/aac
|
| 427 |
+
if sys.platform == "darwin" and shutil.which("afinfo"):
|
| 428 |
+
try:
|
| 429 |
+
result = subprocess.run(
|
| 430 |
+
["afinfo", audio_path],
|
| 431 |
+
check=False,
|
| 432 |
+
capture_output=True,
|
| 433 |
+
text=True,
|
| 434 |
+
)
|
| 435 |
+
if result.stdout:
|
| 436 |
+
for line in result.stdout.splitlines():
|
| 437 |
+
if "duration:" in line:
|
| 438 |
+
# Example: "duration: 183.165s"
|
| 439 |
+
parts = line.strip().split()
|
| 440 |
+
for p in parts:
|
| 441 |
+
if p.endswith("s"):
|
| 442 |
+
try:
|
| 443 |
+
return float(p.rstrip("s"))
|
| 444 |
+
except ValueError:
|
| 445 |
+
continue
|
| 446 |
+
except Exception:
|
| 447 |
+
pass
|
| 448 |
+
return None
|
| 449 |
+
|
| 450 |
+
# Clamp duration and batch size to GPU limits (applies to non-Gradio callers too)
|
| 451 |
+
try:
|
| 452 |
+
# If duration not provided, try to infer from source audio to enable safe clamping.
|
| 453 |
+
if (audio_duration is None or float(audio_duration) <= 0) and (params.src_audio or params.reference_audio):
|
| 454 |
+
audio_path = params.src_audio or params.reference_audio
|
| 455 |
+
try:
|
| 456 |
+
inferred = _infer_audio_duration_seconds(audio_path)
|
| 457 |
+
if inferred and inferred > 0:
|
| 458 |
+
audio_duration = inferred
|
| 459 |
+
params.duration = inferred
|
| 460 |
+
logger.info(f"[generate_music] Inferred duration from audio file: {inferred:.2f}s")
|
| 461 |
+
except Exception as e:
|
| 462 |
+
logger.warning(f"[generate_music] Failed to infer duration from audio file: {e}")
|
| 463 |
+
|
| 464 |
+
gpu_config = get_gpu_config()
|
| 465 |
+
max_duration = gpu_config.max_duration_with_lm if use_lm else gpu_config.max_duration_without_lm
|
| 466 |
+
if audio_duration is not None and float(audio_duration) > 0 and float(audio_duration) > max_duration:
|
| 467 |
+
logger.warning(f"[generate_music] Duration {audio_duration}s exceeds GPU limit {max_duration}s. Clamping.")
|
| 468 |
+
audio_duration = float(max_duration)
|
| 469 |
+
params.duration = float(max_duration)
|
| 470 |
+
|
| 471 |
+
max_batch = gpu_config.max_batch_size_with_lm if use_lm else gpu_config.max_batch_size_without_lm
|
| 472 |
+
if config.batch_size is not None and config.batch_size > max_batch:
|
| 473 |
+
logger.warning(f"[generate_music] Batch size {config.batch_size} exceeds GPU limit {max_batch}. Clamping.")
|
| 474 |
+
config.batch_size = max_batch
|
| 475 |
+
|
| 476 |
+
# Extra safety for MPS: large durations can OOM with batch > 1
|
| 477 |
+
if (
|
| 478 |
+
hasattr(dit_handler, "device")
|
| 479 |
+
and dit_handler.device == "mps"
|
| 480 |
+
and audio_duration is not None
|
| 481 |
+
and float(audio_duration) > 180
|
| 482 |
+
and config.batch_size is not None
|
| 483 |
+
and config.batch_size > 1
|
| 484 |
+
):
|
| 485 |
+
logger.warning("[generate_music] MPS with long duration detected; reducing batch size to 1 to avoid OOM.")
|
| 486 |
+
config.batch_size = 1
|
| 487 |
+
except Exception as e:
|
| 488 |
+
logger.warning(f"[generate_music] Failed to clamp duration/batch to GPU limits: {e}")
|
| 489 |
+
|
| 490 |
+
if use_lm:
|
| 491 |
+
# Convert sampling parameters - handle None values safely
|
| 492 |
+
top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k)
|
| 493 |
+
top_p_value = None if not params.lm_top_p or params.lm_top_p >= 1.0 else params.lm_top_p
|
| 494 |
+
|
| 495 |
+
# Build user_metadata from user-provided values
|
| 496 |
+
user_metadata = {}
|
| 497 |
+
if bpm is not None:
|
| 498 |
+
try:
|
| 499 |
+
bpm_value = float(bpm)
|
| 500 |
+
if bpm_value > 0:
|
| 501 |
+
user_metadata['bpm'] = int(bpm_value)
|
| 502 |
+
except (ValueError, TypeError):
|
| 503 |
+
pass
|
| 504 |
+
|
| 505 |
+
if key_scale and key_scale.strip():
|
| 506 |
+
key_scale_clean = key_scale.strip()
|
| 507 |
+
if key_scale_clean.lower() not in ["n/a", ""]:
|
| 508 |
+
user_metadata['keyscale'] = key_scale_clean
|
| 509 |
+
|
| 510 |
+
if time_signature and time_signature.strip():
|
| 511 |
+
time_sig_clean = time_signature.strip()
|
| 512 |
+
if time_sig_clean.lower() not in ["n/a", ""]:
|
| 513 |
+
user_metadata['timesignature'] = time_sig_clean
|
| 514 |
+
|
| 515 |
+
if audio_duration is not None:
|
| 516 |
+
try:
|
| 517 |
+
duration_value = float(audio_duration)
|
| 518 |
+
if duration_value > 0:
|
| 519 |
+
user_metadata['duration'] = int(duration_value)
|
| 520 |
+
except (ValueError, TypeError):
|
| 521 |
+
pass
|
| 522 |
+
|
| 523 |
+
user_metadata_to_pass = user_metadata if user_metadata else None
|
| 524 |
+
|
| 525 |
+
# Determine infer_type based on whether we need audio codes
|
| 526 |
+
# - "llm_dit": generates both metas and audio codes (two-phase internally)
|
| 527 |
+
# - "dit": generates only metas (single phase)
|
| 528 |
+
infer_type = "llm_dit" if need_audio_codes and params.thinking else "dit"
|
| 529 |
+
|
| 530 |
+
# Use chunk size from config, or default to batch_size if not set
|
| 531 |
+
max_inference_batch_size = int(config.lm_batch_chunk_size) if config.lm_batch_chunk_size > 0 else actual_batch_size
|
| 532 |
+
num_chunks = math.ceil(actual_batch_size / max_inference_batch_size)
|
| 533 |
+
|
| 534 |
+
all_metadata_list = []
|
| 535 |
+
all_audio_codes_list = []
|
| 536 |
+
|
| 537 |
+
for chunk_idx in range(num_chunks):
|
| 538 |
+
chunk_start = chunk_idx * max_inference_batch_size
|
| 539 |
+
chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size)
|
| 540 |
+
chunk_size = chunk_end - chunk_start
|
| 541 |
+
chunk_seeds = actual_seed_list[chunk_start:chunk_end] if chunk_start < len(actual_seed_list) else None
|
| 542 |
+
|
| 543 |
+
logger.info(f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) "
|
| 544 |
+
f"(size: {chunk_size}, seeds: {chunk_seeds})")
|
| 545 |
+
|
| 546 |
+
# Use the determined infer_type
|
| 547 |
+
# - "llm_dit" will internally run two phases (metas + codes)
|
| 548 |
+
# - "dit" will only run phase 1 (metas only)
|
| 549 |
+
result = llm_handler.generate_with_stop_condition(
|
| 550 |
+
caption=params.caption or "",
|
| 551 |
+
lyrics=params.lyrics or "",
|
| 552 |
+
infer_type=infer_type,
|
| 553 |
+
temperature=params.lm_temperature,
|
| 554 |
+
cfg_scale=params.lm_cfg_scale,
|
| 555 |
+
negative_prompt=params.lm_negative_prompt,
|
| 556 |
+
top_k=top_k_value,
|
| 557 |
+
top_p=top_p_value,
|
| 558 |
+
target_duration=audio_duration, # Pass duration to limit audio codes generation
|
| 559 |
+
user_metadata=user_metadata_to_pass,
|
| 560 |
+
use_cot_caption=params.use_cot_caption,
|
| 561 |
+
use_cot_language=params.use_cot_language,
|
| 562 |
+
use_cot_metas=params.use_cot_metas,
|
| 563 |
+
use_constrained_decoding=params.use_constrained_decoding,
|
| 564 |
+
constrained_decoding_debug=config.constrained_decoding_debug,
|
| 565 |
+
batch_size=chunk_size,
|
| 566 |
+
seeds=chunk_seeds,
|
| 567 |
+
progress=progress,
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
# Check if LM generation failed
|
| 571 |
+
if not result.get("success", False):
|
| 572 |
+
error_msg = result.get("error", "Unknown LM error")
|
| 573 |
+
lm_status.append(f"❌ LM Error: {error_msg}")
|
| 574 |
+
# Return early with error
|
| 575 |
+
return GenerationResult(
|
| 576 |
+
audios=[],
|
| 577 |
+
status_message=f"❌ LM generation failed: {error_msg}",
|
| 578 |
+
extra_outputs={},
|
| 579 |
+
success=False,
|
| 580 |
+
error=error_msg,
|
| 581 |
+
)
|
| 582 |
+
|
| 583 |
+
# Extract metadata and audio_codes from result dict
|
| 584 |
+
if chunk_size > 1:
|
| 585 |
+
metadata_list = result.get("metadata", [])
|
| 586 |
+
audio_codes_list = result.get("audio_codes", [])
|
| 587 |
+
all_metadata_list.extend(metadata_list)
|
| 588 |
+
all_audio_codes_list.extend(audio_codes_list)
|
| 589 |
+
else:
|
| 590 |
+
metadata = result.get("metadata", {})
|
| 591 |
+
audio_codes = result.get("audio_codes", "")
|
| 592 |
+
all_metadata_list.append(metadata)
|
| 593 |
+
all_audio_codes_list.append(audio_codes)
|
| 594 |
+
|
| 595 |
+
# Collect time costs from LM extra_outputs
|
| 596 |
+
lm_extra = result.get("extra_outputs", {})
|
| 597 |
+
lm_chunk_time_costs = lm_extra.get("time_costs", {})
|
| 598 |
+
if lm_chunk_time_costs:
|
| 599 |
+
# Accumulate time costs from all chunks
|
| 600 |
+
for key in ["phase1_time", "phase2_time", "total_time"]:
|
| 601 |
+
if key in lm_chunk_time_costs:
|
| 602 |
+
lm_total_time_costs[key] += lm_chunk_time_costs[key]
|
| 603 |
+
|
| 604 |
+
time_str = ", ".join([f"{k}: {v:.2f}s" for k, v in lm_chunk_time_costs.items()])
|
| 605 |
+
lm_status.append(f"✅ LM chunk {chunk_idx+1}: {time_str}")
|
| 606 |
+
|
| 607 |
+
lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
|
| 608 |
+
lm_generated_audio_codes_list = all_audio_codes_list
|
| 609 |
+
|
| 610 |
+
# Set audio_code_string_to_use based on infer_type
|
| 611 |
+
if infer_type == "llm_dit":
|
| 612 |
+
# If batch mode, use list; otherwise use single string
|
| 613 |
+
if actual_batch_size > 1:
|
| 614 |
+
audio_code_string_to_use = all_audio_codes_list
|
| 615 |
+
else:
|
| 616 |
+
audio_code_string_to_use = all_audio_codes_list[0] if all_audio_codes_list else ""
|
| 617 |
+
else:
|
| 618 |
+
# For "dit" mode, keep user-provided codes or empty
|
| 619 |
+
audio_code_string_to_use = params.audio_codes
|
| 620 |
+
|
| 621 |
+
# Update metadata from LM if not provided by user
|
| 622 |
+
if lm_generated_metadata:
|
| 623 |
+
bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics = _update_metadata_from_lm(
|
| 624 |
+
metadata=lm_generated_metadata,
|
| 625 |
+
bpm=bpm,
|
| 626 |
+
key_scale=key_scale,
|
| 627 |
+
time_signature=time_signature,
|
| 628 |
+
audio_duration=audio_duration,
|
| 629 |
+
vocal_language=dit_input_vocal_language,
|
| 630 |
+
caption=dit_input_caption,
|
| 631 |
+
lyrics=dit_input_lyrics)
|
| 632 |
+
if not params.bpm:
|
| 633 |
+
params.cot_bpm = bpm
|
| 634 |
+
if not params.keyscale:
|
| 635 |
+
params.cot_keyscale = key_scale
|
| 636 |
+
if not params.timesignature:
|
| 637 |
+
params.cot_timesignature = time_signature
|
| 638 |
+
if not params.duration:
|
| 639 |
+
params.cot_duration = audio_duration
|
| 640 |
+
if not params.vocal_language:
|
| 641 |
+
params.cot_vocal_language = vocal_language
|
| 642 |
+
if not params.caption:
|
| 643 |
+
params.cot_caption = caption
|
| 644 |
+
if not params.lyrics:
|
| 645 |
+
params.cot_lyrics = lyrics
|
| 646 |
+
dit_input_lyrics = lyrics
|
| 647 |
+
|
| 648 |
+
# set cot caption and language if needed
|
| 649 |
+
if params.use_cot_caption:
|
| 650 |
+
dit_input_caption = lm_generated_metadata.get("caption", dit_input_caption)
|
| 651 |
+
if params.use_cot_language:
|
| 652 |
+
dit_input_vocal_language = lm_generated_metadata.get("vocal_language", dit_input_vocal_language)
|
| 653 |
+
|
| 654 |
+
# Phase 2: DiT music generation
|
| 655 |
+
# Use seed_for_generation (from config.seed or params.seed) instead of params.seed for actual generation
|
| 656 |
+
result = dit_handler.generate_music(
|
| 657 |
+
captions=dit_input_caption,
|
| 658 |
+
lyrics=dit_input_lyrics,
|
| 659 |
+
bpm=bpm,
|
| 660 |
+
key_scale=key_scale,
|
| 661 |
+
time_signature=time_signature,
|
| 662 |
+
vocal_language=dit_input_vocal_language,
|
| 663 |
+
inference_steps=params.inference_steps,
|
| 664 |
+
guidance_scale=params.guidance_scale,
|
| 665 |
+
use_random_seed=config.use_random_seed,
|
| 666 |
+
seed=seed_for_generation, # Use config.seed (or params.seed fallback) instead of params.seed directly
|
| 667 |
+
reference_audio=params.reference_audio,
|
| 668 |
+
audio_duration=audio_duration,
|
| 669 |
+
batch_size=config.batch_size if config.batch_size is not None else 1,
|
| 670 |
+
src_audio=params.src_audio,
|
| 671 |
+
audio_code_string=audio_code_string_to_use,
|
| 672 |
+
repainting_start=params.repainting_start,
|
| 673 |
+
repainting_end=params.repainting_end,
|
| 674 |
+
instruction=params.instruction,
|
| 675 |
+
audio_cover_strength=params.audio_cover_strength,
|
| 676 |
+
task_type=params.task_type,
|
| 677 |
+
use_adg=params.use_adg,
|
| 678 |
+
cfg_interval_start=params.cfg_interval_start,
|
| 679 |
+
cfg_interval_end=params.cfg_interval_end,
|
| 680 |
+
shift=params.shift,
|
| 681 |
+
infer_method=params.infer_method,
|
| 682 |
+
timesteps=params.timesteps,
|
| 683 |
+
progress=progress,
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
# Check if generation failed
|
| 687 |
+
if not result.get("success", False):
|
| 688 |
+
return GenerationResult(
|
| 689 |
+
audios=[],
|
| 690 |
+
status_message=result.get("status_message", ""),
|
| 691 |
+
extra_outputs={},
|
| 692 |
+
success=False,
|
| 693 |
+
error=result.get("error"),
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
# Extract results from dit_handler.generate_music dict
|
| 697 |
+
dit_audios = result.get("audios", [])
|
| 698 |
+
status_message = result.get("status_message", "")
|
| 699 |
+
dit_extra_outputs = result.get("extra_outputs", {})
|
| 700 |
+
|
| 701 |
+
# Use the seed list already prepared above (from config.seed or params.seed fallback)
|
| 702 |
+
# actual_seed_list was computed earlier using dit_handler.prepare_seeds
|
| 703 |
+
seed_list = actual_seed_list
|
| 704 |
+
|
| 705 |
+
# Get base params dictionary
|
| 706 |
+
base_params_dict = params.to_dict()
|
| 707 |
+
|
| 708 |
+
# Save audio files using AudioSaver (format from config)
|
| 709 |
+
audio_format = config.audio_format if config.audio_format else "flac"
|
| 710 |
+
audio_saver = AudioSaver(default_format=audio_format)
|
| 711 |
+
|
| 712 |
+
# Use handler's temp_dir for saving files
|
| 713 |
+
if save_dir is not None:
|
| 714 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 715 |
+
|
| 716 |
+
# Build audios list for GenerationResult with params and save files
|
| 717 |
+
# Audio saving and UUID generation handled here, outside of handler
|
| 718 |
+
audios = []
|
| 719 |
+
silent_warnings = []
|
| 720 |
+
for idx, dit_audio in enumerate(dit_audios):
|
| 721 |
+
# Create a copy of params dict for this audio
|
| 722 |
+
audio_params = base_params_dict.copy()
|
| 723 |
+
|
| 724 |
+
# Update audio-specific values
|
| 725 |
+
audio_params["seed"] = seed_list[idx] if idx < len(seed_list) else None
|
| 726 |
+
|
| 727 |
+
# Add audio codes if batch mode
|
| 728 |
+
if lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list):
|
| 729 |
+
audio_params["audio_codes"] = lm_generated_audio_codes_list[idx]
|
| 730 |
+
|
| 731 |
+
# Get audio tensor and metadata
|
| 732 |
+
audio_tensor = dit_audio.get("tensor")
|
| 733 |
+
sample_rate = dit_audio.get("sample_rate", 48000)
|
| 734 |
+
|
| 735 |
+
# Generate UUID for this audio (moved from handler)
|
| 736 |
+
batch_seed = seed_list[idx] if idx < len(seed_list) else seed_list[0] if seed_list else -1
|
| 737 |
+
audio_code_str = lm_generated_audio_codes_list[idx] if (
|
| 738 |
+
lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list)) else audio_code_string_to_use
|
| 739 |
+
if isinstance(audio_code_str, list):
|
| 740 |
+
audio_code_str = audio_code_str[idx] if idx < len(audio_code_str) else ""
|
| 741 |
+
|
| 742 |
+
audio_key = generate_uuid_from_params(audio_params)
|
| 743 |
+
|
| 744 |
+
silent_check = False
|
| 745 |
+
if audio_tensor is not None:
|
| 746 |
+
silent_check, rms_val, peak_val = is_audio_silent(audio_tensor, channels_first=True)
|
| 747 |
+
if silent_check:
|
| 748 |
+
logger.warning(
|
| 749 |
+
f"[generate_music] Silent output detected (idx={idx}, RMS={rms_val:.2e}, peak={peak_val:.2e}). "
|
| 750 |
+
"Likely cause: LLM backend returned empty conditioning, or incompatible torch/triton/flash-attn. "
|
| 751 |
+
"Suggest running with --backend pt."
|
| 752 |
+
)
|
| 753 |
+
silent_warnings.append(
|
| 754 |
+
f"Output {idx + 1}: silent or near-silent (RMS≈{rms_val:.2e}). "
|
| 755 |
+
"Likely causes: LLM backend failure, incompatible torch/triton/flash-attn, or CPU/fallback path. "
|
| 756 |
+
"Try running with --backend pt."
|
| 757 |
+
)
|
| 758 |
+
|
| 759 |
+
audio_path = None
|
| 760 |
+
if audio_tensor is not None and save_dir is not None and not silent_check:
|
| 761 |
+
try:
|
| 762 |
+
audio_file = os.path.join(save_dir, f"{audio_key}.{audio_format}")
|
| 763 |
+
audio_path = audio_saver.save_audio(audio_tensor,
|
| 764 |
+
audio_file,
|
| 765 |
+
sample_rate=sample_rate,
|
| 766 |
+
format=audio_format,
|
| 767 |
+
channels_first=True)
|
| 768 |
+
except Exception as e:
|
| 769 |
+
logger.error(f"[generate_music] Failed to save audio file: {e}")
|
| 770 |
+
audio_path = ""
|
| 771 |
+
|
| 772 |
+
audio_dict = {
|
| 773 |
+
"path": audio_path or "",
|
| 774 |
+
"tensor": audio_tensor,
|
| 775 |
+
"key": audio_key,
|
| 776 |
+
"sample_rate": sample_rate,
|
| 777 |
+
"params": audio_params,
|
| 778 |
+
"silent": silent_check,
|
| 779 |
+
}
|
| 780 |
+
|
| 781 |
+
audios.append(audio_dict)
|
| 782 |
+
|
| 783 |
+
# Merge extra_outputs: include dit_extra_outputs (latents, masks) and add LM metadata
|
| 784 |
+
extra_outputs = dit_extra_outputs.copy()
|
| 785 |
+
extra_outputs["lm_metadata"] = lm_generated_metadata
|
| 786 |
+
|
| 787 |
+
# Merge time_costs from both LM and DiT into a unified dictionary
|
| 788 |
+
unified_time_costs = {}
|
| 789 |
+
|
| 790 |
+
# Add LM time costs (if LM was used)
|
| 791 |
+
if use_lm and lm_total_time_costs:
|
| 792 |
+
for key, value in lm_total_time_costs.items():
|
| 793 |
+
unified_time_costs[f"lm_{key}"] = value
|
| 794 |
+
|
| 795 |
+
# Add DiT time costs (if available)
|
| 796 |
+
dit_time_costs = dit_extra_outputs.get("time_costs", {})
|
| 797 |
+
if dit_time_costs:
|
| 798 |
+
for key, value in dit_time_costs.items():
|
| 799 |
+
unified_time_costs[f"dit_{key}"] = value
|
| 800 |
+
|
| 801 |
+
# Calculate total pipeline time
|
| 802 |
+
if unified_time_costs:
|
| 803 |
+
lm_total = unified_time_costs.get("lm_total_time", 0.0)
|
| 804 |
+
dit_total = unified_time_costs.get("dit_total_time_cost", 0.0)
|
| 805 |
+
unified_time_costs["pipeline_total_time"] = lm_total + dit_total
|
| 806 |
+
|
| 807 |
+
# Update extra_outputs with unified time_costs
|
| 808 |
+
extra_outputs["time_costs"] = unified_time_costs
|
| 809 |
+
|
| 810 |
+
if lm_status:
|
| 811 |
+
status_message = "\n".join(lm_status) + "\n" + status_message
|
| 812 |
+
else:
|
| 813 |
+
status_message = status_message
|
| 814 |
+
if silent_warnings:
|
| 815 |
+
status_message = "⚠️ Silent output detected:\n" + "\n".join(silent_warnings) + "\n\nSuggested fix: try running with --backend pt\n\n" + (status_message or "")
|
| 816 |
+
# Create and return GenerationResult
|
| 817 |
+
return GenerationResult(
|
| 818 |
+
audios=audios,
|
| 819 |
+
status_message=status_message,
|
| 820 |
+
extra_outputs=extra_outputs,
|
| 821 |
+
success=True,
|
| 822 |
+
error=None,
|
| 823 |
+
)
|
| 824 |
+
|
| 825 |
+
except Exception as e:
|
| 826 |
+
logger.exception("Music generation failed")
|
| 827 |
+
return GenerationResult(
|
| 828 |
+
audios=[],
|
| 829 |
+
status_message=f"Error: {str(e)}",
|
| 830 |
+
extra_outputs={},
|
| 831 |
+
success=False,
|
| 832 |
+
error=str(e),
|
| 833 |
+
)
|
| 834 |
+
|
| 835 |
+
|
| 836 |
+
def understand_music(
    llm_handler,
    audio_codes: str,
    temperature: float = 0.85,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
    use_constrained_decoding: bool = True,
    constrained_decoding_debug: bool = False,
) -> UnderstandResult:
    """Understand music from audio codes using the 5Hz Language Model.

    Analyzes audio semantic codes and generates metadata about the music,
    including caption, lyrics, BPM, duration, key scale, language, and
    time signature.

    If ``audio_codes`` is empty or "NO USER INPUT", the LM generates a sample
    example instead of analyzing existing codes.

    Note: cfg_scale and negative_prompt are not supported in understand mode.

    Args:
        llm_handler: Initialized LLM handler (LLMHandler instance)
        audio_codes: String of audio code tokens
            (e.g., "<|audio_code_123|><|audio_code_456|>...").
            Use empty string or "NO USER INPUT" to generate a sample example.
        temperature: Sampling temperature for generation (0.0-2.0).
            Higher = more creative.
        top_k: Top-K sampling (None or 0 = disabled)
        top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
        repetition_penalty: Repetition penalty (1.0 = no penalty)
        use_constrained_decoding: Whether to use FSM-based constrained
            decoding for metadata
        constrained_decoding_debug: Whether to enable debug logging for
            constrained decoding

    Returns:
        UnderstandResult with parsed metadata fields and status.

    Example:
        >>> result = understand_music(llm_handler, audio_codes="<|audio_code_123|>...")
        >>> if result.success:
        ...     print(f"Caption: {result.caption}")
    """

    def _clean_text(value):
        # The LM sometimes emits "N/A" (in any casing) for unknown fields;
        # normalize those to empty strings, consistent with the lowercase
        # "n/a" handling used elsewhere in this module.
        if isinstance(value, str) and value.strip().upper() == "N/A":
            return ""
        return value

    def _to_int(value) -> Optional[int]:
        # int(float(...)) tolerates float-like strings such as "120.5",
        # which plain int() would reject.
        if value is None or value in ("", "N/A"):
            return None
        try:
            return int(float(value))
        except (ValueError, TypeError):
            return None

    def _to_float(value) -> Optional[float]:
        if value is None or value in ("", "N/A"):
            return None
        try:
            return float(value)
        except (ValueError, TypeError):
            return None

    # Refuse early when the language model has not been loaded.
    if not llm_handler.llm_initialized:
        return UnderstandResult(
            status_message="5Hz LM not initialized. Please initialize it first.",
            success=False,
            error="LLM not initialized",
        )

    # Empty input means "generate a sample example" on the LM side.
    if not audio_codes or not audio_codes.strip():
        audio_codes = "NO USER INPUT"

    try:
        metadata, status = llm_handler.understand_audio_from_codes(
            audio_codes=audio_codes,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            use_constrained_decoding=use_constrained_decoding,
            constrained_decoding_debug=constrained_decoding_debug,
        )

        # Empty metadata signals an LM-side failure.
        if not metadata:
            return UnderstandResult(
                status_message=status or "Failed to understand audio codes",
                success=False,
                error=status or "Empty metadata returned",
            )

        return UnderstandResult(
            caption=metadata.get('caption', ''),
            lyrics=metadata.get('lyrics', ''),
            bpm=_to_int(metadata.get('bpm')),
            duration=_to_float(metadata.get('duration')),
            keyscale=_clean_text(metadata.get('keyscale', '')),
            language=_clean_text(metadata.get('language', metadata.get('vocal_language', ''))),
            timesignature=_clean_text(metadata.get('timesignature', '')),
            status_message=status,
            success=True,
            error=None,
        )

    except Exception as e:
        logger.exception("Music understanding failed")
        return UnderstandResult(
            status_message=f"Error: {str(e)}",
            success=False,
            error=str(e),
        )
|
| 962 |
+
|
| 963 |
+
|
| 964 |
+
@dataclass
class CreateSampleResult:
    """Generated music sample produced from a natural-language query.

    Returned by the "Simple Mode" / "Inspiration Mode" feature, where a user
    supplies a free-form description and the LLM produces a complete sample:
    caption, lyrics, and musical metadata.

    Attributes:
        caption: Generated detailed music description/caption.
        lyrics: Generated lyrics (or "[Instrumental]" for instrumental music).
        bpm: Beats per minute (None if not generated).
        duration: Duration in seconds (None if not generated).
        keyscale: Musical key (e.g., "C Major").
        language: Vocal language code (e.g., "en", "zh").
        timesignature: Time signature (e.g., "4").
        instrumental: Whether this is an instrumental piece.
        status_message: Status message from sample creation.
        success: Whether sample creation completed successfully.
        error: Error message if sample creation failed.
    """

    # -- metadata produced by the LLM --
    caption: str = ""
    lyrics: str = ""
    bpm: Optional[int] = None
    duration: Optional[float] = None
    keyscale: str = ""
    language: str = ""
    timesignature: str = ""
    instrumental: bool = False

    # -- outcome of the generation call --
    status_message: str = ""
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict of all fields."""
        return asdict(self)
|
| 1006 |
+
|
| 1007 |
+
|
| 1008 |
+
def create_sample(
    llm_handler,
    query: str,
    instrumental: bool = False,
    vocal_language: Optional[str] = None,
    temperature: float = 0.85,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
    use_constrained_decoding: bool = True,
    constrained_decoding_debug: bool = False,
) -> CreateSampleResult:
    """Create a music sample from a natural language query using the 5Hz Language Model.

    This is the "Simple Mode" / "Inspiration Mode" feature that takes a user's
    natural language description of music and generates a complete sample:
    - Detailed caption/description
    - Lyrics (unless instrumental)
    - Metadata (BPM, duration, key, language, time signature)

    Note: cfg_scale and negative_prompt are not supported in create_sample mode.

    Args:
        llm_handler: Initialized LLM handler (LLMHandler instance)
        query: User's natural language music description
            (e.g., "a soft Bengali love song")
        instrumental: Whether to generate instrumental music (no vocals)
        vocal_language: Allowed vocal language for constrained decoding
            (e.g., "en", "zh"). If provided, lyrics are constrained to this
            language; if None or "unknown", no language constraint is applied.
        temperature: Sampling temperature for generation (0.0-2.0).
            Higher = more creative.
        top_k: Top-K sampling (None or 0 = disabled)
        top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
        repetition_penalty: Repetition penalty (1.0 = no penalty)
        use_constrained_decoding: Whether to use FSM-based constrained decoding
        constrained_decoding_debug: Whether to enable debug logging

    Returns:
        CreateSampleResult with generated sample fields and status.

    Example:
        >>> result = create_sample(llm_handler, "a soft Bengali love song", vocal_language="bn")
        >>> if result.success:
        ...     print(f"Caption: {result.caption}")
    """

    def _clean_text(value):
        # The LM sometimes emits "N/A" (in any casing) for unknown fields;
        # normalize those to empty strings, consistent with the lowercase
        # "n/a" handling used elsewhere in this module.
        if isinstance(value, str) and value.strip().upper() == "N/A":
            return ""
        return value

    def _to_int(value) -> Optional[int]:
        # int(float(...)) tolerates float-like strings such as "120.5",
        # which plain int() would reject.
        if value is None or value in ("", "N/A"):
            return None
        try:
            return int(float(value))
        except (ValueError, TypeError):
            return None

    def _to_float(value) -> Optional[float]:
        if value is None or value in ("", "N/A"):
            return None
        try:
            return float(value)
        except (ValueError, TypeError):
            return None

    # Refuse early when the language model has not been loaded.
    if not llm_handler.llm_initialized:
        return CreateSampleResult(
            status_message="5Hz LM not initialized. Please initialize it first.",
            success=False,
            error="LLM not initialized",
        )

    try:
        metadata, status = llm_handler.create_sample_from_query(
            query=query,
            instrumental=instrumental,
            vocal_language=vocal_language,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            use_constrained_decoding=use_constrained_decoding,
            constrained_decoding_debug=constrained_decoding_debug,
        )

        # Empty metadata signals an LM-side failure.
        if not metadata:
            return CreateSampleResult(
                status_message=status or "Failed to create sample",
                success=False,
                error=status or "Empty metadata returned",
            )

        return CreateSampleResult(
            caption=metadata.get('caption', ''),
            lyrics=metadata.get('lyrics', ''),
            bpm=_to_int(metadata.get('bpm')),
            duration=_to_float(metadata.get('duration')),
            keyscale=_clean_text(metadata.get('keyscale', '')),
            language=_clean_text(metadata.get('language', metadata.get('vocal_language', ''))),
            timesignature=_clean_text(metadata.get('timesignature', '')),
            # Fall back to the caller's flag when the LM omits the field.
            instrumental=metadata.get('instrumental', instrumental),
            status_message=status,
            success=True,
            error=None,
        )

    except Exception as e:
        logger.exception("Sample creation failed")
        return CreateSampleResult(
            status_message=f"Error: {str(e)}",
            success=False,
            error=str(e),
        )
|
| 1139 |
+
|
| 1140 |
+
|
| 1141 |
+
@dataclass
class FormatSampleResult:
    """Result of formatting user-provided caption and lyrics.

    This is used by the "Format" feature where users provide caption and lyrics,
    and the LLM formats them into structured music metadata and an enhanced
    description.

    Attributes:
        caption: Enhanced/formatted music description/caption
        lyrics: Formatted lyrics (may be same as input or reformatted)
        bpm: Beats per minute (None if not detected)
        duration: Duration in seconds (None if not detected)
        keyscale: Musical key (e.g., "C Major")
        language: Vocal language code (e.g., "en", "zh")
        timesignature: Time signature (e.g., "4")
        status_message: Status message from formatting
        success: Whether formatting completed successfully
        error: Error message if formatting failed
    """
    # Metadata fields
    caption: str = ""
    lyrics: str = ""
    bpm: Optional[int] = None
    duration: Optional[float] = None
    keyscale: str = ""
    language: str = ""
    timesignature: str = ""

    # Status
    status_message: str = ""
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert result to dictionary for JSON serialization."""
        return asdict(self)


def _coerce_optional_int(value: Any) -> Optional[int]:
    """Return ``int(value)``, or None when value is missing, 'N/A', empty,
    or not parseable as an integer."""
    if value is None or value == 'N/A' or value == '':
        return None
    try:
        return int(value)
    except (ValueError, TypeError):
        return None


def _coerce_optional_float(value: Any) -> Optional[float]:
    """Return ``float(value)``, or None when value is missing, 'N/A', empty,
    or not parseable as a float."""
    if value is None or value == 'N/A' or value == '':
        return None
    try:
        return float(value)
    except (ValueError, TypeError):
        return None


def _blank_if_na(value: str) -> str:
    """Map the LLM's literal 'N/A' placeholder to an empty string."""
    return '' if value == 'N/A' else value


def format_sample(
    llm_handler,
    caption: str,
    lyrics: str,
    user_metadata: Optional[Dict[str, Any]] = None,
    temperature: float = 0.85,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
    use_constrained_decoding: bool = True,
    constrained_decoding_debug: bool = False,
) -> FormatSampleResult:
    """Format user-provided caption and lyrics using the 5Hz Language Model.

    Takes user input (caption and lyrics) and generates structured music
    metadata including an enhanced caption, BPM, duration, key, language,
    and time signature.

    If user_metadata is provided, those values constrain the decoding so the
    output matches user-specified values.

    Note: cfg_scale and negative_prompt are not supported in format mode.

    Args:
        llm_handler: Initialized LLM handler (LLMHandler instance)
        caption: User's caption/description (e.g., "Latin pop, reggaeton")
        lyrics: User's lyrics with structure tags
        user_metadata: Optional dict with user-provided metadata to constrain
            decoding. Supported keys: bpm, duration, keyscale, timesignature,
            language
        temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
        top_k: Top-K sampling (None or 0 = disabled)
        top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
        repetition_penalty: Repetition penalty (1.0 = no penalty)
        use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
        constrained_decoding_debug: Whether to enable debug logging for constrained decoding

    Returns:
        FormatSampleResult with formatted metadata fields and status.
    """
    # Fail fast when the LM has not been loaded yet.
    if not llm_handler.llm_initialized:
        return FormatSampleResult(
            status_message="5Hz LM not initialized. Please initialize it first.",
            success=False,
            error="LLM not initialized",
        )

    try:
        metadata, status = llm_handler.format_sample_from_input(
            caption=caption,
            lyrics=lyrics,
            user_metadata=user_metadata,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            use_constrained_decoding=use_constrained_decoding,
            constrained_decoding_debug=constrained_decoding_debug,
        )

        # Empty metadata means the LLM call itself failed.
        if not metadata:
            return FormatSampleResult(
                status_message=status or "Failed to format input",
                success=False,
                error=status or "Empty metadata returned",
            )

        return FormatSampleResult(
            caption=metadata.get('caption', ''),
            lyrics=metadata.get('lyrics', lyrics),  # Fall back to input lyrics
            bpm=_coerce_optional_int(metadata.get('bpm')),
            duration=_coerce_optional_float(metadata.get('duration')),
            keyscale=_blank_if_na(metadata.get('keyscale', '')),
            language=_blank_if_na(
                metadata.get('language', metadata.get('vocal_language', ''))
            ),
            timesignature=_blank_if_na(metadata.get('timesignature', '')),
            status_message=status,
            success=True,
            error=None,
        )

    except Exception as e:
        logger.exception("Format sample failed")
        return FormatSampleResult(
            status_message=f"Error: {str(e)}",
            success=False,
            error=str(e),
        )
|
acestep/llm_inference.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
acestep/model_downloader.py
ADDED
|
@@ -0,0 +1,634 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ACE-Step Model Downloader
|
| 3 |
+
|
| 4 |
+
This module provides functionality to download models from HuggingFace Hub or ModelScope.
|
| 5 |
+
It supports automatic downloading when models are not found locally,
|
| 6 |
+
with intelligent fallback between download sources.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import argparse
|
| 12 |
+
from typing import Optional, List, Dict, Tuple
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
from loguru import logger
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# =============================================================================
|
| 19 |
+
# Network Detection & Smart Download
|
| 20 |
+
# =============================================================================
|
| 21 |
+
|
| 22 |
+
def _can_access_google(timeout: float = 3.0) -> bool:
|
| 23 |
+
"""
|
| 24 |
+
Check if Google is accessible (to determine HuggingFace vs ModelScope).
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
timeout: Connection timeout in seconds
|
| 28 |
+
|
| 29 |
+
Returns:
|
| 30 |
+
True if Google is accessible, False otherwise
|
| 31 |
+
"""
|
| 32 |
+
import socket
|
| 33 |
+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
| 34 |
+
try:
|
| 35 |
+
sock.settimeout(timeout)
|
| 36 |
+
sock.connect(("www.google.com", 443))
|
| 37 |
+
return True
|
| 38 |
+
except (socket.timeout, socket.error, OSError):
|
| 39 |
+
return False
|
| 40 |
+
finally:
|
| 41 |
+
sock.close()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _download_from_huggingface_internal(
    repo_id: str,
    local_dir: Path,
    token: Optional[str] = None,
) -> None:
    """
    Internal function to download from HuggingFace Hub.

    Args:
        repo_id: HuggingFace repository ID (e.g., "ACE-Step/Ace-Step1.5")
        local_dir: Local directory to save the model
        token: HuggingFace token for private repos (optional)

    Raises:
        Exception: If download fails
    """
    from huggingface_hub import snapshot_download

    logger.info(f"[Model Download] Downloading from HuggingFace: {repo_id} -> {local_dir}")

    # NOTE: `local_dir_use_symlinks` is deprecated and ignored by recent
    # huggingface_hub releases (files passed via `local_dir` are always
    # materialized there), so it is intentionally not passed anymore.
    snapshot_download(
        repo_id=repo_id,
        local_dir=str(local_dir),
        token=token,
    )
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _download_from_modelscope_internal(
    repo_id: str,
    local_dir: Path,
) -> None:
    """
    Internal function to download from ModelScope.

    Args:
        repo_id: ModelScope repository ID (e.g., "ACE-Step/Ace-Step1.5")
        local_dir: Local directory to save the model

    Raises:
        Exception: If download fails
    """
    # Imported lazily so the module loads even without modelscope installed.
    from modelscope import snapshot_download

    logger.info(f"[Model Download] Downloading from ModelScope: {repo_id} -> {local_dir}")
    snapshot_download(model_id=repo_id, local_dir=str(local_dir))
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _smart_download(
    repo_id: str,
    local_dir: Path,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Smart download with automatic fallback between HuggingFace and ModelScope.

    Automatically detects the network environment and chooses the best
    download source. If the primary source fails, automatically falls back
    to the alternative.

    Args:
        repo_id: Repository ID (same format for both HF and ModelScope)
        local_dir: Local directory to save the model
        token: HuggingFace token for private repos (optional)
        prefer_source: Preferred download source ("huggingface", "modelscope",
            or None for auto-detect)

    Returns:
        Tuple of (success, message)
    """
    # Ensure directory exists
    local_dir.mkdir(parents=True, exist_ok=True)

    # Determine primary source
    if prefer_source == "huggingface":
        use_huggingface_first = True
        logger.info("[Model Download] User preference: HuggingFace Hub")
    elif prefer_source == "modelscope":
        use_huggingface_first = False
        logger.info("[Model Download] User preference: ModelScope")
    else:
        # Auto-detect network environment
        can_access_google = _can_access_google()
        use_huggingface_first = can_access_google
        logger.info(f"[Model Download] Auto-detected: {'HuggingFace Hub' if can_access_google else 'ModelScope'}")

    # Each source: (message name, log label, abbreviation, download callable).
    huggingface = (
        "HuggingFace", "HuggingFace Hub", "HF",
        lambda: _download_from_huggingface_internal(repo_id, local_dir, token),
    )
    modelscope = (
        "ModelScope", "ModelScope", "MS",
        lambda: _download_from_modelscope_internal(repo_id, local_dir),
    )
    primary, fallback = (
        (huggingface, modelscope) if use_huggingface_first else (modelscope, huggingface)
    )
    p_name, p_label, p_abbr, p_download = primary
    f_name, f_label, f_abbr, f_download = fallback

    logger.info(f"[Model Download] Using {p_label}...")
    try:
        p_download()
        return True, f"Successfully downloaded from {p_name}: {repo_id}"
    except Exception as primary_err:
        logger.warning(f"[Model Download] {p_name} download failed: {primary_err}")
        logger.info(f"[Model Download] Falling back to {f_label}...")
        try:
            f_download()
            return True, f"Successfully downloaded from {f_name}: {repo_id}"
        except Exception as fallback_err:
            error_msg = (
                f"Both {p_name} and {f_name} downloads failed. "
                f"{p_abbr}: {primary_err}, {f_abbr}: {fallback_err}"
            )
            logger.error(error_msg)
            return False, error_msg
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# =============================================================================
# Model Registry
# =============================================================================
# Main model contains core components (vae, text_encoder, default DiT).
# The same repo ID is used on both HuggingFace Hub and ModelScope.
MAIN_MODEL_REPO = "ACE-Step/Ace-Step1.5"

# Sub-models that can be downloaded separately into the checkpoints directory.
# Keys are local directory names under checkpoints/; values are repo IDs.
SUBMODEL_REGISTRY: Dict[str, str] = {
    # LM models
    "acestep-5Hz-lm-0.6B": "ACE-Step/acestep-5Hz-lm-0.6B",
    "acestep-5Hz-lm-4B": "ACE-Step/acestep-5Hz-lm-4B",
    # DiT models
    "acestep-v15-turbo-shift3": "ACE-Step/acestep-v15-turbo-shift3",
    "acestep-v15-sft": "ACE-Step/acestep-v15-sft",
    "acestep-v15-base": "ACE-Step/acestep-v15-base",
    "acestep-v15-turbo-shift1": "ACE-Step/acestep-v15-turbo-shift1",
    "acestep-v15-turbo-continuous": "ACE-Step/acestep-v15-turbo-continuous",
}

# Components that come from the main model repo (ACE-Step/Ace-Step1.5).
# check_main_model_exists() requires every one of these directories.
MAIN_MODEL_COMPONENTS = [
    "acestep-v15-turbo",  # Default DiT model
    "vae",  # VAE for audio encoding/decoding
    "Qwen3-Embedding-0.6B",  # Text encoder
    "acestep-5Hz-lm-1.7B",  # Default LM model (1.7B)
]

# Default LM model (included in main model)
DEFAULT_LM_MODEL = "acestep-5Hz-lm-1.7B"
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def get_project_root() -> Path:
    """Return the project root directory (two levels above this module)."""
    return Path(__file__).resolve().parents[1]
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def get_checkpoints_dir(custom_dir: Optional[str] = None) -> Path:
    """Return the checkpoints directory; a custom path takes precedence
    over the default ``<project_root>/checkpoints``."""
    return Path(custom_dir) if custom_dir else get_project_root() / "checkpoints"
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def check_main_model_exists(checkpoints_dir: Optional[Path] = None) -> bool:
    """
    Check if the main model components exist in the checkpoints directory.

    Returns:
        True if all main model components exist, False otherwise.
    """
    base = get_checkpoints_dir() if checkpoints_dir is None else checkpoints_dir
    # Every component directory must be present for the main model to count.
    return all((base / component).exists() for component in MAIN_MODEL_COMPONENTS)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def check_model_exists(model_name: str, checkpoints_dir: Optional[Path] = None) -> bool:
    """
    Check if a specific model exists in the checkpoints directory.

    Args:
        model_name: Name of the model to check
        checkpoints_dir: Custom checkpoints directory (optional)

    Returns:
        True if the model exists, False otherwise.
    """
    base = get_checkpoints_dir() if checkpoints_dir is None else checkpoints_dir
    return (base / model_name).exists()
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def list_available_models() -> Dict[str, str]:
    """
    List all available models for download.

    Returns:
        Dictionary mapping local names to HuggingFace repo IDs
        ("main" entry first, then every registered sub-model).
    """
    return {"main": MAIN_MODEL_REPO, **SUBMODEL_REGISTRY}
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def download_main_model(
    checkpoints_dir: Optional[Path] = None,
    force: bool = False,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Download the main ACE-Step model from HuggingFace or ModelScope.

    The main model includes:
    - acestep-v15-turbo (default DiT model)
    - vae (audio encoder/decoder)
    - Qwen3-Embedding-0.6B (text encoder)
    - acestep-5Hz-lm-1.7B (default LM model)

    Args:
        checkpoints_dir: Custom checkpoints directory (optional)
        force: Force re-download even if model exists
        token: HuggingFace token for private repos (optional)
        prefer_source: Preferred download source ("huggingface", "modelscope",
            or None for auto-detect)

    Returns:
        Tuple of (success, message)
    """
    target = get_checkpoints_dir() if checkpoints_dir is None else checkpoints_dir
    target.mkdir(parents=True, exist_ok=True)

    # Skip the download entirely when everything is already in place.
    if not force and check_main_model_exists(target):
        return True, f"Main model already exists at {target}"

    print(f"Downloading main model from {MAIN_MODEL_REPO}...")
    print(f"Destination: {target}")
    print("This may take a while depending on your internet connection...")

    # Smart download handles HF <-> ModelScope fallback automatically.
    return _smart_download(MAIN_MODEL_REPO, target, token, prefer_source)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def download_submodel(
    model_name: str,
    checkpoints_dir: Optional[Path] = None,
    force: bool = False,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Download a specific sub-model from HuggingFace or ModelScope.

    Args:
        model_name: Name of the model to download (must be in SUBMODEL_REGISTRY)
        checkpoints_dir: Custom checkpoints directory (optional)
        force: Force re-download even if model exists
        token: HuggingFace token for private repos (optional)
        prefer_source: Preferred download source ("huggingface", "modelscope",
            or None for auto-detect)

    Returns:
        Tuple of (success, message)
    """
    try:
        repo_id = SUBMODEL_REGISTRY[model_name]
    except KeyError:
        available = ", ".join(SUBMODEL_REGISTRY.keys())
        return False, f"Unknown model '{model_name}'. Available models: {available}"

    base = get_checkpoints_dir() if checkpoints_dir is None else checkpoints_dir
    base.mkdir(parents=True, exist_ok=True)

    model_path = base / model_name
    if not force and model_path.exists():
        return True, f"Model '{model_name}' already exists at {model_path}"

    print(f"Downloading {model_name} from {repo_id}...")
    print(f"Destination: {model_path}")

    # Smart download handles HF <-> ModelScope fallback automatically.
    return _smart_download(repo_id, model_path, token, prefer_source)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def download_all_models(
    checkpoints_dir: Optional[Path] = None,
    force: bool = False,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, List[str]]:
    """
    Download all available models (main model first, then every sub-model).

    Args:
        checkpoints_dir: Custom checkpoints directory (optional)
        force: Force re-download even if models exist
        token: HuggingFace token for private repos (optional)
        prefer_source: Preferred download source ("huggingface", "modelscope",
            or None for auto-detect). Added for consistency with the other
            download helpers; defaults to auto-detect.

    Returns:
        Tuple of (all_success, list of messages)
    """
    if checkpoints_dir is None:
        checkpoints_dir = get_checkpoints_dir()

    messages: List[str] = []
    all_success = True

    # Download main model first; its components are required by everything else.
    success, msg = download_main_model(checkpoints_dir, force, token, prefer_source)
    messages.append(msg)
    all_success = all_success and success

    # Download all sub-models, continuing past individual failures.
    for model_name in SUBMODEL_REGISTRY:
        success, msg = download_submodel(
            model_name, checkpoints_dir, force, token, prefer_source
        )
        messages.append(msg)
        all_success = all_success and success

    return all_success, messages
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def ensure_main_model(
    checkpoints_dir: Optional[Path] = None,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Ensure the main model is available, downloading if necessary.

    Designed to be called during initialization; only downloads when the
    model does not already exist.

    Args:
        checkpoints_dir: Custom checkpoints directory (optional)
        token: HuggingFace token for private repos (optional)
        prefer_source: Preferred download source ("huggingface", "modelscope",
            or None for auto-detect)

    Returns:
        Tuple of (success, message)
    """
    target = get_checkpoints_dir() if checkpoints_dir is None else checkpoints_dir

    if check_main_model_exists(target):
        return True, "Main model is available"

    banner = "=" * 60
    print("\n" + banner)
    print("Main model not found. Starting automatic download...")
    print(banner + "\n")

    return download_main_model(target, token=token, prefer_source=prefer_source)
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def ensure_lm_model(
    model_name: Optional[str] = None,
    checkpoints_dir: Optional[Path] = None,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Ensure an LM model is available, downloading if necessary.

    Args:
        model_name: Name of the LM model (defaults to DEFAULT_LM_MODEL)
        checkpoints_dir: Custom checkpoints directory (optional)
        token: HuggingFace token for private repos (optional)
        prefer_source: Preferred download source ("huggingface", "modelscope",
            or None for auto-detect)

    Returns:
        Tuple of (success, message)
    """
    name = DEFAULT_LM_MODEL if model_name is None else model_name
    base = get_checkpoints_dir() if checkpoints_dir is None else checkpoints_dir

    if check_model_exists(name, base):
        return True, f"LM model '{name}' is available"

    if name not in SUBMODEL_REGISTRY:
        # The name may be a shorthand/variant; match it against known LM entries.
        resolved = next(
            (
                known
                for known in SUBMODEL_REGISTRY
                if "lm" in known.lower() and name.lower() in known.lower()
            ),
            None,
        )
        if resolved is None:
            return False, f"Unknown LM model: {name}"
        name = resolved

    banner = "=" * 60
    print("\n" + banner)
    print(f"LM model '{name}' not found. Starting automatic download...")
    print(banner + "\n")

    return download_submodel(name, base, token=token, prefer_source=prefer_source)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def ensure_dit_model(
    model_name: str,
    checkpoints_dir: Optional[Path] = None,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Ensure a DiT model is available, downloading if necessary.

    Args:
        model_name: Name of the DiT model
        checkpoints_dir: Custom checkpoints directory (optional)
        token: HuggingFace token for private repos (optional)
        prefer_source: Preferred download source ("huggingface", "modelscope",
            or None for auto-detect)

    Returns:
        Tuple of (success, message)
    """
    base = get_checkpoints_dir() if checkpoints_dir is None else checkpoints_dir

    if check_model_exists(model_name, base):
        return True, f"DiT model '{model_name}' is available"

    # The default turbo model ships inside the main bundle.
    if model_name == "acestep-v15-turbo":
        return ensure_main_model(base, token, prefer_source)

    if model_name not in SUBMODEL_REGISTRY:
        return False, f"Unknown DiT model: {model_name}"

    banner = "=" * 60
    print("\n" + banner)
    print(f"DiT model '{model_name}' not found. Starting automatic download...")
    print(banner + "\n")
    return download_submodel(model_name, base, token=token, prefer_source=prefer_source)
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def print_model_list():
    """Print formatted list of available models."""
    sep = "=" * 60
    # Partition the registry into LM and DiT entries up front.
    lm_entries = [(n, r) for n, r in SUBMODEL_REGISTRY.items() if "lm" in n.lower()]
    dit_entries = [(n, r) for n, r in SUBMODEL_REGISTRY.items() if "lm" not in n.lower()]

    print("\nAvailable Models for Download:")
    print(sep)
    print("\nSupported Sources: HuggingFace Hub <-> ModelScope (auto-fallback)")

    print("\n[Main Model]")
    print(f"  main -> {MAIN_MODEL_REPO}")
    print("  Contains: vae, Qwen3-Embedding-0.6B, acestep-v15-turbo, acestep-5Hz-lm-1.7B")

    print("\n[Optional LM Models]")
    for name, repo in lm_entries:
        print(f"  {name} -> {repo}")

    print("\n[Optional DiT Models]")
    for name, repo in dit_entries:
        print(f"  {name} -> {repo}")

    print("\n" + sep)
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
def main():
|
| 519 |
+
"""CLI entry point for model downloading."""
|
| 520 |
+
parser = argparse.ArgumentParser(
|
| 521 |
+
description="Download ACE-Step models with automatic fallback (HuggingFace <-> ModelScope)",
|
| 522 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 523 |
+
epilog="""
|
| 524 |
+
Examples:
|
| 525 |
+
acestep-download # Download main model (includes LM 1.7B)
|
| 526 |
+
acestep-download --all # Download all available models
|
| 527 |
+
acestep-download --model acestep-v15-sft # Download a specific model
|
| 528 |
+
acestep-download --list # List all available models
|
| 529 |
+
|
| 530 |
+
Network Detection:
|
| 531 |
+
Automatically detects network environment and chooses the best download source:
|
| 532 |
+
- Google accessible -> HuggingFace (fallback to ModelScope)
|
| 533 |
+
- Google blocked -> ModelScope (fallback to HuggingFace)
|
| 534 |
+
|
| 535 |
+
Alternative using huggingface-cli:
|
| 536 |
+
huggingface-cli download ACE-Step/Ace-Step1.5 --local-dir ./checkpoints
|
| 537 |
+
huggingface-cli download ACE-Step/acestep-5Hz-lm-0.6B --local-dir ./checkpoints/acestep-5Hz-lm-0.6B
|
| 538 |
+
"""
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
parser.add_argument(
|
| 542 |
+
"--model", "-m",
|
| 543 |
+
type=str,
|
| 544 |
+
help="Specific model to download (use --list to see available models)"
|
| 545 |
+
)
|
| 546 |
+
parser.add_argument(
|
| 547 |
+
"--all", "-a",
|
| 548 |
+
action="store_true",
|
| 549 |
+
help="Download all available models"
|
| 550 |
+
)
|
| 551 |
+
parser.add_argument(
|
| 552 |
+
"--list", "-l",
|
| 553 |
+
action="store_true",
|
| 554 |
+
help="List all available models"
|
| 555 |
+
)
|
| 556 |
+
parser.add_argument(
|
| 557 |
+
"--dir", "-d",
|
| 558 |
+
type=str,
|
| 559 |
+
default=None,
|
| 560 |
+
help="Custom checkpoints directory (default: ./checkpoints)"
|
| 561 |
+
)
|
| 562 |
+
parser.add_argument(
|
| 563 |
+
"--force", "-f",
|
| 564 |
+
action="store_true",
|
| 565 |
+
help="Force re-download even if model exists"
|
| 566 |
+
)
|
| 567 |
+
parser.add_argument(
|
| 568 |
+
"--token", "-t",
|
| 569 |
+
type=str,
|
| 570 |
+
default=None,
|
| 571 |
+
help="HuggingFace token for private repos"
|
| 572 |
+
)
|
| 573 |
+
parser.add_argument(
|
| 574 |
+
"--skip-main",
|
| 575 |
+
action="store_true",
|
| 576 |
+
help="Skip downloading the main model (only download specified sub-model)"
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
args = parser.parse_args()
|
| 580 |
+
|
| 581 |
+
# Handle --list
|
| 582 |
+
if args.list:
|
| 583 |
+
print_model_list()
|
| 584 |
+
return 0
|
| 585 |
+
|
| 586 |
+
# Get checkpoints directory
|
| 587 |
+
checkpoints_dir = get_checkpoints_dir(args.dir) if args.dir else get_checkpoints_dir()
|
| 588 |
+
print(f"Checkpoints directory: {checkpoints_dir}")
|
| 589 |
+
|
| 590 |
+
# Handle --all
|
| 591 |
+
if args.all:
|
| 592 |
+
success, messages = download_all_models(checkpoints_dir, args.force, args.token)
|
| 593 |
+
for msg in messages:
|
| 594 |
+
print(msg)
|
| 595 |
+
return 0 if success else 1
|
| 596 |
+
|
| 597 |
+
# Handle --model
|
| 598 |
+
if args.model:
|
| 599 |
+
if args.model == "main":
|
| 600 |
+
success, msg = download_main_model(checkpoints_dir, args.force, args.token)
|
| 601 |
+
elif args.model in SUBMODEL_REGISTRY:
|
| 602 |
+
# Download main model first if needed (unless --skip-main)
|
| 603 |
+
if not args.skip_main and not check_main_model_exists(checkpoints_dir):
|
| 604 |
+
print("Main model not found. Downloading main model first...")
|
| 605 |
+
main_success, main_msg = download_main_model(checkpoints_dir, args.force, args.token)
|
| 606 |
+
print(main_msg)
|
| 607 |
+
if not main_success:
|
| 608 |
+
return 1
|
| 609 |
+
|
| 610 |
+
success, msg = download_submodel(args.model, checkpoints_dir, args.force, args.token)
|
| 611 |
+
else:
|
| 612 |
+
print(f"Unknown model: {args.model}")
|
| 613 |
+
print("Use --list to see available models")
|
| 614 |
+
return 1
|
| 615 |
+
|
| 616 |
+
print(msg)
|
| 617 |
+
return 0 if success else 1
|
| 618 |
+
|
| 619 |
+
# Default: download main model (includes default LM 1.7B)
|
| 620 |
+
print("Downloading main model (includes vae, text encoder, DiT, and LM 1.7B)...")
|
| 621 |
+
|
| 622 |
+
# Download main model
|
| 623 |
+
success, msg = download_main_model(checkpoints_dir, args.force, args.token)
|
| 624 |
+
print(msg)
|
| 625 |
+
|
| 626 |
+
if success:
|
| 627 |
+
print("\nDownload complete!")
|
| 628 |
+
print(f"Models are available at: {checkpoints_dir}")
|
| 629 |
+
|
| 630 |
+
return 0 if success else 1
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
if __name__ == "__main__":
|
| 634 |
+
sys.exit(main())
|
handler.py
CHANGED
|
@@ -1,15 +1,13 @@
|
|
| 1 |
# handler.py
|
| 2 |
import base64
|
| 3 |
-
import inspect
|
| 4 |
import io
|
| 5 |
import os
|
| 6 |
import traceback
|
| 7 |
-
from typing import Any, Dict, Tuple
|
| 8 |
|
| 9 |
import numpy as np
|
| 10 |
import soundfile as sf
|
| 11 |
|
| 12 |
-
# Optional torch import for dtype/device handling
|
| 13 |
try:
|
| 14 |
import torch
|
| 15 |
except Exception:
|
|
@@ -20,7 +18,7 @@ class EndpointHandler:
|
|
| 20 |
"""
|
| 21 |
Hugging Face Inference Endpoints custom handler for ACE-Step 1.5.
|
| 22 |
|
| 23 |
-
|
| 24 |
{
|
| 25 |
"inputs": {
|
| 26 |
"prompt": "upbeat pop rap, emotional guitar",
|
|
@@ -29,130 +27,144 @@ class EndpointHandler:
|
|
| 29 |
"sample_rate": 44100,
|
| 30 |
"seed": 42,
|
| 31 |
"guidance_scale": 7.0,
|
| 32 |
-
"steps":
|
| 33 |
"use_lm": true,
|
| 34 |
"simple_prompt": false,
|
| 35 |
-
"
|
|
|
|
| 36 |
}
|
| 37 |
}
|
| 38 |
|
| 39 |
-
|
| 40 |
{
|
| 41 |
"inputs": "upbeat pop rap with emotional guitar"
|
| 42 |
}
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
"sample_rate": 44100,
|
| 48 |
-
"duration_sec": 12,
|
| 49 |
-
"used_fallback": false,
|
| 50 |
-
"model_loaded": true,
|
| 51 |
-
"model_error": null,
|
| 52 |
-
"meta": {...}
|
| 53 |
-
}
|
| 54 |
"""
|
| 55 |
|
| 56 |
def __init__(self, path: str = ""):
|
| 57 |
self.path = path
|
| 58 |
-
self.
|
| 59 |
-
|
| 60 |
self.model_repo = os.getenv("ACE_MODEL_REPO", "ACE-Step/Ace-Step1.5")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
self.default_sr = int(os.getenv("DEFAULT_SAMPLE_RATE", "44100"))
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
-
# Runtime knobs
|
| 64 |
self.device = "cuda" if (torch is not None and torch.cuda.is_available()) else "cpu"
|
| 65 |
self.dtype = "float16" if self.device == "cuda" else "float32"
|
| 66 |
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
# --------------------------
|
| 72 |
-
# Initialization
|
| 73 |
# --------------------------
|
| 74 |
def _init_model(self) -> None:
|
| 75 |
err_msgs = []
|
| 76 |
|
| 77 |
-
# Strategy A: class/factory in acestep.acestep_v15_pipeline
|
| 78 |
try:
|
| 79 |
-
from acestep import
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
cls = getattr(m, "AceStepV15Pipeline")
|
| 86 |
-
if hasattr(cls, "from_pretrained"):
|
| 87 |
-
self.model = cls.from_pretrained(self.model_repo)
|
| 88 |
-
else:
|
| 89 |
-
self.model = cls(model_path=self.model_repo)
|
| 90 |
-
elif hasattr(m, "Pipeline"):
|
| 91 |
-
cls = getattr(m, "Pipeline")
|
| 92 |
-
if hasattr(cls, "from_pretrained"):
|
| 93 |
-
self.model = cls.from_pretrained(self.model_repo)
|
| 94 |
-
else:
|
| 95 |
-
self.model = cls(self.model_repo)
|
| 96 |
-
else:
|
| 97 |
-
raise RuntimeError("No known pipeline class/factory found in acestep_v15_pipeline")
|
| 98 |
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
except Exception:
|
| 104 |
-
pass
|
| 105 |
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
except Exception as e:
|
| 108 |
-
err_msgs.append(f"
|
| 109 |
|
| 110 |
-
# Strategy B: import root `acestep` and find a likely pipeline symbol
|
| 111 |
try:
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
"AceStepV15Pipeline",
|
| 116 |
-
"AceStepPipeline",
|
| 117 |
-
"Pipeline",
|
| 118 |
-
"create_pipeline",
|
| 119 |
-
"build_pipeline",
|
| 120 |
-
"load_pipeline",
|
| 121 |
-
]
|
| 122 |
-
|
| 123 |
-
obj = None
|
| 124 |
-
for name in candidates:
|
| 125 |
-
if hasattr(acestep, name):
|
| 126 |
-
obj = getattr(acestep, name)
|
| 127 |
-
break
|
| 128 |
-
|
| 129 |
-
if obj is None:
|
| 130 |
-
raise RuntimeError("No known pipeline symbol found in `acestep` package")
|
| 131 |
-
|
| 132 |
-
if callable(obj):
|
| 133 |
-
# class or factory
|
| 134 |
-
if hasattr(obj, "from_pretrained"):
|
| 135 |
-
self.model = obj.from_pretrained(self.model_repo)
|
| 136 |
-
else:
|
| 137 |
-
# try keyword variants
|
| 138 |
-
try:
|
| 139 |
-
self.model = obj(model_path=self.model_repo)
|
| 140 |
-
except TypeError:
|
| 141 |
-
self.model = obj(self.model_repo)
|
| 142 |
-
else:
|
| 143 |
-
self.model = obj
|
| 144 |
-
|
| 145 |
-
if self.model is not None and hasattr(self.model, "to"):
|
| 146 |
-
try:
|
| 147 |
-
self.model.to(self.device)
|
| 148 |
-
except Exception:
|
| 149 |
-
pass
|
| 150 |
-
return
|
| 151 |
except Exception as e:
|
| 152 |
-
err_msgs.append(f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
# --------------------------
|
| 158 |
# Audio helpers
|
|
@@ -166,7 +178,6 @@ class EndpointHandler:
|
|
| 166 |
else:
|
| 167 |
arr = np.asarray(audio)
|
| 168 |
|
| 169 |
-
# Convert common tensor shape [channels, samples] to [samples, channels].
|
| 170 |
if arr.ndim == 2 and arr.shape[0] in (1, 2) and arr.shape[1] > arr.shape[0]:
|
| 171 |
arr = arr.T
|
| 172 |
|
|
@@ -188,6 +199,9 @@ class EndpointHandler:
|
|
| 188 |
y = (0.07 * np.sin(2 * np.pi * 440 * t) + 0.01 * rng.standard_normal(len(t))).astype(np.float32)
|
| 189 |
return np.clip(y, -1.0, 1.0)
|
| 190 |
|
|
|
|
|
|
|
|
|
|
| 191 |
@staticmethod
|
| 192 |
def _to_bool(value: Any, default: bool = False) -> bool:
|
| 193 |
if value is None:
|
|
@@ -242,39 +256,18 @@ class EndpointHandler:
|
|
| 242 |
if not lyrics and (instrumental or simple_prompt):
|
| 243 |
lyrics = "[Instrumental]"
|
| 244 |
|
| 245 |
-
duration_sec = self._to_int(raw_inputs.get("duration_sec", raw_inputs.get("duration",
|
| 246 |
-
duration_sec = max(
|
| 247 |
|
| 248 |
sample_rate = self._to_int(raw_inputs.get("sample_rate", self.default_sr), self.default_sr)
|
| 249 |
sample_rate = max(8000, min(sample_rate, 48000))
|
| 250 |
|
| 251 |
seed = self._to_int(raw_inputs.get("seed", 42), 42)
|
| 252 |
guidance_scale = self._to_float(raw_inputs.get("guidance_scale", 7.0), 7.0)
|
| 253 |
-
steps = self._to_int(raw_inputs.get("steps", raw_inputs.get("inference_steps",
|
| 254 |
-
steps = max(1, min(steps,
|
| 255 |
use_lm = self._to_bool(raw_inputs.get("use_lm", raw_inputs.get("thinking", True)), True)
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
model_repo = raw_inputs.get("model_repo")
|
| 259 |
-
|
| 260 |
-
model_kwargs = {
|
| 261 |
-
"task_type": task_type,
|
| 262 |
-
"prompt": prompt,
|
| 263 |
-
"caption": prompt,
|
| 264 |
-
"query": prompt,
|
| 265 |
-
"lyrics": lyrics,
|
| 266 |
-
"duration_sec": duration_sec,
|
| 267 |
-
"duration": duration_sec,
|
| 268 |
-
"sample_rate": sample_rate,
|
| 269 |
-
"seed": seed,
|
| 270 |
-
"guidance_scale": guidance_scale,
|
| 271 |
-
"steps": steps,
|
| 272 |
-
"inference_steps": steps,
|
| 273 |
-
"num_inference_steps": steps,
|
| 274 |
-
"use_lm": use_lm,
|
| 275 |
-
"thinking": use_lm,
|
| 276 |
-
"instrumental": instrumental,
|
| 277 |
-
}
|
| 278 |
|
| 279 |
return {
|
| 280 |
"prompt": prompt,
|
|
@@ -287,165 +280,144 @@ class EndpointHandler:
|
|
| 287 |
"use_lm": use_lm,
|
| 288 |
"instrumental": instrumental,
|
| 289 |
"simple_prompt": simple_prompt,
|
| 290 |
-
"
|
| 291 |
-
"model_kwargs": model_kwargs,
|
| 292 |
}
|
| 293 |
|
| 294 |
-
@staticmethod
|
| 295 |
-
def _invoke_with_supported_kwargs(fn: Any, kwargs: Dict[str, Any]) -> Any:
|
| 296 |
-
try:
|
| 297 |
-
sig = inspect.signature(fn)
|
| 298 |
-
has_var_kw = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values())
|
| 299 |
-
if has_var_kw:
|
| 300 |
-
return fn(**kwargs)
|
| 301 |
-
accepted = {
|
| 302 |
-
name
|
| 303 |
-
for name, p in sig.parameters.items()
|
| 304 |
-
if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
|
| 305 |
-
}
|
| 306 |
-
filtered = {k: v for k, v in kwargs.items() if k in accepted}
|
| 307 |
-
return fn(**filtered)
|
| 308 |
-
except Exception:
|
| 309 |
-
# Fallback for C-extension callables or dynamic signatures.
|
| 310 |
-
return fn(**kwargs)
|
| 311 |
-
|
| 312 |
-
def _normalize_model_output(self, out: Any, default_sr: int) -> Tuple[np.ndarray, int]:
|
| 313 |
-
if out is None:
|
| 314 |
-
raise RuntimeError("Model returned None")
|
| 315 |
-
|
| 316 |
-
if hasattr(out, "success") and not getattr(out, "success"):
|
| 317 |
-
err = getattr(out, "error", "unknown model error")
|
| 318 |
-
raise RuntimeError(str(err))
|
| 319 |
-
|
| 320 |
-
if hasattr(out, "audios"):
|
| 321 |
-
audios = getattr(out, "audios") or []
|
| 322 |
-
if not audios:
|
| 323 |
-
raise RuntimeError("Model result has no audios")
|
| 324 |
-
first = audios[0]
|
| 325 |
-
if isinstance(first, dict):
|
| 326 |
-
audio = first.get("tensor", first.get("audio", first.get("waveform", first.get("wav"))))
|
| 327 |
-
sr = first.get("sample_rate", default_sr)
|
| 328 |
-
else:
|
| 329 |
-
audio = getattr(first, "tensor", getattr(first, "audio", None))
|
| 330 |
-
sr = getattr(first, "sample_rate", default_sr)
|
| 331 |
-
if audio is None:
|
| 332 |
-
raise RuntimeError("Model result audio entry is missing tensor/audio")
|
| 333 |
-
return self._as_float32(audio), int(sr)
|
| 334 |
-
|
| 335 |
-
if isinstance(out, tuple) and len(out) >= 1:
|
| 336 |
-
audio = out[0]
|
| 337 |
-
sr = int(out[1]) if len(out) > 1 and out[1] is not None else default_sr
|
| 338 |
-
return self._as_float32(audio), sr
|
| 339 |
-
|
| 340 |
-
if isinstance(out, dict):
|
| 341 |
-
if "audios" in out:
|
| 342 |
-
audios = out.get("audios") or []
|
| 343 |
-
if not audios:
|
| 344 |
-
raise RuntimeError("Model output `audios` is empty")
|
| 345 |
-
first = audios[0]
|
| 346 |
-
if not isinstance(first, dict):
|
| 347 |
-
raise RuntimeError("Model output `audios[0]` must be a dict")
|
| 348 |
-
audio = first.get("tensor", first.get("audio", first.get("waveform", first.get("wav"))))
|
| 349 |
-
sr = first.get("sample_rate", default_sr)
|
| 350 |
-
if audio is None:
|
| 351 |
-
raise RuntimeError("Model output `audios[0]` missing tensor/audio")
|
| 352 |
-
return self._as_float32(audio), int(sr)
|
| 353 |
-
|
| 354 |
-
audio = out.get("audio", out.get("waveform", out.get("wav", out.get("tensor"))))
|
| 355 |
-
sr = out.get("sample_rate", out.get("sr", default_sr))
|
| 356 |
-
if audio is None:
|
| 357 |
-
raise RuntimeError("Model dict output missing audio/waveform field")
|
| 358 |
-
return self._as_float32(audio), int(sr)
|
| 359 |
-
|
| 360 |
-
for name in ("audio", "waveform", "wav", "tensor"):
|
| 361 |
-
if hasattr(out, name):
|
| 362 |
-
audio = getattr(out, name)
|
| 363 |
-
if audio is not None:
|
| 364 |
-
sr = getattr(out, "sample_rate", getattr(out, "sr", default_sr))
|
| 365 |
-
return self._as_float32(audio), int(sr)
|
| 366 |
-
|
| 367 |
-
return self._as_float32(out), default_sr
|
| 368 |
-
|
| 369 |
# --------------------------
|
| 370 |
-
#
|
| 371 |
# --------------------------
|
| 372 |
-
def
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
sample_rate: int,
|
| 376 |
-
) -> Tuple[np.ndarray, int]:
|
| 377 |
-
"""
|
| 378 |
-
Tries multiple invocation styles to tolerate minor ACE-Step API differences.
|
| 379 |
-
Returns (audio_np, sample_rate).
|
| 380 |
-
"""
|
| 381 |
-
if self.model is None:
|
| 382 |
-
raise RuntimeError("Model is not loaded")
|
| 383 |
-
|
| 384 |
-
# Common callable entrypoints
|
| 385 |
-
methods = [
|
| 386 |
-
"__call__",
|
| 387 |
-
"generate",
|
| 388 |
-
"infer",
|
| 389 |
-
"inference",
|
| 390 |
-
"text_to_music",
|
| 391 |
-
"run",
|
| 392 |
-
]
|
| 393 |
-
|
| 394 |
-
last_err = None
|
| 395 |
-
for m in methods:
|
| 396 |
-
try:
|
| 397 |
-
fn = self.model if m == "__call__" else getattr(self.model, m, None)
|
| 398 |
-
if fn is None:
|
| 399 |
-
continue
|
| 400 |
-
|
| 401 |
-
# Try full kwargs
|
| 402 |
-
try:
|
| 403 |
-
out = self._invoke_with_supported_kwargs(fn, model_kwargs)
|
| 404 |
-
except TypeError:
|
| 405 |
-
# Narrow payload if signature is strict
|
| 406 |
-
skinny = {
|
| 407 |
-
"prompt": model_kwargs.get("prompt"),
|
| 408 |
-
"caption": model_kwargs.get("caption"),
|
| 409 |
-
"lyrics": model_kwargs.get("lyrics"),
|
| 410 |
-
"duration": model_kwargs.get("duration"),
|
| 411 |
-
"seed": model_kwargs.get("seed"),
|
| 412 |
-
}
|
| 413 |
-
skinny = {k: v for k, v in skinny.items() if v is not None and (k != "prompt" or str(v).strip())}
|
| 414 |
-
out = self._invoke_with_supported_kwargs(fn, skinny)
|
| 415 |
-
|
| 416 |
-
return self._normalize_model_output(out, sample_rate)
|
| 417 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
except Exception as e:
|
| 419 |
-
|
| 420 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
|
| 422 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
| 425 |
try:
|
| 426 |
req = self._normalize_request(data)
|
| 427 |
|
| 428 |
-
# Optional override
|
| 429 |
-
model_repo = req.get("model_repo")
|
| 430 |
-
if model_repo and model_repo != self.model_repo:
|
| 431 |
-
# hot-switch model only if user asks
|
| 432 |
-
self.model_repo = str(model_repo)
|
| 433 |
-
self._init_model()
|
| 434 |
-
|
| 435 |
used_fallback = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
|
| 437 |
-
if self.model is not None:
|
| 438 |
-
try:
|
| 439 |
-
audio, out_sr = self._call_model(
|
| 440 |
-
model_kwargs=req["model_kwargs"],
|
| 441 |
-
sample_rate=req["sample_rate"],
|
| 442 |
-
)
|
| 443 |
-
except Exception as e:
|
| 444 |
-
used_fallback = True
|
| 445 |
-
self.model_error = f"Inference failed: {type(e).__name__}: {e}"
|
| 446 |
-
audio = self._fallback_sine(req["duration_sec"], req["sample_rate"], req["seed"])
|
| 447 |
-
out_sr = req["sample_rate"]
|
| 448 |
-
else:
|
| 449 |
used_fallback = True
|
| 450 |
audio = self._fallback_sine(req["duration_sec"], req["sample_rate"], req["seed"])
|
| 451 |
out_sr = req["sample_rate"]
|
|
@@ -455,7 +427,7 @@ class EndpointHandler:
|
|
| 455 |
"sample_rate": int(out_sr),
|
| 456 |
"duration_sec": int(req["duration_sec"]),
|
| 457 |
"used_fallback": used_fallback,
|
| 458 |
-
"model_loaded": self.
|
| 459 |
"model_repo": self.model_repo,
|
| 460 |
"model_error": self.model_error,
|
| 461 |
"meta": {
|
|
@@ -469,16 +441,34 @@ class EndpointHandler:
|
|
| 469 |
"use_lm": req["use_lm"],
|
| 470 |
"simple_prompt": req["simple_prompt"],
|
| 471 |
"instrumental": req["instrumental"],
|
| 472 |
-
"
|
| 473 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
},
|
| 475 |
}
|
| 476 |
|
| 477 |
except Exception as e:
|
| 478 |
return {
|
| 479 |
"error": f"{type(e).__name__}: {e}",
|
| 480 |
-
"traceback": traceback.format_exc(limit=
|
| 481 |
"audio_base64_wav": None,
|
| 482 |
"sample_rate": None,
|
| 483 |
"duration_sec": None,
|
| 484 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# handler.py
|
| 2 |
import base64
|
|
|
|
| 3 |
import io
|
| 4 |
import os
|
| 5 |
import traceback
|
| 6 |
+
from typing import Any, Dict, Optional, Tuple
|
| 7 |
|
| 8 |
import numpy as np
|
| 9 |
import soundfile as sf
|
| 10 |
|
|
|
|
| 11 |
try:
|
| 12 |
import torch
|
| 13 |
except Exception:
|
|
|
|
| 18 |
"""
|
| 19 |
Hugging Face Inference Endpoints custom handler for ACE-Step 1.5.
|
| 20 |
|
| 21 |
+
Supported request shapes:
|
| 22 |
{
|
| 23 |
"inputs": {
|
| 24 |
"prompt": "upbeat pop rap, emotional guitar",
|
|
|
|
| 27 |
"sample_rate": 44100,
|
| 28 |
"seed": 42,
|
| 29 |
"guidance_scale": 7.0,
|
| 30 |
+
"steps": 8,
|
| 31 |
"use_lm": true,
|
| 32 |
"simple_prompt": false,
|
| 33 |
+
"instrumental": false,
|
| 34 |
+
"allow_fallback": false
|
| 35 |
}
|
| 36 |
}
|
| 37 |
|
| 38 |
+
Or simple mode:
|
| 39 |
{
|
| 40 |
"inputs": "upbeat pop rap with emotional guitar"
|
| 41 |
}
|
| 42 |
|
| 43 |
+
Notes:
|
| 44 |
+
- This handler uses ACE-Step's official Python API internally.
|
| 45 |
+
- Fallback sine generation is disabled by default so model failures are explicit.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
"""
|
| 47 |
|
| 48 |
def __init__(self, path: str = ""):
|
| 49 |
self.path = path
|
| 50 |
+
self.project_root = os.path.dirname(os.path.abspath(__file__))
|
| 51 |
+
|
| 52 |
self.model_repo = os.getenv("ACE_MODEL_REPO", "ACE-Step/Ace-Step1.5")
|
| 53 |
+
self.config_path = os.getenv("ACE_CONFIG_PATH", "acestep-v15-turbo")
|
| 54 |
+
self.lm_model_path = os.getenv("ACE_LM_MODEL_PATH", "acestep-5Hz-lm-1.7B")
|
| 55 |
+
self.lm_backend = os.getenv("ACE_LM_BACKEND", "pt")
|
| 56 |
+
self.download_source = os.getenv("ACE_DOWNLOAD_SOURCE", "huggingface")
|
| 57 |
+
|
| 58 |
self.default_sr = int(os.getenv("DEFAULT_SAMPLE_RATE", "44100"))
|
| 59 |
+
self.enable_fallback = self._to_bool(os.getenv("ACE_ENABLE_FALLBACK"), False)
|
| 60 |
+
self.init_lm_on_start = self._to_bool(os.getenv("ACE_INIT_LLM"), False)
|
| 61 |
+
self.skip_init = self._to_bool(os.getenv("ACE_SKIP_INIT"), False)
|
| 62 |
|
|
|
|
| 63 |
self.device = "cuda" if (torch is not None and torch.cuda.is_available()) else "cpu"
|
| 64 |
self.dtype = "float16" if self.device == "cuda" else "float32"
|
| 65 |
|
| 66 |
+
self.model_loaded = False
|
| 67 |
+
self.model_error: Optional[str] = None
|
| 68 |
+
self.init_details: Dict[str, Any] = {}
|
| 69 |
+
|
| 70 |
+
self.dit_handler = None
|
| 71 |
+
self.llm_handler = None
|
| 72 |
+
self.llm_initialized = False
|
| 73 |
+
self.llm_error: Optional[str] = None
|
| 74 |
+
|
| 75 |
+
self._GenerationParams = None
|
| 76 |
+
self._GenerationConfig = None
|
| 77 |
+
self._generate_music = None
|
| 78 |
+
self._create_sample = None
|
| 79 |
+
|
| 80 |
+
if self.skip_init:
|
| 81 |
+
self.model_error = "Initialization skipped because ACE_SKIP_INIT=true"
|
| 82 |
+
else:
|
| 83 |
+
self._init_model()
|
| 84 |
|
| 85 |
# --------------------------
|
| 86 |
+
# Initialization
|
| 87 |
# --------------------------
|
| 88 |
def _init_model(self) -> None:
|
| 89 |
err_msgs = []
|
| 90 |
|
|
|
|
| 91 |
try:
|
| 92 |
+
from acestep.handler import AceStepHandler
|
| 93 |
+
from acestep.inference import GenerationConfig, GenerationParams, create_sample, generate_music
|
| 94 |
+
from acestep.llm_inference import LLMHandler
|
| 95 |
+
except Exception as e:
|
| 96 |
+
self.model_error = f"ACE-Step import failed: {type(e).__name__}: {e}"
|
| 97 |
+
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
+
self._GenerationParams = GenerationParams
|
| 100 |
+
self._GenerationConfig = GenerationConfig
|
| 101 |
+
self._generate_music = generate_music
|
| 102 |
+
self._create_sample = create_sample
|
|
|
|
|
|
|
| 103 |
|
| 104 |
+
try:
|
| 105 |
+
self.dit_handler = AceStepHandler()
|
| 106 |
+
prefer_source = self.download_source if self.download_source in {"huggingface", "modelscope"} else None
|
| 107 |
+
init_status, ok = self.dit_handler.initialize_service(
|
| 108 |
+
project_root=self.project_root,
|
| 109 |
+
config_path=self.config_path,
|
| 110 |
+
device=self.device,
|
| 111 |
+
use_flash_attention=False,
|
| 112 |
+
compile_model=False,
|
| 113 |
+
offload_to_cpu=False,
|
| 114 |
+
offload_dit_to_cpu=False,
|
| 115 |
+
prefer_source=prefer_source,
|
| 116 |
+
)
|
| 117 |
+
self.init_details["dit_status"] = init_status
|
| 118 |
+
if not ok:
|
| 119 |
+
raise RuntimeError(init_status)
|
| 120 |
except Exception as e:
|
| 121 |
+
err_msgs.append(f"DiT init failed: {type(e).__name__}: {e}")
|
| 122 |
|
|
|
|
| 123 |
try:
|
| 124 |
+
self.llm_handler = LLMHandler()
|
| 125 |
+
if self.init_lm_on_start:
|
| 126 |
+
self._ensure_llm_initialized()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
except Exception as e:
|
| 128 |
+
err_msgs.append(f"LLM bootstrap failed: {type(e).__name__}: {e}")
|
| 129 |
+
|
| 130 |
+
if err_msgs:
|
| 131 |
+
self.model_loaded = False
|
| 132 |
+
self.model_error = " | ".join(err_msgs)
|
| 133 |
+
return
|
| 134 |
+
|
| 135 |
+
self.model_loaded = True
|
| 136 |
+
self.model_error = None
|
| 137 |
+
|
| 138 |
+
def _ensure_llm_initialized(self) -> bool:
|
| 139 |
+
if self.llm_handler is None:
|
| 140 |
+
self.llm_error = "LLM handler is not available"
|
| 141 |
+
return False
|
| 142 |
+
|
| 143 |
+
if self.llm_initialized:
|
| 144 |
+
return True
|
| 145 |
|
| 146 |
+
try:
|
| 147 |
+
checkpoint_dir = os.path.join(self.project_root, "checkpoints")
|
| 148 |
+
status, ok = self.llm_handler.initialize(
|
| 149 |
+
checkpoint_dir=checkpoint_dir,
|
| 150 |
+
lm_model_path=self.lm_model_path,
|
| 151 |
+
backend=self.lm_backend,
|
| 152 |
+
device=self.device,
|
| 153 |
+
offload_to_cpu=False,
|
| 154 |
+
)
|
| 155 |
+
self.init_details["llm_status"] = status
|
| 156 |
+
if not ok:
|
| 157 |
+
self.llm_error = status
|
| 158 |
+
self.llm_initialized = False
|
| 159 |
+
return False
|
| 160 |
+
|
| 161 |
+
self.llm_error = None
|
| 162 |
+
self.llm_initialized = True
|
| 163 |
+
return True
|
| 164 |
+
except Exception as e:
|
| 165 |
+
self.llm_error = f"LLM init exception: {type(e).__name__}: {e}"
|
| 166 |
+
self.llm_initialized = False
|
| 167 |
+
return False
|
| 168 |
|
| 169 |
# --------------------------
|
| 170 |
# Audio helpers
|
|
|
|
| 178 |
else:
|
| 179 |
arr = np.asarray(audio)
|
| 180 |
|
|
|
|
| 181 |
if arr.ndim == 2 and arr.shape[0] in (1, 2) and arr.shape[1] > arr.shape[0]:
|
| 182 |
arr = arr.T
|
| 183 |
|
|
|
|
| 199 |
y = (0.07 * np.sin(2 * np.pi * 440 * t) + 0.01 * rng.standard_normal(len(t))).astype(np.float32)
|
| 200 |
return np.clip(y, -1.0, 1.0)
|
| 201 |
|
| 202 |
+
# --------------------------
|
| 203 |
+
# Request normalization
|
| 204 |
+
# --------------------------
|
| 205 |
@staticmethod
|
| 206 |
def _to_bool(value: Any, default: bool = False) -> bool:
|
| 207 |
if value is None:
|
|
|
|
| 256 |
if not lyrics and (instrumental or simple_prompt):
|
| 257 |
lyrics = "[Instrumental]"
|
| 258 |
|
| 259 |
+
duration_sec = self._to_int(raw_inputs.get("duration_sec", raw_inputs.get("duration", 12)), 12)
|
| 260 |
+
duration_sec = max(10, min(duration_sec, 600))
|
| 261 |
|
| 262 |
sample_rate = self._to_int(raw_inputs.get("sample_rate", self.default_sr), self.default_sr)
|
| 263 |
sample_rate = max(8000, min(sample_rate, 48000))
|
| 264 |
|
| 265 |
seed = self._to_int(raw_inputs.get("seed", 42), 42)
|
| 266 |
guidance_scale = self._to_float(raw_inputs.get("guidance_scale", 7.0), 7.0)
|
| 267 |
+
steps = self._to_int(raw_inputs.get("steps", raw_inputs.get("inference_steps", 8)), 8)
|
| 268 |
+
steps = max(1, min(steps, 200))
|
| 269 |
use_lm = self._to_bool(raw_inputs.get("use_lm", raw_inputs.get("thinking", True)), True)
|
| 270 |
+
allow_fallback = self._to_bool(raw_inputs.get("allow_fallback"), self.enable_fallback)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
|
| 272 |
return {
|
| 273 |
"prompt": prompt,
|
|
|
|
| 280 |
"use_lm": use_lm,
|
| 281 |
"instrumental": instrumental,
|
| 282 |
"simple_prompt": simple_prompt,
|
| 283 |
+
"allow_fallback": allow_fallback,
|
|
|
|
| 284 |
}
|
| 285 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
# --------------------------
|
| 287 |
+
# ACE-Step invocation
|
| 288 |
# --------------------------
|
| 289 |
+
def _build_generation_inputs(self, req: Dict[str, Any], llm_ready: bool) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
| 290 |
+
caption = req["prompt"]
|
| 291 |
+
lyrics = req["lyrics"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
|
| 293 |
+
extras: Dict[str, Any] = {
|
| 294 |
+
"simple_expansion_used": False,
|
| 295 |
+
"simple_expansion_error": None,
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
bpm = None
|
| 299 |
+
keyscale = ""
|
| 300 |
+
timesignature = ""
|
| 301 |
+
vocal_language = "unknown"
|
| 302 |
+
duration = float(req["duration_sec"])
|
| 303 |
+
|
| 304 |
+
if req["simple_prompt"] and req["use_lm"] and llm_ready and caption:
|
| 305 |
+
try:
|
| 306 |
+
sample = self._create_sample(
|
| 307 |
+
llm_handler=self.llm_handler,
|
| 308 |
+
query=caption,
|
| 309 |
+
instrumental=req["instrumental"],
|
| 310 |
+
)
|
| 311 |
+
if getattr(sample, "success", False):
|
| 312 |
+
caption = getattr(sample, "caption", "") or caption
|
| 313 |
+
lyrics = getattr(sample, "lyrics", "") or lyrics
|
| 314 |
+
bpm = getattr(sample, "bpm", None)
|
| 315 |
+
keyscale = getattr(sample, "keyscale", "") or ""
|
| 316 |
+
timesignature = getattr(sample, "timesignature", "") or ""
|
| 317 |
+
vocal_language = getattr(sample, "language", "") or "unknown"
|
| 318 |
+
sample_duration = getattr(sample, "duration", None)
|
| 319 |
+
if sample_duration:
|
| 320 |
+
duration = float(sample_duration)
|
| 321 |
+
extras["simple_expansion_used"] = True
|
| 322 |
+
else:
|
| 323 |
+
extras["simple_expansion_error"] = getattr(sample, "error", "create_sample failed")
|
| 324 |
except Exception as e:
|
| 325 |
+
extras["simple_expansion_error"] = f"{type(e).__name__}: {e}"
|
| 326 |
+
|
| 327 |
+
params = self._GenerationParams(
|
| 328 |
+
task_type="text2music",
|
| 329 |
+
caption=caption,
|
| 330 |
+
lyrics=lyrics,
|
| 331 |
+
instrumental=req["instrumental"],
|
| 332 |
+
duration=duration,
|
| 333 |
+
inference_steps=req["steps"],
|
| 334 |
+
guidance_scale=req["guidance_scale"],
|
| 335 |
+
seed=req["seed"],
|
| 336 |
+
bpm=bpm,
|
| 337 |
+
keyscale=keyscale,
|
| 338 |
+
timesignature=timesignature,
|
| 339 |
+
vocal_language=vocal_language,
|
| 340 |
+
thinking=bool(req["use_lm"] and llm_ready),
|
| 341 |
+
use_cot_metas=bool(req["use_lm"] and llm_ready),
|
| 342 |
+
use_cot_caption=bool(req["use_lm"] and llm_ready and not req["simple_prompt"]),
|
| 343 |
+
use_cot_language=bool(req["use_lm"] and llm_ready),
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
config = self._GenerationConfig(
|
| 347 |
+
batch_size=1,
|
| 348 |
+
allow_lm_batch=False,
|
| 349 |
+
use_random_seed=False,
|
| 350 |
+
seeds=[req["seed"]],
|
| 351 |
+
audio_format="wav",
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
extras["resolved_prompt"] = caption
|
| 355 |
+
extras["resolved_lyrics"] = lyrics
|
| 356 |
+
extras["resolved_duration"] = duration
|
| 357 |
|
| 358 |
+
return {"params": params, "config": config}, extras
|
| 359 |
+
|
| 360 |
+
def _call_model(self, req: Dict[str, Any]) -> Tuple[np.ndarray, int, Dict[str, Any]]:
|
| 361 |
+
if not self.model_loaded or self.dit_handler is None:
|
| 362 |
+
raise RuntimeError(self.model_error or "Model is not loaded")
|
| 363 |
+
|
| 364 |
+
llm_ready = False
|
| 365 |
+
if req["use_lm"]:
|
| 366 |
+
llm_ready = self._ensure_llm_initialized()
|
| 367 |
+
|
| 368 |
+
generation_inputs, extras = self._build_generation_inputs(req, llm_ready)
|
| 369 |
+
|
| 370 |
+
result = self._generate_music(
|
| 371 |
+
self.dit_handler,
|
| 372 |
+
self.llm_handler if llm_ready else None,
|
| 373 |
+
generation_inputs["params"],
|
| 374 |
+
generation_inputs["config"],
|
| 375 |
+
save_dir=None,
|
| 376 |
+
progress=None,
|
| 377 |
+
)
|
| 378 |
|
| 379 |
+
if not getattr(result, "success", False):
|
| 380 |
+
raise RuntimeError(getattr(result, "error", "generation failed"))
|
| 381 |
+
|
| 382 |
+
audios = getattr(result, "audios", None) or []
|
| 383 |
+
if not audios:
|
| 384 |
+
raise RuntimeError("generation succeeded but no audio was returned")
|
| 385 |
+
|
| 386 |
+
first = audios[0]
|
| 387 |
+
audio_tensor = first.get("tensor") if isinstance(first, dict) else None
|
| 388 |
+
if audio_tensor is None:
|
| 389 |
+
raise RuntimeError("generated audio tensor is missing")
|
| 390 |
+
|
| 391 |
+
sample_rate = int(first.get("sample_rate", req["sample_rate"]))
|
| 392 |
+
status_message = getattr(result, "status_message", "")
|
| 393 |
+
|
| 394 |
+
meta = {
|
| 395 |
+
"llm_requested": req["use_lm"],
|
| 396 |
+
"llm_initialized": llm_ready,
|
| 397 |
+
"llm_error": self.llm_error,
|
| 398 |
+
"status_message": status_message,
|
| 399 |
+
}
|
| 400 |
+
meta.update(extras)
|
| 401 |
+
|
| 402 |
+
return self._as_float32(audio_tensor), sample_rate, meta
|
| 403 |
+
|
| 404 |
+
# --------------------------
|
| 405 |
+
# Endpoint entry
|
| 406 |
+
# --------------------------
|
| 407 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
| 408 |
try:
|
| 409 |
req = self._normalize_request(data)
|
| 410 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 411 |
used_fallback = False
|
| 412 |
+
runtime_meta: Dict[str, Any] = {}
|
| 413 |
+
|
| 414 |
+
try:
|
| 415 |
+
audio, out_sr, runtime_meta = self._call_model(req)
|
| 416 |
+
except Exception as model_exc:
|
| 417 |
+
self.model_error = f"Inference failed: {type(model_exc).__name__}: {model_exc}"
|
| 418 |
+
if not req["allow_fallback"]:
|
| 419 |
+
raise RuntimeError(self.model_error)
|
| 420 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
used_fallback = True
|
| 422 |
audio = self._fallback_sine(req["duration_sec"], req["sample_rate"], req["seed"])
|
| 423 |
out_sr = req["sample_rate"]
|
|
|
|
| 427 |
"sample_rate": int(out_sr),
|
| 428 |
"duration_sec": int(req["duration_sec"]),
|
| 429 |
"used_fallback": used_fallback,
|
| 430 |
+
"model_loaded": self.model_loaded,
|
| 431 |
"model_repo": self.model_repo,
|
| 432 |
"model_error": self.model_error,
|
| 433 |
"meta": {
|
|
|
|
| 441 |
"use_lm": req["use_lm"],
|
| 442 |
"simple_prompt": req["simple_prompt"],
|
| 443 |
"instrumental": req["instrumental"],
|
| 444 |
+
"allow_fallback": req["allow_fallback"],
|
| 445 |
+
"resolved_prompt": runtime_meta.get("resolved_prompt", req["prompt"]),
|
| 446 |
+
"resolved_lyrics": runtime_meta.get("resolved_lyrics", req["lyrics"]),
|
| 447 |
+
"simple_expansion_used": runtime_meta.get("simple_expansion_used", False),
|
| 448 |
+
"simple_expansion_error": runtime_meta.get("simple_expansion_error"),
|
| 449 |
+
"llm_requested": runtime_meta.get("llm_requested", False),
|
| 450 |
+
"llm_initialized": runtime_meta.get("llm_initialized", False),
|
| 451 |
+
"llm_error": runtime_meta.get("llm_error"),
|
| 452 |
+
"status_message": runtime_meta.get("status_message", ""),
|
| 453 |
+
"init_details": self.init_details,
|
| 454 |
},
|
| 455 |
}
|
| 456 |
|
| 457 |
except Exception as e:
|
| 458 |
return {
|
| 459 |
"error": f"{type(e).__name__}: {e}",
|
| 460 |
+
"traceback": traceback.format_exc(limit=4),
|
| 461 |
"audio_base64_wav": None,
|
| 462 |
"sample_rate": None,
|
| 463 |
"duration_sec": None,
|
| 464 |
+
"used_fallback": False,
|
| 465 |
+
"model_loaded": self.model_loaded,
|
| 466 |
+
"model_repo": self.model_repo,
|
| 467 |
+
"model_error": self.model_error,
|
| 468 |
+
"meta": {
|
| 469 |
+
"device": self.device,
|
| 470 |
+
"dtype": self.dtype,
|
| 471 |
+
"init_details": self.init_details,
|
| 472 |
+
"llm_error": self.llm_error,
|
| 473 |
+
},
|
| 474 |
+
}
|
requirements.txt
CHANGED
|
@@ -2,7 +2,13 @@ numpy
|
|
| 2 |
soundfile
|
| 3 |
torch
|
| 4 |
torchaudio
|
| 5 |
-
transformers
|
| 6 |
accelerate
|
| 7 |
huggingface_hub
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
soundfile
|
| 3 |
torch
|
| 4 |
torchaudio
|
| 5 |
+
transformers>=4.51.0,<4.58.0
|
| 6 |
accelerate
|
| 7 |
huggingface_hub
|
| 8 |
+
diffusers
|
| 9 |
+
loguru
|
| 10 |
+
tqdm
|
| 11 |
+
numba>=0.63.1
|
| 12 |
+
PyYAML
|
| 13 |
+
modelscope
|
| 14 |
+
filelock>=3.13.0
|