Spaces:
Running
on
A100
Running
on
A100
Commit
·
84d50ff
1
Parent(s):
1a9005f
Add torch.compile, SDPA, and Intel XPU support
Browse files- .gitignore +3 -0
- acestep/handler.py +58 -8
- test.py +134 -0
.gitignore
CHANGED
|
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# Byte-compiled / optimized / DLL files
|
| 2 |
__pycache__/
|
| 3 |
*.py[codz]
|
|
|
|
| 1 |
+
*.mp3
|
| 2 |
+
*.wav
|
| 3 |
+
|
| 4 |
# Byte-compiled / optimized / DLL files
|
| 5 |
__pycache__/
|
| 6 |
*.py[codz]
|
acestep/handler.py
CHANGED
|
@@ -144,6 +144,7 @@ class AceStepHandler:
|
|
| 144 |
init_llm: bool = False,
|
| 145 |
lm_model_path: str = "acestep-5Hz-lm-0.6B",
|
| 146 |
use_flash_attention: bool = False,
|
|
|
|
| 147 |
) -> Tuple[str, bool]:
|
| 148 |
"""
|
| 149 |
Initialize model service
|
|
@@ -155,6 +156,7 @@ class AceStepHandler:
|
|
| 155 |
init_llm: Whether to initialize 5Hz LM model
|
| 156 |
lm_model_path: 5Hz LM model path
|
| 157 |
use_flash_attention: Whether to use flash attention (requires flash_attn package)
|
|
|
|
| 158 |
|
| 159 |
Returns:
|
| 160 |
(status_message, enable_generate_button)
|
|
@@ -165,7 +167,7 @@ class AceStepHandler:
|
|
| 165 |
|
| 166 |
self.device = device
|
| 167 |
# Set dtype based on device: bfloat16 for cuda, float32 for cpu
|
| 168 |
-
self.dtype = torch.bfloat16 if device
|
| 169 |
|
| 170 |
# Auto-detect project root (independent of passed project_root parameter)
|
| 171 |
current_file = os.path.abspath(__file__)
|
|
@@ -177,15 +179,44 @@ class AceStepHandler:
|
|
| 177 |
acestep_v15_checkpoint_path = os.path.join(checkpoint_dir, config_path)
|
| 178 |
if os.path.exists(acestep_v15_checkpoint_path):
|
| 179 |
# Determine attention implementation
|
| 180 |
-
attn_implementation = "flash_attention_2" if use_flash_attention and self.is_flash_attention_available() else "eager"
|
| 181 |
if use_flash_attention and self.is_flash_attention_available():
|
|
|
|
| 182 |
self.dtype = torch.bfloat16
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
self.model.config._attn_implementation = attn_implementation
|
| 185 |
self.config = self.model.config
|
| 186 |
# Move model to device and set dtype
|
| 187 |
self.model = self.model.to(device).to(self.dtype)
|
| 188 |
self.model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
|
| 190 |
if os.path.exists(silence_latent_path):
|
| 191 |
self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
|
|
@@ -199,7 +230,9 @@ class AceStepHandler:
|
|
| 199 |
vae_checkpoint_path = os.path.join(checkpoint_dir, "vae")
|
| 200 |
if os.path.exists(vae_checkpoint_path):
|
| 201 |
self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
|
| 202 |
-
|
|
|
|
|
|
|
| 203 |
self.vae.eval()
|
| 204 |
else:
|
| 205 |
raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}")
|
|
@@ -229,7 +262,7 @@ class AceStepHandler:
|
|
| 229 |
return f"❌ 5Hz LM model not found at {full_lm_model_path}", False
|
| 230 |
|
| 231 |
# Determine actual attention implementation used
|
| 232 |
-
actual_attn =
|
| 233 |
|
| 234 |
status_msg = f"✅ Model initialized successfully on {device}\n"
|
| 235 |
status_msg += f"Main model: {acestep_v15_checkpoint_path}\n"
|
|
@@ -240,7 +273,8 @@ class AceStepHandler:
|
|
| 240 |
else:
|
| 241 |
status_msg += f"5Hz LM model: Not loaded (checkbox not selected)\n"
|
| 242 |
status_msg += f"Dtype: {self.dtype}\n"
|
| 243 |
-
status_msg += f"Attention: {actual_attn}"
|
|
|
|
| 244 |
|
| 245 |
return status_msg, True
|
| 246 |
|
|
@@ -1115,7 +1149,11 @@ class AceStepHandler:
|
|
| 1115 |
expected_latent_length = current_wav.shape[-1] // 1920
|
| 1116 |
target_latent = self.silence_latent[0, :expected_latent_length, :]
|
| 1117 |
else:
|
| 1118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1119 |
target_latent = target_latent.squeeze(0).transpose(0, 1)
|
| 1120 |
target_latents_list.append(target_latent)
|
| 1121 |
latent_lengths.append(target_latent.shape[0])
|
|
@@ -1471,7 +1509,11 @@ class AceStepHandler:
|
|
| 1471 |
refer_audio_order_mask.append(batch_idx)
|
| 1472 |
else:
|
| 1473 |
for refer_audio in refer_audios:
|
| 1474 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1475 |
refer_audio_latents.append(refer_audio_latent.transpose(1, 2))
|
| 1476 |
refer_audio_order_mask.append(batch_idx)
|
| 1477 |
|
|
@@ -1802,6 +1844,10 @@ class AceStepHandler:
|
|
| 1802 |
seed_value_for_ui, align_score_1, align_text_1, align_plot_1,
|
| 1803 |
align_score_2, align_text_2, align_plot_2)
|
| 1804 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1805 |
if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
|
| 1806 |
return None, None, [], "", "❌ Model not fully initialized. Please initialize all components first.", "-1", "", "", None, "", "", None
|
| 1807 |
|
|
@@ -1927,7 +1973,11 @@ class AceStepHandler:
|
|
| 1927 |
with torch.no_grad():
|
| 1928 |
# Transpose for VAE decode: [batch, latent_length, latent_dim] -> [batch, latent_dim, latent_length]
|
| 1929 |
pred_latents_for_decode = pred_latents.transpose(1, 2)
|
|
|
|
|
|
|
| 1930 |
pred_wavs = self.vae.decode(pred_latents_for_decode).sample # [batch, channels, samples]
|
|
|
|
|
|
|
| 1931 |
end_time = time.time()
|
| 1932 |
time_costs["vae_decode_time_cost"] = end_time - start_time
|
| 1933 |
time_costs["total_time_cost"] = time_costs["total_time_cost"] + time_costs["vae_decode_time_cost"]
|
|
|
|
| 144 |
init_llm: bool = False,
|
| 145 |
lm_model_path: str = "acestep-5Hz-lm-0.6B",
|
| 146 |
use_flash_attention: bool = False,
|
| 147 |
+
compile_model: bool = False,
|
| 148 |
) -> Tuple[str, bool]:
|
| 149 |
"""
|
| 150 |
Initialize model service
|
|
|
|
| 156 |
init_llm: Whether to initialize 5Hz LM model
|
| 157 |
lm_model_path: 5Hz LM model path
|
| 158 |
use_flash_attention: Whether to use flash attention (requires flash_attn package)
|
| 159 |
+
compile_model: Whether to use torch.compile to optimize the model
|
| 160 |
|
| 161 |
Returns:
|
| 162 |
(status_message, enable_generate_button)
|
|
|
|
| 167 |
|
| 168 |
self.device = device
|
| 169 |
# Set dtype based on device: bfloat16 for cuda/xpu, float32 otherwise
|
| 170 |
+
self.dtype = torch.bfloat16 if device in ["cuda","xpu"] else torch.float32
|
| 171 |
|
| 172 |
# Auto-detect project root (independent of passed project_root parameter)
|
| 173 |
current_file = os.path.abspath(__file__)
|
|
|
|
| 179 |
acestep_v15_checkpoint_path = os.path.join(checkpoint_dir, config_path)
|
| 180 |
if os.path.exists(acestep_v15_checkpoint_path):
|
| 181 |
# Determine attention implementation
|
|
|
|
| 182 |
if use_flash_attention and self.is_flash_attention_available():
|
| 183 |
+
attn_implementation = "flash_attention_2"
|
| 184 |
self.dtype = torch.bfloat16
|
| 185 |
+
else:
|
| 186 |
+
attn_implementation = "sdpa"
|
| 187 |
+
|
| 188 |
+
try:
|
| 189 |
+
logger.info(f"Attempting to load model with attention implementation: {attn_implementation}")
|
| 190 |
+
self.model = AutoModel.from_pretrained(
|
| 191 |
+
acestep_v15_checkpoint_path,
|
| 192 |
+
trust_remote_code=True,
|
| 193 |
+
dtype=self.dtype,
|
| 194 |
+
attn_implementation=attn_implementation
|
| 195 |
+
)
|
| 196 |
+
except Exception as e:
|
| 197 |
+
logger.warning(f"Failed to load model with {attn_implementation}: {e}")
|
| 198 |
+
if attn_implementation == "sdpa":
|
| 199 |
+
logger.info("Falling back to eager attention")
|
| 200 |
+
attn_implementation = "eager"
|
| 201 |
+
self.model = AutoModel.from_pretrained(
|
| 202 |
+
acestep_v15_checkpoint_path,
|
| 203 |
+
trust_remote_code=True,
|
| 204 |
+
dtype=self.dtype,
|
| 205 |
+
attn_implementation=attn_implementation
|
| 206 |
+
)
|
| 207 |
+
else:
|
| 208 |
+
raise e
|
| 209 |
+
|
| 210 |
self.model.config._attn_implementation = attn_implementation
|
| 211 |
self.config = self.model.config
|
| 212 |
# Move model to device and set dtype
|
| 213 |
self.model = self.model.to(device).to(self.dtype)
|
| 214 |
self.model.eval()
|
| 215 |
+
|
| 216 |
+
if compile_model:
|
| 217 |
+
logger.info("Compiling model with torch.compile...")
|
| 218 |
+
self.model = torch.compile(self.model)
|
| 219 |
+
|
| 220 |
silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
|
| 221 |
if os.path.exists(silence_latent_path):
|
| 222 |
self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
|
|
|
|
| 230 |
vae_checkpoint_path = os.path.join(checkpoint_dir, "vae")
|
| 231 |
if os.path.exists(vae_checkpoint_path):
|
| 232 |
self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
|
| 233 |
+
# Use bfloat16 for VAE on GPU, otherwise use self.dtype (float32 on CPU)
|
| 234 |
+
vae_dtype = torch.bfloat16 if device in ["cuda", "xpu"] else self.dtype
|
| 235 |
+
self.vae = self.vae.to(device).to(vae_dtype)
|
| 236 |
self.vae.eval()
|
| 237 |
else:
|
| 238 |
raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}")
|
|
|
|
| 262 |
return f"❌ 5Hz LM model not found at {full_lm_model_path}", False
|
| 263 |
|
| 264 |
# Determine actual attention implementation used
|
| 265 |
+
actual_attn = getattr(self.config, "_attn_implementation", "eager")
|
| 266 |
|
| 267 |
status_msg = f"✅ Model initialized successfully on {device}\n"
|
| 268 |
status_msg += f"Main model: {acestep_v15_checkpoint_path}\n"
|
|
|
|
| 273 |
else:
|
| 274 |
status_msg += f"5Hz LM model: Not loaded (checkbox not selected)\n"
|
| 275 |
status_msg += f"Dtype: {self.dtype}\n"
|
| 276 |
+
status_msg += f"Attention: {actual_attn}\n"
|
| 277 |
+
status_msg += f"Compiled: {compile_model}"
|
| 278 |
|
| 279 |
return status_msg, True
|
| 280 |
|
|
|
|
| 1149 |
expected_latent_length = current_wav.shape[-1] // 1920
|
| 1150 |
target_latent = self.silence_latent[0, :expected_latent_length, :]
|
| 1151 |
else:
|
| 1152 |
+
# Ensure input is in VAE's dtype
|
| 1153 |
+
vae_input = current_wav.to(self.device).to(self.vae.dtype)
|
| 1154 |
+
target_latent = self.vae.encode(vae_input).latent_dist.sample()
|
| 1155 |
+
# Cast back to model dtype
|
| 1156 |
+
target_latent = target_latent.to(self.dtype)
|
| 1157 |
target_latent = target_latent.squeeze(0).transpose(0, 1)
|
| 1158 |
target_latents_list.append(target_latent)
|
| 1159 |
latent_lengths.append(target_latent.shape[0])
|
|
|
|
| 1509 |
refer_audio_order_mask.append(batch_idx)
|
| 1510 |
else:
|
| 1511 |
for refer_audio in refer_audios:
|
| 1512 |
+
# Ensure input is in VAE's dtype
|
| 1513 |
+
vae_input = refer_audio.unsqueeze(0).to(self.vae.dtype)
|
| 1514 |
+
refer_audio_latent = self.vae.encode(vae_input).latent_dist.sample()
|
| 1515 |
+
# Cast back to model dtype
|
| 1516 |
+
refer_audio_latent = refer_audio_latent.to(self.dtype)
|
| 1517 |
refer_audio_latents.append(refer_audio_latent.transpose(1, 2))
|
| 1518 |
refer_audio_order_mask.append(batch_idx)
|
| 1519 |
|
|
|
|
| 1844 |
seed_value_for_ui, align_score_1, align_text_1, align_plot_1,
|
| 1845 |
align_score_2, align_text_2, align_plot_2)
|
| 1846 |
"""
|
| 1847 |
+
if progress is None:
|
| 1848 |
+
def progress(*args, **kwargs):
|
| 1849 |
+
pass
|
| 1850 |
+
|
| 1851 |
if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
|
| 1852 |
return None, None, [], "", "❌ Model not fully initialized. Please initialize all components first.", "-1", "", "", None, "", "", None
|
| 1853 |
|
|
|
|
| 1973 |
with torch.no_grad():
|
| 1974 |
# Transpose for VAE decode: [batch, latent_length, latent_dim] -> [batch, latent_dim, latent_length]
|
| 1975 |
pred_latents_for_decode = pred_latents.transpose(1, 2)
|
| 1976 |
+
# Ensure input is in VAE's dtype
|
| 1977 |
+
pred_latents_for_decode = pred_latents_for_decode.to(self.vae.dtype)
|
| 1978 |
pred_wavs = self.vae.decode(pred_latents_for_decode).sample # [batch, channels, samples]
|
| 1979 |
+
# Cast output to float32 for audio processing/saving
|
| 1980 |
+
pred_wavs = pred_wavs.to(torch.float32)
|
| 1981 |
end_time = time.time()
|
| 1982 |
time_costs["vae_decode_time_cost"] = end_time - start_time
|
| 1983 |
time_costs["total_time_cost"] = time_costs["total_time_cost"] + time_costs["vae_decode_time_cost"]
|
test.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
import shutil
|
| 5 |
+
from acestep.handler import AceStepHandler
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _select_device():
    """Return the name of the best available torch device.

    XPU is probed first so the script keeps exercising the Intel XPU path it
    was written for. The ``hasattr`` guard covers torch builds that do not
    ship the ``torch.xpu`` module. Falls back to CUDA, then CPU.
    """
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def main():
    """Run an end-to-end text2music smoke test through AceStepHandler.

    Steps:
        1. Locate the checkpoints directory and pick the first available model.
        2. Initialize the service with the LM enabled and torch.compile on.
        3. Generate one track from fixed captions, lyrics, and seeds.
        4. Copy every saved output file into the current working directory.

    Returns early (after printing the status) if initialization fails.
    """
    print("Initializing AceStepHandler...")
    handler = AceStepHandler()

    # Find checkpoints
    checkpoints = handler.get_available_checkpoints()
    if checkpoints:
        project_root = checkpoints[0]
    else:
        # Fallback: a "checkpoints" directory next to this script
        current_file = os.path.abspath(__file__)
        project_root = os.path.join(os.path.dirname(current_file), "checkpoints")

    print(f"Project root (checkpoints dir): {project_root}")

    # Find models
    models = handler.get_available_acestep_v15_models()
    if not models:
        print("No models found. Using default 'acestep-v15-turbo'.")
        model_name = "./acestep-v15-turbo"
    else:
        model_name = models[0]
        print(f"Found models: {models}")
        print(f"Using model: {model_name}")

    # Initialize service on whichever accelerator is actually present instead
    # of assuming an Intel XPU, so the script also runs on CUDA or CPU hosts.
    device = _select_device()
    print(f"Using device: {device}")

    status, enabled = handler.initialize_service(
        project_root=project_root,
        config_path=model_name,
        device=device,
        init_llm=True,
        use_flash_attention=False,  # Default in UI
        compile_model=True,
    )

    if not enabled:
        print(f"Error initializing service: {status}")
        return

    print(status)
    print("Service initialized successfully.")

    # Prepare inputs
    captions = "A soft pop arrangement led by light, fingerpicked guitar sets a gentle foundation, Airy keys subtly fill the background, while delicate percussion adds warmth, The sweet female voice floats above, blending naturally with minimal harmonies in the chorus for an intimate, uplifting sound"

    lyrics = """[Intro]

