Spaces:
Running
Running
add inference code and doc
Browse files- acestep/api_server.py +12 -2
- acestep/audio_utils.py +396 -0
- acestep/constrained_logits_processor.py +76 -97
- acestep/gradio_ui/event.py +0 -0
- acestep/gradio_ui/events/results_handlers.py +22 -5
- acestep/handler.py +329 -297
- acestep/inference.py +383 -571
- acestep/llm_inference.py +544 -598
- acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py +60 -44
- acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py +101 -47
- acestep/third_parts/nano-vllm/pyproject.toml +0 -2
- profile_inference.py +223 -0
acestep/api_server.py
CHANGED
|
@@ -868,7 +868,7 @@ def create_app() -> FastAPI:
|
|
| 868 |
if s in {"", "N/A"}:
|
| 869 |
return None
|
| 870 |
return s
|
| 871 |
-
|
| 872 |
captions=req.caption,
|
| 873 |
lyrics=req.lyrics,
|
| 874 |
bpm=bpm_val,
|
|
@@ -896,10 +896,20 @@ def create_app() -> FastAPI:
|
|
| 896 |
use_tiled_decode=req.use_tiled_decode,
|
| 897 |
progress=None,
|
| 898 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 899 |
return {
|
| 900 |
"first_audio_path": _path_to_audio_url(first) if first else None,
|
| 901 |
"second_audio_path": _path_to_audio_url(second) if second else None,
|
| 902 |
-
"audio_paths": [_path_to_audio_url(p) for p in (
|
| 903 |
"generation_info": gen_info,
|
| 904 |
"status_message": status_msg,
|
| 905 |
"seed_value": seed_value,
|
|
|
|
| 868 |
if s in {"", "N/A"}:
|
| 869 |
return None
|
| 870 |
return s
|
| 871 |
+
result = h.generate_music(
|
| 872 |
captions=req.caption,
|
| 873 |
lyrics=req.lyrics,
|
| 874 |
bpm=bpm_val,
|
|
|
|
| 896 |
use_tiled_decode=req.use_tiled_decode,
|
| 897 |
progress=None,
|
| 898 |
)
|
| 899 |
+
|
| 900 |
+
# Extract values from new dict structure
|
| 901 |
+
audios = result.get("audios", [])
|
| 902 |
+
audio_paths = [audio.get("path") for audio in audios]
|
| 903 |
+
first = audio_paths[0] if len(audio_paths) > 0 else None
|
| 904 |
+
second = audio_paths[1] if len(audio_paths) > 1 else None
|
| 905 |
+
gen_info = result.get("generation_info", "")
|
| 906 |
+
status_msg = result.get("status_message", "")
|
| 907 |
+
seed_value = result.get("extra_outputs", {}).get("seed_value", "")
|
| 908 |
+
|
| 909 |
return {
|
| 910 |
"first_audio_path": _path_to_audio_url(first) if first else None,
|
| 911 |
"second_audio_path": _path_to_audio_url(second) if second else None,
|
| 912 |
+
"audio_paths": [_path_to_audio_url(p) for p in (audio_paths or [])],
|
| 913 |
"generation_info": gen_info,
|
| 914 |
"status_message": status_msg,
|
| 915 |
"seed_value": seed_value,
|
acestep/audio_utils.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Audio saving and transcoding utility module
|
| 3 |
+
|
| 4 |
+
Independent audio file operations outside of handler, supporting:
|
| 5 |
+
- Save audio tensor/numpy to files (default FLAC format, fast)
|
| 6 |
+
- Format conversion (FLAC/WAV/MP3)
|
| 7 |
+
- Batch processing
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import hashlib
|
| 12 |
+
import json
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Union, Optional, List, Tuple
|
| 15 |
+
import torch
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torchaudio
|
| 18 |
+
from loguru import logger
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class AudioSaver:
    """Audio saving and transcoding utility class.

    Saves torch/numpy audio to FLAC/WAV/MP3 files, converts between
    formats, and saves batches. FLAC/WAV are written through the
    soundfile backend (fastest); MP3 goes through ffmpeg.
    """

    # Formats this saver knows how to write.
    _SUPPORTED_FORMATS = ("flac", "wav", "mp3")

    def __init__(self, default_format: str = "flac"):
        """
        Initialize audio saver.

        Args:
            default_format: Default save format ('flac', 'wav', 'mp3').
                Unsupported values fall back to 'flac' with a warning.
        """
        self.default_format = default_format.lower()
        if self.default_format not in self._SUPPORTED_FORMATS:
            logger.warning(f"Unsupported format {default_format}, using 'flac'")
            self.default_format = "flac"

    @staticmethod
    def _to_channels_first_tensor(
        audio_data: Union[torch.Tensor, np.ndarray],
        channels_first: bool,
    ) -> torch.Tensor:
        """Normalize input to a contiguous float CPU tensor [channels, samples].

        NOTE(review): for numpy input the channels_first=True branch transposes,
        i.e. numpy data is assumed [samples, channels] regardless of the flag —
        confirm this asymmetry against callers.
        """
        if isinstance(audio_data, np.ndarray):
            if channels_first:
                # numpy [samples, channels] -> tensor [channels, samples]
                audio_tensor = torch.from_numpy(audio_data.T).float()
            else:
                # numpy [samples, channels] -> tensor [samples, channels] -> [channels, samples]
                audio_tensor = torch.from_numpy(audio_data).float()
                # Bugfix: transpose when rows outnumber columns (samples-first),
                # matching the torch branch below; the comparison was previously
                # inverted, leaving [samples, channels] input untransposed.
                if audio_tensor.dim() == 2 and audio_tensor.shape[0] > audio_tensor.shape[1]:
                    audio_tensor = audio_tensor.T
        else:
            # torch tensor
            audio_tensor = audio_data.cpu().float()
            if not channels_first and audio_tensor.dim() == 2:
                # [samples, channels] -> [channels, samples]
                if audio_tensor.shape[0] > audio_tensor.shape[1]:
                    audio_tensor = audio_tensor.T

        # Ensure memory is contiguous before handing to the save backend.
        return audio_tensor.contiguous()

    def save_audio(
        self,
        audio_data: Union[torch.Tensor, np.ndarray],
        output_path: Union[str, Path],
        sample_rate: int = 48000,
        format: Optional[str] = None,
        channels_first: bool = True,
    ) -> str:
        """
        Save audio data to file.

        Args:
            audio_data: Audio data, torch.Tensor [channels, samples] or numpy.ndarray
            output_path: Output file path (extension can be omitted)
            sample_rate: Sample rate
            format: Audio format ('flac', 'wav', 'mp3'), defaults to default_format
            channels_first: If True, tensor format is [channels, samples], else [samples, channels]

        Returns:
            Actual saved file path

        Raises:
            Exception: re-raises any backend failure after logging it.
        """
        format = (format or self.default_format).lower()
        if format not in self._SUPPORTED_FORMATS:
            logger.warning(f"Unsupported format {format}, using {self.default_format}")
            format = self.default_format

        # Ensure output path has correct extension
        output_path = Path(output_path)
        if output_path.suffix.lower() not in ['.flac', '.wav', '.mp3']:
            output_path = output_path.with_suffix(f'.{format}')

        # Convert to a [channels, samples] torch tensor.
        audio_tensor = self._to_channels_first_tensor(audio_data, channels_first)

        # Select backend and save
        try:
            if format == "mp3":
                # MP3 uses ffmpeg backend
                from torchaudio.io import CodecConfig
                config = CodecConfig(bit_rate=192000, compression_level=1)
                torchaudio.save(
                    str(output_path),
                    audio_tensor,
                    sample_rate,
                    channels_first=True,
                    backend='ffmpeg',
                    compression=config,
                    buffer_size=65536
                )
            elif format in ["flac", "wav"]:
                # FLAC and WAV use soundfile backend (fastest)
                torchaudio.save(
                    str(output_path),
                    audio_tensor,
                    sample_rate,
                    channels_first=True,
                    backend='soundfile',
                    buffer_size=65536
                )
            else:
                # Other formats use default backend
                torchaudio.save(
                    str(output_path),
                    audio_tensor,
                    sample_rate,
                    channels_first=True,
                    buffer_size=65536
                )

            logger.debug(f"[AudioSaver] Saved audio to {output_path} ({format}, {sample_rate}Hz)")
            return str(output_path)

        except Exception as e:
            logger.error(f"[AudioSaver] Failed to save audio: {e}")
            raise

    def convert_audio(
        self,
        input_path: Union[str, Path],
        output_path: Union[str, Path],
        output_format: str,
        remove_input: bool = False,
    ) -> str:
        """
        Convert audio format.

        Args:
            input_path: Input audio file path
            output_path: Output audio file path
            output_format: Target format ('flac', 'wav', 'mp3')
            remove_input: Whether to delete input file

        Returns:
            Output file path

        Raises:
            FileNotFoundError: if input_path does not exist.
        """
        input_path = Path(input_path)
        output_path = Path(output_path)

        if not input_path.exists():
            raise FileNotFoundError(f"Input file not found: {input_path}")

        # Load audio (torchaudio returns [channels, samples]).
        audio_tensor, sample_rate = torchaudio.load(str(input_path))

        # Save as new format
        output_path = self.save_audio(
            audio_tensor,
            output_path,
            sample_rate=sample_rate,
            format=output_format,
            channels_first=True
        )

        # Delete input file if needed
        if remove_input:
            input_path.unlink()
            logger.debug(f"[AudioSaver] Removed input file: {input_path}")

        return output_path

    def save_batch(
        self,
        audio_batch: Union[List[torch.Tensor], torch.Tensor],
        output_dir: Union[str, Path],
        file_prefix: str = "audio",
        sample_rate: int = 48000,
        format: Optional[str] = None,
        channels_first: bool = True,
    ) -> List[str]:
        """
        Save audio batch.

        Args:
            audio_batch: Audio batch, List[tensor] or tensor [batch, channels, samples]
            output_dir: Output directory
            file_prefix: File prefix
            sample_rate: Sample rate
            format: Audio format
            channels_first: Tensor format flag

        Returns:
            List of saved file paths
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Normalize the batch to a list of per-item tensors.
        if isinstance(audio_batch, torch.Tensor) and audio_batch.dim() == 3:
            # [batch, channels, samples]
            audio_list = [audio_batch[i] for i in range(audio_batch.shape[0])]
        elif isinstance(audio_batch, list):
            audio_list = audio_batch
        else:
            audio_list = [audio_batch]

        saved_paths = []
        for i, audio in enumerate(audio_list):
            # Zero-padded index keeps files lexicographically ordered.
            output_path = output_dir / f"{file_prefix}_{i:04d}"
            saved_path = self.save_audio(
                audio,
                output_path,
                sample_rate=sample_rate,
                format=format,
                channels_first=channels_first
            )
            saved_paths.append(saved_path)

        return saved_paths
|
| 223 |
+
|
| 224 |
+
def get_audio_file_hash(audio_file) -> str:
    """
    Get hash identifier for an audio file.

    Args:
        audio_file: Path to audio file (str) or file-like object

    Returns:
        MD5 hex digest string, or empty string when audio_file is None
    """
    if audio_file is None:
        return ""

    try:
        if isinstance(audio_file, str):
            if os.path.exists(audio_file):
                # Hash file contents in fixed-size chunks so large audio
                # files are not read fully into memory at once.
                md5 = hashlib.md5()
                with open(audio_file, 'rb') as f:
                    for chunk in iter(lambda: f.read(1 << 20), b''):
                        md5.update(chunk)
                return md5.hexdigest()
            # Path does not exist: fall back to hashing the path string itself.
            return hashlib.md5(audio_file.encode('utf-8')).hexdigest()
        elif hasattr(audio_file, 'name'):
            # File-like object (e.g. an upload): hash its name, not its contents.
            return hashlib.md5(str(audio_file.name).encode('utf-8')).hexdigest()
        return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
    except Exception:
        # Best-effort: hashing must never abort the caller, so fall back
        # to hashing the object's string representation.
        return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
|
| 249 |
+
|
| 250 |
+
def generate_uuid_from_params(
    captions: str,
    lyrics: str,
    bpm: Optional[int],
    key_scale: str,
    time_signature: str,
    vocal_language: str,
    inference_steps: int,
    guidance_scale: float,
    seed: Union[str, float, int],
    audio_duration: Optional[float],
    audio_code_string: Union[str, List[str]],
    repainting_start: float,
    repainting_end: Optional[float],
    instruction: str,
    audio_cover_strength: float,
    task_type: str,
    use_adg: bool,
    cfg_interval_start: float,
    cfg_interval_end: float,
    audio_format: str,
    reference_audio=None,
    src_audio=None,
    batch_index: int = 0,
) -> str:
    """
    Generate deterministic UUID from generation parameters.
    Same parameters will always generate the same UUID.

    Args:
        captions: Music caption
        lyrics: Lyrics text
        bpm: BPM value
        key_scale: Musical key and scale
        time_signature: Time signature
        vocal_language: Vocal language code
        inference_steps: Number of inference steps
        guidance_scale: Guidance scale
        seed: Random seed
        audio_duration: Audio duration in seconds
        audio_code_string: Audio code string or list
        repainting_start: Repainting start time
        repainting_end: Repainting end time
        instruction: Task instruction
        audio_cover_strength: Audio cover strength
        task_type: Task type
        use_adg: Whether to use ADG
        cfg_interval_start: CFG interval start
        cfg_interval_end: CFG interval end
        audio_format: Audio format
        reference_audio: Reference audio file path
        src_audio: Source audio file path
        batch_index: Index in batch (for audio_code_string list access)

    Returns:
        UUID string
    """
    # Resolve the per-item audio code: a plain str applies to every batch
    # item; a list is indexed by batch_index; anything else (or an
    # out-of-range index) hashes as "".
    if isinstance(audio_code_string, str):
        code_str = audio_code_string
    elif isinstance(audio_code_string, list) and batch_index < len(audio_code_string):
        code_str = audio_code_string[batch_index]
    else:
        code_str = ""

    params_dict = {
        "captions": captions or "",
        "lyrics": lyrics or "",
        "bpm": bpm,
        "key_scale": key_scale or "",
        "time_signature": time_signature or "",
        "vocal_language": vocal_language or "",
        "inference_steps": inference_steps,
        "guidance_scale": guidance_scale,
        "seed": seed,
        "audio_duration": audio_duration,
        "audio_code_string": code_str,
        "repainting_start": repainting_start,
        "repainting_end": repainting_end,
        "instruction": instruction or "",
        "audio_cover_strength": audio_cover_strength,
        "task_type": task_type or "",
        "use_adg": use_adg,
        "cfg_interval_start": cfg_interval_start,
        "cfg_interval_end": cfg_interval_end,
        "audio_format": audio_format or "",
        # Audio inputs participate via content/name hashes so identical
        # files map to the same UUID regardless of path.
        "reference_audio_hash": get_audio_file_hash(reference_audio),
        "src_audio_hash": get_audio_file_hash(src_audio),
    }

    # sort_keys makes the JSON (and therefore the hash) order-independent.
    params_json = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
    hash_obj = hashlib.sha256(params_json.encode('utf-8'))
    hash_hex = hash_obj.hexdigest()
    # Format the first 32 hex chars as a dashed UUID-like string.
    uuid_str = f"{hash_hex[0:8]}-{hash_hex[8:12]}-{hash_hex[12:16]}-{hash_hex[16:20]}-{hash_hex[20:32]}"
    return uuid_str
|
| 338 |
+
|
| 339 |
+
def generate_uuid_from_audio_data(
    audio_data: Union[torch.Tensor, np.ndarray],
    seed: Optional[int] = None
) -> str:
    """
    Derive a deterministic identifier from raw audio samples, for
    caching/deduplication.

    Note: despite the name, the result is a 32-character MD5 hex digest,
    not a dashed RFC-4122 UUID.

    Args:
        audio_data: Audio samples as a torch tensor or numpy array
        seed: Optional seed value mixed into the hash

    Returns:
        Hex digest string
    """
    # Normalize to a numpy buffer so hashing is backend-independent.
    buffer = (
        audio_data.cpu().numpy()
        if isinstance(audio_data, torch.Tensor)
        else audio_data
    )

    digest = hashlib.md5(buffer.tobytes()).hexdigest()

    if seed is None:
        return digest

    # Fold the seed into the content hash.
    return hashlib.md5(f"{digest}_{seed}".encode()).hexdigest()
| 368 |
+
|
| 369 |
+
# Global default instance: module-level saver (FLAC output) backing the
# convenience save_audio() wrapper in this module.
_default_saver = AudioSaver(default_format="flac")
+
|
| 372 |
+
|
| 373 |
+
def save_audio(
    audio_data: Union[torch.Tensor, np.ndarray],
    output_path: Union[str, Path],
    sample_rate: int = 48000,
    format: Optional[str] = None,
    channels_first: bool = True,
) -> str:
    """
    Convenience wrapper: save audio through the module-level default
    saver (FLAC by default).

    Args:
        audio_data: Audio data (torch tensor or numpy array)
        output_path: Destination path (extension may be omitted)
        sample_rate: Sample rate in Hz
        format: Target format; None selects the default ('flac')
        channels_first: Tensor format flag

    Returns:
        Path of the saved file
    """
    return _default_saver.save_audio(
        audio_data=audio_data,
        output_path=output_path,
        sample_rate=sample_rate,
        format=format,
        channels_first=channels_first,
    )
| 396 |
+
|
acestep/constrained_logits_processor.py
CHANGED
|
@@ -571,6 +571,33 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 571 |
if self.debug:
|
| 572 |
logger.debug(f"Built audio code masks for {len(self.audio_code_token_ids)} tokens")
|
| 573 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 574 |
def _build_keyscale_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
|
| 575 |
"""
|
| 576 |
Build keyscale prefix to allowed tokens mapping based on ACTUAL tokenization.
|
|
@@ -1484,10 +1511,10 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1484 |
if self.debug:
|
| 1485 |
logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, blocking EOS")
|
| 1486 |
else:
|
| 1487 |
-
# Force EOS token when target codes count is reached
|
| 1488 |
-
|
| 1489 |
-
|
| 1490 |
-
scores =
|
| 1491 |
if self.debug:
|
| 1492 |
logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, forcing EOS")
|
| 1493 |
return self._apply_temperature_scaling(scores)
|
|
@@ -1609,20 +1636,15 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1609 |
input_ids: torch.LongTensor,
|
| 1610 |
scores: torch.FloatTensor,
|
| 1611 |
) -> torch.FloatTensor:
|
| 1612 |
-
"""Process a single sequence and return modified scores."""
|
| 1613 |
|
| 1614 |
# Check if we have tokens in queue for user-provided field
|
| 1615 |
# If so, inject the next token directly
|
| 1616 |
if self.user_field_token_queue:
|
| 1617 |
-
mask = torch.full_like(scores, float('-inf'))
|
| 1618 |
next_token = self.user_field_token_queue[0]
|
| 1619 |
-
|
| 1620 |
-
scores = scores + mask
|
| 1621 |
return scores
|
| 1622 |
|
| 1623 |
-
# Create mask (all -inf initially)
|
| 1624 |
-
mask = torch.full_like(scores, float('-inf'))
|
| 1625 |
-
|
| 1626 |
if self.state in self.fixed_strings:
|
| 1627 |
# Fixed string state: force specific tokens
|
| 1628 |
fixed_str = self.fixed_strings[self.state]
|
|
@@ -1633,28 +1655,18 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1633 |
# This happens when we're about to complete the </think> tag
|
| 1634 |
if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
|
| 1635 |
# Check if the next token would complete the fixed string
|
| 1636 |
-
# We check if position_in_state + length of next token would complete it
|
| 1637 |
-
# Since we don't know which token will be selected, we check if we're close to completion
|
| 1638 |
-
# Actually, a better approach: check if this is the last character(s) of the fixed string
|
| 1639 |
remaining_chars = len(fixed_str) - self.position_in_state
|
| 1640 |
# If remaining is small (<= 10 chars, which is typically 1-2 tokens), force EOS
|
| 1641 |
if remaining_chars <= 10:
|
| 1642 |
# Force EOS token to stop generation
|
| 1643 |
if self.eos_token_id is not None:
|
| 1644 |
-
|
| 1645 |
-
scores = scores + mask
|
| 1646 |
if self.debug:
|
| 1647 |
logger.debug(f"stop_at_reasoning=True: forcing EOS near end of </think> tag (remaining: {remaining_chars} chars)")
|
| 1648 |
return scores
|
| 1649 |
|
| 1650 |
-
|
| 1651 |
-
|
| 1652 |
-
# Apply mask
|
| 1653 |
-
scores = scores + mask
|
| 1654 |
-
|
| 1655 |
-
# Update position tracking
|
| 1656 |
-
# We need to check if the selected token completes the fixed string
|
| 1657 |
-
# This will be done in update_state() after token selection
|
| 1658 |
else:
|
| 1659 |
# Position exceeds string, move to next state
|
| 1660 |
# If stop_at_reasoning is True and we're transitioning from THINK_END_TAG,
|
|
@@ -1662,8 +1674,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1662 |
if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
|
| 1663 |
# Force EOS token to stop generation
|
| 1664 |
if self.eos_token_id is not None:
|
| 1665 |
-
|
| 1666 |
-
scores = scores + mask
|
| 1667 |
if self.debug:
|
| 1668 |
logger.debug(f"stop_at_reasoning=True: forcing EOS after completing </think> tag")
|
| 1669 |
return scores
|
|
@@ -1676,7 +1687,9 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1676 |
if self.debug:
|
| 1677 |
logger.warning(f"State transition from {old_state.name} to {self.state.name} still in fixed_strings, avoiding recursion")
|
| 1678 |
return scores
|
| 1679 |
-
|
|
|
|
|
|
|
| 1680 |
|
| 1681 |
elif self.state == FSMState.BPM_VALUE:
|
| 1682 |
# Check if field is user-provided and we haven't started injecting yet
|
|
@@ -1690,22 +1703,18 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1690 |
self.user_field_token_queue = value_tokens
|
| 1691 |
self.current_user_field = "bpm"
|
| 1692 |
# Inject first token
|
| 1693 |
-
|
| 1694 |
-
scores = scores + mask
|
| 1695 |
return scores
|
| 1696 |
|
| 1697 |
# Allow valid numeric tokens using prefix tree (supports multi-digit tokens like "120")
|
| 1698 |
allowed = self._get_allowed_numeric_tokens(self.bpm_prefix_tree)
|
| 1699 |
-
for t in allowed:
|
| 1700 |
-
mask[0, t] = 0
|
| 1701 |
|
| 1702 |
# Also allow newline if current token sequence prefix allows it
|
| 1703 |
-
# Check if current token sequence is in prefix tree and allows newline
|
| 1704 |
token_prefix = tuple(self.accumulated_token_ids)
|
| 1705 |
if token_prefix in self.bpm_prefix_tree and self.newline_token in self.bpm_prefix_tree[token_prefix]:
|
| 1706 |
-
|
| 1707 |
|
| 1708 |
-
scores
|
| 1709 |
|
| 1710 |
elif self.state == FSMState.CAPTION_VALUE:
|
| 1711 |
# Caption field generation with YAML format support:
|
|
@@ -1724,8 +1733,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1724 |
self.user_field_token_queue = value_tokens
|
| 1725 |
self.current_user_field = "caption"
|
| 1726 |
# Inject first token
|
| 1727 |
-
|
| 1728 |
-
scores = scores + mask
|
| 1729 |
return scores
|
| 1730 |
|
| 1731 |
# Check if we should transition after a newline (non-indented line = new field)
|
|
@@ -1757,7 +1765,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1757 |
# The field name detection will happen in update_state()
|
| 1758 |
return scores
|
| 1759 |
|
| 1760 |
-
# Block backticks (code blocks)
|
| 1761 |
if self.backtick_token is not None:
|
| 1762 |
scores[0, self.backtick_token] = float('-inf')
|
| 1763 |
|
|
@@ -1773,8 +1781,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1773 |
if self.caption_token_count >= 512:
|
| 1774 |
# Force end by only allowing newline
|
| 1775 |
if self.newline_token is not None:
|
| 1776 |
-
|
| 1777 |
-
scores = scores + mask
|
| 1778 |
return scores
|
| 1779 |
|
| 1780 |
# Allow natural generation (with blocked audio codes and backticks)
|
|
@@ -1791,8 +1798,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1791 |
self.user_field_token_queue = value_tokens
|
| 1792 |
self.current_user_field = "duration"
|
| 1793 |
# Inject first token
|
| 1794 |
-
|
| 1795 |
-
scores = scores + mask
|
| 1796 |
return scores
|
| 1797 |
|
| 1798 |
# If target_duration is set, force generate that exact value
|
|
@@ -1804,26 +1810,22 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1804 |
# Force the next digit
|
| 1805 |
next_digit = int(target_str[current_pos])
|
| 1806 |
if next_digit in self.digit_tokens:
|
| 1807 |
-
|
| 1808 |
else:
|
| 1809 |
# All digits generated, force newline
|
| 1810 |
if self.newline_token:
|
| 1811 |
-
|
| 1812 |
-
|
| 1813 |
-
scores = scores + mask
|
| 1814 |
else:
|
| 1815 |
# Normal duration generation with range constraint
|
| 1816 |
# Allow valid numeric tokens using prefix tree (supports multi-digit tokens like "60", "120")
|
| 1817 |
allowed = self._get_allowed_numeric_tokens(self.duration_prefix_tree)
|
| 1818 |
-
for t in allowed:
|
| 1819 |
-
mask[0, t] = 0
|
| 1820 |
|
| 1821 |
# Also allow newline if current token sequence prefix allows it
|
| 1822 |
token_prefix = tuple(self.accumulated_token_ids)
|
| 1823 |
if token_prefix in self.duration_prefix_tree and self.newline_token in self.duration_prefix_tree[token_prefix]:
|
| 1824 |
-
|
| 1825 |
|
| 1826 |
-
scores
|
| 1827 |
|
| 1828 |
elif self.state == FSMState.GENRES_VALUE:
|
| 1829 |
# Check if field is user-provided and we haven't started injecting yet
|
|
@@ -1836,8 +1838,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1836 |
self.user_field_token_queue = value_tokens
|
| 1837 |
self.current_user_field = "genres"
|
| 1838 |
# Inject first token
|
| 1839 |
-
|
| 1840 |
-
scores = scores + mask
|
| 1841 |
return scores
|
| 1842 |
|
| 1843 |
# Try to hot-reload genres vocab if file has changed
|
|
@@ -1848,24 +1849,20 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1848 |
|
| 1849 |
if allowed:
|
| 1850 |
# Use vocabulary-constrained decoding
|
| 1851 |
-
|
| 1852 |
-
mask[0, t] = 0
|
| 1853 |
-
scores = scores + mask
|
| 1854 |
elif self.genres_vocab:
|
| 1855 |
# Vocab is loaded but no valid continuation found
|
| 1856 |
# Force newline to end the field
|
| 1857 |
if self.newline_token:
|
| 1858 |
-
mask[0, self.newline_token] = 0
|
| 1859 |
if self.debug:
|
| 1860 |
logger.debug(f"No valid genre continuation for '{self.accumulated_value}', forcing newline")
|
| 1861 |
-
|
| 1862 |
else:
|
| 1863 |
# Fallback: no vocab loaded, use probability-based ending
|
| 1864 |
if self._should_end_text_field(scores):
|
| 1865 |
if self.newline_token:
|
| 1866 |
-
|
| 1867 |
self._transition_to_next_state()
|
| 1868 |
-
scores = scores + mask
|
| 1869 |
else:
|
| 1870 |
# Allow any token except newline if we don't have content yet
|
| 1871 |
if not self.accumulated_value.strip():
|
|
@@ -1884,8 +1881,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1884 |
self.user_field_token_queue = value_tokens
|
| 1885 |
self.current_user_field = "keyscale"
|
| 1886 |
# Inject first token
|
| 1887 |
-
|
| 1888 |
-
scores = scores + mask
|
| 1889 |
return scores
|
| 1890 |
|
| 1891 |
# Check if current token sequence is complete (allows newline)
|
|
@@ -1893,21 +1889,17 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1893 |
if token_prefix in self.keyscale_prefix_tree and self.newline_token in self.keyscale_prefix_tree[token_prefix]:
|
| 1894 |
# Complete keyscale, allow newline
|
| 1895 |
if self.newline_token:
|
| 1896 |
-
|
| 1897 |
-
scores = scores + mask
|
| 1898 |
else:
|
| 1899 |
# Not complete, allow valid continuation tokens
|
| 1900 |
allowed = self._get_allowed_keyscale_tokens()
|
| 1901 |
if allowed:
|
| 1902 |
-
|
| 1903 |
-
mask[0, t] = 0
|
| 1904 |
-
scores = scores + mask
|
| 1905 |
else:
|
| 1906 |
# No valid tokens found - force newline to end field
|
| 1907 |
# This handles edge cases where keyscale format is unexpected
|
| 1908 |
if self.newline_token:
|
| 1909 |
-
|
| 1910 |
-
scores = scores + mask
|
| 1911 |
|
| 1912 |
elif self.state == FSMState.LANGUAGE_VALUE:
|
| 1913 |
# Language field: Use top-1 probability language (greedy selection)
|
|
@@ -1925,8 +1917,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1925 |
self.user_field_token_queue = value_tokens
|
| 1926 |
self.current_user_field = "language"
|
| 1927 |
# Inject first token
|
| 1928 |
-
|
| 1929 |
-
scores = scores + mask
|
| 1930 |
return scores
|
| 1931 |
|
| 1932 |
# If we haven't started generating language yet (empty accumulated_token_ids),
|
|
@@ -1938,19 +1929,17 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1938 |
candidate_tokens = list(self.language_prefix_tree[empty_prefix])
|
| 1939 |
|
| 1940 |
if candidate_tokens:
|
| 1941 |
-
# Find the token with highest probability (top-1)
|
| 1942 |
-
#
|
| 1943 |
-
|
| 1944 |
-
|
| 1945 |
-
temp_mask[0, t] = 0
|
| 1946 |
-
temp_scores = scores + temp_mask
|
| 1947 |
|
| 1948 |
# Get the highest probability token among candidates
|
| 1949 |
-
|
|
|
|
| 1950 |
|
| 1951 |
-
# Only allow this top-1 token, block all others
|
| 1952 |
-
|
| 1953 |
-
scores = scores + mask
|
| 1954 |
|
| 1955 |
if self.debug:
|
| 1956 |
top_token_text = self.tokenizer.decode([top_token_id])
|
|
@@ -1958,13 +1947,11 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1958 |
else:
|
| 1959 |
# No valid first tokens found - force newline
|
| 1960 |
if self.newline_token:
|
| 1961 |
-
|
| 1962 |
-
scores = scores + mask
|
| 1963 |
else:
|
| 1964 |
# Empty prefix not in tree - force newline
|
| 1965 |
if self.newline_token:
|
| 1966 |
-
|
| 1967 |
-
scores = scores + mask
|
| 1968 |
else:
|
| 1969 |
# We've started generating a language, continue with prefix tree constraints
|
| 1970 |
# Check if current token sequence is complete (allows newline)
|
|
@@ -1972,20 +1959,16 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1972 |
if token_prefix in self.language_prefix_tree and self.newline_token in self.language_prefix_tree[token_prefix]:
|
| 1973 |
# Complete language, allow newline
|
| 1974 |
if self.newline_token:
|
| 1975 |
-
|
| 1976 |
-
scores = scores + mask
|
| 1977 |
else:
|
| 1978 |
# Not complete, allow valid continuation tokens
|
| 1979 |
allowed = self._get_allowed_language_tokens()
|
| 1980 |
if allowed:
|
| 1981 |
-
|
| 1982 |
-
mask[0, t] = 0
|
| 1983 |
-
scores = scores + mask
|
| 1984 |
else:
|
| 1985 |
# No valid tokens found - force newline to end field
|
| 1986 |
if self.newline_token:
|
| 1987 |
-
|
| 1988 |
-
scores = scores + mask
|
| 1989 |
|
| 1990 |
elif self.state == FSMState.TIMESIG_VALUE:
|
| 1991 |
# Check if field is user-provided and we haven't started injecting yet
|
|
@@ -1998,8 +1981,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1998 |
self.user_field_token_queue = value_tokens
|
| 1999 |
self.current_user_field = "timesignature"
|
| 2000 |
# Inject first token
|
| 2001 |
-
|
| 2002 |
-
scores = scores + mask
|
| 2003 |
return scores
|
| 2004 |
|
| 2005 |
# Check if current token sequence is complete (allows newline)
|
|
@@ -2007,14 +1989,11 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 2007 |
if token_prefix in self.timesig_prefix_tree and self.newline_token in self.timesig_prefix_tree[token_prefix]:
|
| 2008 |
# Complete value, allow newline
|
| 2009 |
if self.newline_token:
|
| 2010 |
-
|
| 2011 |
-
scores = scores + mask
|
| 2012 |
else:
|
| 2013 |
# Not complete, allow valid continuation tokens
|
| 2014 |
allowed = self._get_allowed_timesig_tokens()
|
| 2015 |
-
|
| 2016 |
-
mask[0, t] = 0
|
| 2017 |
-
scores = scores + mask
|
| 2018 |
|
| 2019 |
return scores
|
| 2020 |
|
|
|
|
| 571 |
if self.debug:
|
| 572 |
logger.debug(f"Built audio code masks for {len(self.audio_code_token_ids)} tokens")
|
| 573 |
|
| 574 |
+
def _apply_whitelist_inplace(self, scores: torch.Tensor, allowed_tokens: List[int]) -> None:
|
| 575 |
+
"""
|
| 576 |
+
Apply whitelist constraint inplace: only allow specified tokens, block all others.
|
| 577 |
+
|
| 578 |
+
This is more efficient than creating a mask tensor because:
|
| 579 |
+
1. No memory allocation for mask
|
| 580 |
+
2. No tensor addition operation
|
| 581 |
+
|
| 582 |
+
Args:
|
| 583 |
+
scores: [1, vocab_size] scores tensor to modify inplace
|
| 584 |
+
allowed_tokens: List of token IDs to allow (all others will be set to -inf)
|
| 585 |
+
"""
|
| 586 |
+
if not allowed_tokens:
|
| 587 |
+
# No tokens allowed, set all to -inf
|
| 588 |
+
scores.fill_(float('-inf'))
|
| 589 |
+
return
|
| 590 |
+
|
| 591 |
+
# Save the original values of allowed tokens
|
| 592 |
+
allowed_indices = torch.tensor(allowed_tokens, device=scores.device, dtype=torch.long)
|
| 593 |
+
saved_values = scores[0, allowed_indices].clone()
|
| 594 |
+
|
| 595 |
+
# Set all scores to -inf
|
| 596 |
+
scores.fill_(float('-inf'))
|
| 597 |
+
|
| 598 |
+
# Restore allowed token values
|
| 599 |
+
scores[0, allowed_indices] = saved_values
|
| 600 |
+
|
| 601 |
def _build_keyscale_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
|
| 602 |
"""
|
| 603 |
Build keyscale prefix to allowed tokens mapping based on ACTUAL tokenization.
|
|
|
|
| 1511 |
if self.debug:
|
| 1512 |
logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, blocking EOS")
|
| 1513 |
else:
|
| 1514 |
+
# Force EOS token when target codes count is reached - inplace
|
| 1515 |
+
eos_scores = scores[:, self.eos_token_id].clone()
|
| 1516 |
+
scores.fill_(float('-inf'))
|
| 1517 |
+
scores[:, self.eos_token_id] = eos_scores
|
| 1518 |
if self.debug:
|
| 1519 |
logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, forcing EOS")
|
| 1520 |
return self._apply_temperature_scaling(scores)
|
|
|
|
| 1636 |
input_ids: torch.LongTensor,
|
| 1637 |
scores: torch.FloatTensor,
|
| 1638 |
) -> torch.FloatTensor:
|
| 1639 |
+
"""Process a single sequence and return modified scores (inplace when possible)."""
|
| 1640 |
|
| 1641 |
# Check if we have tokens in queue for user-provided field
|
| 1642 |
# If so, inject the next token directly
|
| 1643 |
if self.user_field_token_queue:
|
|
|
|
| 1644 |
next_token = self.user_field_token_queue[0]
|
| 1645 |
+
self._apply_whitelist_inplace(scores, [next_token])
|
|
|
|
| 1646 |
return scores
|
| 1647 |
|
|
|
|
|
|
|
|
|
|
| 1648 |
if self.state in self.fixed_strings:
|
| 1649 |
# Fixed string state: force specific tokens
|
| 1650 |
fixed_str = self.fixed_strings[self.state]
|
|
|
|
| 1655 |
# This happens when we're about to complete the </think> tag
|
| 1656 |
if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
|
| 1657 |
# Check if the next token would complete the fixed string
|
|
|
|
|
|
|
|
|
|
| 1658 |
remaining_chars = len(fixed_str) - self.position_in_state
|
| 1659 |
# If remaining is small (<= 10 chars, which is typically 1-2 tokens), force EOS
|
| 1660 |
if remaining_chars <= 10:
|
| 1661 |
# Force EOS token to stop generation
|
| 1662 |
if self.eos_token_id is not None:
|
| 1663 |
+
self._apply_whitelist_inplace(scores, [self.eos_token_id])
|
|
|
|
| 1664 |
if self.debug:
|
| 1665 |
logger.debug(f"stop_at_reasoning=True: forcing EOS near end of </think> tag (remaining: {remaining_chars} chars)")
|
| 1666 |
return scores
|
| 1667 |
|
| 1668 |
+
# Apply whitelist constraint inplace
|
| 1669 |
+
self._apply_whitelist_inplace(scores, allowed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1670 |
else:
|
| 1671 |
# Position exceeds string, move to next state
|
| 1672 |
# If stop_at_reasoning is True and we're transitioning from THINK_END_TAG,
|
|
|
|
| 1674 |
if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
|
| 1675 |
# Force EOS token to stop generation
|
| 1676 |
if self.eos_token_id is not None:
|
| 1677 |
+
self._apply_whitelist_inplace(scores, [self.eos_token_id])
|
|
|
|
| 1678 |
if self.debug:
|
| 1679 |
logger.debug(f"stop_at_reasoning=True: forcing EOS after completing </think> tag")
|
| 1680 |
return scores
|
|
|
|
| 1687 |
if self.debug:
|
| 1688 |
logger.warning(f"State transition from {old_state.name} to {self.state.name} still in fixed_strings, avoiding recursion")
|
| 1689 |
return scores
|
| 1690 |
+
# For recursion, reset scores to zero (no constraints from previous state)
|
| 1691 |
+
scores.zero_()
|
| 1692 |
+
return self._process_single_sequence(input_ids, scores)
|
| 1693 |
|
| 1694 |
elif self.state == FSMState.BPM_VALUE:
|
| 1695 |
# Check if field is user-provided and we haven't started injecting yet
|
|
|
|
| 1703 |
self.user_field_token_queue = value_tokens
|
| 1704 |
self.current_user_field = "bpm"
|
| 1705 |
# Inject first token
|
| 1706 |
+
self._apply_whitelist_inplace(scores, [value_tokens[0]])
|
|
|
|
| 1707 |
return scores
|
| 1708 |
|
| 1709 |
# Allow valid numeric tokens using prefix tree (supports multi-digit tokens like "120")
|
| 1710 |
allowed = self._get_allowed_numeric_tokens(self.bpm_prefix_tree)
|
|
|
|
|
|
|
| 1711 |
|
| 1712 |
# Also allow newline if current token sequence prefix allows it
|
|
|
|
| 1713 |
token_prefix = tuple(self.accumulated_token_ids)
|
| 1714 |
if token_prefix in self.bpm_prefix_tree and self.newline_token in self.bpm_prefix_tree[token_prefix]:
|
| 1715 |
+
allowed = allowed + [self.newline_token]
|
| 1716 |
|
| 1717 |
+
self._apply_whitelist_inplace(scores, allowed)
|
| 1718 |
|
| 1719 |
elif self.state == FSMState.CAPTION_VALUE:
|
| 1720 |
# Caption field generation with YAML format support:
|
|
|
|
| 1733 |
self.user_field_token_queue = value_tokens
|
| 1734 |
self.current_user_field = "caption"
|
| 1735 |
# Inject first token
|
| 1736 |
+
self._apply_whitelist_inplace(scores, [value_tokens[0]])
|
|
|
|
| 1737 |
return scores
|
| 1738 |
|
| 1739 |
# Check if we should transition after a newline (non-indented line = new field)
|
|
|
|
| 1765 |
# The field name detection will happen in update_state()
|
| 1766 |
return scores
|
| 1767 |
|
| 1768 |
+
# Block backticks (code blocks) - inplace
|
| 1769 |
if self.backtick_token is not None:
|
| 1770 |
scores[0, self.backtick_token] = float('-inf')
|
| 1771 |
|
|
|
|
| 1781 |
if self.caption_token_count >= 512:
|
| 1782 |
# Force end by only allowing newline
|
| 1783 |
if self.newline_token is not None:
|
| 1784 |
+
self._apply_whitelist_inplace(scores, [self.newline_token])
|
|
|
|
| 1785 |
return scores
|
| 1786 |
|
| 1787 |
# Allow natural generation (with blocked audio codes and backticks)
|
|
|
|
| 1798 |
self.user_field_token_queue = value_tokens
|
| 1799 |
self.current_user_field = "duration"
|
| 1800 |
# Inject first token
|
| 1801 |
+
self._apply_whitelist_inplace(scores, [value_tokens[0]])
|
|
|
|
| 1802 |
return scores
|
| 1803 |
|
| 1804 |
# If target_duration is set, force generate that exact value
|
|
|
|
| 1810 |
# Force the next digit
|
| 1811 |
next_digit = int(target_str[current_pos])
|
| 1812 |
if next_digit in self.digit_tokens:
|
| 1813 |
+
self._apply_whitelist_inplace(scores, [self.digit_tokens[next_digit]])
|
| 1814 |
else:
|
| 1815 |
# All digits generated, force newline
|
| 1816 |
if self.newline_token:
|
| 1817 |
+
self._apply_whitelist_inplace(scores, [self.newline_token])
|
|
|
|
|
|
|
| 1818 |
else:
|
| 1819 |
# Normal duration generation with range constraint
|
| 1820 |
# Allow valid numeric tokens using prefix tree (supports multi-digit tokens like "60", "120")
|
| 1821 |
allowed = self._get_allowed_numeric_tokens(self.duration_prefix_tree)
|
|
|
|
|
|
|
| 1822 |
|
| 1823 |
# Also allow newline if current token sequence prefix allows it
|
| 1824 |
token_prefix = tuple(self.accumulated_token_ids)
|
| 1825 |
if token_prefix in self.duration_prefix_tree and self.newline_token in self.duration_prefix_tree[token_prefix]:
|
| 1826 |
+
allowed = allowed + [self.newline_token]
|
| 1827 |
|
| 1828 |
+
self._apply_whitelist_inplace(scores, allowed)
|
| 1829 |
|
| 1830 |
elif self.state == FSMState.GENRES_VALUE:
|
| 1831 |
# Check if field is user-provided and we haven't started injecting yet
|
|
|
|
| 1838 |
self.user_field_token_queue = value_tokens
|
| 1839 |
self.current_user_field = "genres"
|
| 1840 |
# Inject first token
|
| 1841 |
+
self._apply_whitelist_inplace(scores, [value_tokens[0]])
|
|
|
|
| 1842 |
return scores
|
| 1843 |
|
| 1844 |
# Try to hot-reload genres vocab if file has changed
|
|
|
|
| 1849 |
|
| 1850 |
if allowed:
|
| 1851 |
# Use vocabulary-constrained decoding
|
| 1852 |
+
self._apply_whitelist_inplace(scores, allowed)
|
|
|
|
|
|
|
| 1853 |
elif self.genres_vocab:
|
| 1854 |
# Vocab is loaded but no valid continuation found
|
| 1855 |
# Force newline to end the field
|
| 1856 |
if self.newline_token:
|
|
|
|
| 1857 |
if self.debug:
|
| 1858 |
logger.debug(f"No valid genre continuation for '{self.accumulated_value}', forcing newline")
|
| 1859 |
+
self._apply_whitelist_inplace(scores, [self.newline_token])
|
| 1860 |
else:
|
| 1861 |
# Fallback: no vocab loaded, use probability-based ending
|
| 1862 |
if self._should_end_text_field(scores):
|
| 1863 |
if self.newline_token:
|
| 1864 |
+
self._apply_whitelist_inplace(scores, [self.newline_token])
|
| 1865 |
self._transition_to_next_state()
|
|
|
|
| 1866 |
else:
|
| 1867 |
# Allow any token except newline if we don't have content yet
|
| 1868 |
if not self.accumulated_value.strip():
|
|
|
|
| 1881 |
self.user_field_token_queue = value_tokens
|
| 1882 |
self.current_user_field = "keyscale"
|
| 1883 |
# Inject first token
|
| 1884 |
+
self._apply_whitelist_inplace(scores, [value_tokens[0]])
|
|
|
|
| 1885 |
return scores
|
| 1886 |
|
| 1887 |
# Check if current token sequence is complete (allows newline)
|
|
|
|
| 1889 |
if token_prefix in self.keyscale_prefix_tree and self.newline_token in self.keyscale_prefix_tree[token_prefix]:
|
| 1890 |
# Complete keyscale, allow newline
|
| 1891 |
if self.newline_token:
|
| 1892 |
+
self._apply_whitelist_inplace(scores, [self.newline_token])
|
|
|
|
| 1893 |
else:
|
| 1894 |
# Not complete, allow valid continuation tokens
|
| 1895 |
allowed = self._get_allowed_keyscale_tokens()
|
| 1896 |
if allowed:
|
| 1897 |
+
self._apply_whitelist_inplace(scores, allowed)
|
|
|
|
|
|
|
| 1898 |
else:
|
| 1899 |
# No valid tokens found - force newline to end field
|
| 1900 |
# This handles edge cases where keyscale format is unexpected
|
| 1901 |
if self.newline_token:
|
| 1902 |
+
self._apply_whitelist_inplace(scores, [self.newline_token])
|
|
|
|
| 1903 |
|
| 1904 |
elif self.state == FSMState.LANGUAGE_VALUE:
|
| 1905 |
# Language field: Use top-1 probability language (greedy selection)
|
|
|
|
| 1917 |
self.user_field_token_queue = value_tokens
|
| 1918 |
self.current_user_field = "language"
|
| 1919 |
# Inject first token
|
| 1920 |
+
self._apply_whitelist_inplace(scores, [value_tokens[0]])
|
|
|
|
| 1921 |
return scores
|
| 1922 |
|
| 1923 |
# If we haven't started generating language yet (empty accumulated_token_ids),
|
|
|
|
| 1929 |
candidate_tokens = list(self.language_prefix_tree[empty_prefix])
|
| 1930 |
|
| 1931 |
if candidate_tokens:
|
| 1932 |
+
# Find the token with highest probability (top-1) among candidates
|
| 1933 |
+
# Use tensor indexing to get scores of candidate tokens directly
|
| 1934 |
+
candidate_indices = torch.tensor(candidate_tokens, device=scores.device, dtype=torch.long)
|
| 1935 |
+
candidate_scores = scores[0, candidate_indices]
|
|
|
|
|
|
|
| 1936 |
|
| 1937 |
# Get the highest probability token among candidates
|
| 1938 |
+
best_idx = torch.argmax(candidate_scores).item()
|
| 1939 |
+
top_token_id = candidate_tokens[best_idx]
|
| 1940 |
|
| 1941 |
+
# Only allow this top-1 token, block all others
|
| 1942 |
+
self._apply_whitelist_inplace(scores, [top_token_id])
|
|
|
|
| 1943 |
|
| 1944 |
if self.debug:
|
| 1945 |
top_token_text = self.tokenizer.decode([top_token_id])
|
|
|
|
| 1947 |
else:
|
| 1948 |
# No valid first tokens found - force newline
|
| 1949 |
if self.newline_token:
|
| 1950 |
+
self._apply_whitelist_inplace(scores, [self.newline_token])
|
|
|
|
| 1951 |
else:
|
| 1952 |
# Empty prefix not in tree - force newline
|
| 1953 |
if self.newline_token:
|
| 1954 |
+
self._apply_whitelist_inplace(scores, [self.newline_token])
|
|
|
|
| 1955 |
else:
|
| 1956 |
# We've started generating a language, continue with prefix tree constraints
|
| 1957 |
# Check if current token sequence is complete (allows newline)
|
|
|
|
| 1959 |
if token_prefix in self.language_prefix_tree and self.newline_token in self.language_prefix_tree[token_prefix]:
|
| 1960 |
# Complete language, allow newline
|
| 1961 |
if self.newline_token:
|
| 1962 |
+
self._apply_whitelist_inplace(scores, [self.newline_token])
|
|
|
|
| 1963 |
else:
|
| 1964 |
# Not complete, allow valid continuation tokens
|
| 1965 |
allowed = self._get_allowed_language_tokens()
|
| 1966 |
if allowed:
|
| 1967 |
+
self._apply_whitelist_inplace(scores, allowed)
|
|
|
|
|
|
|
| 1968 |
else:
|
| 1969 |
# No valid tokens found - force newline to end field
|
| 1970 |
if self.newline_token:
|
| 1971 |
+
self._apply_whitelist_inplace(scores, [self.newline_token])
|
|
|
|
| 1972 |
|
| 1973 |
elif self.state == FSMState.TIMESIG_VALUE:
|
| 1974 |
# Check if field is user-provided and we haven't started injecting yet
|
|
|
|
| 1981 |
self.user_field_token_queue = value_tokens
|
| 1982 |
self.current_user_field = "timesignature"
|
| 1983 |
# Inject first token
|
| 1984 |
+
self._apply_whitelist_inplace(scores, [value_tokens[0]])
|
|
|
|
| 1985 |
return scores
|
| 1986 |
|
| 1987 |
# Check if current token sequence is complete (allows newline)
|
|
|
|
| 1989 |
if token_prefix in self.timesig_prefix_tree and self.newline_token in self.timesig_prefix_tree[token_prefix]:
|
| 1990 |
# Complete value, allow newline
|
| 1991 |
if self.newline_token:
|
| 1992 |
+
self._apply_whitelist_inplace(scores, [self.newline_token])
|
|
|
|
| 1993 |
else:
|
| 1994 |
# Not complete, allow valid continuation tokens
|
| 1995 |
allowed = self._get_allowed_timesig_tokens()
|
| 1996 |
+
self._apply_whitelist_inplace(scores, allowed)
|
|
|
|
|
|
|
| 1997 |
|
| 1998 |
return scores
|
| 1999 |
|
acestep/gradio_ui/event.py
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
acestep/gradio_ui/events/results_handlers.py
CHANGED
|
@@ -332,10 +332,9 @@ def generate_with_progress(
|
|
| 332 |
logger.info(f"Generating LM batch chunk {chunk_idx+1}/{num_chunks} (size: {chunk_size}, seeds: {chunk_seeds})...")
|
| 333 |
|
| 334 |
# Generate batch
|
| 335 |
-
metadata_list, audio_codes_list, status = llm_handler.
|
| 336 |
caption=captions or "",
|
| 337 |
lyrics=lyrics or "",
|
| 338 |
-
batch_size=chunk_size,
|
| 339 |
infer_type="llm_dit",
|
| 340 |
temperature=lm_temperature,
|
| 341 |
cfg_scale=lm_cfg_scale,
|
|
@@ -347,6 +346,7 @@ def generate_with_progress(
|
|
| 347 |
use_cot_language=use_cot_language,
|
| 348 |
is_format_caption=is_format_caption,
|
| 349 |
constrained_decoding_debug=constrained_decoding_debug,
|
|
|
|
| 350 |
seeds=chunk_seeds,
|
| 351 |
)
|
| 352 |
|
|
@@ -474,9 +474,26 @@ def generate_with_progress(
|
|
| 474 |
progress=progress
|
| 475 |
)
|
| 476 |
|
| 477 |
-
# Extract results
|
| 478 |
-
|
| 479 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 480 |
|
| 481 |
# Extract LM timing from status if available and prepend to generation_info
|
| 482 |
if status:
|
|
|
|
| 332 |
logger.info(f"Generating LM batch chunk {chunk_idx+1}/{num_chunks} (size: {chunk_size}, seeds: {chunk_seeds})...")
|
| 333 |
|
| 334 |
# Generate batch
|
| 335 |
+
metadata_list, audio_codes_list, status = llm_handler.generate_with_stop_condition(
|
| 336 |
caption=captions or "",
|
| 337 |
lyrics=lyrics or "",
|
|
|
|
| 338 |
infer_type="llm_dit",
|
| 339 |
temperature=lm_temperature,
|
| 340 |
cfg_scale=lm_cfg_scale,
|
|
|
|
| 346 |
use_cot_language=use_cot_language,
|
| 347 |
is_format_caption=is_format_caption,
|
| 348 |
constrained_decoding_debug=constrained_decoding_debug,
|
| 349 |
+
batch_size=chunk_size,
|
| 350 |
seeds=chunk_seeds,
|
| 351 |
)
|
| 352 |
|
|
|
|
| 474 |
progress=progress
|
| 475 |
)
|
| 476 |
|
| 477 |
+
# Extract results from new dict structure
|
| 478 |
+
if not isinstance(result, dict):
|
| 479 |
+
# Fallback for old tuple format (should not happen)
|
| 480 |
+
first_audio, second_audio, all_audio_paths, generation_info, status_message, seed_value_for_ui, \
|
| 481 |
+
align_score_1, align_text_1, align_plot_1, align_score_2, align_text_2, align_plot_2 = result
|
| 482 |
+
else:
|
| 483 |
+
audios = result.get("audios", [])
|
| 484 |
+
all_audio_paths = [audio.get("path") for audio in audios]
|
| 485 |
+
first_audio = all_audio_paths[0] if len(all_audio_paths) > 0 else None
|
| 486 |
+
second_audio = all_audio_paths[1] if len(all_audio_paths) > 1 else None
|
| 487 |
+
generation_info = result.get("generation_info", "")
|
| 488 |
+
status_message = result.get("status_message", "")
|
| 489 |
+
seed_value_for_ui = result.get("extra_outputs", {}).get("seed_value", "")
|
| 490 |
+
# Legacy alignment fields (no longer used)
|
| 491 |
+
align_score_1 = ""
|
| 492 |
+
align_text_1 = ""
|
| 493 |
+
align_plot_1 = None
|
| 494 |
+
align_score_2 = ""
|
| 495 |
+
align_text_2 = ""
|
| 496 |
+
align_plot_2 = None
|
| 497 |
|
| 498 |
# Extract LM timing from status if available and prepend to generation_info
|
| 499 |
if status:
|
acestep/handler.py
CHANGED
|
@@ -10,6 +10,8 @@ import traceback
|
|
| 10 |
import re
|
| 11 |
import random
|
| 12 |
import uuid
|
|
|
|
|
|
|
| 13 |
from contextlib import contextmanager
|
| 14 |
from typing import Optional, Dict, Any, Tuple, List, Union
|
| 15 |
|
|
@@ -37,16 +39,12 @@ warnings.filterwarnings("ignore")
|
|
| 37 |
class AceStepHandler:
|
| 38 |
"""ACE-Step Business Logic Handler"""
|
| 39 |
|
| 40 |
-
def __init__(self
|
| 41 |
self.model = None
|
| 42 |
self.config = None
|
| 43 |
self.device = "cpu"
|
| 44 |
self.dtype = torch.float32 # Will be set based on device in initialize_service
|
| 45 |
-
|
| 46 |
-
self.temp_dir = tempfile.mkdtemp()
|
| 47 |
-
else:
|
| 48 |
-
self.temp_dir = save_root
|
| 49 |
-
|
| 50 |
# VAE for audio encoding/decoding
|
| 51 |
self.vae = None
|
| 52 |
|
|
@@ -81,8 +79,7 @@ class AceStepHandler:
|
|
| 81 |
def get_available_checkpoints(self) -> str:
|
| 82 |
"""Return project root directory path"""
|
| 83 |
# Get project root (handler.py is in acestep/, so go up two levels to project root)
|
| 84 |
-
|
| 85 |
-
project_root = os.path.dirname(os.path.dirname(current_file))
|
| 86 |
# default checkpoints
|
| 87 |
checkpoint_dir = os.path.join(project_root, "checkpoints")
|
| 88 |
if os.path.exists(checkpoint_dir):
|
|
@@ -93,8 +90,7 @@ class AceStepHandler:
|
|
| 93 |
def get_available_acestep_v15_models(self) -> List[str]:
|
| 94 |
"""Scan and return all model directory names starting with 'acestep-v15-'"""
|
| 95 |
# Get project root
|
| 96 |
-
|
| 97 |
-
project_root = os.path.dirname(os.path.dirname(current_file))
|
| 98 |
checkpoint_dir = os.path.join(project_root, "checkpoints")
|
| 99 |
|
| 100 |
models = []
|
|
@@ -171,8 +167,7 @@ class AceStepHandler:
|
|
| 171 |
|
| 172 |
|
| 173 |
# Auto-detect project root (independent of passed project_root parameter)
|
| 174 |
-
|
| 175 |
-
actual_project_root = os.path.dirname(os.path.dirname(current_file))
|
| 176 |
checkpoint_dir = os.path.join(actual_project_root, "checkpoints")
|
| 177 |
|
| 178 |
# 1. Load main model
|
|
@@ -187,7 +182,7 @@ class AceStepHandler:
|
|
| 187 |
attn_implementation = "sdpa"
|
| 188 |
|
| 189 |
try:
|
| 190 |
-
logger.info(f"Attempting to load model with attention implementation: {attn_implementation}")
|
| 191 |
self.model = AutoModel.from_pretrained(
|
| 192 |
acestep_v15_checkpoint_path,
|
| 193 |
trust_remote_code=True,
|
|
@@ -195,9 +190,9 @@ class AceStepHandler:
|
|
| 195 |
dtype="bfloat16"
|
| 196 |
)
|
| 197 |
except Exception as e:
|
| 198 |
-
logger.warning(f"Failed to load model with {attn_implementation}: {e}")
|
| 199 |
if attn_implementation == "sdpa":
|
| 200 |
-
logger.info("Falling back to eager attention")
|
| 201 |
attn_implementation = "eager"
|
| 202 |
self.model = AutoModel.from_pretrained(
|
| 203 |
acestep_v15_checkpoint_path,
|
|
@@ -215,7 +210,7 @@ class AceStepHandler:
|
|
| 215 |
else:
|
| 216 |
# If offload_to_cpu is True, check if we should keep DiT on GPU
|
| 217 |
if not self.offload_dit_to_cpu:
|
| 218 |
-
logger.info(f"Keeping main model on {device} (persistent)")
|
| 219 |
self.model = self.model.to(device).to(self.dtype)
|
| 220 |
else:
|
| 221 |
self.model = self.model.to("cpu").to(self.dtype)
|
|
@@ -239,7 +234,7 @@ class AceStepHandler:
|
|
| 239 |
raise ValueError(f"Unsupported quantization type: {self.quantization}")
|
| 240 |
|
| 241 |
quantize_(self.model, quant_config)
|
| 242 |
-
logger.info(f"DiT quantized with: {self.quantization}")
|
| 243 |
|
| 244 |
|
| 245 |
silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
|
|
@@ -260,7 +255,7 @@ class AceStepHandler:
|
|
| 260 |
if os.path.exists(vae_checkpoint_path):
|
| 261 |
self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
|
| 262 |
# Use bfloat16 for VAE on GPU, otherwise use self.dtype (float32 on CPU)
|
| 263 |
-
vae_dtype =
|
| 264 |
if not self.offload_to_cpu:
|
| 265 |
self.vae = self.vae.to(device).to(vae_dtype)
|
| 266 |
else:
|
|
@@ -302,6 +297,7 @@ class AceStepHandler:
|
|
| 302 |
|
| 303 |
except Exception as e:
|
| 304 |
error_msg = f"❌ Error initializing model: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
|
|
|
|
| 305 |
return error_msg, False
|
| 306 |
|
| 307 |
@contextmanager
|
|
@@ -326,7 +322,7 @@ class AceStepHandler:
|
|
| 326 |
try:
|
| 327 |
param = next(model.parameters())
|
| 328 |
if param.device.type == "cpu":
|
| 329 |
-
logger.info(f"Moving {model_name} to {self.device} (persistent)")
|
| 330 |
model.to(self.device).to(self.dtype)
|
| 331 |
if hasattr(self, "silence_latent"):
|
| 332 |
self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
|
|
@@ -341,10 +337,10 @@ class AceStepHandler:
|
|
| 341 |
return
|
| 342 |
|
| 343 |
# Load to GPU
|
| 344 |
-
logger.info(f"Loading {model_name} to {self.device}")
|
| 345 |
start_time = time.time()
|
| 346 |
if model_name == "vae":
|
| 347 |
-
vae_dtype =
|
| 348 |
model.to(self.device).to(vae_dtype)
|
| 349 |
else:
|
| 350 |
model.to(self.device).to(self.dtype)
|
|
@@ -354,13 +350,13 @@ class AceStepHandler:
|
|
| 354 |
|
| 355 |
load_time = time.time() - start_time
|
| 356 |
self.current_offload_cost += load_time
|
| 357 |
-
logger.info(f"Loaded {model_name} to {self.device} in {load_time:.4f}s")
|
| 358 |
|
| 359 |
try:
|
| 360 |
yield
|
| 361 |
finally:
|
| 362 |
# Offload to CPU
|
| 363 |
-
logger.info(f"Offloading {model_name} to CPU")
|
| 364 |
start_time = time.time()
|
| 365 |
model.to("cpu")
|
| 366 |
|
|
@@ -370,7 +366,7 @@ class AceStepHandler:
|
|
| 370 |
torch.cuda.empty_cache()
|
| 371 |
offload_time = time.time() - start_time
|
| 372 |
self.current_offload_cost += offload_time
|
| 373 |
-
logger.info(f"Offloaded {model_name} to CPU in {offload_time:.4f}s")
|
| 374 |
|
| 375 |
def process_target_audio(self, audio_file) -> Optional[torch.Tensor]:
|
| 376 |
"""Process target audio"""
|
|
@@ -386,23 +382,12 @@ class AceStepHandler:
|
|
| 386 |
else:
|
| 387 |
audio = torch.from_numpy(audio_np.T)
|
| 388 |
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
audio = audio[:2]
|
| 393 |
-
|
| 394 |
-
# Resample if needed
|
| 395 |
-
if sr != 48000:
|
| 396 |
-
import torch.nn.functional as F
|
| 397 |
-
ratio = 48000 / sr
|
| 398 |
-
new_length = int(audio.shape[-1] * ratio)
|
| 399 |
-
audio = F.interpolate(audio.unsqueeze(0), size=new_length, mode='linear', align_corners=False).squeeze(0)
|
| 400 |
-
|
| 401 |
-
audio = torch.clamp(audio, -1.0, 1.0)
|
| 402 |
|
| 403 |
return audio
|
| 404 |
except Exception as e:
|
| 405 |
-
logger.
|
| 406 |
return None
|
| 407 |
|
| 408 |
def _parse_audio_code_string(self, code_str: str) -> List[int]:
|
|
@@ -411,7 +396,8 @@ class AceStepHandler:
|
|
| 411 |
return []
|
| 412 |
try:
|
| 413 |
return [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
|
| 414 |
-
except Exception:
|
|
|
|
| 415 |
return []
|
| 416 |
|
| 417 |
def _decode_audio_codes_to_latents(self, code_str: str) -> Optional[torch.Tensor]:
|
|
@@ -538,9 +524,7 @@ class AceStepHandler:
|
|
| 538 |
)
|
| 539 |
"""
|
| 540 |
# Align instruction formatting with _prepare_batch
|
| 541 |
-
final_instruction = instruction or DEFAULT_DIT_INSTRUCTION
|
| 542 |
-
if not final_instruction.endswith(":"):
|
| 543 |
-
final_instruction = final_instruction + ":"
|
| 544 |
|
| 545 |
# Extract caption and language from metas if available (from LM CoT output)
|
| 546 |
# Fallback to user-provided values if not in metas
|
|
@@ -571,7 +555,7 @@ class AceStepHandler:
|
|
| 571 |
|
| 572 |
parsed_meta = self._parse_metas([metas])[0]
|
| 573 |
caption_input = SFT_GEN_PROMPT.format(final_instruction, actual_caption, parsed_meta)
|
| 574 |
-
lyrics_input =
|
| 575 |
return caption_input, lyrics_input
|
| 576 |
|
| 577 |
def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
@@ -614,7 +598,7 @@ class AceStepHandler:
|
|
| 614 |
return match.group(1).strip()
|
| 615 |
return caption
|
| 616 |
except Exception as e:
|
| 617 |
-
logger.
|
| 618 |
return caption
|
| 619 |
|
| 620 |
def prepare_seeds(self, actual_batch_size, seed, use_random_seed):
|
|
@@ -638,7 +622,8 @@ class AceStepHandler:
|
|
| 638 |
else:
|
| 639 |
try:
|
| 640 |
seed_list.append(int(float(s)))
|
| 641 |
-
except (ValueError, TypeError):
|
|
|
|
| 642 |
seed_list.append(-1)
|
| 643 |
elif seed is None or (isinstance(seed, (int, float)) and seed < 0):
|
| 644 |
# If seed is None or negative, use -1 for all items
|
|
@@ -679,7 +664,176 @@ class AceStepHandler:
|
|
| 679 |
return actual_seed_list, seed_value_for_ui
|
| 680 |
|
| 681 |
def prepare_metadata(self, bpm, key_scale, time_signature):
|
| 682 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 683 |
metadata_dict = {}
|
| 684 |
if bpm:
|
| 685 |
metadata_dict["bpm"] = bpm
|
|
@@ -695,10 +849,12 @@ class AceStepHandler:
|
|
| 695 |
metadata_dict["timesignature"] = time_signature
|
| 696 |
else:
|
| 697 |
metadata_dict["timesignature"] = "N/A"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 698 |
return metadata_dict
|
| 699 |
-
|
| 700 |
-
def is_silence(self, audio):
|
| 701 |
-
return torch.all(audio.abs() < 1e-6)
|
| 702 |
|
| 703 |
def generate_instruction(
|
| 704 |
self,
|
|
@@ -745,23 +901,12 @@ class AceStepHandler:
|
|
| 745 |
# Load audio file
|
| 746 |
audio, sr = torchaudio.load(audio_file)
|
| 747 |
|
| 748 |
-
logger.
|
| 749 |
-
logger.
|
| 750 |
-
logger.
|
| 751 |
-
|
| 752 |
-
# Convert to stereo (duplicate channel if mono)
|
| 753 |
-
if audio.shape[0] == 1:
|
| 754 |
-
audio = torch.cat([audio, audio], dim=0)
|
| 755 |
|
| 756 |
-
#
|
| 757 |
-
audio = audio
|
| 758 |
-
|
| 759 |
-
# Resample to 48kHz if needed
|
| 760 |
-
if sr != 48000:
|
| 761 |
-
audio = torchaudio.transforms.Resample(sr, 48000)(audio)
|
| 762 |
-
|
| 763 |
-
# Clamp values to [-1.0, 1.0]
|
| 764 |
-
audio = torch.clamp(audio, -1.0, 1.0)
|
| 765 |
|
| 766 |
is_silence = self.is_silence(audio)
|
| 767 |
if is_silence:
|
|
@@ -800,7 +945,7 @@ class AceStepHandler:
|
|
| 800 |
return audio
|
| 801 |
|
| 802 |
except Exception as e:
|
| 803 |
-
logger.
|
| 804 |
return None
|
| 805 |
|
| 806 |
def process_src_audio(self, audio_file) -> Optional[torch.Tensor]:
|
|
@@ -811,24 +956,13 @@ class AceStepHandler:
|
|
| 811 |
# Load audio file
|
| 812 |
audio, sr = torchaudio.load(audio_file)
|
| 813 |
|
| 814 |
-
#
|
| 815 |
-
|
| 816 |
-
audio = torch.cat([audio, audio], dim=0)
|
| 817 |
-
|
| 818 |
-
# Keep only first 2 channels
|
| 819 |
-
audio = audio[:2]
|
| 820 |
-
|
| 821 |
-
# Resample to 48kHz if needed
|
| 822 |
-
if sr != 48000:
|
| 823 |
-
audio = torchaudio.transforms.Resample(sr, 48000)(audio)
|
| 824 |
-
|
| 825 |
-
# Clamp values to [-1.0, 1.0]
|
| 826 |
-
audio = torch.clamp(audio, -1.0, 1.0)
|
| 827 |
|
| 828 |
return audio
|
| 829 |
|
| 830 |
except Exception as e:
|
| 831 |
-
logger.
|
| 832 |
return None
|
| 833 |
|
| 834 |
def convert_src_audio_to_codes(self, audio_file) -> str:
|
|
@@ -856,19 +990,12 @@ class AceStepHandler:
|
|
| 856 |
# Encode audio to latents using VAE
|
| 857 |
with torch.no_grad():
|
| 858 |
with self._load_model_context("vae"):
|
| 859 |
-
# Prepare audio for VAE: [channels, samples] -> [1, channels, samples]
|
| 860 |
-
vae_input = processed_audio.unsqueeze(0).to(self.device).to(self.vae.dtype)
|
| 861 |
-
|
| 862 |
# Check if audio is silence
|
| 863 |
-
if self.is_silence(
|
| 864 |
return "❌ Audio file appears to be silent"
|
| 865 |
|
| 866 |
-
# Encode to latents
|
| 867 |
-
latents = self.
|
| 868 |
-
# Cast back to model dtype
|
| 869 |
-
latents = latents.to(self.dtype)
|
| 870 |
-
# Transpose: [1, d, T] -> [1, T, d] -> [T, d]
|
| 871 |
-
latents = latents.squeeze(0).transpose(0, 1) # [T, d]
|
| 872 |
|
| 873 |
# Create attention mask for latents
|
| 874 |
attention_mask = torch.ones(latents.shape[0], dtype=torch.bool, device=self.device)
|
|
@@ -893,7 +1020,7 @@ class AceStepHandler:
|
|
| 893 |
|
| 894 |
except Exception as e:
|
| 895 |
error_msg = f"❌ Error converting audio to codes: {str(e)}\n{traceback.format_exc()}"
|
| 896 |
-
logger.
|
| 897 |
return error_msg
|
| 898 |
|
| 899 |
def prepare_batch_data(
|
|
@@ -922,26 +1049,7 @@ class AceStepHandler:
|
|
| 922 |
calculated_duration = audio_duration
|
| 923 |
|
| 924 |
# Build metadata dict - use "N/A" as default for empty fields
|
| 925 |
-
metadata_dict =
|
| 926 |
-
if bpm:
|
| 927 |
-
metadata_dict["bpm"] = bpm
|
| 928 |
-
else:
|
| 929 |
-
metadata_dict["bpm"] = "N/A"
|
| 930 |
-
|
| 931 |
-
if key_scale.strip():
|
| 932 |
-
metadata_dict["keyscale"] = key_scale
|
| 933 |
-
else:
|
| 934 |
-
metadata_dict["keyscale"] = "N/A"
|
| 935 |
-
|
| 936 |
-
if time_signature.strip() and time_signature != "N/A" and time_signature:
|
| 937 |
-
metadata_dict["timesignature"] = time_signature
|
| 938 |
-
else:
|
| 939 |
-
metadata_dict["timesignature"] = "N/A"
|
| 940 |
-
|
| 941 |
-
# Add duration to metadata if available (inference service format: "30 seconds")
|
| 942 |
-
if calculated_duration is not None:
|
| 943 |
-
metadata_dict["duration"] = f"{int(calculated_duration)} seconds"
|
| 944 |
-
# If duration not set, inference service will use default (30 seconds)
|
| 945 |
|
| 946 |
# Format metadata - inference service accepts dict and will convert to string
|
| 947 |
# Create a copy for each batch item (in case we modify it)
|
|
@@ -977,7 +1085,7 @@ class AceStepHandler:
|
|
| 977 |
target_wavs = torch.zeros(2, frames)
|
| 978 |
return target_wavs
|
| 979 |
except Exception as e:
|
| 980 |
-
logger.
|
| 981 |
# Fallback to 30 seconds if error
|
| 982 |
return torch.zeros(2, 30 * 48000)
|
| 983 |
|
|
@@ -1158,16 +1266,8 @@ class AceStepHandler:
|
|
| 1158 |
"""
|
| 1159 |
batch_size = len(captions)
|
| 1160 |
|
| 1161 |
-
#
|
| 1162 |
-
|
| 1163 |
-
audio_code_hints = [None] * batch_size
|
| 1164 |
-
elif len(audio_code_hints) != batch_size:
|
| 1165 |
-
if len(audio_code_hints) == 1:
|
| 1166 |
-
audio_code_hints = audio_code_hints * batch_size
|
| 1167 |
-
else:
|
| 1168 |
-
audio_code_hints = audio_code_hints[:batch_size]
|
| 1169 |
-
while len(audio_code_hints) < batch_size:
|
| 1170 |
-
audio_code_hints.append(None)
|
| 1171 |
|
| 1172 |
for ii, refer_audio_list in enumerate(refer_audios):
|
| 1173 |
if isinstance(refer_audio_list, list):
|
|
@@ -1179,17 +1279,6 @@ class AceStepHandler:
|
|
| 1179 |
if vocal_languages is None:
|
| 1180 |
vocal_languages = self._create_fallback_vocal_languages(batch_size)
|
| 1181 |
|
| 1182 |
-
# Normalize audio_code_hints to batch list
|
| 1183 |
-
if audio_code_hints is None:
|
| 1184 |
-
audio_code_hints = [None] * batch_size
|
| 1185 |
-
elif not isinstance(audio_code_hints, list):
|
| 1186 |
-
audio_code_hints = [audio_code_hints] * batch_size
|
| 1187 |
-
elif len(audio_code_hints) == 1 and batch_size > 1:
|
| 1188 |
-
audio_code_hints = audio_code_hints * batch_size
|
| 1189 |
-
else:
|
| 1190 |
-
audio_code_hints = (audio_code_hints + [None] * batch_size)[:batch_size]
|
| 1191 |
-
audio_code_hints = [hint if isinstance(hint, str) and hint.strip() else None for hint in audio_code_hints]
|
| 1192 |
-
|
| 1193 |
# Parse metas with fallbacks
|
| 1194 |
parsed_metas = self._parse_metas(metas)
|
| 1195 |
|
|
@@ -1223,13 +1312,9 @@ class AceStepHandler:
|
|
| 1223 |
expected_latent_length = current_wav.shape[-1] // 1920
|
| 1224 |
target_latent = self.silence_latent[0, :expected_latent_length, :]
|
| 1225 |
else:
|
| 1226 |
-
#
|
| 1227 |
logger.info(f"[generate_music] Encoding target audio to latents for item {i}...")
|
| 1228 |
-
|
| 1229 |
-
target_latent = self.vae.encode(vae_input).latent_dist.sample()
|
| 1230 |
-
# Cast back to model dtype
|
| 1231 |
-
target_latent = target_latent.to(self.dtype)
|
| 1232 |
-
target_latent = target_latent.squeeze(0).transpose(0, 1)
|
| 1233 |
target_latents_list.append(target_latent)
|
| 1234 |
latent_lengths.append(target_latent.shape[0])
|
| 1235 |
|
|
@@ -1268,18 +1353,7 @@ class AceStepHandler:
|
|
| 1268 |
|
| 1269 |
# Process instructions early so we can use them for task type detection
|
| 1270 |
# Use custom instructions if provided, otherwise use default
|
| 1271 |
-
|
| 1272 |
-
instructions = [DEFAULT_DIT_INSTRUCTION] * batch_size
|
| 1273 |
-
|
| 1274 |
-
# Ensure instructions list has the same length as batch_size
|
| 1275 |
-
if len(instructions) != batch_size:
|
| 1276 |
-
if len(instructions) == 1:
|
| 1277 |
-
instructions = instructions * batch_size
|
| 1278 |
-
else:
|
| 1279 |
-
# Pad or truncate to match batch_size
|
| 1280 |
-
instructions = instructions[:batch_size]
|
| 1281 |
-
while len(instructions) < batch_size:
|
| 1282 |
-
instructions.append(DEFAULT_DIT_INSTRUCTION)
|
| 1283 |
|
| 1284 |
# Generate chunk_masks and spans based on repainting parameters
|
| 1285 |
# Also determine if this is a cover task (target audio provided without repainting)
|
|
@@ -1428,6 +1502,10 @@ class AceStepHandler:
|
|
| 1428 |
else:
|
| 1429 |
precomputed_lm_hints_25Hz = None
|
| 1430 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1431 |
# Format text_inputs
|
| 1432 |
text_inputs = []
|
| 1433 |
text_token_idss = []
|
|
@@ -1437,26 +1515,10 @@ class AceStepHandler:
|
|
| 1437 |
|
| 1438 |
for i in range(batch_size):
|
| 1439 |
# Use custom instruction for this batch item
|
| 1440 |
-
instruction = instructions[i] if i < len(instructions) else DEFAULT_DIT_INSTRUCTION
|
| 1441 |
-
|
| 1442 |
-
|
| 1443 |
-
|
| 1444 |
-
|
| 1445 |
-
# Extract caption and language from metas if available (from LM CoT output)
|
| 1446 |
-
# Fallback to user-provided values if not in metas
|
| 1447 |
-
actual_caption = captions[i]
|
| 1448 |
-
actual_language = vocal_languages[i]
|
| 1449 |
-
|
| 1450 |
-
# Check if metas contains caption/language from LM CoT
|
| 1451 |
-
if i < len(parsed_metas) and parsed_metas[i]:
|
| 1452 |
-
meta_dict = parsed_metas[i]
|
| 1453 |
-
if isinstance(meta_dict, dict):
|
| 1454 |
-
# Extract caption from metas if available
|
| 1455 |
-
if 'caption' in meta_dict and meta_dict['caption']:
|
| 1456 |
-
actual_caption = str(meta_dict['caption'])
|
| 1457 |
-
# Extract language from metas if available
|
| 1458 |
-
if 'language' in meta_dict and meta_dict['language']:
|
| 1459 |
-
actual_language = str(meta_dict['language'])
|
| 1460 |
|
| 1461 |
# Format text prompt with custom instruction (using LM-generated caption if available)
|
| 1462 |
text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])
|
|
@@ -1473,7 +1535,7 @@ class AceStepHandler:
|
|
| 1473 |
text_attention_mask = text_inputs_dict.attention_mask[0].bool()
|
| 1474 |
|
| 1475 |
# Format and tokenize lyrics (using LM-generated language if available)
|
| 1476 |
-
lyrics_text =
|
| 1477 |
lyrics_inputs_dict = self.text_tokenizer(
|
| 1478 |
lyrics_text,
|
| 1479 |
padding="longest",
|
|
@@ -1495,36 +1557,12 @@ class AceStepHandler:
|
|
| 1495 |
|
| 1496 |
# Pad tokenized sequences
|
| 1497 |
max_text_length = max(len(seq) for seq in text_token_idss)
|
| 1498 |
-
padded_text_token_idss =
|
| 1499 |
-
|
| 1500 |
-
seq, (0, max_text_length - len(seq)), 'constant',
|
| 1501 |
-
self.text_tokenizer.pad_token_id
|
| 1502 |
-
)
|
| 1503 |
-
for seq in text_token_idss
|
| 1504 |
-
])
|
| 1505 |
-
|
| 1506 |
-
padded_text_attention_masks = torch.stack([
|
| 1507 |
-
torch.nn.functional.pad(
|
| 1508 |
-
seq, (0, max_text_length - len(seq)), 'constant', 0
|
| 1509 |
-
)
|
| 1510 |
-
for seq in text_attention_masks
|
| 1511 |
-
])
|
| 1512 |
|
| 1513 |
max_lyric_length = max(len(seq) for seq in lyric_token_idss)
|
| 1514 |
-
padded_lyric_token_idss =
|
| 1515 |
-
|
| 1516 |
-
seq, (0, max_lyric_length - len(seq)), 'constant',
|
| 1517 |
-
self.text_tokenizer.pad_token_id
|
| 1518 |
-
)
|
| 1519 |
-
for seq in lyric_token_idss
|
| 1520 |
-
])
|
| 1521 |
-
|
| 1522 |
-
padded_lyric_attention_masks = torch.stack([
|
| 1523 |
-
torch.nn.functional.pad(
|
| 1524 |
-
seq, (0, max_lyric_length - len(seq)), 'constant', 0
|
| 1525 |
-
)
|
| 1526 |
-
for seq in lyric_attention_masks
|
| 1527 |
-
])
|
| 1528 |
|
| 1529 |
padded_non_cover_text_input_ids = None
|
| 1530 |
padded_non_cover_text_attention_masks = None
|
|
@@ -1533,14 +1571,10 @@ class AceStepHandler:
|
|
| 1533 |
non_cover_text_attention_masks = []
|
| 1534 |
for i in range(batch_size):
|
| 1535 |
# Use custom instruction for this batch item
|
| 1536 |
-
instruction = DEFAULT_DIT_INSTRUCTION
|
| 1537 |
|
| 1538 |
# Extract caption from metas if available (from LM CoT output)
|
| 1539 |
-
actual_caption =
|
| 1540 |
-
if i < len(parsed_metas) and parsed_metas[i]:
|
| 1541 |
-
meta_dict = parsed_metas[i]
|
| 1542 |
-
if isinstance(meta_dict, dict) and 'caption' in meta_dict and meta_dict['caption']:
|
| 1543 |
-
actual_caption = str(meta_dict['caption'])
|
| 1544 |
|
| 1545 |
# Format text prompt with custom instruction (using LM-generated caption if available)
|
| 1546 |
text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])
|
|
@@ -1558,19 +1592,8 @@ class AceStepHandler:
|
|
| 1558 |
non_cover_text_input_ids.append(text_token_ids)
|
| 1559 |
non_cover_text_attention_masks.append(non_cover_text_attention_mask)
|
| 1560 |
|
| 1561 |
-
padded_non_cover_text_input_ids =
|
| 1562 |
-
|
| 1563 |
-
seq, (0, max_text_length - len(seq)), 'constant',
|
| 1564 |
-
self.text_tokenizer.pad_token_id
|
| 1565 |
-
)
|
| 1566 |
-
for seq in non_cover_text_input_ids
|
| 1567 |
-
])
|
| 1568 |
-
padded_non_cover_text_attention_masks = torch.stack([
|
| 1569 |
-
torch.nn.functional.pad(
|
| 1570 |
-
seq, (0, max_text_length - len(seq)), 'constant', 0
|
| 1571 |
-
)
|
| 1572 |
-
for seq in non_cover_text_attention_masks
|
| 1573 |
-
])
|
| 1574 |
|
| 1575 |
if audio_cover_strength < 1.0:
|
| 1576 |
assert padded_non_cover_text_input_ids is not None, "When audio_cover_strength < 1.0, padded_non_cover_text_input_ids must not be None"
|
|
@@ -1804,7 +1827,7 @@ class AceStepHandler:
|
|
| 1804 |
if self.config.is_turbo:
|
| 1805 |
# Limit inference steps to maximum 8
|
| 1806 |
if infer_steps > 8:
|
| 1807 |
-
logger.warning(f"dmd_gan version: infer_steps {infer_steps} exceeds maximum 8, clamping to 8")
|
| 1808 |
infer_steps = 8
|
| 1809 |
# CFG parameters are not adjustable for dmd_gan (they will be ignored)
|
| 1810 |
# Note: guidance_scale, cfg_interval_start, cfg_interval_end are still passed but may be ignored by the model
|
|
@@ -1827,30 +1850,12 @@ class AceStepHandler:
|
|
| 1827 |
if isinstance(repainting_end, (int, float)):
|
| 1828 |
repainting_end = [repainting_end]
|
| 1829 |
|
| 1830 |
-
# Convert instructions to list
|
| 1831 |
-
if isinstance(instructions, str):
|
| 1832 |
-
instructions = [instructions]
|
| 1833 |
-
elif instructions is None:
|
| 1834 |
-
instructions = None
|
| 1835 |
-
|
| 1836 |
-
# Convert audio_code_hints to list
|
| 1837 |
-
if isinstance(audio_code_hints, str):
|
| 1838 |
-
audio_code_hints = [audio_code_hints]
|
| 1839 |
-
elif audio_code_hints is None:
|
| 1840 |
-
audio_code_hints = None
|
| 1841 |
-
|
| 1842 |
# Get batch size from captions
|
| 1843 |
batch_size = len(captions)
|
| 1844 |
|
| 1845 |
-
#
|
| 1846 |
-
if
|
| 1847 |
-
|
| 1848 |
-
if len(audio_code_hints) == 1:
|
| 1849 |
-
audio_code_hints = audio_code_hints * batch_size
|
| 1850 |
-
else:
|
| 1851 |
-
audio_code_hints = audio_code_hints[:batch_size]
|
| 1852 |
-
while len(audio_code_hints) < batch_size:
|
| 1853 |
-
audio_code_hints.append(None)
|
| 1854 |
|
| 1855 |
# Convert seed to list format
|
| 1856 |
if seed is None:
|
|
@@ -1947,6 +1952,14 @@ class AceStepHandler:
|
|
| 1947 |
logger.info("[service_generate] Generating audio...")
|
| 1948 |
with self._load_model_context("model"):
|
| 1949 |
outputs = self.model.generate_audio(**generate_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1950 |
return outputs
|
| 1951 |
|
| 1952 |
def tiled_decode(self, latents, chunk_size=512, overlap=64):
|
|
@@ -2042,25 +2055,34 @@ class AceStepHandler:
|
|
| 2042 |
use_adg: bool = False,
|
| 2043 |
cfg_interval_start: float = 0.0,
|
| 2044 |
cfg_interval_end: float = 1.0,
|
| 2045 |
-
audio_format: str = "mp3",
|
| 2046 |
-
lm_temperature: float = 0.6,
|
| 2047 |
use_tiled_decode: bool = True,
|
| 2048 |
progress=None
|
| 2049 |
-
) ->
|
| 2050 |
"""
|
| 2051 |
Main interface for music generation
|
| 2052 |
|
| 2053 |
Returns:
|
| 2054 |
-
|
| 2055 |
-
|
| 2056 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2057 |
"""
|
| 2058 |
if progress is None:
|
| 2059 |
def progress(*args, **kwargs):
|
| 2060 |
pass
|
| 2061 |
|
| 2062 |
if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
|
| 2063 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2064 |
|
| 2065 |
def _has_audio_codes(v: Union[str, List[str]]) -> bool:
|
| 2066 |
if isinstance(v, list):
|
|
@@ -2191,8 +2213,8 @@ class AceStepHandler:
|
|
| 2191 |
pred_latents = outputs["target_latents"] # [batch, latent_length, latent_dim]
|
| 2192 |
time_costs = outputs["time_costs"]
|
| 2193 |
time_costs["offload_time_cost"] = self.current_offload_cost
|
| 2194 |
-
logger.
|
| 2195 |
-
logger.
|
| 2196 |
if progress:
|
| 2197 |
progress(0.8, desc="Decoding audio...")
|
| 2198 |
logger.info("[generate_music] Decoding latents with VAE...")
|
|
@@ -2221,30 +2243,19 @@ class AceStepHandler:
|
|
| 2221 |
# Update offload cost one last time to include VAE offloading
|
| 2222 |
time_costs["offload_time_cost"] = self.current_offload_cost
|
| 2223 |
|
| 2224 |
-
logger.info("[generate_music] VAE decode completed.
|
| 2225 |
if progress:
|
| 2226 |
-
progress(0.9, desc="
|
| 2227 |
|
| 2228 |
-
#
|
| 2229 |
-
|
| 2230 |
-
|
| 2231 |
-
|
| 2232 |
|
| 2233 |
-
saved_files = []
|
| 2234 |
-
saved_uuids = [] # Store UUIDs for each file
|
| 2235 |
for i in range(actual_batch_size):
|
| 2236 |
-
#
|
| 2237 |
-
|
| 2238 |
-
|
| 2239 |
-
# Convert to numpy: [channels, samples] -> [samples, channels]
|
| 2240 |
-
audio_np = pred_wavs[i].cpu().float().numpy().T
|
| 2241 |
-
sf.write(audio_file, audio_np, self.sample_rate)
|
| 2242 |
-
saved_files.append(audio_file)
|
| 2243 |
-
saved_uuids.append(file_uuid)
|
| 2244 |
-
|
| 2245 |
-
# Prepare return values
|
| 2246 |
-
first_audio = saved_files[0] if len(saved_files) > 0 else None
|
| 2247 |
-
second_audio = saved_files[1] if len(saved_files) > 1 else None
|
| 2248 |
|
| 2249 |
# Format time costs if available
|
| 2250 |
time_costs_str = ""
|
|
@@ -2262,34 +2273,55 @@ class AceStepHandler:
|
|
| 2262 |
|
| 2263 |
**Seeds:** {seed_value_for_ui}
|
| 2264 |
**Steps:** {inference_steps}
|
| 2265 |
-
**
|
| 2266 |
status_message = f"✅ Generation completed successfully!"
|
| 2267 |
-
logger.info(f"[generate_music] Done! Generated {len(
|
| 2268 |
-
|
| 2269 |
-
#
|
| 2270 |
-
|
| 2271 |
-
|
| 2272 |
-
|
| 2273 |
-
|
| 2274 |
-
|
| 2275 |
-
|
| 2276 |
-
|
| 2277 |
-
|
| 2278 |
-
|
| 2279 |
-
|
| 2280 |
-
|
| 2281 |
-
|
| 2282 |
-
|
| 2283 |
-
|
| 2284 |
-
|
| 2285 |
-
|
| 2286 |
-
|
| 2287 |
-
|
| 2288 |
-
|
| 2289 |
-
|
| 2290 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2291 |
|
| 2292 |
except Exception as e:
|
| 2293 |
error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
|
| 2294 |
-
|
| 2295 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
import re
|
| 11 |
import random
|
| 12 |
import uuid
|
| 13 |
+
import hashlib
|
| 14 |
+
import json
|
| 15 |
from contextlib import contextmanager
|
| 16 |
from typing import Optional, Dict, Any, Tuple, List, Union
|
| 17 |
|
|
|
|
| 39 |
class AceStepHandler:
|
| 40 |
"""ACE-Step Business Logic Handler"""
|
| 41 |
|
| 42 |
+
def __init__(self):
|
| 43 |
self.model = None
|
| 44 |
self.config = None
|
| 45 |
self.device = "cpu"
|
| 46 |
self.dtype = torch.float32 # Will be set based on device in initialize_service
|
| 47 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
# VAE for audio encoding/decoding
|
| 49 |
self.vae = None
|
| 50 |
|
|
|
|
| 79 |
def get_available_checkpoints(self) -> str:
|
| 80 |
"""Return project root directory path"""
|
| 81 |
# Get project root (handler.py is in acestep/, so go up two levels to project root)
|
| 82 |
+
project_root = self._get_project_root()
|
|
|
|
| 83 |
# default checkpoints
|
| 84 |
checkpoint_dir = os.path.join(project_root, "checkpoints")
|
| 85 |
if os.path.exists(checkpoint_dir):
|
|
|
|
| 90 |
def get_available_acestep_v15_models(self) -> List[str]:
|
| 91 |
"""Scan and return all model directory names starting with 'acestep-v15-'"""
|
| 92 |
# Get project root
|
| 93 |
+
project_root = self._get_project_root()
|
|
|
|
| 94 |
checkpoint_dir = os.path.join(project_root, "checkpoints")
|
| 95 |
|
| 96 |
models = []
|
|
|
|
| 167 |
|
| 168 |
|
| 169 |
# Auto-detect project root (independent of passed project_root parameter)
|
| 170 |
+
actual_project_root = self._get_project_root()
|
|
|
|
| 171 |
checkpoint_dir = os.path.join(actual_project_root, "checkpoints")
|
| 172 |
|
| 173 |
# 1. Load main model
|
|
|
|
| 182 |
attn_implementation = "sdpa"
|
| 183 |
|
| 184 |
try:
|
| 185 |
+
logger.info(f"[initialize_service] Attempting to load model with attention implementation: {attn_implementation}")
|
| 186 |
self.model = AutoModel.from_pretrained(
|
| 187 |
acestep_v15_checkpoint_path,
|
| 188 |
trust_remote_code=True,
|
|
|
|
| 190 |
dtype="bfloat16"
|
| 191 |
)
|
| 192 |
except Exception as e:
|
| 193 |
+
logger.warning(f"[initialize_service] Failed to load model with {attn_implementation}: {e}")
|
| 194 |
if attn_implementation == "sdpa":
|
| 195 |
+
logger.info("[initialize_service] Falling back to eager attention")
|
| 196 |
attn_implementation = "eager"
|
| 197 |
self.model = AutoModel.from_pretrained(
|
| 198 |
acestep_v15_checkpoint_path,
|
|
|
|
| 210 |
else:
|
| 211 |
# If offload_to_cpu is True, check if we should keep DiT on GPU
|
| 212 |
if not self.offload_dit_to_cpu:
|
| 213 |
+
logger.info(f"[initialize_service] Keeping main model on {device} (persistent)")
|
| 214 |
self.model = self.model.to(device).to(self.dtype)
|
| 215 |
else:
|
| 216 |
self.model = self.model.to("cpu").to(self.dtype)
|
|
|
|
| 234 |
raise ValueError(f"Unsupported quantization type: {self.quantization}")
|
| 235 |
|
| 236 |
quantize_(self.model, quant_config)
|
| 237 |
+
logger.info(f"[initialize_service] DiT quantized with: {self.quantization}")
|
| 238 |
|
| 239 |
|
| 240 |
silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
|
|
|
|
| 255 |
if os.path.exists(vae_checkpoint_path):
|
| 256 |
self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
|
| 257 |
# Use bfloat16 for VAE on GPU, otherwise use self.dtype (float32 on CPU)
|
| 258 |
+
vae_dtype = self._get_vae_dtype(device)
|
| 259 |
if not self.offload_to_cpu:
|
| 260 |
self.vae = self.vae.to(device).to(vae_dtype)
|
| 261 |
else:
|
|
|
|
| 297 |
|
| 298 |
except Exception as e:
|
| 299 |
error_msg = f"❌ Error initializing model: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
|
| 300 |
+
logger.exception("[initialize_service] Error initializing model")
|
| 301 |
return error_msg, False
|
| 302 |
|
| 303 |
@contextmanager
|
|
|
|
| 322 |
try:
|
| 323 |
param = next(model.parameters())
|
| 324 |
if param.device.type == "cpu":
|
| 325 |
+
logger.info(f"[_load_model_context] Moving {model_name} to {self.device} (persistent)")
|
| 326 |
model.to(self.device).to(self.dtype)
|
| 327 |
if hasattr(self, "silence_latent"):
|
| 328 |
self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
|
|
|
|
| 337 |
return
|
| 338 |
|
| 339 |
# Load to GPU
|
| 340 |
+
logger.info(f"[_load_model_context] Loading {model_name} to {self.device}")
|
| 341 |
start_time = time.time()
|
| 342 |
if model_name == "vae":
|
| 343 |
+
vae_dtype = self._get_vae_dtype()
|
| 344 |
model.to(self.device).to(vae_dtype)
|
| 345 |
else:
|
| 346 |
model.to(self.device).to(self.dtype)
|
|
|
|
| 350 |
|
| 351 |
load_time = time.time() - start_time
|
| 352 |
self.current_offload_cost += load_time
|
| 353 |
+
logger.info(f"[_load_model_context] Loaded {model_name} to {self.device} in {load_time:.4f}s")
|
| 354 |
|
| 355 |
try:
|
| 356 |
yield
|
| 357 |
finally:
|
| 358 |
# Offload to CPU
|
| 359 |
+
logger.info(f"[_load_model_context] Offloading {model_name} to CPU")
|
| 360 |
start_time = time.time()
|
| 361 |
model.to("cpu")
|
| 362 |
|
|
|
|
| 366 |
torch.cuda.empty_cache()
|
| 367 |
offload_time = time.time() - start_time
|
| 368 |
self.current_offload_cost += offload_time
|
| 369 |
+
logger.info(f"[_load_model_context] Offloaded {model_name} to CPU in {offload_time:.4f}s")
|
| 370 |
|
| 371 |
def process_target_audio(self, audio_file) -> Optional[torch.Tensor]:
|
| 372 |
"""Process target audio"""
|
|
|
|
| 382 |
else:
|
| 383 |
audio = torch.from_numpy(audio_np.T)
|
| 384 |
|
| 385 |
+
# Normalize to stereo 48kHz
|
| 386 |
+
audio = self._normalize_audio_to_stereo_48k(audio, sr)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
|
| 388 |
return audio
|
| 389 |
except Exception as e:
|
| 390 |
+
logger.exception("[process_target_audio] Error processing target audio")
|
| 391 |
return None
|
| 392 |
|
| 393 |
def _parse_audio_code_string(self, code_str: str) -> List[int]:
|
|
|
|
| 396 |
return []
|
| 397 |
try:
|
| 398 |
return [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
|
| 399 |
+
except Exception as e:
|
| 400 |
+
logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}")
|
| 401 |
return []
|
| 402 |
|
| 403 |
def _decode_audio_codes_to_latents(self, code_str: str) -> Optional[torch.Tensor]:
|
|
|
|
| 524 |
)
|
| 525 |
"""
|
| 526 |
# Align instruction formatting with _prepare_batch
|
| 527 |
+
final_instruction = self._format_instruction(instruction or DEFAULT_DIT_INSTRUCTION)
|
|
|
|
|
|
|
| 528 |
|
| 529 |
# Extract caption and language from metas if available (from LM CoT output)
|
| 530 |
# Fallback to user-provided values if not in metas
|
|
|
|
| 555 |
|
| 556 |
parsed_meta = self._parse_metas([metas])[0]
|
| 557 |
caption_input = SFT_GEN_PROMPT.format(final_instruction, actual_caption, parsed_meta)
|
| 558 |
+
lyrics_input = self._format_lyrics(lyrics, actual_language)
|
| 559 |
return caption_input, lyrics_input
|
| 560 |
|
| 561 |
def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
| 598 |
return match.group(1).strip()
|
| 599 |
return caption
|
| 600 |
except Exception as e:
|
| 601 |
+
logger.exception("[extract_caption_from_sft_format] Error extracting caption")
|
| 602 |
return caption
|
| 603 |
|
| 604 |
def prepare_seeds(self, actual_batch_size, seed, use_random_seed):
|
|
|
|
| 622 |
else:
|
| 623 |
try:
|
| 624 |
seed_list.append(int(float(s)))
|
| 625 |
+
except (ValueError, TypeError) as e:
|
| 626 |
+
logger.debug(f"[prepare_seeds] Failed to parse seed value '{s}': {e}")
|
| 627 |
seed_list.append(-1)
|
| 628 |
elif seed is None or (isinstance(seed, (int, float)) and seed < 0):
|
| 629 |
# If seed is None or negative, use -1 for all items
|
|
|
|
| 664 |
return actual_seed_list, seed_value_for_ui
|
| 665 |
|
| 666 |
def prepare_metadata(self, bpm, key_scale, time_signature):
    """Assemble the generation metadata dict; empty fields default to "N/A"."""
    metadata = self._build_metadata_dict(bpm, key_scale, time_signature)
    return metadata
|
| 669 |
+
|
| 670 |
+
def is_silence(self, audio):
    """Return True when every sample's magnitude is below 1e-6 (i.e. effectively silent)."""
    loud_samples = audio.abs() >= 1e-6
    return ~loud_samples.any()
|
| 672 |
+
|
| 673 |
+
def _get_project_root(self) -> str:
|
| 674 |
+
"""Get project root directory path."""
|
| 675 |
+
current_file = os.path.abspath(__file__)
|
| 676 |
+
return os.path.dirname(os.path.dirname(current_file))
|
| 677 |
+
|
| 678 |
+
def _get_vae_dtype(self, device: Optional[str] = None) -> torch.dtype:
|
| 679 |
+
"""Get VAE dtype based on device."""
|
| 680 |
+
device = device or self.device
|
| 681 |
+
return torch.bfloat16 if device in ["cuda", "xpu"] else self.dtype
|
| 682 |
+
|
| 683 |
+
def _format_instruction(self, instruction: str) -> str:
|
| 684 |
+
"""Format instruction to ensure it ends with colon."""
|
| 685 |
+
if not instruction.endswith(":"):
|
| 686 |
+
instruction = instruction + ":"
|
| 687 |
+
return instruction
|
| 688 |
+
|
| 689 |
+
def _normalize_audio_to_stereo_48k(self, audio: torch.Tensor, sr: int) -> torch.Tensor:
|
| 690 |
+
"""
|
| 691 |
+
Normalize audio to stereo 48kHz format.
|
| 692 |
+
|
| 693 |
+
Args:
|
| 694 |
+
audio: Audio tensor [channels, samples] or [samples]
|
| 695 |
+
sr: Sample rate
|
| 696 |
+
|
| 697 |
+
Returns:
|
| 698 |
+
Normalized audio tensor [2, samples] at 48kHz
|
| 699 |
+
"""
|
| 700 |
+
# Convert to stereo (duplicate channel if mono)
|
| 701 |
+
if audio.shape[0] == 1:
|
| 702 |
+
audio = torch.cat([audio, audio], dim=0)
|
| 703 |
+
|
| 704 |
+
# Keep only first 2 channels
|
| 705 |
+
audio = audio[:2]
|
| 706 |
+
|
| 707 |
+
# Resample to 48kHz if needed
|
| 708 |
+
if sr != 48000:
|
| 709 |
+
audio = torchaudio.transforms.Resample(sr, 48000)(audio)
|
| 710 |
+
|
| 711 |
+
# Clamp values to [-1.0, 1.0]
|
| 712 |
+
audio = torch.clamp(audio, -1.0, 1.0)
|
| 713 |
+
|
| 714 |
+
return audio
|
| 715 |
+
|
| 716 |
+
def _normalize_audio_code_hints(self, audio_code_hints: Optional[Union[str, List[str]]], batch_size: int) -> List[Optional[str]]:
|
| 717 |
+
"""Normalize audio_code_hints to list of correct length."""
|
| 718 |
+
if audio_code_hints is None:
|
| 719 |
+
normalized = [None] * batch_size
|
| 720 |
+
elif isinstance(audio_code_hints, str):
|
| 721 |
+
normalized = [audio_code_hints] * batch_size
|
| 722 |
+
elif len(audio_code_hints) == 1 and batch_size > 1:
|
| 723 |
+
normalized = audio_code_hints * batch_size
|
| 724 |
+
elif len(audio_code_hints) != batch_size:
|
| 725 |
+
# Pad or truncate to match batch_size
|
| 726 |
+
normalized = list(audio_code_hints[:batch_size])
|
| 727 |
+
while len(normalized) < batch_size:
|
| 728 |
+
normalized.append(None)
|
| 729 |
+
else:
|
| 730 |
+
normalized = list(audio_code_hints)
|
| 731 |
+
|
| 732 |
+
# Clean up: convert empty strings to None
|
| 733 |
+
normalized = [hint if isinstance(hint, str) and hint.strip() else None for hint in normalized]
|
| 734 |
+
return normalized
|
| 735 |
+
|
| 736 |
+
def _normalize_instructions(self, instructions: Optional[Union[str, List[str]]], batch_size: int, default: Optional[str] = None) -> List[str]:
|
| 737 |
+
"""Normalize instructions to list of correct length."""
|
| 738 |
+
if instructions is None:
|
| 739 |
+
default_instruction = default or DEFAULT_DIT_INSTRUCTION
|
| 740 |
+
return [default_instruction] * batch_size
|
| 741 |
+
elif isinstance(instructions, str):
|
| 742 |
+
return [instructions] * batch_size
|
| 743 |
+
elif len(instructions) == 1:
|
| 744 |
+
return instructions * batch_size
|
| 745 |
+
elif len(instructions) != batch_size:
|
| 746 |
+
# Pad or truncate to match batch_size
|
| 747 |
+
normalized = list(instructions[:batch_size])
|
| 748 |
+
default_instruction = default or DEFAULT_DIT_INSTRUCTION
|
| 749 |
+
while len(normalized) < batch_size:
|
| 750 |
+
normalized.append(default_instruction)
|
| 751 |
+
return normalized
|
| 752 |
+
else:
|
| 753 |
+
return list(instructions)
|
| 754 |
+
|
| 755 |
+
def _format_lyrics(self, lyrics: str, language: str) -> str:
|
| 756 |
+
"""Format lyrics text with language header."""
|
| 757 |
+
return f"# Languages\n{language}\n\n# Lyric\n{lyrics}<|endoftext|>"
|
| 758 |
+
|
| 759 |
+
def _pad_sequences(self, sequences: List[torch.Tensor], max_length: int, pad_value: int = 0) -> torch.Tensor:
|
| 760 |
+
"""Pad sequences to same length."""
|
| 761 |
+
return torch.stack([
|
| 762 |
+
torch.nn.functional.pad(seq, (0, max_length - len(seq)), 'constant', pad_value)
|
| 763 |
+
for seq in sequences
|
| 764 |
+
])
|
| 765 |
+
|
| 766 |
+
def _extract_caption_and_language(self, metas: List[Union[str, Dict[str, Any]]], captions: List[str], vocal_languages: List[str]) -> Tuple[List[str], List[str]]:
|
| 767 |
+
"""Extract caption and language from metas with fallback to provided values."""
|
| 768 |
+
actual_captions = list(captions)
|
| 769 |
+
actual_languages = list(vocal_languages)
|
| 770 |
+
|
| 771 |
+
for i, meta in enumerate(metas):
|
| 772 |
+
if i >= len(actual_captions):
|
| 773 |
+
break
|
| 774 |
+
|
| 775 |
+
meta_dict = None
|
| 776 |
+
if isinstance(meta, str):
|
| 777 |
+
parsed = self._parse_metas([meta])
|
| 778 |
+
if parsed and isinstance(parsed[0], dict):
|
| 779 |
+
meta_dict = parsed[0]
|
| 780 |
+
elif isinstance(meta, dict):
|
| 781 |
+
meta_dict = meta
|
| 782 |
+
|
| 783 |
+
if meta_dict:
|
| 784 |
+
if 'caption' in meta_dict and meta_dict['caption']:
|
| 785 |
+
actual_captions[i] = str(meta_dict['caption'])
|
| 786 |
+
if 'language' in meta_dict and meta_dict['language']:
|
| 787 |
+
actual_languages[i] = str(meta_dict['language'])
|
| 788 |
+
|
| 789 |
+
return actual_captions, actual_languages
|
| 790 |
+
|
| 791 |
+
def _encode_audio_to_latents(self, audio: torch.Tensor) -> torch.Tensor:
    """
    Encode audio to latents using VAE.

    Args:
        audio: Audio tensor [channels, samples] or [batch, channels, samples]

    Returns:
        Latents tensor [T, D] or [batch, T, D]
    """
    # Fix: capture whether the input was unbatched BEFORE unsqueezing.
    # The original re-checked `audio.dim() == 2` after the unsqueeze (at
    # which point audio is always 3-D), so the batch dim was never removed
    # and 2-D callers got [1, T, D] instead of the documented [T, D].
    squeeze_output = audio.dim() == 2

    # Ensure batch dimension
    if squeeze_output:
        audio = audio.unsqueeze(0)

    # Ensure input is in VAE's dtype
    vae_input = audio.to(self.device).to(self.vae.dtype)

    # Encode to latents
    with torch.no_grad():
        latents = self.vae.encode(vae_input).latent_dist.sample()

    # Cast back to model dtype
    latents = latents.to(self.dtype)

    # Transpose: [batch, d, T] -> [batch, T, d]
    latents = latents.transpose(1, 2)

    # Remove batch dimension if input didn't have it
    if squeeze_output:
        latents = latents.squeeze(0)

    return latents
|
| 823 |
+
|
| 824 |
+
def _build_metadata_dict(self, bpm: Optional[Union[int, str]], key_scale: str, time_signature: str, duration: Optional[float] = None) -> Dict[str, Any]:
|
| 825 |
+
"""
|
| 826 |
+
Build metadata dictionary with default values.
|
| 827 |
+
|
| 828 |
+
Args:
|
| 829 |
+
bpm: BPM value (optional)
|
| 830 |
+
key_scale: Key/scale string
|
| 831 |
+
time_signature: Time signature string
|
| 832 |
+
duration: Duration in seconds (optional)
|
| 833 |
+
|
| 834 |
+
Returns:
|
| 835 |
+
Metadata dictionary
|
| 836 |
+
"""
|
| 837 |
metadata_dict = {}
|
| 838 |
if bpm:
|
| 839 |
metadata_dict["bpm"] = bpm
|
|
|
|
| 849 |
metadata_dict["timesignature"] = time_signature
|
| 850 |
else:
|
| 851 |
metadata_dict["timesignature"] = "N/A"
|
| 852 |
+
|
| 853 |
+
# Add duration if provided
|
| 854 |
+
if duration is not None:
|
| 855 |
+
metadata_dict["duration"] = f"{int(duration)} seconds"
|
| 856 |
+
|
| 857 |
return metadata_dict
|
|
|
|
|
|
|
|
|
|
| 858 |
|
| 859 |
def generate_instruction(
|
| 860 |
self,
|
|
|
|
| 901 |
# Load audio file
|
| 902 |
audio, sr = torchaudio.load(audio_file)
|
| 903 |
|
| 904 |
+
logger.debug(f"[process_reference_audio] Reference audio shape: {audio.shape}")
|
| 905 |
+
logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}")
|
| 906 |
+
logger.debug(f"[process_reference_audio] Reference audio duration: {audio.shape[-1] / 48000.0} seconds")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 907 |
|
| 908 |
+
# Normalize to stereo 48kHz
|
| 909 |
+
audio = self._normalize_audio_to_stereo_48k(audio, sr)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 910 |
|
| 911 |
is_silence = self.is_silence(audio)
|
| 912 |
if is_silence:
|
|
|
|
| 945 |
return audio
|
| 946 |
|
| 947 |
except Exception as e:
|
| 948 |
+
logger.exception("[process_reference_audio] Error processing reference audio")
|
| 949 |
return None
|
| 950 |
|
| 951 |
def process_src_audio(self, audio_file) -> Optional[torch.Tensor]:
|
|
|
|
| 956 |
# Load audio file
|
| 957 |
audio, sr = torchaudio.load(audio_file)
|
| 958 |
|
| 959 |
+
# Normalize to stereo 48kHz
|
| 960 |
+
audio = self._normalize_audio_to_stereo_48k(audio, sr)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 961 |
|
| 962 |
return audio
|
| 963 |
|
| 964 |
except Exception as e:
|
| 965 |
+
logger.exception("[process_src_audio] Error processing source audio")
|
| 966 |
return None
|
| 967 |
|
| 968 |
def convert_src_audio_to_codes(self, audio_file) -> str:
|
|
|
|
| 990 |
# Encode audio to latents using VAE
|
| 991 |
with torch.no_grad():
|
| 992 |
with self._load_model_context("vae"):
|
|
|
|
|
|
|
|
|
|
| 993 |
# Check if audio is silence
|
| 994 |
+
if self.is_silence(processed_audio.unsqueeze(0)):
|
| 995 |
return "❌ Audio file appears to be silent"
|
| 996 |
|
| 997 |
+
# Encode to latents using helper method
|
| 998 |
+
latents = self._encode_audio_to_latents(processed_audio) # [T, d]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 999 |
|
| 1000 |
# Create attention mask for latents
|
| 1001 |
attention_mask = torch.ones(latents.shape[0], dtype=torch.bool, device=self.device)
|
|
|
|
| 1020 |
|
| 1021 |
except Exception as e:
|
| 1022 |
error_msg = f"❌ Error converting audio to codes: {str(e)}\n{traceback.format_exc()}"
|
| 1023 |
+
logger.exception("[convert_src_audio_to_codes] Error converting audio to codes")
|
| 1024 |
return error_msg
|
| 1025 |
|
| 1026 |
def prepare_batch_data(
|
|
|
|
| 1049 |
calculated_duration = audio_duration
|
| 1050 |
|
| 1051 |
# Build metadata dict - use "N/A" as default for empty fields
|
| 1052 |
+
metadata_dict = self._build_metadata_dict(bpm, key_scale, time_signature, calculated_duration)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1053 |
|
| 1054 |
# Format metadata - inference service accepts dict and will convert to string
|
| 1055 |
# Create a copy for each batch item (in case we modify it)
|
|
|
|
| 1085 |
target_wavs = torch.zeros(2, frames)
|
| 1086 |
return target_wavs
|
| 1087 |
except Exception as e:
|
| 1088 |
+
logger.exception("[create_target_wavs] Error creating target audio")
|
| 1089 |
# Fallback to 30 seconds if error
|
| 1090 |
return torch.zeros(2, 30 * 48000)
|
| 1091 |
|
|
|
|
| 1266 |
"""
|
| 1267 |
batch_size = len(captions)
|
| 1268 |
|
| 1269 |
+
# Normalize audio_code_hints to batch list
|
| 1270 |
+
audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1271 |
|
| 1272 |
for ii, refer_audio_list in enumerate(refer_audios):
|
| 1273 |
if isinstance(refer_audio_list, list):
|
|
|
|
| 1279 |
if vocal_languages is None:
|
| 1280 |
vocal_languages = self._create_fallback_vocal_languages(batch_size)
|
| 1281 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1282 |
# Parse metas with fallbacks
|
| 1283 |
parsed_metas = self._parse_metas(metas)
|
| 1284 |
|
|
|
|
| 1312 |
expected_latent_length = current_wav.shape[-1] // 1920
|
| 1313 |
target_latent = self.silence_latent[0, :expected_latent_length, :]
|
| 1314 |
else:
|
| 1315 |
+
# Encode using helper method
|
| 1316 |
logger.info(f"[generate_music] Encoding target audio to latents for item {i}...")
|
| 1317 |
+
target_latent = self._encode_audio_to_latents(current_wav.squeeze(0)) # Remove batch dim for helper
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1318 |
target_latents_list.append(target_latent)
|
| 1319 |
latent_lengths.append(target_latent.shape[0])
|
| 1320 |
|
|
|
|
| 1353 |
|
| 1354 |
# Process instructions early so we can use them for task type detection
|
| 1355 |
# Use custom instructions if provided, otherwise use default
|
| 1356 |
+
instructions = self._normalize_instructions(instructions, batch_size, DEFAULT_DIT_INSTRUCTION)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1357 |
|
| 1358 |
# Generate chunk_masks and spans based on repainting parameters
|
| 1359 |
# Also determine if this is a cover task (target audio provided without repainting)
|
|
|
|
| 1502 |
else:
|
| 1503 |
precomputed_lm_hints_25Hz = None
|
| 1504 |
|
| 1505 |
+
# Extract caption and language from metas if available (from LM CoT output)
|
| 1506 |
+
# Fallback to user-provided values if not in metas
|
| 1507 |
+
actual_captions, actual_languages = self._extract_caption_and_language(parsed_metas, captions, vocal_languages)
|
| 1508 |
+
|
| 1509 |
# Format text_inputs
|
| 1510 |
text_inputs = []
|
| 1511 |
text_token_idss = []
|
|
|
|
| 1515 |
|
| 1516 |
for i in range(batch_size):
|
| 1517 |
# Use custom instruction for this batch item
|
| 1518 |
+
instruction = self._format_instruction(instructions[i] if i < len(instructions) else DEFAULT_DIT_INSTRUCTION)
|
| 1519 |
+
|
| 1520 |
+
actual_caption = actual_captions[i]
|
| 1521 |
+
actual_language = actual_languages[i]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1522 |
|
| 1523 |
# Format text prompt with custom instruction (using LM-generated caption if available)
|
| 1524 |
text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])
|
|
|
|
| 1535 |
text_attention_mask = text_inputs_dict.attention_mask[0].bool()
|
| 1536 |
|
| 1537 |
# Format and tokenize lyrics (using LM-generated language if available)
|
| 1538 |
+
lyrics_text = self._format_lyrics(lyrics[i], actual_language)
|
| 1539 |
lyrics_inputs_dict = self.text_tokenizer(
|
| 1540 |
lyrics_text,
|
| 1541 |
padding="longest",
|
|
|
|
| 1557 |
|
| 1558 |
# Pad tokenized sequences
|
| 1559 |
max_text_length = max(len(seq) for seq in text_token_idss)
|
| 1560 |
+
padded_text_token_idss = self._pad_sequences(text_token_idss, max_text_length, self.text_tokenizer.pad_token_id)
|
| 1561 |
+
padded_text_attention_masks = self._pad_sequences(text_attention_masks, max_text_length, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1562 |
|
| 1563 |
max_lyric_length = max(len(seq) for seq in lyric_token_idss)
|
| 1564 |
+
padded_lyric_token_idss = self._pad_sequences(lyric_token_idss, max_lyric_length, self.text_tokenizer.pad_token_id)
|
| 1565 |
+
padded_lyric_attention_masks = self._pad_sequences(lyric_attention_masks, max_lyric_length, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1566 |
|
| 1567 |
padded_non_cover_text_input_ids = None
|
| 1568 |
padded_non_cover_text_attention_masks = None
|
|
|
|
| 1571 |
non_cover_text_attention_masks = []
|
| 1572 |
for i in range(batch_size):
|
| 1573 |
# Use custom instruction for this batch item
|
| 1574 |
+
instruction = self._format_instruction(DEFAULT_DIT_INSTRUCTION)
|
| 1575 |
|
| 1576 |
# Extract caption from metas if available (from LM CoT output)
|
| 1577 |
+
actual_caption = actual_captions[i]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1578 |
|
| 1579 |
# Format text prompt with custom instruction (using LM-generated caption if available)
|
| 1580 |
text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])
|
|
|
|
| 1592 |
non_cover_text_input_ids.append(text_token_ids)
|
| 1593 |
non_cover_text_attention_masks.append(non_cover_text_attention_mask)
|
| 1594 |
|
| 1595 |
+
padded_non_cover_text_input_ids = self._pad_sequences(non_cover_text_input_ids, max_text_length, self.text_tokenizer.pad_token_id)
|
| 1596 |
+
padded_non_cover_text_attention_masks = self._pad_sequences(non_cover_text_attention_masks, max_text_length, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1597 |
|
| 1598 |
if audio_cover_strength < 1.0:
|
| 1599 |
assert padded_non_cover_text_input_ids is not None, "When audio_cover_strength < 1.0, padded_non_cover_text_input_ids must not be None"
|
|
|
|
| 1827 |
if self.config.is_turbo:
|
| 1828 |
# Limit inference steps to maximum 8
|
| 1829 |
if infer_steps > 8:
|
| 1830 |
+
logger.warning(f"[service_generate] dmd_gan version: infer_steps {infer_steps} exceeds maximum 8, clamping to 8")
|
| 1831 |
infer_steps = 8
|
| 1832 |
# CFG parameters are not adjustable for dmd_gan (they will be ignored)
|
| 1833 |
# Note: guidance_scale, cfg_interval_start, cfg_interval_end are still passed but may be ignored by the model
|
|
|
|
| 1850 |
if isinstance(repainting_end, (int, float)):
|
| 1851 |
repainting_end = [repainting_end]
|
| 1852 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1853 |
# Get batch size from captions
|
| 1854 |
batch_size = len(captions)
|
| 1855 |
|
| 1856 |
+
# Normalize instructions and audio_code_hints to match batch size
|
| 1857 |
+
instructions = self._normalize_instructions(instructions, batch_size, DEFAULT_DIT_INSTRUCTION) if instructions is not None else None
|
| 1858 |
+
audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size) if audio_code_hints is not None else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1859 |
|
| 1860 |
# Convert seed to list format
|
| 1861 |
if seed is None:
|
|
|
|
| 1952 |
logger.info("[service_generate] Generating audio...")
|
| 1953 |
with self._load_model_context("model"):
|
| 1954 |
outputs = self.model.generate_audio(**generate_kwargs)
|
| 1955 |
+
|
| 1956 |
+
# Add intermediate information to outputs for extra_outputs
|
| 1957 |
+
outputs["src_latents"] = src_latents
|
| 1958 |
+
outputs["target_latents_input"] = target_latents # Input target latents (before generation)
|
| 1959 |
+
outputs["chunk_masks"] = chunk_mask
|
| 1960 |
+
outputs["spans"] = spans
|
| 1961 |
+
outputs["latent_masks"] = batch.get("latent_masks") # Latent masks for valid length
|
| 1962 |
+
|
| 1963 |
return outputs
|
| 1964 |
|
| 1965 |
def tiled_decode(self, latents, chunk_size=512, overlap=64):
|
|
|
|
| 2055 |
use_adg: bool = False,
|
| 2056 |
cfg_interval_start: float = 0.0,
|
| 2057 |
cfg_interval_end: float = 1.0,
|
|
|
|
|
|
|
| 2058 |
use_tiled_decode: bool = True,
|
| 2059 |
progress=None
|
| 2060 |
+
) -> Dict[str, Any]:
|
| 2061 |
"""
|
| 2062 |
Main interface for music generation
|
| 2063 |
|
| 2064 |
Returns:
|
| 2065 |
+
Dictionary containing:
|
| 2066 |
+
- audios: List of audio dictionaries with path, key, params
|
| 2067 |
+
- generation_info: Markdown-formatted generation information
|
| 2068 |
+
- status_message: Status message
|
| 2069 |
+
- extra_outputs: Dictionary with latents, masks, time_costs, etc.
|
| 2070 |
+
- success: Whether generation completed successfully
|
| 2071 |
+
- error: Error message if generation failed
|
| 2072 |
"""
|
| 2073 |
if progress is None:
|
| 2074 |
def progress(*args, **kwargs):
|
| 2075 |
pass
|
| 2076 |
|
| 2077 |
if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
|
| 2078 |
+
return {
|
| 2079 |
+
"audios": [],
|
| 2080 |
+
"generation_info": "",
|
| 2081 |
+
"status_message": "❌ Model not fully initialized. Please initialize all components first.",
|
| 2082 |
+
"extra_outputs": {},
|
| 2083 |
+
"success": False,
|
| 2084 |
+
"error": "Model not fully initialized",
|
| 2085 |
+
}
|
| 2086 |
|
| 2087 |
def _has_audio_codes(v: Union[str, List[str]]) -> bool:
|
| 2088 |
if isinstance(v, list):
|
|
|
|
| 2213 |
pred_latents = outputs["target_latents"] # [batch, latent_length, latent_dim]
|
| 2214 |
time_costs = outputs["time_costs"]
|
| 2215 |
time_costs["offload_time_cost"] = self.current_offload_cost
|
| 2216 |
+
logger.debug(f"[generate_music] pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} {pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} {pred_latents.std()=}")
|
| 2217 |
+
logger.debug(f"[generate_music] time_costs: {time_costs}")
|
| 2218 |
if progress:
|
| 2219 |
progress(0.8, desc="Decoding audio...")
|
| 2220 |
logger.info("[generate_music] Decoding latents with VAE...")
|
|
|
|
| 2243 |
# Update offload cost one last time to include VAE offloading
|
| 2244 |
time_costs["offload_time_cost"] = self.current_offload_cost
|
| 2245 |
|
| 2246 |
+
logger.info("[generate_music] VAE decode completed. Preparing audio tensors...")
|
| 2247 |
if progress:
|
| 2248 |
+
progress(0.9, desc="Preparing audio data...")
|
| 2249 |
|
| 2250 |
+
# Prepare audio tensors (no file I/O here, no UUID generation)
|
| 2251 |
+
# pred_wavs is already [batch, channels, samples] format
|
| 2252 |
+
# Move to CPU and convert to float32 for return
|
| 2253 |
+
audio_tensors = []
|
| 2254 |
|
|
|
|
|
|
|
| 2255 |
for i in range(actual_batch_size):
|
| 2256 |
+
# Extract audio tensor: [channels, samples] format, CPU, float32
|
| 2257 |
+
audio_tensor = pred_wavs[i].cpu().float()
|
| 2258 |
+
audio_tensors.append(audio_tensor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2259 |
|
| 2260 |
# Format time costs if available
|
| 2261 |
time_costs_str = ""
|
|
|
|
| 2273 |
|
| 2274 |
**Seeds:** {seed_value_for_ui}
|
| 2275 |
**Steps:** {inference_steps}
|
| 2276 |
+
**Audio Count:** {len(audio_tensors)} audio(s){time_costs_str}"""
|
| 2277 |
status_message = f"✅ Generation completed successfully!"
|
| 2278 |
+
logger.info(f"[generate_music] Done! Generated {len(audio_tensors)} audio tensors.")
|
| 2279 |
+
|
| 2280 |
+
# Extract intermediate information from outputs
|
| 2281 |
+
src_latents = outputs.get("src_latents") # [batch, T, D]
|
| 2282 |
+
target_latents_input = outputs.get("target_latents_input") # [batch, T, D]
|
| 2283 |
+
chunk_masks = outputs.get("chunk_masks") # [batch, T]
|
| 2284 |
+
spans = outputs.get("spans", []) # List of tuples
|
| 2285 |
+
latent_masks = outputs.get("latent_masks") # [batch, T]
|
| 2286 |
+
|
| 2287 |
+
# Move latents to CPU to save memory (they can be large)
|
| 2288 |
+
extra_outputs = {
|
| 2289 |
+
"pred_latents": pred_latents.cpu() if pred_latents is not None else None,
|
| 2290 |
+
"target_latents": target_latents_input.cpu() if target_latents_input is not None else None,
|
| 2291 |
+
"src_latents": src_latents.cpu() if src_latents is not None else None,
|
| 2292 |
+
"chunk_masks": chunk_masks.cpu() if chunk_masks is not None else None,
|
| 2293 |
+
"latent_masks": latent_masks.cpu() if latent_masks is not None else None,
|
| 2294 |
+
"spans": spans,
|
| 2295 |
+
"time_costs": time_costs,
|
| 2296 |
+
"seed_value": seed_value_for_ui,
|
| 2297 |
+
}
|
| 2298 |
+
|
| 2299 |
+
# Build audios list with tensor data (no file paths, no UUIDs, handled outside)
|
| 2300 |
+
audios = []
|
| 2301 |
+
for idx, audio_tensor in enumerate(audio_tensors):
|
| 2302 |
+
audio_dict = {
|
| 2303 |
+
"tensor": audio_tensor, # torch.Tensor [channels, samples], CPU, float32
|
| 2304 |
+
"sample_rate": self.sample_rate,
|
| 2305 |
+
}
|
| 2306 |
+
audios.append(audio_dict)
|
| 2307 |
+
|
| 2308 |
+
return {
|
| 2309 |
+
"audios": audios,
|
| 2310 |
+
"generation_info": generation_info,
|
| 2311 |
+
"status_message": status_message,
|
| 2312 |
+
"extra_outputs": extra_outputs,
|
| 2313 |
+
"success": True,
|
| 2314 |
+
"error": None,
|
| 2315 |
+
}
|
| 2316 |
|
| 2317 |
except Exception as e:
|
| 2318 |
error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
|
| 2319 |
+
logger.exception("[generate_music] Generation failed")
|
| 2320 |
+
return {
|
| 2321 |
+
"audios": [],
|
| 2322 |
+
"generation_info": "",
|
| 2323 |
+
"status_message": error_msg,
|
| 2324 |
+
"extra_outputs": {},
|
| 2325 |
+
"success": False,
|
| 2326 |
+
"error": str(e),
|
| 2327 |
+
}
|
acestep/inference.py
CHANGED
|
@@ -7,105 +7,100 @@ backward-compatible Gradio UI support.
|
|
| 7 |
"""
|
| 8 |
|
| 9 |
import math
|
|
|
|
|
|
|
| 10 |
from typing import Optional, Union, List, Dict, Any, Tuple
|
| 11 |
from dataclasses import dataclass, field, asdict
|
| 12 |
from loguru import logger
|
| 13 |
-
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
@dataclass
|
| 17 |
-
class
|
| 18 |
-
"""Configuration for music generation.
|
| 19 |
|
| 20 |
Attributes:
|
| 21 |
# Text Inputs
|
| 22 |
-
caption:
|
| 23 |
-
lyrics: Lyrics
|
|
|
|
| 24 |
|
| 25 |
# Music Metadata
|
| 26 |
-
bpm:
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
vocal_language: Language code for vocals
|
| 30 |
-
|
| 31 |
|
| 32 |
# Generation Parameters
|
| 33 |
-
inference_steps: Number of
|
| 34 |
-
guidance_scale:
|
| 35 |
-
|
| 36 |
-
seed: Random seed for reproducibility (-1 for random)
|
| 37 |
-
batch_size: Number of samples to generate (1-8)
|
| 38 |
|
| 39 |
# Advanced DiT Parameters
|
| 40 |
-
use_adg:
|
| 41 |
-
cfg_interval_start:
|
| 42 |
-
cfg_interval_end:
|
| 43 |
-
audio_format: Output audio format ("mp3", "wav", "flac")
|
| 44 |
|
| 45 |
# Task-Specific Parameters
|
| 46 |
-
task_type:
|
| 47 |
-
reference_audio: Path to reference audio file
|
| 48 |
-
src_audio: Path to source audio file
|
| 49 |
-
|
| 50 |
-
repainting_start:
|
| 51 |
-
repainting_end:
|
| 52 |
-
audio_cover_strength: Strength of
|
| 53 |
-
instruction:
|
| 54 |
|
| 55 |
-
# 5Hz Language Model Parameters
|
| 56 |
-
|
| 57 |
-
lm_temperature:
|
| 58 |
-
lm_cfg_scale:
|
| 59 |
-
lm_top_k:
|
| 60 |
-
lm_top_p:
|
| 61 |
-
lm_negative_prompt: Negative prompt for
|
| 62 |
-
use_cot_metas:
|
| 63 |
-
use_cot_caption:
|
| 64 |
-
use_cot_language:
|
| 65 |
-
is_format_caption: Whether caption is already formatted
|
| 66 |
-
constrained_decoding_debug: Enable debug logging for constrained decoding
|
| 67 |
-
|
| 68 |
-
# Batch LM Generation
|
| 69 |
-
allow_lm_batch: Allow batch LM code generation (faster for batch_size >= 2)
|
| 70 |
-
lm_batch_chunk_size: Maximum batch size per LM inference chunk (GPU memory constraint)
|
| 71 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
# Text Inputs
|
| 74 |
caption: str = ""
|
| 75 |
lyrics: str = ""
|
|
|
|
| 76 |
|
| 77 |
-
#
|
| 78 |
-
bpm: Optional[int] = None
|
| 79 |
-
key_scale: str = ""
|
| 80 |
-
time_signature: str = ""
|
| 81 |
vocal_language: str = "unknown"
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
| 85 |
inference_steps: int = 8
|
| 86 |
-
guidance_scale: float = 7.0
|
| 87 |
-
use_random_seed: bool = True
|
| 88 |
seed: int = -1
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
# Advanced DiT Parameters
|
| 92 |
use_adg: bool = False
|
| 93 |
cfg_interval_start: float = 0.0
|
| 94 |
cfg_interval_end: float = 1.0
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
# Task-Specific Parameters
|
| 98 |
-
task_type: str = "text2music"
|
| 99 |
-
reference_audio: Optional[str] = None
|
| 100 |
-
src_audio: Optional[str] = None
|
| 101 |
-
audio_code_string: Union[str, List[str]] = ""
|
| 102 |
repainting_start: float = 0.0
|
| 103 |
repainting_end: float = -1
|
| 104 |
audio_cover_strength: float = 1.0
|
| 105 |
-
instruction: str = ""
|
| 106 |
|
| 107 |
# 5Hz Language Model Parameters
|
| 108 |
-
|
| 109 |
lm_temperature: float = 0.85
|
| 110 |
lm_cfg_scale: float = 2.0
|
| 111 |
lm_top_k: int = 0
|
|
@@ -114,66 +109,59 @@ class GenerationConfig:
|
|
| 114 |
use_cot_metas: bool = True
|
| 115 |
use_cot_caption: bool = True
|
| 116 |
use_cot_language: bool = True
|
| 117 |
-
is_format_caption: bool = False
|
| 118 |
-
constrained_decoding_debug: bool = False
|
| 119 |
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
|
| 124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
@dataclass
|
| 126 |
class GenerationResult:
|
| 127 |
"""Result of music generation.
|
| 128 |
|
| 129 |
Attributes:
|
| 130 |
# Audio Outputs
|
| 131 |
-
|
| 132 |
-
first_audio: Path to first generated audio (backward compatibility)
|
| 133 |
-
second_audio: Path to second generated audio (backward compatibility)
|
| 134 |
-
|
| 135 |
-
# Generation Information
|
| 136 |
generation_info: Markdown-formatted generation information
|
| 137 |
status_message: Status message from generation
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
# LM-Generated Metadata (if applicable)
|
| 141 |
-
lm_metadata: Metadata generated by language model (dict or None)
|
| 142 |
-
|
| 143 |
-
# Audio-Text Alignment Scores (if available)
|
| 144 |
-
align_score_1: First alignment score
|
| 145 |
-
align_text_1: First alignment text description
|
| 146 |
-
align_plot_1: First alignment plot image
|
| 147 |
-
align_score_2: Second alignment score
|
| 148 |
-
align_text_2: Second alignment text description
|
| 149 |
-
align_plot_2: Second alignment plot image
|
| 150 |
-
|
| 151 |
-
# Success Status
|
| 152 |
success: Whether generation completed successfully
|
| 153 |
error: Error message if generation failed
|
| 154 |
"""
|
| 155 |
|
| 156 |
# Audio Outputs
|
| 157 |
-
|
| 158 |
-
first_audio: Optional[str] = None
|
| 159 |
-
second_audio: Optional[str] = None
|
| 160 |
-
|
| 161 |
# Generation Information
|
| 162 |
generation_info: str = ""
|
| 163 |
status_message: str = ""
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
# LM-Generated Metadata
|
| 167 |
-
lm_metadata: Optional[Dict[str, Any]] = None
|
| 168 |
-
|
| 169 |
-
# Audio-Text Alignment Scores
|
| 170 |
-
align_score_1: Optional[float] = None
|
| 171 |
-
align_text_1: Optional[str] = None
|
| 172 |
-
align_plot_1: Optional[Any] = None
|
| 173 |
-
align_score_2: Optional[float] = None
|
| 174 |
-
align_text_2: Optional[str] = None
|
| 175 |
-
align_plot_2: Optional[Any] = None
|
| 176 |
-
|
| 177 |
# Success Status
|
| 178 |
success: bool = True
|
| 179 |
error: Optional[str] = None
|
|
@@ -186,75 +174,71 @@ class GenerationResult:
|
|
| 186 |
def generate_music(
|
| 187 |
dit_handler,
|
| 188 |
llm_handler,
|
|
|
|
| 189 |
config: GenerationConfig,
|
|
|
|
| 190 |
) -> GenerationResult:
|
| 191 |
"""Generate music using ACE-Step model with optional LM reasoning.
|
| 192 |
|
| 193 |
-
This is the main inference API for music generation. It supports various task types
|
| 194 |
-
(text2music, cover, repaint, etc.) and can optionally use a 5Hz Language Model for
|
| 195 |
-
Chain-of-Thought reasoning to generate metadata and audio codes.
|
| 196 |
-
|
| 197 |
Args:
|
| 198 |
dit_handler: Initialized DiT model handler (AceStepHandler instance)
|
| 199 |
llm_handler: Initialized LLM handler (LLMHandler instance)
|
|
|
|
| 200 |
config: Generation configuration (GenerationConfig instance)
|
| 201 |
|
| 202 |
Returns:
|
| 203 |
-
GenerationResult
|
| 204 |
-
|
| 205 |
-
Example:
|
| 206 |
-
>>> from acestep.handler import AceStepHandler
|
| 207 |
-
>>> from acestep.llm_inference import LLMHandler
|
| 208 |
-
>>> from acestep.inference import GenerationConfig, generate_music
|
| 209 |
-
>>>
|
| 210 |
-
>>> # Initialize handlers
|
| 211 |
-
>>> dit_handler = AceStepHandler()
|
| 212 |
-
>>> llm_handler = LLMHandler()
|
| 213 |
-
>>> dit_handler.initialize_service(...)
|
| 214 |
-
>>> llm_handler.initialize(...)
|
| 215 |
-
>>>
|
| 216 |
-
>>> # Configure generation
|
| 217 |
-
>>> config = GenerationConfig(
|
| 218 |
-
... caption="upbeat electronic dance music",
|
| 219 |
-
... bpm=128,
|
| 220 |
-
... audio_duration=30,
|
| 221 |
-
... batch_size=2,
|
| 222 |
-
... )
|
| 223 |
-
>>>
|
| 224 |
-
>>> # Generate music
|
| 225 |
-
>>> result = generate_music(dit_handler, llm_handler, config)
|
| 226 |
-
>>> print(f"Generated {len(result.audio_paths)} audio files")
|
| 227 |
-
>>> for path in result.audio_paths:
|
| 228 |
-
... print(f"Audio: {path}")
|
| 229 |
"""
|
| 230 |
-
|
| 231 |
try:
|
| 232 |
# Phase 1: LM-based metadata and code generation (if enabled)
|
| 233 |
-
audio_code_string_to_use =
|
| 234 |
lm_generated_metadata = None
|
| 235 |
-
lm_generated_audio_codes = None
|
| 236 |
lm_generated_audio_codes_list = []
|
| 237 |
|
| 238 |
# Extract mutable copies of metadata (will be updated by LM if needed)
|
| 239 |
-
bpm =
|
| 240 |
-
key_scale =
|
| 241 |
-
time_signature =
|
| 242 |
-
audio_duration =
|
| 243 |
|
| 244 |
-
# Determine if we
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
)
|
|
|
|
|
|
|
| 252 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
# LM-based Chain-of-Thought reasoning
|
| 254 |
-
if
|
| 255 |
# Convert sampling parameters
|
| 256 |
-
top_k_value = None if
|
| 257 |
-
top_p_value = None if
|
| 258 |
|
| 259 |
# Build user_metadata from user-provided values
|
| 260 |
user_metadata = {}
|
|
@@ -286,165 +270,231 @@ def generate_music(
|
|
| 286 |
|
| 287 |
user_metadata_to_pass = user_metadata if user_metadata else None
|
| 288 |
|
| 289 |
-
#
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
logger.info(
|
| 308 |
-
f"LM batch chunk {chunk_idx+1}/{num_chunks} "
|
| 309 |
-
f"(size: {chunk_size}, seeds: {chunk_seeds})"
|
| 310 |
-
)
|
| 311 |
-
|
| 312 |
-
metadata_list, audio_codes_list, status = llm_handler.generate_with_stop_condition_batch(
|
| 313 |
-
caption=config.caption or "",
|
| 314 |
-
lyrics=config.lyrics or "",
|
| 315 |
-
batch_size=chunk_size,
|
| 316 |
-
infer_type="llm_dit",
|
| 317 |
-
temperature=config.lm_temperature,
|
| 318 |
-
cfg_scale=config.lm_cfg_scale,
|
| 319 |
-
negative_prompt=config.lm_negative_prompt,
|
| 320 |
-
top_k=top_k_value,
|
| 321 |
-
top_p=top_p_value,
|
| 322 |
-
user_metadata=user_metadata_to_pass,
|
| 323 |
-
use_cot_caption=config.use_cot_caption,
|
| 324 |
-
use_cot_language=config.use_cot_language,
|
| 325 |
-
is_format_caption=config.is_format_caption,
|
| 326 |
-
constrained_decoding_debug=config.constrained_decoding_debug,
|
| 327 |
-
seeds=chunk_seeds,
|
| 328 |
-
)
|
| 329 |
-
|
| 330 |
-
all_metadata_list.extend(metadata_list)
|
| 331 |
-
all_audio_codes_list.extend(audio_codes_list)
|
| 332 |
-
|
| 333 |
-
lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
|
| 334 |
-
lm_generated_audio_codes_list = all_audio_codes_list
|
| 335 |
-
audio_code_string_to_use = all_audio_codes_list
|
| 336 |
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
lm_generated_metadata, bpm, key_scale, time_signature, audio_duration
|
| 341 |
-
)
|
| 342 |
-
|
| 343 |
-
else:
|
| 344 |
-
# Sequential LM generation (current behavior)
|
| 345 |
-
# Phase 1: Generate CoT metadata
|
| 346 |
-
phase1_start = time_module.time()
|
| 347 |
-
metadata, _, status = llm_handler.generate_with_stop_condition(
|
| 348 |
-
caption=config.caption or "",
|
| 349 |
-
lyrics=config.lyrics or "",
|
| 350 |
-
infer_type="dit",
|
| 351 |
-
temperature=config.lm_temperature,
|
| 352 |
-
cfg_scale=config.lm_cfg_scale,
|
| 353 |
-
negative_prompt=config.lm_negative_prompt,
|
| 354 |
-
top_k=top_k_value,
|
| 355 |
-
top_p=top_p_value,
|
| 356 |
-
user_metadata=user_metadata_to_pass,
|
| 357 |
-
use_cot_caption=config.use_cot_caption,
|
| 358 |
-
use_cot_language=config.use_cot_language,
|
| 359 |
-
is_format_caption=config.is_format_caption,
|
| 360 |
-
constrained_decoding_debug=config.constrained_decoding_debug,
|
| 361 |
)
|
| 362 |
-
lm_phase1_time = time_module.time() - phase1_start
|
| 363 |
-
logger.info(f"LM Phase 1 (CoT) completed in {lm_phase1_time:.2f}s")
|
| 364 |
|
| 365 |
-
#
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
|
|
|
| 374 |
top_k=top_k_value,
|
| 375 |
top_p=top_p_value,
|
| 376 |
user_metadata=user_metadata_to_pass,
|
| 377 |
-
use_cot_caption=
|
| 378 |
-
use_cot_language=
|
| 379 |
is_format_caption=config.is_format_caption,
|
|
|
|
| 380 |
constrained_decoding_debug=config.constrained_decoding_debug,
|
|
|
|
|
|
|
| 381 |
)
|
| 382 |
-
lm_phase2_time = time_module.time() - phase2_start
|
| 383 |
-
logger.info(f"LM Phase 2 (Codes) completed in {lm_phase2_time:.2f}s")
|
| 384 |
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
|
| 395 |
# Phase 2: DiT music generation
|
|
|
|
| 396 |
result = dit_handler.generate_music(
|
| 397 |
-
captions=
|
| 398 |
-
lyrics=
|
| 399 |
bpm=bpm,
|
| 400 |
key_scale=key_scale,
|
| 401 |
time_signature=time_signature,
|
| 402 |
-
vocal_language=
|
| 403 |
-
inference_steps=
|
| 404 |
-
guidance_scale=
|
| 405 |
use_random_seed=config.use_random_seed,
|
| 406 |
-
seed=config.seed
|
| 407 |
-
reference_audio=
|
| 408 |
audio_duration=audio_duration,
|
| 409 |
-
batch_size=config.batch_size,
|
| 410 |
-
src_audio=
|
| 411 |
audio_code_string=audio_code_string_to_use,
|
| 412 |
-
repainting_start=
|
| 413 |
-
repainting_end=
|
| 414 |
-
instruction=
|
| 415 |
-
audio_cover_strength=
|
| 416 |
-
task_type=
|
| 417 |
-
use_adg=
|
| 418 |
-
cfg_interval_start=
|
| 419 |
-
cfg_interval_end=
|
| 420 |
-
audio_format=config.audio_format,
|
| 421 |
-
lm_temperature=config.lm_temperature,
|
| 422 |
)
|
| 423 |
|
| 424 |
-
#
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
|
| 429 |
# Append LM metadata to generation info
|
| 430 |
if lm_generated_metadata:
|
| 431 |
generation_info = _append_lm_metadata_to_info(generation_info, lm_generated_metadata)
|
| 432 |
|
| 433 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
return GenerationResult(
|
| 435 |
-
|
| 436 |
-
first_audio=first_audio,
|
| 437 |
-
second_audio=second_audio,
|
| 438 |
generation_info=generation_info,
|
| 439 |
status_message=status_message,
|
| 440 |
-
|
| 441 |
-
lm_metadata=lm_generated_metadata,
|
| 442 |
-
align_score_1=align_score_1,
|
| 443 |
-
align_text_1=align_text_1,
|
| 444 |
-
align_plot_1=align_plot_1,
|
| 445 |
-
align_score_2=align_score_2,
|
| 446 |
-
align_text_2=align_text_2,
|
| 447 |
-
align_plot_2=align_plot_2,
|
| 448 |
success=True,
|
| 449 |
error=None,
|
| 450 |
)
|
|
@@ -452,10 +502,12 @@ def generate_music(
|
|
| 452 |
except Exception as e:
|
| 453 |
logger.exception("Music generation failed")
|
| 454 |
return GenerationResult(
|
| 455 |
-
|
| 456 |
-
error=str(e),
|
| 457 |
generation_info=f"❌ Generation failed: {str(e)}",
|
| 458 |
status_message=f"Error: {str(e)}",
|
|
|
|
|
|
|
|
|
|
| 459 |
)
|
| 460 |
|
| 461 |
|
|
@@ -525,7 +577,7 @@ def _append_lm_metadata_to_info(generation_info: str, metadata: Dict[str, Any])
|
|
| 525 |
# LEGACY GRADIO UI COMPATIBILITY LAYER
|
| 526 |
# ============================================================================
|
| 527 |
|
| 528 |
-
def
|
| 529 |
dit_handler,
|
| 530 |
llm_handler,
|
| 531 |
captions,
|
|
@@ -575,20 +627,19 @@ def generate(
|
|
| 575 |
Tuple with 28 elements for Gradio UI component updates
|
| 576 |
"""
|
| 577 |
|
| 578 |
-
# Convert legacy parameters to
|
| 579 |
-
|
| 580 |
caption=captions,
|
| 581 |
lyrics=lyrics,
|
| 582 |
bpm=bpm,
|
| 583 |
-
|
| 584 |
-
|
| 585 |
vocal_language=vocal_language,
|
| 586 |
-
|
|
|
|
| 587 |
inference_steps=inference_steps,
|
| 588 |
guidance_scale=guidance_scale,
|
| 589 |
-
use_random_seed=random_seed_checkbox,
|
| 590 |
seed=seed,
|
| 591 |
-
batch_size=batch_size_input,
|
| 592 |
use_adg=use_adg,
|
| 593 |
cfg_interval_start=cfg_interval_start,
|
| 594 |
cfg_interval_end=cfg_interval_end,
|
|
@@ -596,12 +647,11 @@ def generate(
|
|
| 596 |
task_type=task_type,
|
| 597 |
reference_audio=reference_audio,
|
| 598 |
src_audio=src_audio,
|
| 599 |
-
audio_code_string=text2music_audio_code_string,
|
| 600 |
repainting_start=repainting_start,
|
| 601 |
repainting_end=repainting_end,
|
| 602 |
audio_cover_strength=audio_cover_strength,
|
| 603 |
instruction=instruction_display_gen,
|
| 604 |
-
|
| 605 |
lm_temperature=lm_temperature,
|
| 606 |
lm_cfg_scale=lm_cfg_scale,
|
| 607 |
lm_top_k=lm_top_k,
|
|
@@ -610,29 +660,49 @@ def generate(
|
|
| 610 |
use_cot_metas=use_cot_metas,
|
| 611 |
use_cot_caption=use_cot_caption,
|
| 612 |
use_cot_language=use_cot_language,
|
| 613 |
-
is_format_caption=is_format_caption,
|
| 614 |
-
constrained_decoding_debug=constrained_decoding_debug,
|
| 615 |
-
allow_lm_batch=allow_lm_batch,
|
| 616 |
-
lm_batch_chunk_size=lm_batch_chunk_size,
|
| 617 |
)
|
| 618 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 619 |
# Call new API
|
| 620 |
-
result = generate_music(dit_handler, llm_handler, config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
|
| 622 |
# Determine which codes to update in UI
|
| 623 |
-
if config.allow_lm_batch and
|
| 624 |
# Batch mode: extract codes from metadata if available
|
| 625 |
-
lm_codes_list =
|
| 626 |
updated_audio_codes = lm_codes_list[0] if lm_codes_list else text2music_audio_code_string
|
| 627 |
codes_outputs = (lm_codes_list + [""] * 8)[:8]
|
| 628 |
else:
|
| 629 |
# Single mode
|
| 630 |
-
lm_codes =
|
| 631 |
updated_audio_codes = lm_codes if lm_codes else text2music_audio_code_string
|
| 632 |
codes_outputs = [""] * 8
|
| 633 |
|
| 634 |
# Prepare audio outputs (up to 8)
|
| 635 |
-
audio_outputs = (
|
| 636 |
|
| 637 |
# Return tuple for Gradio UI (28 elements)
|
| 638 |
return (
|
|
@@ -644,16 +714,16 @@ def generate(
|
|
| 644 |
audio_outputs[5], # generated_audio_6
|
| 645 |
audio_outputs[6], # generated_audio_7
|
| 646 |
audio_outputs[7], # generated_audio_8
|
| 647 |
-
|
| 648 |
result.generation_info,
|
| 649 |
result.status_message,
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
updated_audio_codes, # Update main audio codes in UI
|
| 658 |
codes_outputs[0], # text2music_audio_code_string_1
|
| 659 |
codes_outputs[1], # text2music_audio_code_string_2
|
|
@@ -663,266 +733,8 @@ def generate(
|
|
| 663 |
codes_outputs[5], # text2music_audio_code_string_6
|
| 664 |
codes_outputs[6], # text2music_audio_code_string_7
|
| 665 |
codes_outputs[7], # text2music_audio_code_string_8
|
| 666 |
-
|
| 667 |
is_format_caption, # Keep is_format_caption unchanged
|
| 668 |
)
|
| 669 |
|
| 670 |
|
| 671 |
-
# ============================================================================
|
| 672 |
-
# TESTING & EXAMPLES
|
| 673 |
-
# ============================================================================
|
| 674 |
-
|
| 675 |
-
if __name__ == "__main__":
|
| 676 |
-
"""
|
| 677 |
-
Test suite for the inference API.
|
| 678 |
-
Demonstrates various usage scenarios and validates functionality.
|
| 679 |
-
|
| 680 |
-
Usage:
|
| 681 |
-
python -m acestep.inference
|
| 682 |
-
"""
|
| 683 |
-
|
| 684 |
-
import os
|
| 685 |
-
import json
|
| 686 |
-
from acestep.handler import AceStepHandler
|
| 687 |
-
from acestep.llm_inference import LLMHandler
|
| 688 |
-
|
| 689 |
-
# Initialize paths
|
| 690 |
-
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
| 691 |
-
checkpoint_dir = os.path.join(project_root, "checkpoints")
|
| 692 |
-
|
| 693 |
-
print("=" * 80)
|
| 694 |
-
print("ACE-Step Inference API Test Suite")
|
| 695 |
-
print("=" * 80)
|
| 696 |
-
|
| 697 |
-
# ========================================================================
|
| 698 |
-
# Initialize Handlers
|
| 699 |
-
# ========================================================================
|
| 700 |
-
print("\n[1/3] Initializing handlers...")
|
| 701 |
-
dit_handler = AceStepHandler(save_root="./")
|
| 702 |
-
llm_handler = LLMHandler()
|
| 703 |
-
|
| 704 |
-
try:
|
| 705 |
-
# Initialize DiT handler
|
| 706 |
-
print(" - Initializing DiT model...")
|
| 707 |
-
status_dit, success_dit = dit_handler.initialize_service(
|
| 708 |
-
project_root=project_root,
|
| 709 |
-
config_path="acestep-v15-turbo-rl",
|
| 710 |
-
device="cuda",
|
| 711 |
-
)
|
| 712 |
-
if not success_dit:
|
| 713 |
-
print(f" ❌ DiT initialization failed: {status_dit}")
|
| 714 |
-
exit(1)
|
| 715 |
-
print(f" ✓ DiT model initialized successfully")
|
| 716 |
-
|
| 717 |
-
# Initialize LLM handler
|
| 718 |
-
print(" - Initializing 5Hz LM model...")
|
| 719 |
-
status_llm, success_llm = llm_handler.initialize(
|
| 720 |
-
checkpoint_dir=checkpoint_dir,
|
| 721 |
-
lm_model_path="acestep-5Hz-lm-0.6B-v3",
|
| 722 |
-
backend="vllm",
|
| 723 |
-
device="cuda",
|
| 724 |
-
)
|
| 725 |
-
if success_llm:
|
| 726 |
-
print(f" ✓ LM model initialized successfully")
|
| 727 |
-
else:
|
| 728 |
-
print(f" ⚠ LM initialization failed (will skip LM tests): {status_llm}")
|
| 729 |
-
|
| 730 |
-
except Exception as e:
|
| 731 |
-
print(f" ❌ Initialization error: {e}")
|
| 732 |
-
exit(1)
|
| 733 |
-
|
| 734 |
-
# ========================================================================
|
| 735 |
-
# Helper Functions
|
| 736 |
-
# ========================================================================
|
| 737 |
-
def load_example_config(example_file: str) -> GenerationConfig:
|
| 738 |
-
"""Load configuration from an example JSON file."""
|
| 739 |
-
try:
|
| 740 |
-
with open(example_file, 'r', encoding='utf-8') as f:
|
| 741 |
-
data = json.load(f)
|
| 742 |
-
|
| 743 |
-
# Convert example format to GenerationConfig
|
| 744 |
-
# Handle time signature format (example uses "4" instead of "4/4")
|
| 745 |
-
time_sig = data.get('timesignature', '')
|
| 746 |
-
if time_sig and '/' not in time_sig:
|
| 747 |
-
time_sig = f"{time_sig}/4" # Default to /4 if only numerator given
|
| 748 |
-
|
| 749 |
-
config = GenerationConfig(
|
| 750 |
-
caption=data.get('caption', ''),
|
| 751 |
-
lyrics=data.get('lyrics', ''),
|
| 752 |
-
bpm=data.get('bpm'),
|
| 753 |
-
key_scale=data.get('keyscale', ''),
|
| 754 |
-
time_signature=time_sig,
|
| 755 |
-
vocal_language=data.get('language', 'unknown'),
|
| 756 |
-
audio_duration=data.get('duration'),
|
| 757 |
-
use_llm_thinking=data.get('think', False),
|
| 758 |
-
batch_size=data.get('batch_size', 1),
|
| 759 |
-
inference_steps=data.get('inference_steps', 8),
|
| 760 |
-
)
|
| 761 |
-
return config
|
| 762 |
-
|
| 763 |
-
except Exception as e:
|
| 764 |
-
print(f" ⚠ Failed to load example file: {e}")
|
| 765 |
-
return None
|
| 766 |
-
|
| 767 |
-
# ========================================================================
|
| 768 |
-
# Test Cases
|
| 769 |
-
# ========================================================================
|
| 770 |
-
test_results = []
|
| 771 |
-
|
| 772 |
-
def run_test(test_name: str, config: GenerationConfig, expected_outputs: int = 1):
|
| 773 |
-
"""Run a single test case and collect results."""
|
| 774 |
-
print(f"\n{'=' * 80}")
|
| 775 |
-
print(f"Test: {test_name}")
|
| 776 |
-
print(f"{'=' * 80}")
|
| 777 |
-
|
| 778 |
-
# Display configuration
|
| 779 |
-
print("\nConfiguration:")
|
| 780 |
-
print(f" Task Type: {config.task_type}")
|
| 781 |
-
print(f" Caption: {config.caption[:60]}..." if len(config.caption) > 60 else f" Caption: {config.caption}")
|
| 782 |
-
if config.lyrics:
|
| 783 |
-
print(f" Lyrics: {config.lyrics[:60]}..." if len(config.lyrics) > 60 else f" Lyrics: {config.lyrics}")
|
| 784 |
-
if config.bpm:
|
| 785 |
-
print(f" BPM: {config.bpm}")
|
| 786 |
-
if config.key_scale:
|
| 787 |
-
print(f" Key Scale: {config.key_scale}")
|
| 788 |
-
if config.time_signature:
|
| 789 |
-
print(f" Time Signature: {config.time_signature}")
|
| 790 |
-
if config.audio_duration:
|
| 791 |
-
print(f" Duration: {config.audio_duration}s")
|
| 792 |
-
print(f" Batch Size: {config.batch_size}")
|
| 793 |
-
print(f" Inference Steps: {config.inference_steps}")
|
| 794 |
-
print(f" Use LLM Thinking: {config.use_llm_thinking}")
|
| 795 |
-
|
| 796 |
-
# Run generation
|
| 797 |
-
print("\nGenerating...")
|
| 798 |
-
import time
|
| 799 |
-
start_time = time.time()
|
| 800 |
-
|
| 801 |
-
result = generate_music(dit_handler, llm_handler, config)
|
| 802 |
-
|
| 803 |
-
elapsed_time = time.time() - start_time
|
| 804 |
-
|
| 805 |
-
# Display results
|
| 806 |
-
print("\nResults:")
|
| 807 |
-
print(f" Success: {'✓' if result.success else '✗'}")
|
| 808 |
-
|
| 809 |
-
if result.success:
|
| 810 |
-
print(f" Generated Files: {len(result.audio_paths)}")
|
| 811 |
-
for i, path in enumerate(result.audio_paths, 1):
|
| 812 |
-
if os.path.exists(path):
|
| 813 |
-
file_size = os.path.getsize(path) / (1024 * 1024) # MB
|
| 814 |
-
print(f" [{i}] {os.path.basename(path)} ({file_size:.2f} MB)")
|
| 815 |
-
else:
|
| 816 |
-
print(f" [{i}] {os.path.basename(path)} (file not found)")
|
| 817 |
-
|
| 818 |
-
print(f" Seed: {result.seed_value}")
|
| 819 |
-
print(f" Generation Time: {elapsed_time:.2f}s")
|
| 820 |
-
|
| 821 |
-
# Display LM metadata if available
|
| 822 |
-
if result.lm_metadata:
|
| 823 |
-
print(f"\n LM-Generated Metadata:")
|
| 824 |
-
for key, value in result.lm_metadata.items():
|
| 825 |
-
if key not in ['audio_codes', 'audio_codes_list']: # Skip large code strings
|
| 826 |
-
print(f" {key}: {value}")
|
| 827 |
-
|
| 828 |
-
# Validate outputs
|
| 829 |
-
if len(result.audio_paths) != expected_outputs:
|
| 830 |
-
print(f" ⚠ Warning: Expected {expected_outputs} outputs, got {len(result.audio_paths)}")
|
| 831 |
-
success = False
|
| 832 |
-
else:
|
| 833 |
-
success = True
|
| 834 |
-
|
| 835 |
-
else:
|
| 836 |
-
print(f" Error: {result.error}")
|
| 837 |
-
success = False
|
| 838 |
-
|
| 839 |
-
# Store test result
|
| 840 |
-
test_results.append({
|
| 841 |
-
"test_name": test_name,
|
| 842 |
-
"success": success,
|
| 843 |
-
"generation_success": result.success,
|
| 844 |
-
"num_outputs": len(result.audio_paths) if result.success else 0,
|
| 845 |
-
"expected_outputs": expected_outputs,
|
| 846 |
-
"elapsed_time": elapsed_time,
|
| 847 |
-
"error": result.error if not result.success else None,
|
| 848 |
-
})
|
| 849 |
-
|
| 850 |
-
return result
|
| 851 |
-
|
| 852 |
-
# ========================================================================
|
| 853 |
-
# Test: Production Example (from examples directory)
|
| 854 |
-
# ========================================================================
|
| 855 |
-
print("\n[2/3] Running Test...")
|
| 856 |
-
|
| 857 |
-
# Load production example (J-Rock song from examples/text2music/example_05.json)
|
| 858 |
-
example_file = os.path.join(project_root, "examples", "text2music", "example_05.json")
|
| 859 |
-
|
| 860 |
-
if not os.path.exists(example_file):
|
| 861 |
-
print(f"\n ❌ Example file not found: {example_file}")
|
| 862 |
-
print(" Please ensure the examples directory exists.")
|
| 863 |
-
exit(1)
|
| 864 |
-
|
| 865 |
-
print(f" Loading example: {os.path.basename(example_file)}")
|
| 866 |
-
config = load_example_config(example_file)
|
| 867 |
-
|
| 868 |
-
if not config:
|
| 869 |
-
print(" ❌ Failed to load example configuration")
|
| 870 |
-
exit(1)
|
| 871 |
-
|
| 872 |
-
# Reduce duration for faster testing (original is 200s)
|
| 873 |
-
print(f" Original duration: {config.audio_duration}s")
|
| 874 |
-
config.audio_duration = 30
|
| 875 |
-
config.use_random_seed = False
|
| 876 |
-
config.seed = 42
|
| 877 |
-
print(f" Test duration: {config.audio_duration}s (reduced for testing)")
|
| 878 |
-
|
| 879 |
-
run_test("Production Example (J-Rock Song)", config, expected_outputs=1)
|
| 880 |
-
|
| 881 |
-
# ========================================================================
|
| 882 |
-
# Test Summary
|
| 883 |
-
# ========================================================================
|
| 884 |
-
print("\n[3/3] Test Summary")
|
| 885 |
-
print("=" * 80)
|
| 886 |
-
|
| 887 |
-
if len(test_results) == 0:
|
| 888 |
-
print("No tests were run.")
|
| 889 |
-
exit(1)
|
| 890 |
-
|
| 891 |
-
result = test_results[0]
|
| 892 |
-
|
| 893 |
-
print(f"\nTest: {result['test_name']}")
|
| 894 |
-
print(f"Status: {'✓ PASS' if result['success'] else '✗ FAIL'}")
|
| 895 |
-
print(f"Generation: {'Success' if result['generation_success'] else 'Failed'}")
|
| 896 |
-
print(f"Outputs: {result['num_outputs']}/{result['expected_outputs']}")
|
| 897 |
-
print(f"Time: {result['elapsed_time']:.2f}s")
|
| 898 |
-
|
| 899 |
-
if result["error"]:
|
| 900 |
-
print(f"Error: {result['error']}")
|
| 901 |
-
|
| 902 |
-
# Save test results to JSON
|
| 903 |
-
results_file = os.path.join(project_root, "test_results.json")
|
| 904 |
-
try:
|
| 905 |
-
with open(results_file, "w") as f:
|
| 906 |
-
json.dump({
|
| 907 |
-
"test_name": result['test_name'],
|
| 908 |
-
"success": result['success'],
|
| 909 |
-
"generation_success": result['generation_success'],
|
| 910 |
-
"num_outputs": result['num_outputs'],
|
| 911 |
-
"expected_outputs": result['expected_outputs'],
|
| 912 |
-
"elapsed_time": result['elapsed_time'],
|
| 913 |
-
"error": result['error'],
|
| 914 |
-
}, f, indent=2)
|
| 915 |
-
print(f"\n✓ Test results saved to: {results_file}")
|
| 916 |
-
except Exception as e:
|
| 917 |
-
print(f"\n⚠ Failed to save test results: {e}")
|
| 918 |
-
|
| 919 |
-
# Exit with appropriate code
|
| 920 |
-
print("\n" + "=" * 80)
|
| 921 |
-
if result['success']:
|
| 922 |
-
print("Test passed! ✓")
|
| 923 |
-
print("=" * 80)
|
| 924 |
-
exit(0)
|
| 925 |
-
else:
|
| 926 |
-
print("Test failed! ✗")
|
| 927 |
-
print("=" * 80)
|
| 928 |
-
exit(1)
|
|
|
|
| 7 |
"""
|
| 8 |
|
| 9 |
import math
|
| 10 |
+
import os
|
| 11 |
+
import tempfile
|
| 12 |
from typing import Optional, Union, List, Dict, Any, Tuple
|
| 13 |
from dataclasses import dataclass, field, asdict
|
| 14 |
from loguru import logger
|
| 15 |
+
|
| 16 |
+
from acestep.audio_utils import AudioSaver, generate_uuid_from_params
|
| 17 |
|
| 18 |
|
| 19 |
@dataclass
|
| 20 |
+
class GenerationParams:
|
| 21 |
+
"""Configuration for music generation parameters.
|
| 22 |
|
| 23 |
Attributes:
|
| 24 |
# Text Inputs
|
| 25 |
+
caption: A short text prompt describing the desired music (main prompt). < 512 characters
|
| 26 |
+
lyrics: Lyrics for the music. Use "[Instrumental]" for instrumental songs. < 4096 characters
|
| 27 |
+
instrumental: If True, generate instrumental music regardless of lyrics.
|
| 28 |
|
| 29 |
# Music Metadata
|
| 30 |
+
bpm: BPM (beats per minute), e.g., 120. Set to None for automatic estimation. 30 ~ 300
|
| 31 |
+
keyscale: Musical key (e.g., "C Major", "Am"). Leave empty for auto-detection. A-G, #/♭, major/minor
|
| 32 |
+
timesignature: Time signature (2 for '2/4', 3 for '3/4', 4 for '4/4', 6 for '6/8'). Leave empty for auto-detection.
|
| 33 |
+
vocal_language: Language code for vocals, e.g., "en", "zh", "ja", or "unknown". see acestep/constants.py:VALID_LANGUAGES
|
| 34 |
+
duration: Target audio length in seconds. If <0 or None, model chooses automatically. 10 ~ 600
|
| 35 |
|
| 36 |
# Generation Parameters
|
| 37 |
+
inference_steps: Number of diffusion steps (e.g., 8 for turbo, 32–100 for base model).
|
| 38 |
+
guidance_scale: CFG (classifier-free guidance) strength. Higher means following the prompt more strictly. Only support for non-turbo model.
|
| 39 |
+
seed: Integer seed for reproducibility. -1 means use random seed each time.
|
|
|
|
|
|
|
| 40 |
|
| 41 |
# Advanced DiT Parameters
|
| 42 |
+
use_adg: Whether to use Adaptive Dual Guidance (only works for base model).
|
| 43 |
+
cfg_interval_start: Start ratio (0.0–1.0) to apply CFG.
|
| 44 |
+
cfg_interval_end: End ratio (0.0–1.0) to apply CFG.
|
|
|
|
| 45 |
|
| 46 |
# Task-Specific Parameters
|
| 47 |
+
task_type: Type of generation task. One of: "text2music", "cover", "repaint", "lego", "extract", "complete".
|
| 48 |
+
reference_audio: Path to a reference audio file for style transfer or cover tasks.
|
| 49 |
+
src_audio: Path to a source audio file for audio-to-audio tasks.
|
| 50 |
+
audio_codes: Audio semantic codes as a string (advanced use, for code-control generation).
|
| 51 |
+
repainting_start: For repaint/lego tasks: start time in seconds for region to repaint.
|
| 52 |
+
repainting_end: For repaint/lego tasks: end time in seconds for region to repaint (-1 for until end).
|
| 53 |
+
audio_cover_strength: Strength of reference audio/codes influence (range 0.0–1.0). set smaller (0.2) for style transfer tasks.
|
| 54 |
+
instruction: Optional task instruction prompt. If empty, auto-generated by system.
|
| 55 |
|
| 56 |
+
# 5Hz Language Model Parameters for CoT reasoning
|
| 57 |
+
thinking: If True, enable 5Hz Language Model "Chain-of-Thought" reasoning for semantic/music metadata and codes.
|
| 58 |
+
lm_temperature: Sampling temperature for the LLM (0.0–2.0). Higher = more creative/varied results.
|
| 59 |
+
lm_cfg_scale: Classifier-free guidance scale for the LLM.
|
| 60 |
+
lm_top_k: LLM top-k sampling (0 = disabled).
|
| 61 |
+
lm_top_p: LLM top-p nucleus sampling (1.0 = disabled).
|
| 62 |
+
lm_negative_prompt: Negative prompt to use for LLM (for control).
|
| 63 |
+
use_cot_metas: Whether to let LLM generate music metadata via CoT reasoning.
|
| 64 |
+
use_cot_caption: Whether to let LLM rewrite or format the input caption via CoT reasoning.
|
| 65 |
+
use_cot_language: Whether to let LLM detect vocal language via CoT.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
"""
|
| 67 |
+
# Required Inputs
|
| 68 |
+
task_type: str = "text2music"
|
| 69 |
+
instruction: str = "Fill the audio semantic mask based on the given conditions:"
|
| 70 |
+
|
| 71 |
+
# Audio Uploads
|
| 72 |
+
reference_audio: Optional[str] = None
|
| 73 |
+
src_audio: Optional[str] = None
|
| 74 |
+
|
| 75 |
+
# LM Codes Hints
|
| 76 |
+
audio_codes: str = ""
|
| 77 |
|
| 78 |
# Text Inputs
|
| 79 |
caption: str = ""
|
| 80 |
lyrics: str = ""
|
| 81 |
+
instrumental: bool = False
|
| 82 |
|
| 83 |
+
# Metadata
|
|
|
|
|
|
|
|
|
|
| 84 |
vocal_language: str = "unknown"
|
| 85 |
+
bpm: Optional[int] = None
|
| 86 |
+
keyscale: str = ""
|
| 87 |
+
timesignature: str = ""
|
| 88 |
+
duration: float = -1.0
|
| 89 |
+
|
| 90 |
+
# Advanced Settings
|
| 91 |
inference_steps: int = 8
|
|
|
|
|
|
|
| 92 |
seed: int = -1
|
| 93 |
+
guidance_scale: float = 7.0
|
|
|
|
|
|
|
| 94 |
use_adg: bool = False
|
| 95 |
cfg_interval_start: float = 0.0
|
| 96 |
cfg_interval_end: float = 1.0
|
| 97 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
repainting_start: float = 0.0
|
| 99 |
repainting_end: float = -1
|
| 100 |
audio_cover_strength: float = 1.0
|
|
|
|
| 101 |
|
| 102 |
# 5Hz Language Model Parameters
|
| 103 |
+
thinking: bool = True
|
| 104 |
lm_temperature: float = 0.85
|
| 105 |
lm_cfg_scale: float = 2.0
|
| 106 |
lm_top_k: int = 0
|
|
|
|
| 109 |
use_cot_metas: bool = True
|
| 110 |
use_cot_caption: bool = True
|
| 111 |
use_cot_language: bool = True
|
|
|
|
|
|
|
| 112 |
|
| 113 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 114 |
+
"""Convert config to dictionary for JSON serialization."""
|
| 115 |
+
return asdict(self)
|
| 116 |
|
| 117 |
|
| 118 |
+
@dataclass
|
| 119 |
+
class GenerationConfig:
|
| 120 |
+
"""Configuration for music generation.
|
| 121 |
+
|
| 122 |
+
Attributes:
|
| 123 |
+
batch_size: Number of audio samples to generate
|
| 124 |
+
allow_lm_batch: Whether to allow batch processing in LM
|
| 125 |
+
use_random_seed: Whether to use random seed
|
| 126 |
+
seed: Seed(s) for batch generation. Can be:
|
| 127 |
+
- None: Use random seeds (when use_random_seed=True) or params.seed (when use_random_seed=False)
|
| 128 |
+
- List[int]: List of seeds, will be padded with random seeds if fewer than batch_size
|
| 129 |
+
- int: Single seed value (will be converted to list and padded)
|
| 130 |
+
lm_batch_chunk_size: Batch chunk size for LM processing
|
| 131 |
+
is_format_caption: Whether to format caption
|
| 132 |
+
constrained_decoding_debug: Whether to enable constrained decoding debug
|
| 133 |
+
audio_format: Output audio format, one of "mp3", "wav", "flac". Default: "flac"
|
| 134 |
+
"""
|
| 135 |
+
batch_size: int = 2
|
| 136 |
+
allow_lm_batch: bool = False
|
| 137 |
+
use_random_seed: bool = True
|
| 138 |
+
seed: Optional[Union[int, List[int]]] = None
|
| 139 |
+
lm_batch_chunk_size: int = 8
|
| 140 |
+
is_format_caption: bool = False
|
| 141 |
+
use_constrained_decoding: bool = True
|
| 142 |
+
constrained_decoding_debug: bool = False
|
| 143 |
+
audio_format: str = "flac" # Default to FLAC for fast saving
|
| 144 |
+
|
| 145 |
@dataclass
|
| 146 |
class GenerationResult:
|
| 147 |
"""Result of music generation.
|
| 148 |
|
| 149 |
Attributes:
|
| 150 |
# Audio Outputs
|
| 151 |
+
audios: List of audio dictionaries with paths, keys, params
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
generation_info: Markdown-formatted generation information
|
| 153 |
status_message: Status message from generation
|
| 154 |
+
extra_outputs: Extra outputs from generation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
success: Whether generation completed successfully
|
| 156 |
error: Error message if generation failed
|
| 157 |
"""
|
| 158 |
|
| 159 |
# Audio Outputs
|
| 160 |
+
audios: List[Dict[str, Any]] = field(default_factory=list)
|
|
|
|
|
|
|
|
|
|
| 161 |
# Generation Information
|
| 162 |
generation_info: str = ""
|
| 163 |
status_message: str = ""
|
| 164 |
+
extra_outputs: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
# Success Status
|
| 166 |
success: bool = True
|
| 167 |
error: Optional[str] = None
|
|
|
|
| 174 |
def generate_music(
|
| 175 |
dit_handler,
|
| 176 |
llm_handler,
|
| 177 |
+
params: GenerationParams,
|
| 178 |
config: GenerationConfig,
|
| 179 |
+
save_dir: Optional[str] = None,
|
| 180 |
) -> GenerationResult:
|
| 181 |
"""Generate music using ACE-Step model with optional LM reasoning.
|
| 182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
Args:
|
| 184 |
dit_handler: Initialized DiT model handler (AceStepHandler instance)
|
| 185 |
llm_handler: Initialized LLM handler (LLMHandler instance)
|
| 186 |
+
params: Generation parameters (GenerationParams instance)
|
| 187 |
config: Generation configuration (GenerationConfig instance)
|
| 188 |
|
| 189 |
Returns:
|
| 190 |
+
GenerationResult with generated audio files and metadata
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
"""
|
|
|
|
| 192 |
try:
|
| 193 |
# Phase 1: LM-based metadata and code generation (if enabled)
|
| 194 |
+
audio_code_string_to_use = params.audio_codes
|
| 195 |
lm_generated_metadata = None
|
|
|
|
| 196 |
lm_generated_audio_codes_list = []
|
| 197 |
|
| 198 |
# Extract mutable copies of metadata (will be updated by LM if needed)
|
| 199 |
+
bpm = params.bpm
|
| 200 |
+
key_scale = params.keyscale
|
| 201 |
+
time_signature = params.timesignature
|
| 202 |
+
audio_duration = params.duration
|
| 203 |
|
| 204 |
+
# Determine if we need to generate audio codes
|
| 205 |
+
# If user has provided audio_codes, we don't need to generate them
|
| 206 |
+
# Otherwise, check if we need audio codes (lm_dit mode) or just metas (dit mode)
|
| 207 |
+
user_provided_audio_codes = bool(params.audio_codes and str(params.audio_codes).strip())
|
| 208 |
+
|
| 209 |
+
# Determine infer_type: use "llm_dit" if we need audio codes, "dit" if only metas needed
|
| 210 |
+
# For now, we use "llm_dit" if batch mode or if user hasn't provided codes
|
| 211 |
+
# Use "dit" if user has provided codes (only need metas) or if explicitly only need metas
|
| 212 |
+
# Note: This logic can be refined based on specific requirements
|
| 213 |
+
need_audio_codes = not user_provided_audio_codes
|
| 214 |
|
| 215 |
+
# Determine if we should use chunk-based LM generation (always use chunks for consistency)
|
| 216 |
+
# Determine actual batch size for chunk processing
|
| 217 |
+
actual_batch_size = config.batch_size if config.batch_size is not None else 1
|
| 218 |
+
|
| 219 |
+
# Prepare seeds for batch generation
|
| 220 |
+
# Use config.seed if provided, otherwise fallback to params.seed
|
| 221 |
+
# Convert config.seed (None, int, or List[int]) to format that prepare_seeds accepts
|
| 222 |
+
seed_for_generation = params.seed # Default fallback
|
| 223 |
+
if config.seed is not None:
|
| 224 |
+
if isinstance(config.seed, list):
|
| 225 |
+
# Convert List[int] to comma-separated string
|
| 226 |
+
seed_for_generation = ",".join(str(s) for s in config.seed)
|
| 227 |
+
elif isinstance(config.seed, int):
|
| 228 |
+
# Single int seed
|
| 229 |
+
seed_for_generation = config.seed
|
| 230 |
+
|
| 231 |
+
# Use dit_handler.prepare_seeds to handle seed list generation and padding
|
| 232 |
+
# This will handle all the logic: padding with random seeds if needed, etc.
|
| 233 |
+
actual_seed_list, _ = dit_handler.prepare_seeds(
|
| 234 |
+
actual_batch_size, seed_for_generation, config.use_random_seed
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
# LM-based Chain-of-Thought reasoning
|
| 238 |
+
if params.thinking and llm_handler.llm_initialized and params.use_cot_metas:
|
| 239 |
# Convert sampling parameters
|
| 240 |
+
top_k_value = None if params.lm_top_k == 0 else int(params.lm_top_k)
|
| 241 |
+
top_p_value = None if params.lm_top_p >= 1.0 else params.lm_top_p
|
| 242 |
|
| 243 |
# Build user_metadata from user-provided values
|
| 244 |
user_metadata = {}
|
|
|
|
| 270 |
|
| 271 |
user_metadata_to_pass = user_metadata if user_metadata else None
|
| 272 |
|
| 273 |
+
# Determine infer_type based on whether we need audio codes
|
| 274 |
+
# - "llm_dit": generates both metas and audio codes (two-phase internally)
|
| 275 |
+
# - "dit": generates only metas (single phase)
|
| 276 |
+
infer_type = "llm_dit" if need_audio_codes else "dit"
|
| 277 |
+
|
| 278 |
+
# Use chunk size from config, or default to batch_size if not set
|
| 279 |
+
max_inference_batch_size = int(config.lm_batch_chunk_size) if config.lm_batch_chunk_size > 0 else actual_batch_size
|
| 280 |
+
num_chunks = math.ceil(actual_batch_size / max_inference_batch_size)
|
| 281 |
+
|
| 282 |
+
all_metadata_list = []
|
| 283 |
+
all_audio_codes_list = []
|
| 284 |
+
|
| 285 |
+
for chunk_idx in range(num_chunks):
|
| 286 |
+
chunk_start = chunk_idx * max_inference_batch_size
|
| 287 |
+
chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size)
|
| 288 |
+
chunk_size = chunk_end - chunk_start
|
| 289 |
+
chunk_seeds = actual_seed_list[chunk_start:chunk_end] if chunk_start < len(actual_seed_list) else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
|
| 291 |
+
logger.info(
|
| 292 |
+
f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) "
|
| 293 |
+
f"(size: {chunk_size}, seeds: {chunk_seeds})"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
)
|
|
|
|
|
|
|
| 295 |
|
| 296 |
+
# Use the determined infer_type
|
| 297 |
+
# - "llm_dit" will internally run two phases (metas + codes)
|
| 298 |
+
# - "dit" will only run phase 1 (metas only)
|
| 299 |
+
result = llm_handler.generate_with_stop_condition(
|
| 300 |
+
caption=params.caption or "",
|
| 301 |
+
lyrics=params.lyrics or "",
|
| 302 |
+
infer_type=infer_type,
|
| 303 |
+
temperature=params.lm_temperature,
|
| 304 |
+
cfg_scale=params.lm_cfg_scale,
|
| 305 |
+
negative_prompt=params.lm_negative_prompt,
|
| 306 |
top_k=top_k_value,
|
| 307 |
top_p=top_p_value,
|
| 308 |
user_metadata=user_metadata_to_pass,
|
| 309 |
+
use_cot_caption=params.use_cot_caption,
|
| 310 |
+
use_cot_language=params.use_cot_language,
|
| 311 |
is_format_caption=config.is_format_caption,
|
| 312 |
+
use_constrained_decoding=config.use_constrained_decoding,
|
| 313 |
constrained_decoding_debug=config.constrained_decoding_debug,
|
| 314 |
+
batch_size=chunk_size,
|
| 315 |
+
seeds=chunk_seeds,
|
| 316 |
)
|
|
|
|
|
|
|
| 317 |
|
| 318 |
+
if chunk_size > 1:
|
| 319 |
+
metadata_list, audio_codes_list, status = result
|
| 320 |
+
all_metadata_list.extend(metadata_list)
|
| 321 |
+
all_audio_codes_list.extend(audio_codes_list)
|
| 322 |
+
else:
|
| 323 |
+
metadata, audio_codes, status = result
|
| 324 |
+
all_metadata_list.append(metadata)
|
| 325 |
+
all_audio_codes_list.append(audio_codes)
|
| 326 |
+
|
| 327 |
+
lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
|
| 328 |
+
lm_generated_audio_codes_list = all_audio_codes_list
|
| 329 |
+
|
| 330 |
+
# Set audio_code_string_to_use based on infer_type
|
| 331 |
+
if infer_type == "llm_dit":
|
| 332 |
+
# If batch mode, use list; otherwise use single string
|
| 333 |
+
if actual_batch_size > 1:
|
| 334 |
+
audio_code_string_to_use = all_audio_codes_list
|
| 335 |
+
else:
|
| 336 |
+
audio_code_string_to_use = all_audio_codes_list[0] if all_audio_codes_list else ""
|
| 337 |
+
else:
|
| 338 |
+
# For "dit" mode, keep user-provided codes or empty
|
| 339 |
+
audio_code_string_to_use = params.audio_codes
|
| 340 |
+
|
| 341 |
+
# Update metadata from LM if not provided by user
|
| 342 |
+
if lm_generated_metadata:
|
| 343 |
+
bpm, key_scale, time_signature, audio_duration = _update_metadata_from_lm(
|
| 344 |
+
lm_generated_metadata, bpm, key_scale, time_signature, audio_duration
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
|
| 348 |
# Phase 2: DiT music generation
|
| 349 |
+
# Use seed_for_generation (from config.seed or params.seed) instead of params.seed for actual generation
|
| 350 |
result = dit_handler.generate_music(
|
| 351 |
+
captions=params.caption,
|
| 352 |
+
lyrics=params.lyrics,
|
| 353 |
bpm=bpm,
|
| 354 |
key_scale=key_scale,
|
| 355 |
time_signature=time_signature,
|
| 356 |
+
vocal_language=params.vocal_language,
|
| 357 |
+
inference_steps=params.inference_steps,
|
| 358 |
+
guidance_scale=params.guidance_scale,
|
| 359 |
use_random_seed=config.use_random_seed,
|
| 360 |
+
seed=seed_for_generation, # Use config.seed (or params.seed fallback) instead of params.seed directly
|
| 361 |
+
reference_audio=params.reference_audio,
|
| 362 |
audio_duration=audio_duration,
|
| 363 |
+
batch_size=config.batch_size if config.batch_size is not None else 1,
|
| 364 |
+
src_audio=params.src_audio,
|
| 365 |
audio_code_string=audio_code_string_to_use,
|
| 366 |
+
repainting_start=params.repainting_start,
|
| 367 |
+
repainting_end=params.repainting_end,
|
| 368 |
+
instruction=params.instruction,
|
| 369 |
+
audio_cover_strength=params.audio_cover_strength,
|
| 370 |
+
task_type=params.task_type,
|
| 371 |
+
use_adg=params.use_adg,
|
| 372 |
+
cfg_interval_start=params.cfg_interval_start,
|
| 373 |
+
cfg_interval_end=params.cfg_interval_end,
|
|
|
|
|
|
|
| 374 |
)
|
| 375 |
|
| 376 |
+
# Check if generation failed
|
| 377 |
+
if not result.get("success", False):
|
| 378 |
+
return GenerationResult(
|
| 379 |
+
audios=[],
|
| 380 |
+
generation_info=result.get("generation_info", ""),
|
| 381 |
+
status_message=result.get("status_message", ""),
|
| 382 |
+
extra_outputs={},
|
| 383 |
+
success=False,
|
| 384 |
+
error=result.get("error"),
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
# Extract results from dit_handler.generate_music dict
|
| 388 |
+
dit_audios = result.get("audios", [])
|
| 389 |
+
generation_info = result.get("generation_info", "")
|
| 390 |
+
status_message = result.get("status_message", "")
|
| 391 |
+
dit_extra_outputs = result.get("extra_outputs", {})
|
| 392 |
|
| 393 |
# Append LM metadata to generation info
|
| 394 |
if lm_generated_metadata:
|
| 395 |
generation_info = _append_lm_metadata_to_info(generation_info, lm_generated_metadata)
|
| 396 |
|
| 397 |
+
# Use the seed list already prepared above (from config.seed or params.seed fallback)
|
| 398 |
+
# actual_seed_list was computed earlier using dit_handler.prepare_seeds
|
| 399 |
+
seed_list = actual_seed_list
|
| 400 |
+
|
| 401 |
+
# Get base params dictionary
|
| 402 |
+
base_params_dict = params.to_dict()
|
| 403 |
+
|
| 404 |
+
# Save audio files using AudioSaver (format from config)
|
| 405 |
+
audio_format = config.audio_format if config.audio_format else "flac"
|
| 406 |
+
audio_saver = AudioSaver(default_format=audio_format)
|
| 407 |
+
|
| 408 |
+
# Use handler's temp_dir for saving files
|
| 409 |
+
if save_dir is not None:
|
| 410 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 411 |
+
|
| 412 |
+
# Build audios list for GenerationResult with params and save files
|
| 413 |
+
# Audio saving and UUID generation handled here, outside of handler
|
| 414 |
+
audios = []
|
| 415 |
+
for idx, dit_audio in enumerate(dit_audios):
|
| 416 |
+
# Create a copy of params dict for this audio
|
| 417 |
+
audio_params = base_params_dict.copy()
|
| 418 |
+
|
| 419 |
+
# Update audio-specific values
|
| 420 |
+
audio_params["seed"] = seed_list[idx] if idx < len(seed_list) else None
|
| 421 |
+
|
| 422 |
+
# Add audio codes if batch mode
|
| 423 |
+
if lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list):
|
| 424 |
+
audio_params["audio_codes"] = lm_generated_audio_codes_list[idx]
|
| 425 |
+
|
| 426 |
+
# Get audio tensor and metadata
|
| 427 |
+
audio_tensor = dit_audio.get("tensor")
|
| 428 |
+
sample_rate = dit_audio.get("sample_rate", 48000)
|
| 429 |
+
|
| 430 |
+
# Generate UUID for this audio (moved from handler)
|
| 431 |
+
batch_seed = seed_list[idx] if idx < len(seed_list) else seed_list[0] if seed_list else -1
|
| 432 |
+
audio_code_str = lm_generated_audio_codes_list[idx] if (lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list)) else audio_code_string_to_use
|
| 433 |
+
if isinstance(audio_code_str, list):
|
| 434 |
+
audio_code_str = audio_code_str[idx] if idx < len(audio_code_str) else ""
|
| 435 |
+
|
| 436 |
+
audio_key = generate_uuid_from_params(
|
| 437 |
+
captions=params.caption,
|
| 438 |
+
lyrics=params.lyrics,
|
| 439 |
+
bpm=bpm,
|
| 440 |
+
key_scale=key_scale,
|
| 441 |
+
time_signature=time_signature,
|
| 442 |
+
vocal_language=params.vocal_language,
|
| 443 |
+
inference_steps=params.inference_steps,
|
| 444 |
+
guidance_scale=params.guidance_scale,
|
| 445 |
+
seed=batch_seed,
|
| 446 |
+
audio_duration=audio_duration,
|
| 447 |
+
audio_code_string=audio_code_str,
|
| 448 |
+
repainting_start=params.repainting_start,
|
| 449 |
+
repainting_end=params.repainting_end,
|
| 450 |
+
instruction=params.instruction,
|
| 451 |
+
audio_cover_strength=params.audio_cover_strength,
|
| 452 |
+
task_type=params.task_type,
|
| 453 |
+
use_adg=params.use_adg,
|
| 454 |
+
cfg_interval_start=params.cfg_interval_start,
|
| 455 |
+
cfg_interval_end=params.cfg_interval_end,
|
| 456 |
+
audio_format=audio_format,
|
| 457 |
+
reference_audio=params.reference_audio,
|
| 458 |
+
src_audio=params.src_audio,
|
| 459 |
+
batch_index=idx,
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
# Save audio file (handled outside handler)
|
| 463 |
+
audio_path = None
|
| 464 |
+
if audio_tensor is not None and save_dir is not None:
|
| 465 |
+
try:
|
| 466 |
+
audio_file = os.path.join(save_dir, f"{audio_key}.{audio_format}")
|
| 467 |
+
audio_path = audio_saver.save_audio(
|
| 468 |
+
audio_tensor,
|
| 469 |
+
audio_file,
|
| 470 |
+
sample_rate=sample_rate,
|
| 471 |
+
format=audio_format,
|
| 472 |
+
channels_first=True
|
| 473 |
+
)
|
| 474 |
+
except Exception as e:
|
| 475 |
+
logger.error(f"[generate_music] Failed to save audio file: {e}")
|
| 476 |
+
audio_path = "" # Fallback to empty path
|
| 477 |
+
|
| 478 |
+
audio_dict = {
|
| 479 |
+
"path": audio_path or "", # File path (saved here, not in handler)
|
| 480 |
+
"tensor": audio_tensor, # Audio tensor [channels, samples], CPU, float32
|
| 481 |
+
"key": audio_key,
|
| 482 |
+
"sample_rate": sample_rate,
|
| 483 |
+
"params": audio_params,
|
| 484 |
+
}
|
| 485 |
+
|
| 486 |
+
audios.append(audio_dict)
|
| 487 |
+
|
| 488 |
+
# Merge extra_outputs: include dit_extra_outputs (latents, masks) and add LM metadata
|
| 489 |
+
extra_outputs = dit_extra_outputs.copy()
|
| 490 |
+
extra_outputs["lm_metadata"] = lm_generated_metadata
|
| 491 |
+
|
| 492 |
+
# Create and return GenerationResult
|
| 493 |
return GenerationResult(
|
| 494 |
+
audios=audios,
|
|
|
|
|
|
|
| 495 |
generation_info=generation_info,
|
| 496 |
status_message=status_message,
|
| 497 |
+
extra_outputs=extra_outputs,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 498 |
success=True,
|
| 499 |
error=None,
|
| 500 |
)
|
|
|
|
| 502 |
except Exception as e:
|
| 503 |
logger.exception("Music generation failed")
|
| 504 |
return GenerationResult(
|
| 505 |
+
audios=[],
|
|
|
|
| 506 |
generation_info=f"❌ Generation failed: {str(e)}",
|
| 507 |
status_message=f"Error: {str(e)}",
|
| 508 |
+
extra_outputs={},
|
| 509 |
+
success=False,
|
| 510 |
+
error=str(e),
|
| 511 |
)
|
| 512 |
|
| 513 |
|
|
|
|
| 577 |
# LEGACY GRADIO UI COMPATIBILITY LAYER
|
| 578 |
# ============================================================================
|
| 579 |
|
| 580 |
+
def generate_for_gradio(
|
| 581 |
dit_handler,
|
| 582 |
llm_handler,
|
| 583 |
captions,
|
|
|
|
| 627 |
Tuple with 28 elements for Gradio UI component updates
|
| 628 |
"""
|
| 629 |
|
| 630 |
+
# Convert legacy parameters to GenerationParams and GenerationConfig
|
| 631 |
+
params = GenerationParams(
|
| 632 |
caption=captions,
|
| 633 |
lyrics=lyrics,
|
| 634 |
bpm=bpm,
|
| 635 |
+
keyscale=key_scale,
|
| 636 |
+
timesignature=time_signature,
|
| 637 |
vocal_language=vocal_language,
|
| 638 |
+
audio_codes=text2music_audio_code_string,
|
| 639 |
+
duration=audio_duration,
|
| 640 |
inference_steps=inference_steps,
|
| 641 |
guidance_scale=guidance_scale,
|
|
|
|
| 642 |
seed=seed,
|
|
|
|
| 643 |
use_adg=use_adg,
|
| 644 |
cfg_interval_start=cfg_interval_start,
|
| 645 |
cfg_interval_end=cfg_interval_end,
|
|
|
|
| 647 |
task_type=task_type,
|
| 648 |
reference_audio=reference_audio,
|
| 649 |
src_audio=src_audio,
|
|
|
|
| 650 |
repainting_start=repainting_start,
|
| 651 |
repainting_end=repainting_end,
|
| 652 |
audio_cover_strength=audio_cover_strength,
|
| 653 |
instruction=instruction_display_gen,
|
| 654 |
+
thinking=think_checkbox,
|
| 655 |
lm_temperature=lm_temperature,
|
| 656 |
lm_cfg_scale=lm_cfg_scale,
|
| 657 |
lm_top_k=lm_top_k,
|
|
|
|
| 660 |
use_cot_metas=use_cot_metas,
|
| 661 |
use_cot_caption=use_cot_caption,
|
| 662 |
use_cot_language=use_cot_language,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 663 |
)
|
| 664 |
|
| 665 |
+
config = GenerationConfig(batch_size=1)
|
| 666 |
+
config.batch_size = batch_size_input
|
| 667 |
+
config.use_random_seed = random_seed_checkbox
|
| 668 |
+
config.allow_lm_batch = allow_lm_batch
|
| 669 |
+
config.lm_batch_chunk_size = lm_batch_chunk_size
|
| 670 |
+
config.is_format_caption = is_format_caption
|
| 671 |
+
config.constrained_decoding_debug = constrained_decoding_debug
|
| 672 |
+
|
| 673 |
# Call new API
|
| 674 |
+
result = generate_music(dit_handler, llm_handler, params, config)
|
| 675 |
+
|
| 676 |
+
# Extract audio paths from result.audios
|
| 677 |
+
audio_paths = [audio["path"] for audio in result.audios]
|
| 678 |
+
|
| 679 |
+
# Extract extra outputs
|
| 680 |
+
extra_outputs = result.extra_outputs
|
| 681 |
+
seed_value = extra_outputs.get("seed_value", "")
|
| 682 |
+
lm_metadata = extra_outputs.get("lm_metadata", None)
|
| 683 |
+
|
| 684 |
+
# Legacy alignment fields (no longer used, set to empty/None)
|
| 685 |
+
align_score_1 = ""
|
| 686 |
+
align_text_1 = ""
|
| 687 |
+
align_plot_1 = None
|
| 688 |
+
align_score_2 = ""
|
| 689 |
+
align_text_2 = ""
|
| 690 |
+
align_plot_2 = None
|
| 691 |
|
| 692 |
# Determine which codes to update in UI
|
| 693 |
+
if config.allow_lm_batch and lm_metadata:
|
| 694 |
# Batch mode: extract codes from metadata if available
|
| 695 |
+
lm_codes_list = lm_metadata.get('audio_codes_list', [])
|
| 696 |
updated_audio_codes = lm_codes_list[0] if lm_codes_list else text2music_audio_code_string
|
| 697 |
codes_outputs = (lm_codes_list + [""] * 8)[:8]
|
| 698 |
else:
|
| 699 |
# Single mode
|
| 700 |
+
lm_codes = lm_metadata.get('audio_codes', '') if lm_metadata else ''
|
| 701 |
updated_audio_codes = lm_codes if lm_codes else text2music_audio_code_string
|
| 702 |
codes_outputs = [""] * 8
|
| 703 |
|
| 704 |
# Prepare audio outputs (up to 8)
|
| 705 |
+
audio_outputs = (audio_paths + [None] * 8)[:8]
|
| 706 |
|
| 707 |
# Return tuple for Gradio UI (28 elements)
|
| 708 |
return (
|
|
|
|
| 714 |
audio_outputs[5], # generated_audio_6
|
| 715 |
audio_outputs[6], # generated_audio_7
|
| 716 |
audio_outputs[7], # generated_audio_8
|
| 717 |
+
audio_paths, # generated_audio_batch
|
| 718 |
result.generation_info,
|
| 719 |
result.status_message,
|
| 720 |
+
seed_value,
|
| 721 |
+
align_score_1,
|
| 722 |
+
align_text_1,
|
| 723 |
+
align_plot_1,
|
| 724 |
+
align_score_2,
|
| 725 |
+
align_text_2,
|
| 726 |
+
align_plot_2,
|
| 727 |
updated_audio_codes, # Update main audio codes in UI
|
| 728 |
codes_outputs[0], # text2music_audio_code_string_1
|
| 729 |
codes_outputs[1], # text2music_audio_code_string_2
|
|
|
|
| 733 |
codes_outputs[5], # text2music_audio_code_string_6
|
| 734 |
codes_outputs[6], # text2music_audio_code_string_7
|
| 735 |
codes_outputs[7], # text2music_audio_code_string_8
|
| 736 |
+
lm_metadata, # Store metadata for "Send to src audio" buttons
|
| 737 |
is_format_caption, # Keep is_format_caption unchanged
|
| 738 |
)
|
| 739 |
|
| 740 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
acestep/llm_inference.py
CHANGED
|
@@ -5,7 +5,7 @@ Handles all LM-related operations including initialization and generation
|
|
| 5 |
import os
|
| 6 |
import traceback
|
| 7 |
import time
|
| 8 |
-
from typing import Optional, Dict, Any, Tuple, List
|
| 9 |
from contextlib import contextmanager
|
| 10 |
|
| 11 |
import yaml
|
|
@@ -85,6 +85,189 @@ class LLMHandler:
|
|
| 85 |
except Exception as e:
|
| 86 |
return 0.9, False
|
| 87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
def initialize(
|
| 89 |
self,
|
| 90 |
checkpoint_dir: str,
|
|
@@ -150,41 +333,21 @@ class LLMHandler:
|
|
| 150 |
# vllm initialization failed, fallback to PyTorch
|
| 151 |
if not self.llm_initialized:
|
| 152 |
logger.warning("vllm initialization failed, falling back to PyTorch backend")
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
else:
|
| 158 |
-
self.llm = self.llm.to("cpu").to(self.dtype)
|
| 159 |
-
self.llm.eval()
|
| 160 |
-
self.llm_backend = "pt"
|
| 161 |
-
self.llm_initialized = True
|
| 162 |
-
logger.info("5Hz LM initialized successfully using PyTorch backend (fallback)")
|
| 163 |
-
status_msg = f"✅ 5Hz LM initialized successfully (PyTorch fallback)\nModel: {full_lm_model_path}\nBackend: PyTorch"
|
| 164 |
-
except Exception as e:
|
| 165 |
-
return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
|
| 166 |
# If vllm initialization succeeded, self.llm_initialized should already be True
|
| 167 |
else:
|
| 168 |
# Use PyTorch backend (pt)
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
self.llm = self.llm.to(device).to(self.dtype)
|
| 173 |
-
else:
|
| 174 |
-
self.llm = self.llm.to("cpu").to(self.dtype)
|
| 175 |
-
self.llm.eval()
|
| 176 |
-
self.llm_backend = "pt"
|
| 177 |
-
self.llm_initialized = True
|
| 178 |
-
logger.info(f"5Hz LM initialized successfully using PyTorch backend on {device}")
|
| 179 |
-
status_msg = f"✅ 5Hz LM initialized successfully\nModel: {full_lm_model_path}\nBackend: PyTorch\nDevice: {device}"
|
| 180 |
-
except Exception as e:
|
| 181 |
-
return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
|
| 182 |
|
| 183 |
return status_msg, True
|
| 184 |
|
| 185 |
except Exception as e:
|
| 186 |
-
|
| 187 |
-
return error_msg, False
|
| 188 |
|
| 189 |
def _initialize_5hz_lm_vllm(self, model_path: str) -> str:
|
| 190 |
"""Initialize 5Hz LM model using vllm backend"""
|
|
@@ -230,12 +393,11 @@ class LLMHandler:
|
|
| 230 |
return f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}"
|
| 231 |
except Exception as e:
|
| 232 |
self.llm_initialized = False
|
| 233 |
-
|
| 234 |
-
return error_msg
|
| 235 |
|
| 236 |
-
def
|
| 237 |
self,
|
| 238 |
-
|
| 239 |
temperature: float,
|
| 240 |
cfg_scale: float,
|
| 241 |
negative_prompt: str,
|
|
@@ -244,7 +406,7 @@ class LLMHandler:
|
|
| 244 |
repetition_penalty: float,
|
| 245 |
use_constrained_decoding: bool = True,
|
| 246 |
constrained_decoding_debug: bool = False,
|
| 247 |
-
metadata_temperature: Optional[float] =
|
| 248 |
codes_temperature: Optional[float] = None,
|
| 249 |
target_duration: Optional[float] = None,
|
| 250 |
user_metadata: Optional[Dict[str, Optional[str]]] = None,
|
|
@@ -256,37 +418,40 @@ class LLMHandler:
|
|
| 256 |
caption: str = "",
|
| 257 |
lyrics: str = "",
|
| 258 |
cot_text: str = "",
|
| 259 |
-
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
from nanovllm import SamplingParams
|
| 262 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
# Determine effective temperature for sampler
|
| 264 |
-
|
|
|
|
|
|
|
| 265 |
effective_sampler_temp = 1.0 if use_phase_temperatures else temperature
|
| 266 |
|
| 267 |
-
#
|
| 268 |
-
constrained_processor =
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
# Set skip_caption and skip_language based on flags
|
| 283 |
-
self.constrained_processor.set_skip_genres(skip_genres)
|
| 284 |
-
self.constrained_processor.set_skip_caption(skip_caption)
|
| 285 |
-
self.constrained_processor.set_skip_language(skip_language)
|
| 286 |
-
# Set generation phase for phase-aware processing
|
| 287 |
-
self.constrained_processor.set_generation_phase(generation_phase)
|
| 288 |
-
|
| 289 |
-
constrained_processor = self.constrained_processor
|
| 290 |
|
| 291 |
sampling_params = SamplingParams(
|
| 292 |
max_tokens=self.max_model_len - 64,
|
|
@@ -301,119 +466,25 @@ class LLMHandler:
|
|
| 301 |
|
| 302 |
if cfg_scale > 1.0:
|
| 303 |
# Build unconditional prompt based on generation phase
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
else:
|
| 312 |
-
# CoT phase: unconditional prompt
|
| 313 |
-
# If negative_prompt is provided, use it as caption; otherwise remove caption and keep only lyrics
|
| 314 |
-
formatted_unconditional_prompt = self.build_formatted_prompt(
|
| 315 |
-
caption, lyrics, is_negative_prompt=True, generation_phase="cot", negative_prompt=negative_prompt
|
| 316 |
-
)
|
| 317 |
-
|
| 318 |
-
outputs = self.llm.generate(
|
| 319 |
-
[formatted_prompt],
|
| 320 |
-
sampling_params,
|
| 321 |
-
unconditional_prompts=[formatted_unconditional_prompt],
|
| 322 |
-
)
|
| 323 |
-
else:
|
| 324 |
-
outputs = self.llm.generate([formatted_prompt], sampling_params)
|
| 325 |
-
|
| 326 |
-
# Extract text (retain original selection order/logic)
|
| 327 |
-
if isinstance(outputs, list) and len(outputs) > 0:
|
| 328 |
-
if hasattr(outputs[0], "outputs") and len(outputs[0].outputs) > 0:
|
| 329 |
-
output_text = outputs[0].outputs[0].text
|
| 330 |
-
elif hasattr(outputs[0], "text"):
|
| 331 |
-
output_text = outputs[0].text
|
| 332 |
-
elif isinstance(outputs[0], dict) and "text" in outputs[0]:
|
| 333 |
-
output_text = outputs[0]["text"]
|
| 334 |
-
else:
|
| 335 |
-
output_text = str(outputs[0])
|
| 336 |
-
else:
|
| 337 |
-
output_text = str(outputs)
|
| 338 |
-
|
| 339 |
-
return output_text
|
| 340 |
-
|
| 341 |
-
def _run_vllm_batch(
|
| 342 |
-
self,
|
| 343 |
-
formatted_prompts: List[str],
|
| 344 |
-
temperature: float,
|
| 345 |
-
cfg_scale: float,
|
| 346 |
-
negative_prompt: str,
|
| 347 |
-
top_k: Optional[int],
|
| 348 |
-
top_p: Optional[float],
|
| 349 |
-
repetition_penalty: float,
|
| 350 |
-
use_constrained_decoding: bool = True,
|
| 351 |
-
constrained_decoding_debug: bool = False,
|
| 352 |
-
target_duration: Optional[float] = None,
|
| 353 |
-
generation_phase: str = "codes",
|
| 354 |
-
caption: str = "",
|
| 355 |
-
lyrics: str = "",
|
| 356 |
-
cot_text: str = "",
|
| 357 |
-
seeds: Optional[List[int]] = None,
|
| 358 |
-
) -> List[str]:
|
| 359 |
-
"""Batch generation using vllm backend"""
|
| 360 |
-
from nanovllm import SamplingParams
|
| 361 |
-
|
| 362 |
-
batch_size = len(formatted_prompts)
|
| 363 |
-
|
| 364 |
-
# Determine effective temperature for sampler
|
| 365 |
-
effective_sampler_temp = temperature
|
| 366 |
-
|
| 367 |
-
# Use shared constrained processor if enabled
|
| 368 |
-
# Note: vllm batch mode uses same processor for all items
|
| 369 |
-
constrained_processor = None
|
| 370 |
-
if use_constrained_decoding:
|
| 371 |
-
# Reset processor state for new generation
|
| 372 |
-
self.constrained_processor.reset()
|
| 373 |
-
|
| 374 |
-
self.constrained_processor.enabled = use_constrained_decoding
|
| 375 |
-
self.constrained_processor.debug = constrained_decoding_debug
|
| 376 |
-
self.constrained_processor.metadata_temperature = None
|
| 377 |
-
self.constrained_processor.codes_temperature = None
|
| 378 |
-
self.constrained_processor.set_target_duration(target_duration)
|
| 379 |
-
self.constrained_processor.set_user_metadata(None)
|
| 380 |
-
self.constrained_processor.set_stop_at_reasoning(False)
|
| 381 |
-
self.constrained_processor.set_skip_genres(True)
|
| 382 |
-
self.constrained_processor.set_skip_caption(True)
|
| 383 |
-
self.constrained_processor.set_skip_language(True)
|
| 384 |
-
self.constrained_processor.set_generation_phase(generation_phase)
|
| 385 |
-
|
| 386 |
-
constrained_processor = self.constrained_processor
|
| 387 |
-
|
| 388 |
-
# Build sampling params
|
| 389 |
-
sampling_params = SamplingParams(
|
| 390 |
-
max_tokens=self.max_model_len - 64,
|
| 391 |
-
temperature=effective_sampler_temp,
|
| 392 |
-
cfg_scale=cfg_scale,
|
| 393 |
-
top_k=top_k,
|
| 394 |
-
top_p=top_p,
|
| 395 |
-
repetition_penalty=repetition_penalty,
|
| 396 |
-
logits_processor=constrained_processor,
|
| 397 |
-
logits_processor_update_state=constrained_processor.update_state if constrained_processor else None,
|
| 398 |
-
)
|
| 399 |
-
|
| 400 |
-
# Generate with or without CFG
|
| 401 |
-
if cfg_scale > 1.0:
|
| 402 |
-
# Build unconditional prompts
|
| 403 |
-
formatted_unconditional_prompt = self.build_formatted_prompt_with_cot(
|
| 404 |
-
caption, lyrics, cot_text, is_negative_prompt=True, negative_prompt=negative_prompt
|
| 405 |
)
|
| 406 |
unconditional_prompts = [formatted_unconditional_prompt] * batch_size
|
| 407 |
|
| 408 |
outputs = self.llm.generate(
|
| 409 |
-
|
| 410 |
sampling_params,
|
| 411 |
unconditional_prompts=unconditional_prompts,
|
| 412 |
)
|
| 413 |
else:
|
| 414 |
-
outputs = self.llm.generate(
|
| 415 |
-
|
| 416 |
-
# Extract text from
|
| 417 |
output_texts = []
|
| 418 |
for output in outputs:
|
| 419 |
if hasattr(output, "outputs") and len(output.outputs) > 0:
|
|
@@ -424,70 +495,11 @@ class LLMHandler:
|
|
| 424 |
output_texts.append(output["text"])
|
| 425 |
else:
|
| 426 |
output_texts.append(str(output))
|
| 427 |
-
|
| 428 |
-
return output_texts
|
| 429 |
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
formatted_prompts: List[str],
|
| 433 |
-
temperature: float,
|
| 434 |
-
cfg_scale: float,
|
| 435 |
-
negative_prompt: str,
|
| 436 |
-
top_k: Optional[int],
|
| 437 |
-
top_p: Optional[float],
|
| 438 |
-
repetition_penalty: float,
|
| 439 |
-
use_constrained_decoding: bool = True,
|
| 440 |
-
constrained_decoding_debug: bool = False,
|
| 441 |
-
target_duration: Optional[float] = None,
|
| 442 |
-
generation_phase: str = "codes",
|
| 443 |
-
caption: str = "",
|
| 444 |
-
lyrics: str = "",
|
| 445 |
-
cot_text: str = "",
|
| 446 |
-
seeds: Optional[List[int]] = None,
|
| 447 |
-
) -> List[str]:
|
| 448 |
-
"""Batch generation using PyTorch backend"""
|
| 449 |
-
import random
|
| 450 |
-
|
| 451 |
-
batch_size = len(formatted_prompts)
|
| 452 |
-
output_texts = []
|
| 453 |
-
|
| 454 |
-
# Generate each item sequentially with different seeds
|
| 455 |
-
# (PyTorch backend doesn't support true batching efficiently)
|
| 456 |
-
for i, formatted_prompt in enumerate(formatted_prompts):
|
| 457 |
-
# Set seed for this item if provided
|
| 458 |
-
if seeds and i < len(seeds):
|
| 459 |
-
torch.manual_seed(seeds[i])
|
| 460 |
-
if torch.cuda.is_available():
|
| 461 |
-
torch.cuda.manual_seed_all(seeds[i])
|
| 462 |
-
|
| 463 |
-
# Generate using single-item method
|
| 464 |
-
output_text = self._run_pt_from_formatted(
|
| 465 |
-
formatted_prompt=formatted_prompt,
|
| 466 |
-
temperature=temperature,
|
| 467 |
-
cfg_scale=cfg_scale,
|
| 468 |
-
negative_prompt=negative_prompt,
|
| 469 |
-
top_k=top_k,
|
| 470 |
-
top_p=top_p,
|
| 471 |
-
repetition_penalty=repetition_penalty,
|
| 472 |
-
use_constrained_decoding=use_constrained_decoding,
|
| 473 |
-
constrained_decoding_debug=constrained_decoding_debug,
|
| 474 |
-
target_duration=target_duration,
|
| 475 |
-
user_metadata=None,
|
| 476 |
-
stop_at_reasoning=False,
|
| 477 |
-
skip_genres=True,
|
| 478 |
-
skip_caption=True,
|
| 479 |
-
skip_language=True,
|
| 480 |
-
generation_phase=generation_phase,
|
| 481 |
-
caption=caption,
|
| 482 |
-
lyrics=lyrics,
|
| 483 |
-
cot_text=cot_text,
|
| 484 |
-
)
|
| 485 |
-
|
| 486 |
-
output_texts.append(output_text)
|
| 487 |
-
|
| 488 |
-
return output_texts
|
| 489 |
|
| 490 |
-
def
|
| 491 |
self,
|
| 492 |
formatted_prompt: str,
|
| 493 |
temperature: float,
|
|
@@ -496,20 +508,20 @@ class LLMHandler:
|
|
| 496 |
top_k: Optional[int],
|
| 497 |
top_p: Optional[float],
|
| 498 |
repetition_penalty: float,
|
| 499 |
-
use_constrained_decoding: bool
|
| 500 |
-
constrained_decoding_debug: bool
|
| 501 |
-
target_duration: Optional[float]
|
| 502 |
-
user_metadata: Optional[Dict[str, Optional[str]]]
|
| 503 |
-
stop_at_reasoning: bool
|
| 504 |
-
skip_genres: bool
|
| 505 |
-
skip_caption: bool
|
| 506 |
-
skip_language: bool
|
| 507 |
-
generation_phase: str
|
| 508 |
-
caption: str
|
| 509 |
-
lyrics: str
|
| 510 |
-
cot_text: str
|
| 511 |
) -> str:
|
| 512 |
-
"""
|
| 513 |
inputs = self.llm_tokenizer(
|
| 514 |
formatted_prompt,
|
| 515 |
return_tensors="pt",
|
|
@@ -517,27 +529,19 @@ class LLMHandler:
|
|
| 517 |
truncation=True,
|
| 518 |
)
|
| 519 |
|
| 520 |
-
#
|
| 521 |
-
constrained_processor =
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
# Set skip_caption and skip_language based on flags
|
| 534 |
-
self.constrained_processor.set_skip_genres(skip_genres)
|
| 535 |
-
self.constrained_processor.set_skip_caption(skip_caption)
|
| 536 |
-
self.constrained_processor.set_skip_language(skip_language)
|
| 537 |
-
# Set generation phase for phase-aware processing
|
| 538 |
-
self.constrained_processor.set_generation_phase(generation_phase)
|
| 539 |
-
|
| 540 |
-
constrained_processor = self.constrained_processor
|
| 541 |
|
| 542 |
with self._load_model_context():
|
| 543 |
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
@@ -546,25 +550,18 @@ class LLMHandler:
|
|
| 546 |
max_new_tokens = min(max_new_tokens, self.max_model_len - 64)
|
| 547 |
|
| 548 |
# Build logits processor list (only for CFG and repetition penalty)
|
| 549 |
-
logits_processor =
|
| 550 |
-
|
| 551 |
-
# Add repetition penalty if needed
|
| 552 |
-
if repetition_penalty != 1.0:
|
| 553 |
-
logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
|
| 554 |
|
| 555 |
if cfg_scale > 1.0:
|
| 556 |
# Build unconditional prompt based on generation phase
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
formatted_unconditional_prompt = self.build_formatted_prompt(
|
| 566 |
-
caption, lyrics, is_negative_prompt=True, generation_phase="cot", negative_prompt=negative_prompt
|
| 567 |
-
)
|
| 568 |
|
| 569 |
# Tokenize both prompts together to ensure same length (with left padding)
|
| 570 |
# Left padding is important for generation tasks
|
|
@@ -657,7 +654,101 @@ class LLMHandler:
|
|
| 657 |
|
| 658 |
output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
|
| 659 |
return output_text
|
| 660 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 661 |
def has_all_metas(self, user_metadata: Optional[Dict[str, Optional[str]]]) -> bool:
|
| 662 |
"""Check if all required metadata are present."""
|
| 663 |
if user_metadata is None:
|
|
@@ -708,7 +799,9 @@ class LLMHandler:
|
|
| 708 |
use_cot_caption: bool = True,
|
| 709 |
use_cot_language: bool = True,
|
| 710 |
is_format_caption: bool = False,
|
| 711 |
-
|
|
|
|
|
|
|
| 712 |
"""Two-phase LM generation: CoT generation followed by audio codes generation.
|
| 713 |
|
| 714 |
- infer_type='dit': Phase 1 only - generate CoT and return metas (no audio codes)
|
|
@@ -721,30 +814,56 @@ class LLMHandler:
|
|
| 721 |
If specified, constrained decoding will inject these values directly.
|
| 722 |
use_cot_caption: Whether to generate caption in CoT (default True).
|
| 723 |
use_cot_language: Whether to generate language in CoT (default True).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 724 |
"""
|
| 725 |
import time
|
|
|
|
| 726 |
|
| 727 |
infer_type = (infer_type or "").strip().lower()
|
| 728 |
if infer_type not in {"dit", "llm_dit"}:
|
|
|
|
|
|
|
| 729 |
return {}, "", f"❌ invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
|
| 730 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 731 |
metadata = {}
|
| 732 |
audio_codes = ""
|
| 733 |
has_all_metas = self.has_all_metas(user_metadata)
|
| 734 |
-
|
| 735 |
-
# Timing variables
|
| 736 |
phase1_time = 0.0
|
| 737 |
phase2_time = 0.0
|
| 738 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 739 |
# ========== PHASE 1: CoT Generation ==========
|
| 740 |
-
#
|
| 741 |
-
if not has_all_metas
|
| 742 |
-
|
|
|
|
|
|
|
|
|
|
| 743 |
phase1_start = time.time()
|
| 744 |
|
| 745 |
# Build formatted prompt for CoT phase
|
| 746 |
formatted_prompt = self.build_formatted_prompt(caption, lyrics, generation_phase="cot")
|
| 747 |
-
|
| 748 |
logger.info(f"generate_with_stop_condition: formatted_prompt={formatted_prompt}")
|
| 749 |
# Generate CoT (stop at </think>)
|
| 750 |
cot_output_text, status = self.generate_from_formatted_prompt(
|
|
@@ -774,23 +893,39 @@ class LLMHandler:
|
|
| 774 |
phase1_time = time.time() - phase1_start
|
| 775 |
|
| 776 |
if not cot_output_text:
|
|
|
|
|
|
|
| 777 |
return {}, "", status
|
| 778 |
|
| 779 |
# Parse metadata from CoT output
|
| 780 |
metadata, _ = self.parse_lm_output(cot_output_text)
|
| 781 |
-
|
|
|
|
|
|
|
|
|
|
| 782 |
else:
|
| 783 |
# Use user-provided metadata
|
| 784 |
-
|
|
|
|
|
|
|
|
|
|
| 785 |
metadata = {k: v for k, v in user_metadata.items() if v is not None}
|
| 786 |
|
| 787 |
# If infer_type is 'dit', stop here and return only metadata
|
| 788 |
if infer_type == "dit":
|
| 789 |
-
|
| 790 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 791 |
|
| 792 |
# ========== PHASE 2: Audio Codes Generation ==========
|
| 793 |
-
|
|
|
|
|
|
|
|
|
|
| 794 |
phase2_start = time.time()
|
| 795 |
|
| 796 |
# Format metadata as CoT using YAML (matching training format)
|
|
@@ -799,221 +934,110 @@ class LLMHandler:
|
|
| 799 |
# Build formatted prompt with CoT for codes generation phase
|
| 800 |
formatted_prompt_with_cot = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text)
|
| 801 |
logger.info(f"generate_with_stop_condition: formatted_prompt_with_cot={formatted_prompt_with_cot}")
|
| 802 |
-
# Generate audio codes
|
| 803 |
-
codes_output_text, status = self.generate_from_formatted_prompt(
|
| 804 |
-
formatted_prompt=formatted_prompt_with_cot,
|
| 805 |
-
cfg={
|
| 806 |
-
"temperature": temperature,
|
| 807 |
-
"cfg_scale": cfg_scale,
|
| 808 |
-
"negative_prompt": negative_prompt,
|
| 809 |
-
"top_k": top_k,
|
| 810 |
-
"top_p": top_p,
|
| 811 |
-
"repetition_penalty": repetition_penalty,
|
| 812 |
-
"target_duration": target_duration,
|
| 813 |
-
"user_metadata": None, # No user metadata injection in Phase 2
|
| 814 |
-
"skip_caption": True, # Skip caption since CoT is already included
|
| 815 |
-
"skip_language": True, # Skip language since CoT is already included
|
| 816 |
-
"generation_phase": "codes",
|
| 817 |
-
# Pass context for building unconditional prompt in codes phase
|
| 818 |
-
"caption": caption,
|
| 819 |
-
"lyrics": lyrics,
|
| 820 |
-
"cot_text": cot_text,
|
| 821 |
-
},
|
| 822 |
-
use_constrained_decoding=use_constrained_decoding,
|
| 823 |
-
constrained_decoding_debug=constrained_decoding_debug,
|
| 824 |
-
stop_at_reasoning=False, # Generate codes until EOS
|
| 825 |
-
)
|
| 826 |
-
|
| 827 |
-
if not codes_output_text:
|
| 828 |
-
return metadata, "", status
|
| 829 |
-
|
| 830 |
-
phase2_time = time.time() - phase2_start
|
| 831 |
-
|
| 832 |
-
# Parse audio codes from output (metadata should be same as Phase 1)
|
| 833 |
-
_, audio_codes = self.parse_lm_output(codes_output_text)
|
| 834 |
-
|
| 835 |
-
codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
|
| 836 |
-
logger.info(f"Phase 2 completed in {phase2_time:.2f}s. Generated {codes_count} audio codes")
|
| 837 |
|
| 838 |
-
|
| 839 |
-
|
| 840 |
-
|
| 841 |
-
def generate_with_stop_condition_batch(
|
| 842 |
-
self,
|
| 843 |
-
caption: str,
|
| 844 |
-
lyrics: str,
|
| 845 |
-
batch_size: int,
|
| 846 |
-
infer_type: str = "llm_dit",
|
| 847 |
-
temperature: float = 0.85,
|
| 848 |
-
cfg_scale: float = 1.0,
|
| 849 |
-
negative_prompt: str = "NO USER INPUT",
|
| 850 |
-
top_k: Optional[int] = None,
|
| 851 |
-
top_p: Optional[float] = None,
|
| 852 |
-
repetition_penalty: float = 1.0,
|
| 853 |
-
use_constrained_decoding: bool = True,
|
| 854 |
-
constrained_decoding_debug: bool = False,
|
| 855 |
-
target_duration: Optional[float] = None,
|
| 856 |
-
user_metadata: Optional[Dict[str, Optional[str]]] = None,
|
| 857 |
-
use_cot_caption: bool = True,
|
| 858 |
-
use_cot_language: bool = True,
|
| 859 |
-
is_format_caption: bool = False,
|
| 860 |
-
seeds: Optional[List[int]] = None,
|
| 861 |
-
) -> Tuple[List[Dict[str, Any]], List[str], str]:
|
| 862 |
-
"""
|
| 863 |
-
Batch version of generate_with_stop_condition.
|
| 864 |
-
|
| 865 |
-
Generates multiple audio codes with same conditions but different seeds (for diversity).
|
| 866 |
-
|
| 867 |
-
Args:
|
| 868 |
-
caption: Same caption for all items
|
| 869 |
-
lyrics: Same lyrics for all items
|
| 870 |
-
batch_size: Number of items to generate
|
| 871 |
-
seeds: Optional list of seeds for each batch item (for reproducibility)
|
| 872 |
-
... (other args same as generate_with_stop_condition)
|
| 873 |
-
|
| 874 |
-
Returns:
|
| 875 |
-
Tuple of (metadata_list, audio_codes_list, status_message)
|
| 876 |
-
- metadata_list: List of metadata dicts (same metadata for all items)
|
| 877 |
-
- audio_codes_list: List of audio code strings (one per item, different due to sampling)
|
| 878 |
-
- status_message: Generation status
|
| 879 |
-
"""
|
| 880 |
-
import random
|
| 881 |
-
import time
|
| 882 |
-
|
| 883 |
-
infer_type = (infer_type or "").strip().lower()
|
| 884 |
-
if infer_type not in {"dit", "llm_dit"}:
|
| 885 |
-
return [], [], f"❌ invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
|
| 886 |
-
|
| 887 |
-
# Generate seeds if not provided
|
| 888 |
-
if seeds is None:
|
| 889 |
-
seeds = [random.randint(0, 2**32 - 1) for _ in range(batch_size)]
|
| 890 |
-
elif len(seeds) < batch_size:
|
| 891 |
-
# Pad with random seeds if not enough provided
|
| 892 |
-
seeds = list(seeds) + [random.randint(0, 2**32 - 1) for _ in range(batch_size - len(seeds))]
|
| 893 |
-
else:
|
| 894 |
-
seeds = seeds[:batch_size] # Truncate if too many
|
| 895 |
-
|
| 896 |
-
# Timing variables
|
| 897 |
-
phase1_time = 0.0
|
| 898 |
-
phase2_time = 0.0
|
| 899 |
-
|
| 900 |
-
# ========== PHASE 1: CoT Generation (ONCE for all items) ==========
|
| 901 |
-
has_all_metas = self.has_all_metas(user_metadata)
|
| 902 |
-
|
| 903 |
-
if not has_all_metas or not is_format_caption:
|
| 904 |
-
logger.info("Batch Phase 1: Generating CoT metadata (once for all items)...")
|
| 905 |
-
phase1_start = time.time()
|
| 906 |
|
| 907 |
-
#
|
| 908 |
-
|
| 909 |
-
|
| 910 |
-
|
| 911 |
-
|
| 912 |
-
|
| 913 |
-
|
| 914 |
-
|
| 915 |
-
|
| 916 |
-
|
| 917 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 918 |
use_constrained_decoding=use_constrained_decoding,
|
| 919 |
constrained_decoding_debug=constrained_decoding_debug,
|
| 920 |
-
|
| 921 |
-
user_metadata=user_metadata,
|
| 922 |
-
use_cot_caption=use_cot_caption,
|
| 923 |
-
use_cot_language=use_cot_language,
|
| 924 |
-
is_format_caption=is_format_caption,
|
| 925 |
)
|
| 926 |
|
| 927 |
-
|
|
|
|
| 928 |
|
| 929 |
-
|
| 930 |
-
return [], [], status
|
| 931 |
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
|
| 938 |
-
|
| 939 |
-
|
| 940 |
-
|
| 941 |
-
status_msg = f"✅ Generated CoT metadata successfully (batch mode)\nFields: {', '.join(metadata.keys())}\nPhase1: {phase1_time:.2f}s"
|
| 942 |
-
return metadata_list, [""] * batch_size, status_msg
|
| 943 |
-
|
| 944 |
-
# ========== PHASE 2: Audio Codes Generation (BATCH) ==========
|
| 945 |
-
logger.info(f"Batch Phase 2: Generating audio codes for {batch_size} items...")
|
| 946 |
-
phase2_start = time.time()
|
| 947 |
-
|
| 948 |
-
# Format metadata as CoT
|
| 949 |
-
cot_text = self._format_metadata_as_cot(metadata)
|
| 950 |
-
|
| 951 |
-
# Build formatted prompt with CoT
|
| 952 |
-
formatted_prompt = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text)
|
| 953 |
-
|
| 954 |
-
# Replicate prompt for batch (all items have same prompt, differ by seeds)
|
| 955 |
-
formatted_prompts = [formatted_prompt] * batch_size
|
| 956 |
-
|
| 957 |
-
# Call backend-specific batch generation
|
| 958 |
-
try:
|
| 959 |
-
if self.llm_backend == "vllm":
|
| 960 |
-
codes_outputs = self._run_vllm_batch(
|
| 961 |
-
formatted_prompts=formatted_prompts,
|
| 962 |
-
temperature=temperature,
|
| 963 |
-
cfg_scale=cfg_scale,
|
| 964 |
-
negative_prompt=negative_prompt,
|
| 965 |
-
top_k=top_k,
|
| 966 |
-
top_p=top_p,
|
| 967 |
-
repetition_penalty=repetition_penalty,
|
| 968 |
-
use_constrained_decoding=use_constrained_decoding,
|
| 969 |
-
constrained_decoding_debug=constrained_decoding_debug,
|
| 970 |
-
target_duration=target_duration,
|
| 971 |
-
generation_phase="codes",
|
| 972 |
-
caption=caption,
|
| 973 |
-
lyrics=lyrics,
|
| 974 |
-
cot_text=cot_text,
|
| 975 |
-
seeds=seeds,
|
| 976 |
-
)
|
| 977 |
-
else: # pt backend
|
| 978 |
-
codes_outputs = self._run_pt_batch(
|
| 979 |
-
formatted_prompts=formatted_prompts,
|
| 980 |
-
temperature=temperature,
|
| 981 |
-
cfg_scale=cfg_scale,
|
| 982 |
-
negative_prompt=negative_prompt,
|
| 983 |
-
top_k=top_k,
|
| 984 |
-
top_p=top_p,
|
| 985 |
-
repetition_penalty=repetition_penalty,
|
| 986 |
-
use_constrained_decoding=use_constrained_decoding,
|
| 987 |
-
constrained_decoding_debug=constrained_decoding_debug,
|
| 988 |
-
target_duration=target_duration,
|
| 989 |
-
generation_phase="codes",
|
| 990 |
-
caption=caption,
|
| 991 |
-
lyrics=lyrics,
|
| 992 |
-
cot_text=cot_text,
|
| 993 |
-
seeds=seeds,
|
| 994 |
-
)
|
| 995 |
-
except Exception as e:
|
| 996 |
-
error_msg = f"❌ Error in batch codes generation: {str(e)}"
|
| 997 |
-
logger.error(error_msg)
|
| 998 |
-
return [], [], error_msg
|
| 999 |
-
|
| 1000 |
-
# Parse audio codes from each output
|
| 1001 |
-
audio_codes_list = []
|
| 1002 |
-
metadata_list = []
|
| 1003 |
-
for output_text in codes_outputs:
|
| 1004 |
-
_, audio_codes = self.parse_lm_output(output_text)
|
| 1005 |
-
audio_codes_list.append(audio_codes)
|
| 1006 |
-
metadata_list.append(metadata.copy()) # Same metadata for all
|
| 1007 |
-
|
| 1008 |
-
phase2_time = time.time() - phase2_start
|
| 1009 |
-
|
| 1010 |
-
# Log results
|
| 1011 |
-
codes_counts = [len(codes.split('<|audio_code_')) - 1 if codes else 0 for codes in audio_codes_list]
|
| 1012 |
-
logger.info(f"Batch Phase 2 completed in {phase2_time:.2f}s. Generated codes: {codes_counts}")
|
| 1013 |
-
|
| 1014 |
-
status_msg = f"✅ Batch generation completed ({batch_size} items)\nPhase 1: CoT metadata\nPhase 2: {sum(codes_counts)} total codes ({codes_counts})\nPhase1: {phase1_time:.2f}s, Phase2: {phase2_time:.2f}s"
|
| 1015 |
-
return metadata_list, audio_codes_list, status_msg
|
| 1016 |
-
|
| 1017 |
def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False, generation_phase: str = "cot", negative_prompt: str = "NO USER INPUT") -> str:
|
| 1018 |
"""
|
| 1019 |
Build the chat-formatted prompt for 5Hz LM from caption/lyrics.
|
|
@@ -1035,7 +1059,7 @@ class LLMHandler:
|
|
| 1035 |
if is_negative_prompt:
|
| 1036 |
# Unconditional prompt for CFG
|
| 1037 |
# Check if user provided a meaningful negative prompt (not the default)
|
| 1038 |
-
has_negative_prompt =
|
| 1039 |
|
| 1040 |
if generation_phase == "cot":
|
| 1041 |
# CoT phase unconditional prompt
|
|
@@ -1086,7 +1110,7 @@ class LLMHandler:
|
|
| 1086 |
if is_negative_prompt:
|
| 1087 |
# Unconditional prompt for codes phase
|
| 1088 |
# Check if user provided a meaningful negative prompt
|
| 1089 |
-
has_negative_prompt =
|
| 1090 |
|
| 1091 |
# Use empty CoT for unconditional
|
| 1092 |
cot_for_prompt = "<think>\n</think>"
|
|
@@ -1369,8 +1393,8 @@ class LLMHandler:
|
|
| 1369 |
|
| 1370 |
try:
|
| 1371 |
if self.llm_backend == "vllm":
|
| 1372 |
-
output_text = self.
|
| 1373 |
-
|
| 1374 |
temperature=temperature,
|
| 1375 |
cfg_scale=cfg_scale,
|
| 1376 |
negative_prompt=negative_prompt,
|
|
@@ -1393,8 +1417,8 @@ class LLMHandler:
|
|
| 1393 |
return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}"
|
| 1394 |
|
| 1395 |
# PyTorch backend
|
| 1396 |
-
output_text = self.
|
| 1397 |
-
|
| 1398 |
temperature=temperature,
|
| 1399 |
cfg_scale=cfg_scale,
|
| 1400 |
negative_prompt=negative_prompt,
|
|
@@ -1459,26 +1483,12 @@ class LLMHandler:
|
|
| 1459 |
eos_token_id = pad_token_id
|
| 1460 |
|
| 1461 |
# Build logits processor for repetition penalty
|
| 1462 |
-
logits_processor =
|
| 1463 |
-
if repetition_penalty != 1.0:
|
| 1464 |
-
logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
|
| 1465 |
|
| 1466 |
with torch.no_grad():
|
| 1467 |
for step in range(max_new_tokens):
|
| 1468 |
# Forward pass
|
| 1469 |
-
|
| 1470 |
-
outputs = model(
|
| 1471 |
-
input_ids=generated_ids,
|
| 1472 |
-
**model_kwargs,
|
| 1473 |
-
use_cache=use_cache,
|
| 1474 |
-
)
|
| 1475 |
-
else:
|
| 1476 |
-
outputs = model(
|
| 1477 |
-
input_ids=generated_ids[:, -1:],
|
| 1478 |
-
past_key_values=past_key_values,
|
| 1479 |
-
**model_kwargs,
|
| 1480 |
-
use_cache=use_cache,
|
| 1481 |
-
)
|
| 1482 |
|
| 1483 |
# Get logits for the last position
|
| 1484 |
next_token_logits = outputs.logits[:, -1, :] # [batch_size, vocab_size]
|
|
@@ -1491,41 +1501,18 @@ class LLMHandler:
|
|
| 1491 |
for processor in logits_processor:
|
| 1492 |
next_token_logits = processor(generated_ids, next_token_logits)
|
| 1493 |
|
| 1494 |
-
# Apply top-k filtering
|
| 1495 |
-
|
| 1496 |
-
|
| 1497 |
-
next_token_logits[indices_to_remove] = float('-inf')
|
| 1498 |
-
|
| 1499 |
-
# Apply top-p filtering
|
| 1500 |
-
if top_p is not None and 0.0 < top_p < 1.0:
|
| 1501 |
-
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
|
| 1502 |
-
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
| 1503 |
-
sorted_indices_to_remove = cumulative_probs > top_p
|
| 1504 |
-
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 1505 |
-
sorted_indices_to_remove[..., 0] = 0
|
| 1506 |
-
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
| 1507 |
-
next_token_logits[indices_to_remove] = float('-inf')
|
| 1508 |
|
| 1509 |
# Apply temperature and sample
|
| 1510 |
-
|
| 1511 |
-
next_token_logits = next_token_logits / temperature
|
| 1512 |
-
probs = torch.softmax(next_token_logits, dim=-1)
|
| 1513 |
-
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
| 1514 |
-
else:
|
| 1515 |
-
next_tokens = torch.argmax(next_token_logits, dim=-1)
|
| 1516 |
|
| 1517 |
# Update constrained processor state
|
| 1518 |
-
|
| 1519 |
-
for b in range(next_tokens.shape[0]):
|
| 1520 |
-
constrained_processor.update_state(next_tokens[b].item())
|
| 1521 |
|
| 1522 |
# Check for EOS token
|
| 1523 |
-
should_stop =
|
| 1524 |
-
if torch.any(next_tokens == eos_token_id):
|
| 1525 |
-
should_stop = True
|
| 1526 |
-
elif pad_token_id is not None and pad_token_id != eos_token_id:
|
| 1527 |
-
if torch.any(next_tokens == pad_token_id):
|
| 1528 |
-
should_stop = True
|
| 1529 |
|
| 1530 |
# Append token to sequence
|
| 1531 |
next_tokens_unsqueezed = next_tokens.unsqueeze(1)
|
|
@@ -1601,28 +1588,12 @@ class LLMHandler:
|
|
| 1601 |
eos_token_id = pad_token_id
|
| 1602 |
|
| 1603 |
# Build logits processor for non-CFG operations (repetition penalty, top_k, top_p)
|
| 1604 |
-
logits_processor =
|
| 1605 |
-
if repetition_penalty != 1.0:
|
| 1606 |
-
logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
|
| 1607 |
|
| 1608 |
with torch.no_grad():
|
| 1609 |
for step in range(max_new_tokens):
|
| 1610 |
# Forward pass for the entire batch (conditional + unconditional)
|
| 1611 |
-
|
| 1612 |
-
# First step: full forward pass
|
| 1613 |
-
outputs = model(
|
| 1614 |
-
input_ids=generated_ids,
|
| 1615 |
-
**model_kwargs,
|
| 1616 |
-
use_cache=use_cache,
|
| 1617 |
-
)
|
| 1618 |
-
else:
|
| 1619 |
-
# Subsequent steps: only forward the last token (utilizing KV cache)
|
| 1620 |
-
outputs = model(
|
| 1621 |
-
input_ids=generated_ids[:, -1:],
|
| 1622 |
-
past_key_values=past_key_values,
|
| 1623 |
-
**model_kwargs,
|
| 1624 |
-
use_cache=use_cache,
|
| 1625 |
-
)
|
| 1626 |
|
| 1627 |
# Get logits for the last position
|
| 1628 |
next_token_logits = outputs.logits[:, -1, :] # [batch_size*2, vocab_size]
|
|
@@ -1645,45 +1616,20 @@ class LLMHandler:
|
|
| 1645 |
for processor in logits_processor:
|
| 1646 |
cfg_logits = processor(current_input_ids, cfg_logits)
|
| 1647 |
|
| 1648 |
-
# Apply top-k filtering
|
| 1649 |
-
|
| 1650 |
-
|
| 1651 |
-
cfg_logits[indices_to_remove] = float('-inf')
|
| 1652 |
-
|
| 1653 |
-
# Apply top-p (nucleus) filtering
|
| 1654 |
-
if top_p is not None and 0.0 < top_p < 1.0:
|
| 1655 |
-
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
|
| 1656 |
-
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
| 1657 |
-
# Remove tokens with cumulative probability above the threshold
|
| 1658 |
-
sorted_indices_to_remove = cumulative_probs > top_p
|
| 1659 |
-
# Shift the indices to the right to keep also the first token above the threshold
|
| 1660 |
-
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 1661 |
-
sorted_indices_to_remove[..., 0] = 0
|
| 1662 |
-
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
| 1663 |
-
cfg_logits[indices_to_remove] = float('-inf')
|
| 1664 |
|
| 1665 |
# Apply temperature and sample
|
| 1666 |
-
|
| 1667 |
-
cfg_logits = cfg_logits / temperature
|
| 1668 |
-
probs = torch.softmax(cfg_logits, dim=-1)
|
| 1669 |
-
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
| 1670 |
-
else:
|
| 1671 |
-
next_tokens = torch.argmax(cfg_logits, dim=-1)
|
| 1672 |
|
| 1673 |
# Update constrained processor state AFTER sampling
|
| 1674 |
-
|
| 1675 |
-
for b in range(next_tokens.shape[0]):
|
| 1676 |
-
constrained_processor.update_state(next_tokens[b].item())
|
| 1677 |
|
| 1678 |
# Check for EOS token in conditional sequences BEFORE unsqueezing
|
| 1679 |
# Stop if any conditional sequence generates EOS token
|
| 1680 |
# next_tokens shape: [batch_size] (only conditional tokens)
|
| 1681 |
-
should_stop =
|
| 1682 |
-
if torch.any(next_tokens == eos_token_id):
|
| 1683 |
-
should_stop = True
|
| 1684 |
-
elif pad_token_id is not None and pad_token_id != eos_token_id:
|
| 1685 |
-
if torch.any(next_tokens == pad_token_id):
|
| 1686 |
-
should_stop = True
|
| 1687 |
|
| 1688 |
# Apply the same sampled tokens to both conditional and unconditional sequences
|
| 1689 |
next_tokens_unsqueezed = next_tokens.unsqueeze(1)
|
|
|
|
| 5 |
import os
|
| 6 |
import traceback
|
| 7 |
import time
|
| 8 |
+
from typing import Optional, Dict, Any, Tuple, List, Union
|
| 9 |
from contextlib import contextmanager
|
| 10 |
|
| 11 |
import yaml
|
|
|
|
| 85 |
except Exception as e:
|
| 86 |
return 0.9, False
|
| 87 |
|
| 88 |
+
def _has_meaningful_negative_prompt(self, negative_prompt: str) -> bool:
|
| 89 |
+
"""Check if negative prompt is meaningful (not default/empty)"""
|
| 90 |
+
return negative_prompt and negative_prompt.strip() and negative_prompt.strip() != "NO USER INPUT"
|
| 91 |
+
|
| 92 |
+
def _build_logits_processor(self, repetition_penalty: float) -> LogitsProcessorList:
    """Assemble the logits-processor pipeline used during sampling.

    Only a repetition penalty is configured today; a penalty of exactly
    1.0 is a no-op, so the returned list stays empty in that case.
    """
    processors = LogitsProcessorList()
    if repetition_penalty != 1.0:
        processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
    return processors
|
| 98 |
+
|
| 99 |
+
def _setup_constrained_processor(
    self,
    use_constrained_decoding: bool,
    constrained_decoding_debug: bool,
    target_duration: Optional[float],
    user_metadata: Optional[Dict[str, Optional[str]]],
    stop_at_reasoning: bool,
    skip_genres: bool,
    skip_caption: bool,
    skip_language: bool,
    generation_phase: str,
    is_batch: bool = False,
    metadata_temperature: Optional[float] = None,
    codes_temperature: Optional[float] = None,
) -> Optional[MetadataConstrainedLogitsProcessor]:
    """Configure the shared constrained-decoding processor for one generation run.

    Returns the shared processor instance, or None when neither constrained
    decoding nor per-phase temperatures are requested.
    """
    # Per-phase temperatures are only supported for single (non-batch) requests.
    phase_temps_active = (not is_batch) and (
        metadata_temperature is not None or codes_temperature is not None
    )

    if not use_constrained_decoding and not phase_temps_active:
        return None

    cp = self.constrained_processor
    # Clear any state left over from the previous generation.
    cp.reset()
    cp.enabled = use_constrained_decoding
    cp.debug = constrained_decoding_debug

    if phase_temps_active:
        cp.metadata_temperature = metadata_temperature
        cp.codes_temperature = codes_temperature
    else:
        cp.metadata_temperature = None
        cp.codes_temperature = None

    cp.set_target_duration(target_duration)

    if is_batch:
        # Batch mode pins these options to fixed defaults.
        cp.set_user_metadata(None)
        cp.set_stop_at_reasoning(False)
        cp.set_skip_genres(True)
        cp.set_skip_caption(True)
        cp.set_skip_language(True)
    else:
        # Single mode honors the caller-provided settings.
        cp.set_user_metadata(user_metadata)
        cp.set_stop_at_reasoning(stop_at_reasoning)
        cp.set_skip_genres(skip_genres)
        cp.set_skip_caption(skip_caption)
        cp.set_skip_language(skip_language)

    # Phase-aware processing (cot vs codes) keys off this.
    cp.set_generation_phase(generation_phase)

    return cp
|
| 156 |
+
|
| 157 |
+
def _build_unconditional_prompt(
|
| 158 |
+
self,
|
| 159 |
+
caption: str,
|
| 160 |
+
lyrics: str,
|
| 161 |
+
cot_text: str,
|
| 162 |
+
negative_prompt: str,
|
| 163 |
+
generation_phase: str,
|
| 164 |
+
is_batch: bool = False,
|
| 165 |
+
) -> str:
|
| 166 |
+
"""Build unconditional prompt for CFG based on generation phase and batch mode"""
|
| 167 |
+
if is_batch or generation_phase == "codes":
|
| 168 |
+
# Codes phase or batch mode: use empty CoT in unconditional prompt
|
| 169 |
+
return self.build_formatted_prompt_with_cot(
|
| 170 |
+
caption, lyrics, cot_text, is_negative_prompt=True, negative_prompt=negative_prompt
|
| 171 |
+
)
|
| 172 |
+
else:
|
| 173 |
+
# CoT phase (single mode only): unconditional prompt
|
| 174 |
+
# If negative_prompt is provided, use it as caption; otherwise remove caption and keep only lyrics
|
| 175 |
+
return self.build_formatted_prompt(
|
| 176 |
+
caption, lyrics, is_negative_prompt=True, generation_phase="cot", negative_prompt=negative_prompt
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
def _load_pytorch_model(self, model_path: str, device: str) -> Tuple[bool, str]:
    """Load the 5Hz LM with the PyTorch backend.

    Returns (success, status_message); on failure the message embeds the
    exception text and full traceback.
    """
    try:
        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
        # With CPU offload enabled the model is parked on CPU until needed.
        target_device = "cpu" if self.offload_to_cpu else device
        self.llm = model.to(target_device).to(self.dtype)
        self.llm.eval()
        self.llm_backend = "pt"
        self.llm_initialized = True
        logger.info(f"5Hz LM initialized successfully using PyTorch backend on {device}")
        status_msg = f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nBackend: PyTorch\nDevice: {device}"
        return True, status_msg
    except Exception as e:
        return False, f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
|
| 195 |
+
|
| 196 |
+
def _apply_top_k_filter(self, logits: torch.Tensor, top_k: Optional[int]) -> torch.Tensor:
|
| 197 |
+
"""Apply top-k filtering to logits"""
|
| 198 |
+
if top_k is not None and top_k > 0:
|
| 199 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
| 200 |
+
logits[indices_to_remove] = float('-inf')
|
| 201 |
+
return logits
|
| 202 |
+
|
| 203 |
+
def _apply_top_p_filter(self, logits: torch.Tensor, top_p: Optional[float]) -> torch.Tensor:
|
| 204 |
+
"""Apply top-p (nucleus) filtering to logits"""
|
| 205 |
+
if top_p is not None and 0.0 < top_p < 1.0:
|
| 206 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 207 |
+
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
| 208 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 209 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 210 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 211 |
+
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
| 212 |
+
logits[indices_to_remove] = float('-inf')
|
| 213 |
+
return logits
|
| 214 |
+
|
| 215 |
+
def _sample_tokens(self, logits: torch.Tensor, temperature: float) -> torch.Tensor:
|
| 216 |
+
"""Sample tokens from logits with temperature"""
|
| 217 |
+
if temperature > 0:
|
| 218 |
+
logits = logits / temperature
|
| 219 |
+
probs = torch.softmax(logits, dim=-1)
|
| 220 |
+
return torch.multinomial(probs, num_samples=1).squeeze(1)
|
| 221 |
+
else:
|
| 222 |
+
return torch.argmax(logits, dim=-1)
|
| 223 |
+
|
| 224 |
+
def _check_eos_token(self, tokens: torch.Tensor, eos_token_id: int, pad_token_id: Optional[int]) -> bool:
|
| 225 |
+
"""Check if any token in the batch is EOS or pad token"""
|
| 226 |
+
if torch.any(tokens == eos_token_id):
|
| 227 |
+
return True
|
| 228 |
+
if pad_token_id is not None and pad_token_id != eos_token_id:
|
| 229 |
+
if torch.any(tokens == pad_token_id):
|
| 230 |
+
return True
|
| 231 |
+
return False
|
| 232 |
+
|
| 233 |
+
def _update_constrained_processor_state(self, constrained_processor: Optional[MetadataConstrainedLogitsProcessor], tokens: torch.Tensor):
    """Feed each sampled token into the constrained processor's state machine.

    Silently does nothing when no processor is configured.
    """
    if constrained_processor is None:
        return
    for token in tokens:
        constrained_processor.update_state(token.item())
|
| 238 |
+
|
| 239 |
+
def _forward_pass(
|
| 240 |
+
self,
|
| 241 |
+
model: Any,
|
| 242 |
+
generated_ids: torch.Tensor,
|
| 243 |
+
model_kwargs: Dict[str, Any],
|
| 244 |
+
past_key_values: Optional[Any],
|
| 245 |
+
use_cache: bool,
|
| 246 |
+
) -> Any:
|
| 247 |
+
"""Perform forward pass with KV cache support"""
|
| 248 |
+
if past_key_values is None:
|
| 249 |
+
outputs = model(
|
| 250 |
+
input_ids=generated_ids,
|
| 251 |
+
**model_kwargs,
|
| 252 |
+
use_cache=use_cache,
|
| 253 |
+
)
|
| 254 |
+
else:
|
| 255 |
+
outputs = model(
|
| 256 |
+
input_ids=generated_ids[:, -1:],
|
| 257 |
+
past_key_values=past_key_values,
|
| 258 |
+
**model_kwargs,
|
| 259 |
+
use_cache=use_cache,
|
| 260 |
+
)
|
| 261 |
+
return outputs
|
| 262 |
+
|
| 263 |
+
def _normalize_batch_input(self, formatted_prompts: Union[str, List[str]]) -> Tuple[List[str], bool]:
|
| 264 |
+
"""Normalize batch input: convert single string to list and return (list, is_batch)"""
|
| 265 |
+
is_batch = isinstance(formatted_prompts, list)
|
| 266 |
+
if is_batch:
|
| 267 |
+
return formatted_prompts, is_batch
|
| 268 |
+
else:
|
| 269 |
+
return [formatted_prompts], is_batch
|
| 270 |
+
|
| 271 |
def initialize(
|
| 272 |
self,
|
| 273 |
checkpoint_dir: str,
|
|
|
|
| 333 |
# vllm initialization failed, fallback to PyTorch
|
| 334 |
if not self.llm_initialized:
|
| 335 |
logger.warning("vllm initialization failed, falling back to PyTorch backend")
|
| 336 |
+
success, status_msg = self._load_pytorch_model(full_lm_model_path, device)
|
| 337 |
+
if not success:
|
| 338 |
+
return status_msg, False
|
| 339 |
+
status_msg = f"✅ 5Hz LM initialized successfully (PyTorch fallback)\nModel: {full_lm_model_path}\nBackend: PyTorch"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
# If vllm initialization succeeded, self.llm_initialized should already be True
|
| 341 |
else:
|
| 342 |
# Use PyTorch backend (pt)
|
| 343 |
+
success, status_msg = self._load_pytorch_model(full_lm_model_path, device)
|
| 344 |
+
if not success:
|
| 345 |
+
return status_msg, False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
|
| 347 |
return status_msg, True
|
| 348 |
|
| 349 |
except Exception as e:
|
| 350 |
+
return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
|
|
|
|
| 351 |
|
| 352 |
def _initialize_5hz_lm_vllm(self, model_path: str) -> str:
|
| 353 |
"""Initialize 5Hz LM model using vllm backend"""
|
|
|
|
| 393 |
return f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}"
|
| 394 |
except Exception as e:
|
| 395 |
self.llm_initialized = False
|
| 396 |
+
return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
|
|
|
|
| 397 |
|
| 398 |
+
def _run_vllm(
|
| 399 |
self,
|
| 400 |
+
formatted_prompts: Union[str, List[str]],
|
| 401 |
temperature: float,
|
| 402 |
cfg_scale: float,
|
| 403 |
negative_prompt: str,
|
|
|
|
| 406 |
repetition_penalty: float,
|
| 407 |
use_constrained_decoding: bool = True,
|
| 408 |
constrained_decoding_debug: bool = False,
|
| 409 |
+
metadata_temperature: Optional[float] = None,
|
| 410 |
codes_temperature: Optional[float] = None,
|
| 411 |
target_duration: Optional[float] = None,
|
| 412 |
user_metadata: Optional[Dict[str, Optional[str]]] = None,
|
|
|
|
| 418 |
caption: str = "",
|
| 419 |
lyrics: str = "",
|
| 420 |
cot_text: str = "",
|
| 421 |
+
seeds: Optional[List[int]] = None,
|
| 422 |
+
) -> Union[str, List[str]]:
|
| 423 |
+
"""
|
| 424 |
+
Unified vllm generation function supporting both single and batch modes.
|
| 425 |
+
Accepts either a single formatted prompt (str) or a list of formatted prompts (List[str]).
|
| 426 |
+
Returns a single string for single mode, or a list of strings for batch mode.
|
| 427 |
+
"""
|
| 428 |
from nanovllm import SamplingParams
|
| 429 |
|
| 430 |
+
# Determine if batch mode
|
| 431 |
+
formatted_prompt_list, is_batch = self._normalize_batch_input(formatted_prompts)
|
| 432 |
+
batch_size = len(formatted_prompt_list)
|
| 433 |
+
|
| 434 |
# Determine effective temperature for sampler
|
| 435 |
+
# Batch mode doesn't support phase temperatures, so use simple temperature
|
| 436 |
+
# Single mode supports phase temperatures
|
| 437 |
+
use_phase_temperatures = not is_batch and (metadata_temperature is not None or codes_temperature is not None)
|
| 438 |
effective_sampler_temp = 1.0 if use_phase_temperatures else temperature
|
| 439 |
|
| 440 |
+
# Setup constrained processor
|
| 441 |
+
constrained_processor = self._setup_constrained_processor(
|
| 442 |
+
use_constrained_decoding=use_constrained_decoding or use_phase_temperatures,
|
| 443 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 444 |
+
target_duration=target_duration,
|
| 445 |
+
user_metadata=user_metadata,
|
| 446 |
+
stop_at_reasoning=stop_at_reasoning,
|
| 447 |
+
skip_genres=skip_genres,
|
| 448 |
+
skip_caption=skip_caption,
|
| 449 |
+
skip_language=skip_language,
|
| 450 |
+
generation_phase=generation_phase,
|
| 451 |
+
is_batch=is_batch,
|
| 452 |
+
metadata_temperature=metadata_temperature,
|
| 453 |
+
codes_temperature=codes_temperature,
|
| 454 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 455 |
|
| 456 |
sampling_params = SamplingParams(
|
| 457 |
max_tokens=self.max_model_len - 64,
|
|
|
|
| 466 |
|
| 467 |
if cfg_scale > 1.0:
|
| 468 |
# Build unconditional prompt based on generation phase
|
| 469 |
+
formatted_unconditional_prompt = self._build_unconditional_prompt(
|
| 470 |
+
caption=caption,
|
| 471 |
+
lyrics=lyrics,
|
| 472 |
+
cot_text=cot_text,
|
| 473 |
+
negative_prompt=negative_prompt,
|
| 474 |
+
generation_phase=generation_phase,
|
| 475 |
+
is_batch=is_batch,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
)
|
| 477 |
unconditional_prompts = [formatted_unconditional_prompt] * batch_size
|
| 478 |
|
| 479 |
outputs = self.llm.generate(
|
| 480 |
+
formatted_prompt_list,
|
| 481 |
sampling_params,
|
| 482 |
unconditional_prompts=unconditional_prompts,
|
| 483 |
)
|
| 484 |
else:
|
| 485 |
+
outputs = self.llm.generate(formatted_prompt_list, sampling_params)
|
| 486 |
+
|
| 487 |
+
# Extract text from outputs
|
| 488 |
output_texts = []
|
| 489 |
for output in outputs:
|
| 490 |
if hasattr(output, "outputs") and len(output.outputs) > 0:
|
|
|
|
| 495 |
output_texts.append(output["text"])
|
| 496 |
else:
|
| 497 |
output_texts.append(str(output))
|
|
|
|
|
|
|
| 498 |
|
| 499 |
+
# Return single string for single mode, list for batch mode
|
| 500 |
+
return output_texts[0] if not is_batch else output_texts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 501 |
|
| 502 |
+
def _run_pt_single(
|
| 503 |
self,
|
| 504 |
formatted_prompt: str,
|
| 505 |
temperature: float,
|
|
|
|
| 508 |
top_k: Optional[int],
|
| 509 |
top_p: Optional[float],
|
| 510 |
repetition_penalty: float,
|
| 511 |
+
use_constrained_decoding: bool,
|
| 512 |
+
constrained_decoding_debug: bool,
|
| 513 |
+
target_duration: Optional[float],
|
| 514 |
+
user_metadata: Optional[Dict[str, Optional[str]]],
|
| 515 |
+
stop_at_reasoning: bool,
|
| 516 |
+
skip_genres: bool,
|
| 517 |
+
skip_caption: bool,
|
| 518 |
+
skip_language: bool,
|
| 519 |
+
generation_phase: str,
|
| 520 |
+
caption: str,
|
| 521 |
+
lyrics: str,
|
| 522 |
+
cot_text: str,
|
| 523 |
) -> str:
|
| 524 |
+
"""Internal helper function for single-item PyTorch generation."""
|
| 525 |
inputs = self.llm_tokenizer(
|
| 526 |
formatted_prompt,
|
| 527 |
return_tensors="pt",
|
|
|
|
| 529 |
truncation=True,
|
| 530 |
)
|
| 531 |
|
| 532 |
+
# Setup constrained processor
|
| 533 |
+
constrained_processor = self._setup_constrained_processor(
|
| 534 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 535 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 536 |
+
target_duration=target_duration,
|
| 537 |
+
user_metadata=user_metadata,
|
| 538 |
+
stop_at_reasoning=stop_at_reasoning,
|
| 539 |
+
skip_genres=skip_genres,
|
| 540 |
+
skip_caption=skip_caption,
|
| 541 |
+
skip_language=skip_language,
|
| 542 |
+
generation_phase=generation_phase,
|
| 543 |
+
is_batch=False,
|
| 544 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
|
| 546 |
with self._load_model_context():
|
| 547 |
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
|
|
| 550 |
max_new_tokens = min(max_new_tokens, self.max_model_len - 64)
|
| 551 |
|
| 552 |
# Build logits processor list (only for CFG and repetition penalty)
|
| 553 |
+
logits_processor = self._build_logits_processor(repetition_penalty)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 554 |
|
| 555 |
if cfg_scale > 1.0:
|
| 556 |
# Build unconditional prompt based on generation phase
|
| 557 |
+
formatted_unconditional_prompt = self._build_unconditional_prompt(
|
| 558 |
+
caption=caption,
|
| 559 |
+
lyrics=lyrics,
|
| 560 |
+
cot_text=cot_text,
|
| 561 |
+
negative_prompt=negative_prompt,
|
| 562 |
+
generation_phase=generation_phase,
|
| 563 |
+
is_batch=False,
|
| 564 |
+
)
|
|
|
|
|
|
|
|
|
|
| 565 |
|
| 566 |
# Tokenize both prompts together to ensure same length (with left padding)
|
| 567 |
# Left padding is important for generation tasks
|
|
|
|
| 654 |
|
| 655 |
output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
|
| 656 |
return output_text
|
| 657 |
+
|
| 658 |
+
def _run_pt(
|
| 659 |
+
self,
|
| 660 |
+
formatted_prompts: Union[str, List[str]],
|
| 661 |
+
temperature: float,
|
| 662 |
+
cfg_scale: float,
|
| 663 |
+
negative_prompt: str,
|
| 664 |
+
top_k: Optional[int],
|
| 665 |
+
top_p: Optional[float],
|
| 666 |
+
repetition_penalty: float,
|
| 667 |
+
use_constrained_decoding: bool = True,
|
| 668 |
+
constrained_decoding_debug: bool = False,
|
| 669 |
+
target_duration: Optional[float] = None,
|
| 670 |
+
user_metadata: Optional[Dict[str, Optional[str]]] = None,
|
| 671 |
+
stop_at_reasoning: bool = False,
|
| 672 |
+
skip_genres: bool = True,
|
| 673 |
+
skip_caption: bool = False,
|
| 674 |
+
skip_language: bool = False,
|
| 675 |
+
generation_phase: str = "cot",
|
| 676 |
+
caption: str = "",
|
| 677 |
+
lyrics: str = "",
|
| 678 |
+
cot_text: str = "",
|
| 679 |
+
seeds: Optional[List[int]] = None,
|
| 680 |
+
) -> Union[str, List[str]]:
|
| 681 |
+
"""
|
| 682 |
+
Unified PyTorch generation function supporting both single and batch modes.
|
| 683 |
+
Accepts either a single formatted prompt (str) or a list of formatted prompts (List[str]).
|
| 684 |
+
Returns a single string for single mode, or a list of strings for batch mode.
|
| 685 |
+
Note: PyTorch backend processes batch items sequentially (doesn't support true batching efficiently).
|
| 686 |
+
"""
|
| 687 |
+
# Determine if batch mode
|
| 688 |
+
formatted_prompt_list, is_batch = self._normalize_batch_input(formatted_prompts)
|
| 689 |
+
|
| 690 |
+
# For batch mode, process each item sequentially with different seeds
|
| 691 |
+
if is_batch:
|
| 692 |
+
output_texts = []
|
| 693 |
+
for i, formatted_prompt in enumerate(formatted_prompt_list):
|
| 694 |
+
# Set seed for this item if provided
|
| 695 |
+
if seeds and i < len(seeds):
|
| 696 |
+
torch.manual_seed(seeds[i])
|
| 697 |
+
if torch.cuda.is_available():
|
| 698 |
+
torch.cuda.manual_seed_all(seeds[i])
|
| 699 |
+
|
| 700 |
+
# Generate using single-item method with batch-mode defaults
|
| 701 |
+
output_text = self._run_pt_single(
|
| 702 |
+
formatted_prompt=formatted_prompt,
|
| 703 |
+
temperature=temperature,
|
| 704 |
+
cfg_scale=cfg_scale,
|
| 705 |
+
negative_prompt=negative_prompt,
|
| 706 |
+
top_k=top_k,
|
| 707 |
+
top_p=top_p,
|
| 708 |
+
repetition_penalty=repetition_penalty,
|
| 709 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 710 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 711 |
+
target_duration=target_duration,
|
| 712 |
+
user_metadata=None,
|
| 713 |
+
stop_at_reasoning=False,
|
| 714 |
+
skip_genres=True,
|
| 715 |
+
skip_caption=True,
|
| 716 |
+
skip_language=True,
|
| 717 |
+
generation_phase=generation_phase,
|
| 718 |
+
caption=caption,
|
| 719 |
+
lyrics=lyrics,
|
| 720 |
+
cot_text=cot_text,
|
| 721 |
+
)
|
| 722 |
+
|
| 723 |
+
output_texts.append(output_text)
|
| 724 |
+
|
| 725 |
+
return output_texts
|
| 726 |
+
|
| 727 |
+
# Single mode: process the formatted prompt
|
| 728 |
+
formatted_prompt = formatted_prompt_list[0]
|
| 729 |
+
|
| 730 |
+
return self._run_pt_single(
|
| 731 |
+
formatted_prompt=formatted_prompt,
|
| 732 |
+
temperature=temperature,
|
| 733 |
+
cfg_scale=cfg_scale,
|
| 734 |
+
negative_prompt=negative_prompt,
|
| 735 |
+
top_k=top_k,
|
| 736 |
+
top_p=top_p,
|
| 737 |
+
repetition_penalty=repetition_penalty,
|
| 738 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 739 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 740 |
+
target_duration=target_duration,
|
| 741 |
+
user_metadata=user_metadata,
|
| 742 |
+
stop_at_reasoning=stop_at_reasoning,
|
| 743 |
+
skip_genres=skip_genres,
|
| 744 |
+
skip_caption=skip_caption,
|
| 745 |
+
skip_language=skip_language,
|
| 746 |
+
generation_phase=generation_phase,
|
| 747 |
+
caption=caption,
|
| 748 |
+
lyrics=lyrics,
|
| 749 |
+
cot_text=cot_text,
|
| 750 |
+
)
|
| 751 |
+
|
| 752 |
def has_all_metas(self, user_metadata: Optional[Dict[str, Optional[str]]]) -> bool:
|
| 753 |
"""Check if all required metadata are present."""
|
| 754 |
if user_metadata is None:
|
|
|
|
| 799 |
use_cot_caption: bool = True,
|
| 800 |
use_cot_language: bool = True,
|
| 801 |
is_format_caption: bool = False,
|
| 802 |
+
batch_size: Optional[int] = None,
|
| 803 |
+
seeds: Optional[List[int]] = None,
|
| 804 |
+
) -> Union[Tuple[Dict[str, Any], str, str], Tuple[List[Dict[str, Any]], List[str], str]]:
|
| 805 |
"""Two-phase LM generation: CoT generation followed by audio codes generation.
|
| 806 |
|
| 807 |
- infer_type='dit': Phase 1 only - generate CoT and return metas (no audio codes)
|
|
|
|
| 814 |
If specified, constrained decoding will inject these values directly.
|
| 815 |
use_cot_caption: Whether to generate caption in CoT (default True).
|
| 816 |
use_cot_language: Whether to generate language in CoT (default True).
|
| 817 |
+
batch_size: Optional batch size for batch generation. If None or 1, returns single result.
|
| 818 |
+
If > 1, returns batch results (lists).
|
| 819 |
+
seeds: Optional list of seeds for batch generation (for reproducibility).
|
| 820 |
+
Only used when batch_size > 1.
|
| 821 |
+
|
| 822 |
+
Returns:
|
| 823 |
+
If batch_size is None or 1: (metadata, audio_codes, status_msg)
|
| 824 |
+
If batch_size > 1: (metadata_list, audio_codes_list, status_msg)
|
| 825 |
"""
|
| 826 |
import time
|
| 827 |
+
import random
|
| 828 |
|
| 829 |
infer_type = (infer_type or "").strip().lower()
|
| 830 |
if infer_type not in {"dit", "llm_dit"}:
|
| 831 |
+
if batch_size and batch_size > 1:
|
| 832 |
+
return [], [], f"❌ invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
|
| 833 |
return {}, "", f"❌ invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
|
| 834 |
+
|
| 835 |
+
# Determine if batch mode
|
| 836 |
+
is_batch = batch_size and batch_size > 1
|
| 837 |
+
actual_batch_size = batch_size if is_batch else 1
|
| 838 |
+
|
| 839 |
+
# Initialize variables
|
| 840 |
metadata = {}
|
| 841 |
audio_codes = ""
|
| 842 |
has_all_metas = self.has_all_metas(user_metadata)
|
|
|
|
|
|
|
| 843 |
phase1_time = 0.0
|
| 844 |
phase2_time = 0.0
|
| 845 |
|
| 846 |
+
# Handle seeds for batch mode
|
| 847 |
+
if is_batch:
|
| 848 |
+
if seeds is None:
|
| 849 |
+
seeds = [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size)]
|
| 850 |
+
elif len(seeds) < actual_batch_size:
|
| 851 |
+
seeds = list(seeds) + [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size - len(seeds))]
|
| 852 |
+
else:
|
| 853 |
+
seeds = seeds[:actual_batch_size]
|
| 854 |
+
|
| 855 |
# ========== PHASE 1: CoT Generation ==========
|
| 856 |
+
# Skip CoT if all metadata are user-provided OR caption is already formatted
|
| 857 |
+
if not has_all_metas and not is_format_caption:
|
| 858 |
+
if is_batch:
|
| 859 |
+
logger.info("Batch Phase 1: Generating CoT metadata (once for all items)...")
|
| 860 |
+
else:
|
| 861 |
+
logger.info("Phase 1: Generating CoT metadata...")
|
| 862 |
phase1_start = time.time()
|
| 863 |
|
| 864 |
# Build formatted prompt for CoT phase
|
| 865 |
formatted_prompt = self.build_formatted_prompt(caption, lyrics, generation_phase="cot")
|
| 866 |
+
|
| 867 |
logger.info(f"generate_with_stop_condition: formatted_prompt={formatted_prompt}")
|
| 868 |
# Generate CoT (stop at </think>)
|
| 869 |
cot_output_text, status = self.generate_from_formatted_prompt(
|
|
|
|
| 893 |
phase1_time = time.time() - phase1_start
|
| 894 |
|
| 895 |
if not cot_output_text:
|
| 896 |
+
if is_batch:
|
| 897 |
+
return [], [], status
|
| 898 |
return {}, "", status
|
| 899 |
|
| 900 |
# Parse metadata from CoT output
|
| 901 |
metadata, _ = self.parse_lm_output(cot_output_text)
|
| 902 |
+
if is_batch:
|
| 903 |
+
logger.info(f"Batch Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}")
|
| 904 |
+
else:
|
| 905 |
+
logger.info(f"Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}")
|
| 906 |
else:
|
| 907 |
# Use user-provided metadata
|
| 908 |
+
if is_batch:
|
| 909 |
+
logger.info("Batch Phase 1: Using user-provided metadata (skipping generation)")
|
| 910 |
+
else:
|
| 911 |
+
logger.info("Phase 1: Using user-provided metadata (skipping generation)")
|
| 912 |
metadata = {k: v for k, v in user_metadata.items() if v is not None}
|
| 913 |
|
| 914 |
# If infer_type is 'dit', stop here and return only metadata
|
| 915 |
if infer_type == "dit":
|
| 916 |
+
if is_batch:
|
| 917 |
+
metadata_list = [metadata.copy() for _ in range(actual_batch_size)]
|
| 918 |
+
status_msg = f"✅ Generated CoT metadata successfully (batch mode)\nFields: {', '.join(metadata.keys())}\nPhase1: {phase1_time:.2f}s"
|
| 919 |
+
return metadata_list, [""] * actual_batch_size, status_msg
|
| 920 |
+
else:
|
| 921 |
+
status_msg = f"✅ Generated CoT metadata successfully\nFields: {', '.join(metadata.keys())}\nPhase1: {phase1_time:.2f}s"
|
| 922 |
+
return metadata, "", status_msg
|
| 923 |
|
| 924 |
# ========== PHASE 2: Audio Codes Generation ==========
|
| 925 |
+
if is_batch:
|
| 926 |
+
logger.info(f"Batch Phase 2: Generating audio codes for {actual_batch_size} items...")
|
| 927 |
+
else:
|
| 928 |
+
logger.info("Phase 2: Generating audio codes...")
|
| 929 |
phase2_start = time.time()
|
| 930 |
|
| 931 |
# Format metadata as CoT using YAML (matching training format)
|
|
|
|
| 934 |
# Build formatted prompt with CoT for codes generation phase
|
| 935 |
formatted_prompt_with_cot = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text)
|
| 936 |
logger.info(f"generate_with_stop_condition: formatted_prompt_with_cot={formatted_prompt_with_cot}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 937 |
|
| 938 |
+
if is_batch:
|
| 939 |
+
# Batch mode: generate codes for all items
|
| 940 |
+
formatted_prompts = [formatted_prompt_with_cot] * actual_batch_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 941 |
|
| 942 |
+
# Call backend-specific batch generation
|
| 943 |
+
try:
|
| 944 |
+
if self.llm_backend == "vllm":
|
| 945 |
+
codes_outputs = self._run_vllm(
|
| 946 |
+
formatted_prompts=formatted_prompts,
|
| 947 |
+
temperature=temperature,
|
| 948 |
+
cfg_scale=cfg_scale,
|
| 949 |
+
negative_prompt=negative_prompt,
|
| 950 |
+
top_k=top_k,
|
| 951 |
+
top_p=top_p,
|
| 952 |
+
repetition_penalty=repetition_penalty,
|
| 953 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 954 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 955 |
+
target_duration=target_duration,
|
| 956 |
+
generation_phase="codes",
|
| 957 |
+
caption=caption,
|
| 958 |
+
lyrics=lyrics,
|
| 959 |
+
cot_text=cot_text,
|
| 960 |
+
seeds=seeds,
|
| 961 |
+
)
|
| 962 |
+
else: # pt backend
|
| 963 |
+
codes_outputs = self._run_pt(
|
| 964 |
+
formatted_prompts=formatted_prompts,
|
| 965 |
+
temperature=temperature,
|
| 966 |
+
cfg_scale=cfg_scale,
|
| 967 |
+
negative_prompt=negative_prompt,
|
| 968 |
+
top_k=top_k,
|
| 969 |
+
top_p=top_p,
|
| 970 |
+
repetition_penalty=repetition_penalty,
|
| 971 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 972 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 973 |
+
target_duration=target_duration,
|
| 974 |
+
generation_phase="codes",
|
| 975 |
+
caption=caption,
|
| 976 |
+
lyrics=lyrics,
|
| 977 |
+
cot_text=cot_text,
|
| 978 |
+
seeds=seeds,
|
| 979 |
+
)
|
| 980 |
+
except Exception as e:
|
| 981 |
+
error_msg = f"❌ Error in batch codes generation: {str(e)}"
|
| 982 |
+
logger.error(error_msg)
|
| 983 |
+
return [], [], error_msg
|
| 984 |
+
|
| 985 |
+
# Parse audio codes from each output
|
| 986 |
+
audio_codes_list = []
|
| 987 |
+
metadata_list = []
|
| 988 |
+
for output_text in codes_outputs:
|
| 989 |
+
_, audio_codes_item = self.parse_lm_output(output_text)
|
| 990 |
+
audio_codes_list.append(audio_codes_item)
|
| 991 |
+
metadata_list.append(metadata.copy()) # Same metadata for all
|
| 992 |
+
|
| 993 |
+
phase2_time = time.time() - phase2_start
|
| 994 |
+
|
| 995 |
+
# Log results
|
| 996 |
+
codes_counts = [len(codes.split('<|audio_code_')) - 1 if codes else 0 for codes in audio_codes_list]
|
| 997 |
+
logger.info(f"Batch Phase 2 completed in {phase2_time:.2f}s. Generated codes: {codes_counts}")
|
| 998 |
+
|
| 999 |
+
status_msg = f"✅ Batch generation completed ({actual_batch_size} items)\nPhase 1: CoT metadata\nPhase 2: {sum(codes_counts)} total codes ({codes_counts})\nPhase1: {phase1_time:.2f}s, Phase2: {phase2_time:.2f}s"
|
| 1000 |
+
return metadata_list, audio_codes_list, status_msg
|
| 1001 |
+
else:
|
| 1002 |
+
# Single mode: generate codes for one item
|
| 1003 |
+
codes_output_text, status = self.generate_from_formatted_prompt(
|
| 1004 |
+
formatted_prompt=formatted_prompt_with_cot,
|
| 1005 |
+
cfg={
|
| 1006 |
+
"temperature": temperature,
|
| 1007 |
+
"cfg_scale": cfg_scale,
|
| 1008 |
+
"negative_prompt": negative_prompt,
|
| 1009 |
+
"top_k": top_k,
|
| 1010 |
+
"top_p": top_p,
|
| 1011 |
+
"repetition_penalty": repetition_penalty,
|
| 1012 |
+
"target_duration": target_duration,
|
| 1013 |
+
"user_metadata": None, # No user metadata injection in Phase 2
|
| 1014 |
+
"skip_caption": True, # Skip caption since CoT is already included
|
| 1015 |
+
"skip_language": True, # Skip language since CoT is already included
|
| 1016 |
+
"generation_phase": "codes",
|
| 1017 |
+
# Pass context for building unconditional prompt in codes phase
|
| 1018 |
+
"caption": caption,
|
| 1019 |
+
"lyrics": lyrics,
|
| 1020 |
+
"cot_text": cot_text,
|
| 1021 |
+
},
|
| 1022 |
use_constrained_decoding=use_constrained_decoding,
|
| 1023 |
constrained_decoding_debug=constrained_decoding_debug,
|
| 1024 |
+
stop_at_reasoning=False, # Generate codes until EOS
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1025 |
)
|
| 1026 |
|
| 1027 |
+
if not codes_output_text:
|
| 1028 |
+
return metadata, "", status
|
| 1029 |
|
| 1030 |
+
phase2_time = time.time() - phase2_start
|
|
|
|
| 1031 |
|
| 1032 |
+
# Parse audio codes from output (metadata should be same as Phase 1)
|
| 1033 |
+
_, audio_codes = self.parse_lm_output(codes_output_text)
|
| 1034 |
+
|
| 1035 |
+
codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
|
| 1036 |
+
logger.info(f"Phase 2 completed in {phase2_time:.2f}s. Generated {codes_count} audio codes")
|
| 1037 |
+
|
| 1038 |
+
status_msg = f"✅ Generated successfully (2-phase)\nPhase 1: CoT metadata\nPhase 2: {codes_count} audio codes\nPhase1: {phase1_time:.2f}s, Phase2: {phase2_time:.2f}s"
|
| 1039 |
+
return metadata, audio_codes, status_msg
|
| 1040 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1041 |
def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False, generation_phase: str = "cot", negative_prompt: str = "NO USER INPUT") -> str:
|
| 1042 |
"""
|
| 1043 |
Build the chat-formatted prompt for 5Hz LM from caption/lyrics.
|
|
|
|
| 1059 |
if is_negative_prompt:
|
| 1060 |
# Unconditional prompt for CFG
|
| 1061 |
# Check if user provided a meaningful negative prompt (not the default)
|
| 1062 |
+
has_negative_prompt = self._has_meaningful_negative_prompt(negative_prompt)
|
| 1063 |
|
| 1064 |
if generation_phase == "cot":
|
| 1065 |
# CoT phase unconditional prompt
|
|
|
|
| 1110 |
if is_negative_prompt:
|
| 1111 |
# Unconditional prompt for codes phase
|
| 1112 |
# Check if user provided a meaningful negative prompt
|
| 1113 |
+
has_negative_prompt = self._has_meaningful_negative_prompt(negative_prompt)
|
| 1114 |
|
| 1115 |
# Use empty CoT for unconditional
|
| 1116 |
cot_for_prompt = "<think>\n</think>"
|
|
|
|
| 1393 |
|
| 1394 |
try:
|
| 1395 |
if self.llm_backend == "vllm":
|
| 1396 |
+
output_text = self._run_vllm(
|
| 1397 |
+
formatted_prompts=formatted_prompt,
|
| 1398 |
temperature=temperature,
|
| 1399 |
cfg_scale=cfg_scale,
|
| 1400 |
negative_prompt=negative_prompt,
|
|
|
|
| 1417 |
return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}"
|
| 1418 |
|
| 1419 |
# PyTorch backend
|
| 1420 |
+
output_text = self._run_pt(
|
| 1421 |
+
formatted_prompts=formatted_prompt,
|
| 1422 |
temperature=temperature,
|
| 1423 |
cfg_scale=cfg_scale,
|
| 1424 |
negative_prompt=negative_prompt,
|
|
|
|
| 1483 |
eos_token_id = pad_token_id
|
| 1484 |
|
| 1485 |
# Build logits processor for repetition penalty
|
| 1486 |
+
logits_processor = self._build_logits_processor(repetition_penalty)
|
|
|
|
|
|
|
| 1487 |
|
| 1488 |
with torch.no_grad():
|
| 1489 |
for step in range(max_new_tokens):
|
| 1490 |
# Forward pass
|
| 1491 |
+
outputs = self._forward_pass(model, generated_ids, model_kwargs, past_key_values, use_cache)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1492 |
|
| 1493 |
# Get logits for the last position
|
| 1494 |
next_token_logits = outputs.logits[:, -1, :] # [batch_size, vocab_size]
|
|
|
|
| 1501 |
for processor in logits_processor:
|
| 1502 |
next_token_logits = processor(generated_ids, next_token_logits)
|
| 1503 |
|
| 1504 |
+
# Apply top-k and top-p filtering
|
| 1505 |
+
next_token_logits = self._apply_top_k_filter(next_token_logits, top_k)
|
| 1506 |
+
next_token_logits = self._apply_top_p_filter(next_token_logits, top_p)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1507 |
|
| 1508 |
# Apply temperature and sample
|
| 1509 |
+
next_tokens = self._sample_tokens(next_token_logits, temperature)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1510 |
|
| 1511 |
# Update constrained processor state
|
| 1512 |
+
self._update_constrained_processor_state(constrained_processor, next_tokens)
|
|
|
|
|
|
|
| 1513 |
|
| 1514 |
# Check for EOS token
|
| 1515 |
+
should_stop = self._check_eos_token(next_tokens, eos_token_id, pad_token_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1516 |
|
| 1517 |
# Append token to sequence
|
| 1518 |
next_tokens_unsqueezed = next_tokens.unsqueeze(1)
|
|
|
|
| 1588 |
eos_token_id = pad_token_id
|
| 1589 |
|
| 1590 |
# Build logits processor for non-CFG operations (repetition penalty, top_k, top_p)
|
| 1591 |
+
logits_processor = self._build_logits_processor(repetition_penalty)
|
|
|
|
|
|
|
| 1592 |
|
| 1593 |
with torch.no_grad():
|
| 1594 |
for step in range(max_new_tokens):
|
| 1595 |
# Forward pass for the entire batch (conditional + unconditional)
|
| 1596 |
+
outputs = self._forward_pass(model, generated_ids, model_kwargs, past_key_values, use_cache)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1597 |
|
| 1598 |
# Get logits for the last position
|
| 1599 |
next_token_logits = outputs.logits[:, -1, :] # [batch_size*2, vocab_size]
|
|
|
|
| 1616 |
for processor in logits_processor:
|
| 1617 |
cfg_logits = processor(current_input_ids, cfg_logits)
|
| 1618 |
|
| 1619 |
+
# Apply top-k and top-p filtering
|
| 1620 |
+
cfg_logits = self._apply_top_k_filter(cfg_logits, top_k)
|
| 1621 |
+
cfg_logits = self._apply_top_p_filter(cfg_logits, top_p)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1622 |
|
| 1623 |
# Apply temperature and sample
|
| 1624 |
+
next_tokens = self._sample_tokens(cfg_logits, temperature)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1625 |
|
| 1626 |
# Update constrained processor state AFTER sampling
|
| 1627 |
+
self._update_constrained_processor_state(constrained_processor, next_tokens)
|
|
|
|
|
|
|
| 1628 |
|
| 1629 |
# Check for EOS token in conditional sequences BEFORE unsqueezing
|
| 1630 |
# Stop if any conditional sequence generates EOS token
|
| 1631 |
# next_tokens shape: [batch_size] (only conditional tokens)
|
| 1632 |
+
should_stop = self._check_eos_token(next_tokens, eos_token_id, pad_token_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1633 |
|
| 1634 |
# Apply the same sampled tokens to both conditional and unconditional sequences
|
| 1635 |
next_tokens_unsqueezed = next_tokens.unsqueeze(1)
|
acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py
CHANGED
|
@@ -68,10 +68,16 @@ class ModelRunner:
|
|
| 68 |
self.model = Qwen3ForCausalLM(hf_config)
|
| 69 |
load_model(self.model, config.model)
|
| 70 |
self.sampler = Sampler()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
self.warmup_model()
|
| 72 |
self.allocate_kv_cache()
|
| 73 |
if not self.enforce_eager:
|
| 74 |
self.capture_cudagraph()
|
|
|
|
| 75 |
torch.set_default_device("cpu")
|
| 76 |
torch.set_default_dtype(default_dtype)
|
| 77 |
|
|
@@ -84,6 +90,24 @@ class ModelRunner:
|
|
| 84 |
self.shm = SharedMemory(name="nanovllm")
|
| 85 |
self.loop()
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
def exit(self):
|
| 88 |
if self.world_size > 1:
|
| 89 |
self.shm.close()
|
|
@@ -216,57 +240,49 @@ class ModelRunner:
|
|
| 216 |
return input_ids, positions
|
| 217 |
|
| 218 |
def prepare_decode(self, seqs: list[Sequence]):
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
for seq in seqs:
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
|
|
|
|
|
|
| 232 |
block_tables = self.prepare_block_tables(seqs)
|
| 233 |
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
|
| 234 |
return input_ids, positions
|
| 235 |
|
| 236 |
def prepare_sample(self, seqs: list[Sequence], is_cfg_batch: bool = False):
|
| 237 |
-
"""
|
| 238 |
if is_cfg_batch:
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
num_cond = len(seqs) // 2
|
| 242 |
-
temperatures = []
|
| 243 |
-
cfg_scales = []
|
| 244 |
-
top_ks = []
|
| 245 |
-
top_ps = []
|
| 246 |
-
repetition_penalties = []
|
| 247 |
-
for seq in seqs[:num_cond]:
|
| 248 |
-
temperatures.append(seq.temperature)
|
| 249 |
-
cfg_scales.append(seq.cfg_scale)
|
| 250 |
-
top_ks.append(seq.top_k if seq.top_k is not None else 0)
|
| 251 |
-
top_ps.append(seq.top_p if seq.top_p is not None else 1.0)
|
| 252 |
-
repetition_penalties.append(seq.repetition_penalty)
|
| 253 |
else:
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
|
|
|
|
|
|
| 270 |
return temperatures, cfg_scales, top_ks, top_ps, repetition_penalties
|
| 271 |
|
| 272 |
@torch.inference_mode()
|
|
|
|
| 68 |
self.model = Qwen3ForCausalLM(hf_config)
|
| 69 |
load_model(self.model, config.model)
|
| 70 |
self.sampler = Sampler()
|
| 71 |
+
|
| 72 |
+
# Pre-allocate buffers for sampling (optimization: avoid repeated tensor creation)
|
| 73 |
+
# Must be called before warmup_model() since it uses these buffers
|
| 74 |
+
self._allocate_sample_buffers()
|
| 75 |
+
|
| 76 |
self.warmup_model()
|
| 77 |
self.allocate_kv_cache()
|
| 78 |
if not self.enforce_eager:
|
| 79 |
self.capture_cudagraph()
|
| 80 |
+
|
| 81 |
torch.set_default_device("cpu")
|
| 82 |
torch.set_default_dtype(default_dtype)
|
| 83 |
|
|
|
|
| 90 |
self.shm = SharedMemory(name="nanovllm")
|
| 91 |
self.loop()
|
| 92 |
|
| 93 |
+
def _allocate_sample_buffers(self):
|
| 94 |
+
"""Pre-allocate reusable buffers for sampling to avoid repeated tensor creation."""
|
| 95 |
+
max_bs = self.config.max_num_seqs
|
| 96 |
+
|
| 97 |
+
# Pre-allocate pinned memory buffers on CPU for fast transfer
|
| 98 |
+
# Must explicitly specify device="cpu" since default device may be "cuda"
|
| 99 |
+
self._cpu_temperatures = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
|
| 100 |
+
self._cpu_cfg_scales = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
|
| 101 |
+
self._cpu_top_ks = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
|
| 102 |
+
self._cpu_top_ps = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
|
| 103 |
+
self._cpu_repetition_penalties = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
|
| 104 |
+
|
| 105 |
+
# Pre-allocate decode buffers on CPU with pinned memory
|
| 106 |
+
self._cpu_input_ids = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
|
| 107 |
+
self._cpu_positions = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
|
| 108 |
+
self._cpu_slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
|
| 109 |
+
self._cpu_context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
|
| 110 |
+
|
| 111 |
def exit(self):
|
| 112 |
if self.world_size > 1:
|
| 113 |
self.shm.close()
|
|
|
|
| 240 |
return input_ids, positions
|
| 241 |
|
| 242 |
def prepare_decode(self, seqs: list[Sequence]):
|
| 243 |
+
"""Optimized decode preparation using pre-allocated buffers."""
|
| 244 |
+
bs = len(seqs)
|
| 245 |
+
|
| 246 |
+
# Use pre-allocated CPU buffers
|
| 247 |
+
for i, seq in enumerate(seqs):
|
| 248 |
+
self._cpu_input_ids[i] = seq.last_token
|
| 249 |
+
self._cpu_positions[i] = len(seq) - 1
|
| 250 |
+
self._cpu_context_lens[i] = len(seq)
|
| 251 |
+
self._cpu_slot_mapping[i] = seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1
|
| 252 |
+
|
| 253 |
+
# Transfer to GPU using sliced views
|
| 254 |
+
input_ids = self._cpu_input_ids[:bs].cuda(non_blocking=True)
|
| 255 |
+
positions = self._cpu_positions[:bs].cuda(non_blocking=True)
|
| 256 |
+
slot_mapping = self._cpu_slot_mapping[:bs].cuda(non_blocking=True)
|
| 257 |
+
context_lens = self._cpu_context_lens[:bs].cuda(non_blocking=True)
|
| 258 |
block_tables = self.prepare_block_tables(seqs)
|
| 259 |
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
|
| 260 |
return input_ids, positions
|
| 261 |
|
| 262 |
    def prepare_sample(self, seqs: list[Sequence], is_cfg_batch: bool = False):
        """Optimized sample preparation using pre-allocated buffers.

        Gathers per-sequence sampling parameters into pinned CPU staging
        buffers, then issues non-blocking H2D copies of the filled prefixes.

        Args:
            seqs: Sequences in the current step.
            is_cfg_batch: When True, only the first half of `seqs` is read —
                presumably the conditional halves of CFG pairs share sampling
                parameters with their unconditional twins (TODO confirm).

        Returns:
            Tuple of GPU tensors, each of shape [num_seqs]:
            (temperatures, cfg_scales, top_ks, top_ps, repetition_penalties).
        """
        if is_cfg_batch:
            num_seqs = len(seqs) // 2
            target_seqs = seqs[:num_seqs]
        else:
            num_seqs = len(seqs)
            target_seqs = seqs

        # Fill pre-allocated CPU buffers.
        # Sentinels for "disabled": top_k=0, top_p=1.0, repetition_penalty=1.0.
        for i, seq in enumerate(target_seqs):
            self._cpu_temperatures[i] = seq.temperature
            self._cpu_cfg_scales[i] = seq.cfg_scale
            self._cpu_top_ks[i] = seq.top_k if seq.top_k is not None else 0
            self._cpu_top_ps[i] = seq.top_p if seq.top_p is not None else 1.0
            self._cpu_repetition_penalties[i] = seq.repetition_penalty if seq.repetition_penalty is not None else 1.0

        # Transfer to GPU using sliced views (single batched transfer)
        temperatures = self._cpu_temperatures[:num_seqs].cuda(non_blocking=True)
        cfg_scales = self._cpu_cfg_scales[:num_seqs].cuda(non_blocking=True)
        top_ks = self._cpu_top_ks[:num_seqs].cuda(non_blocking=True)
        top_ps = self._cpu_top_ps[:num_seqs].cuda(non_blocking=True)
        repetition_penalties = self._cpu_repetition_penalties[:num_seqs].cuda(non_blocking=True)

        return temperatures, cfg_scales, top_ks, top_ps, repetition_penalties
|
| 287 |
|
| 288 |
@torch.inference_mode()
|
acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py
CHANGED
|
@@ -3,12 +3,88 @@ from torch import nn
|
|
| 3 |
from typing import Optional
|
| 4 |
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
class Sampler(nn.Module):
|
| 7 |
|
| 8 |
def __init__(self):
|
| 9 |
super().__init__()
|
| 10 |
|
| 11 |
-
@torch.compile
|
| 12 |
def forward(
|
| 13 |
self,
|
| 14 |
logits: torch.Tensor,
|
|
@@ -19,56 +95,34 @@ class Sampler(nn.Module):
|
|
| 19 |
input_ids: Optional[torch.Tensor] = None,
|
| 20 |
):
|
| 21 |
"""
|
| 22 |
-
Sample tokens from logits with optional top-k
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
temperatures: [batch_size] temperature values
|
| 27 |
-
top_ks: Optional [batch_size] top-k values (None or 0 means no top-k filtering)
|
| 28 |
-
top_ps: Optional [batch_size] top-p values (None or 1.0 means no top-p filtering)
|
| 29 |
-
repetition_penalties: Optional [batch_size] repetition penalty values (1.0 means no penalty)
|
| 30 |
-
input_ids: Optional [batch_size, seq_len] input token ids for repetition penalty
|
| 31 |
"""
|
| 32 |
-
batch_size, vocab_size = logits.shape
|
| 33 |
-
|
| 34 |
-
# Note: Repetition penalty is applied in ModelRunner before calling sampler
|
| 35 |
-
# This allows us to use the full sequence context
|
| 36 |
-
|
| 37 |
# Apply temperature
|
| 38 |
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
|
| 39 |
|
| 40 |
-
#
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
# Get top-k logits, set others to -inf
|
| 46 |
-
top_k_logits, top_k_indices = torch.topk(logits[i], int(top_k), dim=-1)
|
| 47 |
-
filtered_logits = torch.full_like(logits[i], float('-inf'))
|
| 48 |
-
filtered_logits[top_k_indices] = top_k_logits
|
| 49 |
-
logits[i] = filtered_logits
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
if
|
| 57 |
-
|
| 58 |
-
sorted_probs, sorted_indices = torch.sort(probs[i], descending=True)
|
| 59 |
-
# Calculate cumulative probabilities
|
| 60 |
-
cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
|
| 61 |
-
# Find the cutoff point
|
| 62 |
-
cutoff_idx = (cumsum_probs <= top_p).sum().item()
|
| 63 |
-
if cutoff_idx < len(sorted_indices):
|
| 64 |
-
cutoff_idx += 1 # Include one more token to ensure we have at least one
|
| 65 |
-
# Create mask for tokens to keep
|
| 66 |
-
mask = torch.zeros_like(probs[i])
|
| 67 |
-
mask[sorted_indices[:cutoff_idx]] = 1.0
|
| 68 |
-
# Apply mask: set filtered tokens to -inf
|
| 69 |
-
logits[i] = torch.where(mask > 0, logits[i], torch.tensor(float('-inf'), device=logits.device))
|
| 70 |
|
| 71 |
-
# Sample using
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from typing import Optional
|
| 4 |
|
| 5 |
|
| 6 |
+
def apply_top_k_top_p(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits (vLLM style).

    The logits tensor is updated in-place: filtered entries become -inf.

    Args:
        logits: [batch_size, vocab_size] raw logits.
        k: Optional [batch_size] per-row top-k values. k <= 0 or
           k >= vocab_size means "no top-k filtering" for that row,
           matching the semantics of apply_top_k_only.
        p: Optional [batch_size] per-row nucleus masses; p >= 1.0 keeps
           every token of that row.

    Returns:
        The same (mutated) logits tensor.
    """
    if p is None:
        if k is None:
            return logits
        # Avoid sorting vocab for top-k only case
        return apply_top_k_only(logits, k)

    # Top-p needs each row sorted ascending.
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    if k is not None:
        # Apply top-k first
        vocab_size = logits_sort.size(1)
        # Rows with k <= 0 or k >= vocab_size request no filtering.
        # BUGFIX: previously k was clamped to >= 1, so the k == 0 sentinel
        # ("disabled") silently became greedy top-1 filtering here, while
        # apply_top_k_only treated the same sentinel as "keep everything".
        no_top_k_mask = (k <= 0) | (k >= vocab_size)
        k_clamped = k.clamp(1, vocab_size).long()
        top_k_mask_idx = vocab_size - k_clamped  # shape: [B]
        # Threshold = the k-th largest logit of each row.
        top_k_thresh = logits_sort.gather(1, top_k_mask_idx.unsqueeze(1))
        # Neutralize the threshold for unfiltered rows so nothing is masked.
        top_k_thresh.masked_fill_(no_top_k_mask.unsqueeze(1), float('-inf'))
        top_k_mask = logits_sort < top_k_thresh
        logits_sort.masked_fill_(top_k_mask, float('-inf'))

    # Apply top-p: drop the low-probability tail whose cumulative mass <= 1 - p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)  # reuse buffer
    top_p_mask = probs_sum <= (1.0 - p.unsqueeze(1))
    # Ensure at least one token (the row maximum) is kept
    top_p_mask[:, -1] = False
    logits_sort.masked_fill_(top_p_mask, float('-inf'))

    # Re-sort back to original positions
    logits.scatter_(dim=-1, index=logits_idx, src=logits_sort)
    return logits
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def apply_top_k_only(
    logits: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """In-place per-row top-k filtering without sorting the full vocab.

    Rows where k <= 0 or k >= vocab_size are left untouched; for the rest,
    every logit strictly below that row's k-th largest value becomes -inf.
    Much faster than a full sort when only top-k is needed.
    """
    num_tokens = logits.shape[1]

    # Rows whose k value means "no filtering at all".
    unfiltered = (k <= 0) | (k >= num_tokens)
    # Substitute 1 for disabled rows so gather indices stay in range.
    effective_k = k.masked_fill(unfiltered, 1).long()

    # torch.topk needs a Python int, so this .max() forces one device sync.
    widest_k = int(effective_k.max().clamp(max=num_tokens))

    # [B, widest_k] descending values; column j holds the (j+1)-th largest.
    leading_values = logits.topk(widest_k, dim=1).values

    # Per-row threshold: the k-th largest value lives at 0-based column k-1.
    threshold_col = (effective_k - 1).clamp(0, widest_k - 1).unsqueeze(1)  # [B, 1]
    threshold = leading_values.gather(1, threshold_col)

    # Neutralize thresholds for unfiltered rows so nothing is masked there.
    threshold.masked_fill_(unfiltered.unsqueeze(1), float('-inf'))

    # Knock out everything below each row's threshold.
    logits.masked_fill_(logits < threshold, float('-inf'))
    return logits
|
| 81 |
+
|
| 82 |
+
|
| 83 |
class Sampler(nn.Module):
|
| 84 |
|
| 85 |
def __init__(self):
|
| 86 |
super().__init__()
|
| 87 |
|
|
|
|
| 88 |
def forward(
|
| 89 |
self,
|
| 90 |
logits: torch.Tensor,
|
|
|
|
| 95 |
input_ids: Optional[torch.Tensor] = None,
|
| 96 |
):
|
| 97 |
"""
|
| 98 |
+
Sample tokens from logits with optional top-k and top-p filtering.
|
| 99 |
|
| 100 |
+
Condition checking is done OUTSIDE the compiled function to avoid
|
| 101 |
+
graph breaks from .any() calls.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
# Apply temperature
|
| 104 |
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
|
| 105 |
|
| 106 |
+
# Check conditions OUTSIDE compiled code to avoid graph breaks
|
| 107 |
+
# These .any() calls cause CPU-GPU sync, but we do it once here
|
| 108 |
+
# instead of inside the compiled function
|
| 109 |
+
need_topk = top_ks is not None and bool((top_ks > 0).any()) and bool((top_ks < logits.shape[1]).any())
|
| 110 |
+
need_topp = top_ps is not None and bool((top_ps < 1.0).any()) and bool((top_ps > 0.0).any())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
+
if need_topk or need_topp:
|
| 113 |
+
# Apply filtering (this part is not compiled due to dynamic control flow)
|
| 114 |
+
logits = apply_top_k_top_p(
|
| 115 |
+
logits,
|
| 116 |
+
top_ks if need_topk else None,
|
| 117 |
+
top_ps if need_topp else None,
|
| 118 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
+
# Sample using compiled function
|
| 121 |
+
return self._sample(logits)
|
| 122 |
+
|
| 123 |
+
    @torch.compile(dynamic=True)
    def _sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Compiled sampling kernel - no graph breaks here.

        Draws one token id per row via the exponential-race trick:
        argmax(probs / q) with q ~ Exp(1) is distributed as
        Categorical(probs), avoiding torch.multinomial in the graph.

        Args:
            logits: [batch_size, vocab_size] pre-filtered logits.

        Returns:
            [batch_size] sampled token ids.
        """
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        # q ~ Exponential(1), one draw per (row, token).
        q = torch.empty_like(probs).exponential_()
        return probs.div(q).argmax(dim=-1)
|
acestep/third_parts/nano-vllm/pyproject.toml
CHANGED
|
@@ -15,8 +15,6 @@ dependencies = [
|
|
| 15 |
"triton-windows>=3.0.0; sys_platform == 'win32'",
|
| 16 |
"triton>=3.0.0; sys_platform != 'win32'",
|
| 17 |
"transformers>=4.51.0",
|
| 18 |
-
"flash-attn @ https://github.com/sdbds/flash-attention-for-windows/releases/download/2.8.3/flash_attn-2.8.3+cu128torch2.8.0cxx11abiFALSEfullbackward-cp311-cp311-win_amd64.whl; sys_platform == 'win32'",
|
| 19 |
-
"flash-attn; sys_platform != 'win32'",
|
| 20 |
"xxhash",
|
| 21 |
]
|
| 22 |
|
|
|
|
| 15 |
"triton-windows>=3.0.0; sys_platform == 'win32'",
|
| 16 |
"triton>=3.0.0; sys_platform != 'win32'",
|
| 17 |
"transformers>=4.51.0",
|
|
|
|
|
|
|
| 18 |
"xxhash",
|
| 19 |
]
|
| 20 |
|
profile_inference.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Profiling script for acestep/inference.py using cProfile
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python profile_inference.py
|
| 7 |
+
python profile_inference.py --warmup
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import cProfile
|
| 11 |
+
import pstats
|
| 12 |
+
import io
|
| 13 |
+
import time
|
| 14 |
+
import argparse
|
| 15 |
+
import sys
|
| 16 |
+
import os
|
| 17 |
+
|
| 18 |
+
# Add project root to path
|
| 19 |
+
project_root = os.path.abspath(os.path.dirname(__file__))
|
| 20 |
+
if project_root not in sys.path:
|
| 21 |
+
sys.path.insert(0, project_root)
|
| 22 |
+
|
| 23 |
+
from acestep.inference import generate_music, GenerationParams, GenerationConfig
|
| 24 |
+
from acestep.handler import AceStepHandler
|
| 25 |
+
from acestep.llm_inference import LLMHandler
|
| 26 |
+
import json
|
| 27 |
+
from typing import Tuple
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def profile_with_cprofile(dit_handler, llm_handler, params, config, warmup=False):
    """Profile one music-generation run using Python's built-in cProfile.

    Args:
        dit_handler: Initialized AceStepHandler (DiT model).
        llm_handler: Initialized LLMHandler (language model).
        params: GenerationParams for the run.
        config: GenerationConfig for the run.
        warmup: If True, run once for warmup before profiling (default: False)

    Returns:
        The result object from the profiled generation run (or the failed
        warmup result if the warmup generation did not succeed).
    """
    print("=" * 80)
    print("Profiling with cProfile")
    print("=" * 80)

    # Warmup run (to exclude PyTorch compilation overhead)
    if warmup:
        print("\n[Warmup] Running first generation to warm up (PyTorch compilation, etc.)...")
        warmup_start = time.time()
        # NOTE(review): the warmup mutates params/config in place, so the
        # profiled run below inherits these adjusted settings.
        params.use_cot_metas = False
        config.is_format_caption = True
        config.use_constrained_decoding = False
        warmup_result = generate_music(dit_handler, llm_handler, params, config, save_dir="./")
        warmup_time = time.time() - warmup_start
        print(f"[Warmup] Completed in {warmup_time:.2f}s")
        if not warmup_result.success:
            print(f"[Warmup] ⚠ Warmup generation failed: {warmup_result.error}")
            return warmup_result

    # Actual profiling run (first inference)
    print("\n[Profiling] Running first generation for profiling...")
    profiler = cProfile.Profile()
    profiler.enable()

    profiling_start = time.time()
    try:
        result = generate_music(dit_handler, llm_handler, params, config, save_dir="./")
    finally:
        profiler.disable()
        profiling_time = time.time() - profiling_start

    # FIX: the stats were previously written to an unused io.StringIO buffer
    # (stream=s) that was never printed, so the "Top 30" tables never reached
    # the console. With no stream argument, pstats prints to stdout.
    ps = pstats.Stats(profiler)

    print(f"\n[Profiling] Completed in {profiling_time:.2f}s")
    print("\nTop 30 functions by cumulative time:")
    print("-" * 80)
    ps.sort_stats('cumulative')
    ps.print_stats(30)

    print("\nTop 30 functions by total time:")
    print("-" * 80)
    ps.sort_stats('tottime')
    ps.print_stats(30)

    # Save detailed report to file
    output_file = "profile_cprofile.txt"
    with open(output_file, 'w', encoding='utf-8') as f:
        # A separate Stats object whose stream is the report file.
        ps_file = pstats.Stats(profiler, stream=f)
        ps_file.sort_stats('cumulative')
        ps_file.print_stats()
    print(f"\nDetailed profile saved to: {output_file}")

    return result
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def main():
    """CLI entry point: initialize models, load an example config, and profile.

    Parses command-line options, brings up the DiT and LLM handlers, loads
    generation parameters from examples/text2music/example_05.json, and runs
    profile_with_cprofile. Exits with status 1 on DiT-init or config-load
    failure; a failed LLM init is only warned about.
    """
    parser = argparse.ArgumentParser(description="Profile acestep/inference.py")
    parser.add_argument(
        "--checkpoint-dir",
        type=str,
        default="./checkpoints",
        help="Path to checkpoints directory"
    )
    parser.add_argument(
        "--config-path",
        type=str,
        default="acestep-v15-turbo-rl",
        help="Model config path"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use (cuda/cpu)"
    )
    parser.add_argument(
        "--lm-model",
        type=str,
        default="acestep-5Hz-lm-0.6B-v3",
        help="LM model path"
    )
    parser.add_argument(
        "--lm-backend",
        type=str,
        default="vllm",
        help="LM backend"
    )
    parser.add_argument(
        "--warmup",
        action="store_true",
        help="Enable warmup run before profiling (default: False, profile first run)"
    )

    args = parser.parse_args()

    # Initialize handlers
    print("Initializing handlers...")
    dit_handler = AceStepHandler()
    llm_handler = LLMHandler()

    # Initialize DiT (required: abort on failure)
    print(" - Initializing DiT model...")
    status_dit, success_dit = dit_handler.initialize_service(
        project_root=project_root,
        config_path=args.config_path,
        device=args.device,
    )
    if not success_dit:
        print(f" ❌ DiT initialization failed: {status_dit}")
        sys.exit(1)
    print(" ✓ DiT model initialized")

    # Initialize LLM (optional: warn but continue on failure)
    print(" - Initializing LLM model...")
    status_llm, success_llm = llm_handler.initialize(
        checkpoint_dir=args.checkpoint_dir,
        lm_model_path=args.lm_model,
        backend=args.lm_backend,
        device=args.device,
    )
    if success_llm:
        print(" ✓ LM model initialized")
    else:
        print(f" ⚠ LM initialization failed: {status_llm}")

    # Load test parameters from example file (same as acestep/inference.py)
    def load_example_config(example_file: str) -> Tuple[GenerationParams, GenerationConfig]:
        """Load configuration from an example JSON file.

        Returns (params, config) on success, or (None, None) on any error
        (missing file, invalid JSON, bad field values).
        """
        try:
            with open(example_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # Convert example format to GenerationParams and GenerationConfig
            # Handle time signature format (example uses "4" instead of "4/4")
            time_sig = data.get('timesignature', '')

            params = GenerationParams(
                caption=data.get('caption', ''),
                lyrics=data.get('lyrics', ''),
                bpm=data.get('bpm'),
                keyscale=data.get('keyscale', ''),
                timesignature=time_sig,
                vocal_language=data.get('language', 'unknown'),
                duration=data.get('duration'),
                thinking=data.get('think', False),
                inference_steps=data.get('inference_steps', 8),
                seed=42,  # fixed seed so profiling runs are comparable
            )

            config = GenerationConfig()
            config.batch_size = data.get('batch_size', 1)

            return params, config

        except Exception as e:
            # Best-effort loader: report and let the caller abort.
            print(f" ⚠ Failed to load example file: {e}")
            return None, None

    # Load production example (same as acestep/inference.py)
    example_file = os.path.join(project_root, "examples", "text2music", "example_05.json")

    if not os.path.exists(example_file):
        print(f"\n ❌ Example file not found: {example_file}")
        print(" Please ensure the examples directory exists.")
        sys.exit(1)

    print(f"\n Loading example: {os.path.basename(example_file)}")
    params, config = load_example_config(example_file)

    if not params or not config:
        print(" ❌ Failed to load example configuration")
        sys.exit(1)

    print("\n" + "=" * 80)
    print("Starting profiling...")
    print("=" * 80)

    result = profile_with_cprofile(dit_handler, llm_handler, params, config, warmup=args.warmup)

    if result and not result.success:
        print(f"\n⚠ Generation failed: {result.error}")
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# Script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    main()
|
| 223 |
+
|