[Verse 1]
风吹动那年仲夏
翻开谁青涩喧哗
白枫书架
第七页码

[Verse 2]
珍藏谁的长发
星夜似手中花洒
淋湿旧忆木篱笆
木槿花下
天蓝发夹
她默认了他

[Bridge]
时光将青春的薄荷红蜡
匆匆地融化
她却沉入人海再无应答
隐没在天涯

[Chorus]
燕子在窗前飞掠
寻不到的花被季节带回
拧不干的思念如月
初恋颜色才能够描绘

木槿在窗外落雪
倾泻道别的滋味
闭上眼听见微咸的泪水
到后来才知那故梦珍贵

[Outro]"""

    seeds = "320145306, 1514681811"

    print("Starting generation...")

    # Call generate_music
    results = handler.generate_music(
        captions=captions,
        lyrics=lyrics,
        bpm=90,
        key_scale="A major",
        time_signature="4",
        vocal_language="zh",
        inference_steps=8,
        guidance_scale=7.0,
        use_random_seed=False,
        seed=seeds,
        audio_duration=120,
        batch_size=1,
        task_type="text2music",
        cfg_interval_start=0.0,
        cfg_interval_end=0.95,
        audio_format="wav",
    )

    # Unpack results
    (audio1, audio2, saved_files, info, status_msg, seed_val,
     align_score1, align_text1, align_plot1,
     align_score2, align_text2, align_plot2) = results

    print("\nGeneration Complete!")
    print(f"Status: {status_msg}")
    print(f"Info: {info}")
    print(f"Seeds used: {seed_val}")
    print(f"Saved files: {saved_files}")

    # Copy files into the current working directory
    for f in saved_files:
        if not os.path.exists(f):
            continue
        dst = os.path.basename(f)
        # shutil.copy raises SameFileError when source and destination are the
        # same file, which happens if the handler already saved into the cwd.
        already_in_cwd = os.path.exists(dst) and os.path.samefile(f, dst)
        if not already_in_cwd:
            shutil.copy(f, dst)
        print(f"Saved output to: {os.path.abspath(dst)}")


if __name__ == "__main__":
    main()
|