xushengyuan committed on
Commit
84d50ff
·
1 Parent(s): 1a9005f

Add torch.compile, SDPA attention, and Intel XPU support

Browse files
Files changed (3) hide show
  1. .gitignore +3 -0
  2. acestep/handler.py +58 -8
  3. test.py +134 -0
.gitignore CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  # Byte-compiled / optimized / DLL files
2
  __pycache__/
3
  *.py[codz]
 
1
+ *.mp3
2
+ *.wav
3
+
4
  # Byte-compiled / optimized / DLL files
5
  __pycache__/
6
  *.py[codz]
acestep/handler.py CHANGED
@@ -144,6 +144,7 @@ class AceStepHandler:
144
  init_llm: bool = False,
145
  lm_model_path: str = "acestep-5Hz-lm-0.6B",
146
  use_flash_attention: bool = False,
 
147
  ) -> Tuple[str, bool]:
148
  """
149
  Initialize model service
@@ -155,6 +156,7 @@ class AceStepHandler:
155
  init_llm: Whether to initialize 5Hz LM model
156
  lm_model_path: 5Hz LM model path
157
  use_flash_attention: Whether to use flash attention (requires flash_attn package)
 
158
 
159
  Returns:
160
  (status_message, enable_generate_button)
@@ -165,7 +167,7 @@ class AceStepHandler:
165
 
166
  self.device = device
167
  # Set dtype based on device: bfloat16 for cuda, float32 for cpu
168
- self.dtype = torch.bfloat16 if device == "cuda" else torch.float32
169
 
170
  # Auto-detect project root (independent of passed project_root parameter)
171
  current_file = os.path.abspath(__file__)
@@ -177,15 +179,44 @@ class AceStepHandler:
177
  acestep_v15_checkpoint_path = os.path.join(checkpoint_dir, config_path)
178
  if os.path.exists(acestep_v15_checkpoint_path):
179
  # Determine attention implementation
180
- attn_implementation = "flash_attention_2" if use_flash_attention and self.is_flash_attention_available() else "eager"
181
  if use_flash_attention and self.is_flash_attention_available():
 
182
  self.dtype = torch.bfloat16
183
- self.model = AutoModel.from_pretrained(acestep_v15_checkpoint_path, trust_remote_code=True, dtype=self.dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  self.model.config._attn_implementation = attn_implementation
185
  self.config = self.model.config
186
  # Move model to device and set dtype
187
  self.model = self.model.to(device).to(self.dtype)
188
  self.model.eval()
 
 
 
 
 
189
  silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
190
  if os.path.exists(silence_latent_path):
191
  self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
@@ -199,7 +230,9 @@ class AceStepHandler:
199
  vae_checkpoint_path = os.path.join(checkpoint_dir, "vae")
200
  if os.path.exists(vae_checkpoint_path):
201
  self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
202
- self.vae = self.vae.to(device).to(self.dtype)
 
 
203
  self.vae.eval()
204
  else:
205
  raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}")
@@ -229,7 +262,7 @@ class AceStepHandler:
229
  return f"❌ 5Hz LM model not found at {full_lm_model_path}", False
230
 
231
  # Determine actual attention implementation used
232
- actual_attn = "flash_attention_2" if use_flash_attention and self.is_flash_attention_available() else "eager"
233
 
234
  status_msg = f"✅ Model initialized successfully on {device}\n"
235
  status_msg += f"Main model: {acestep_v15_checkpoint_path}\n"
@@ -240,7 +273,8 @@ class AceStepHandler:
240
  else:
241
  status_msg += f"5Hz LM model: Not loaded (checkbox not selected)\n"
242
  status_msg += f"Dtype: {self.dtype}\n"
243
- status_msg += f"Attention: {actual_attn}"
 
244
 
245
  return status_msg, True
246
 
@@ -1115,7 +1149,11 @@ class AceStepHandler:
1115
  expected_latent_length = current_wav.shape[-1] // 1920
1116
  target_latent = self.silence_latent[0, :expected_latent_length, :]
1117
  else:
1118
- target_latent = self.vae.encode(current_wav.to(self.device).to(self.dtype)).latent_dist.sample()
 
 
 
 
1119
  target_latent = target_latent.squeeze(0).transpose(0, 1)
1120
  target_latents_list.append(target_latent)
1121
  latent_lengths.append(target_latent.shape[0])
@@ -1471,7 +1509,11 @@ class AceStepHandler:
1471
  refer_audio_order_mask.append(batch_idx)
1472
  else:
1473
  for refer_audio in refer_audios:
1474
- refer_audio_latent = self.vae.encode(refer_audio.unsqueeze(0)).latent_dist.sample()
 
 
 
 
1475
  refer_audio_latents.append(refer_audio_latent.transpose(1, 2))
1476
  refer_audio_order_mask.append(batch_idx)
1477
 
@@ -1802,6 +1844,10 @@ class AceStepHandler:
1802
  seed_value_for_ui, align_score_1, align_text_1, align_plot_1,
1803
  align_score_2, align_text_2, align_plot_2)
1804
  """
 
 
 
 
1805
  if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
1806
  return None, None, [], "", "❌ Model not fully initialized. Please initialize all components first.", "-1", "", "", None, "", "", None
1807
 
@@ -1927,7 +1973,11 @@ class AceStepHandler:
1927
  with torch.no_grad():
1928
  # Transpose for VAE decode: [batch, latent_length, latent_dim] -> [batch, latent_dim, latent_length]
1929
  pred_latents_for_decode = pred_latents.transpose(1, 2)
 
 
1930
  pred_wavs = self.vae.decode(pred_latents_for_decode).sample # [batch, channels, samples]
 
 
1931
  end_time = time.time()
1932
  time_costs["vae_decode_time_cost"] = end_time - start_time
1933
  time_costs["total_time_cost"] = time_costs["total_time_cost"] + time_costs["vae_decode_time_cost"]
 
144
  init_llm: bool = False,
145
  lm_model_path: str = "acestep-5Hz-lm-0.6B",
146
  use_flash_attention: bool = False,
147
+ compile_model: bool = False,
148
  ) -> Tuple[str, bool]:
149
  """
150
  Initialize model service
 
156
  init_llm: Whether to initialize 5Hz LM model
157
  lm_model_path: 5Hz LM model path
158
  use_flash_attention: Whether to use flash attention (requires flash_attn package)
159
+ compile_model: Whether to use torch.compile to optimize the model
160
 
161
  Returns:
162
  (status_message, enable_generate_button)
 
167
 
168
  self.device = device
169
  # Set dtype based on device: bfloat16 for cuda, float32 for cpu
170
+ self.dtype = torch.bfloat16 if device in ["cuda","xpu"] else torch.float32
171
 
172
  # Auto-detect project root (independent of passed project_root parameter)
173
  current_file = os.path.abspath(__file__)
 
179
  acestep_v15_checkpoint_path = os.path.join(checkpoint_dir, config_path)
180
  if os.path.exists(acestep_v15_checkpoint_path):
181
  # Determine attention implementation
 
182
  if use_flash_attention and self.is_flash_attention_available():
183
+ attn_implementation = "flash_attention_2"
184
  self.dtype = torch.bfloat16
185
+ else:
186
+ attn_implementation = "sdpa"
187
+
188
+ try:
189
+ logger.info(f"Attempting to load model with attention implementation: {attn_implementation}")
190
+ self.model = AutoModel.from_pretrained(
191
+ acestep_v15_checkpoint_path,
192
+ trust_remote_code=True,
193
+ dtype=self.dtype,
194
+ attn_implementation=attn_implementation
195
+ )
196
+ except Exception as e:
197
+ logger.warning(f"Failed to load model with {attn_implementation}: {e}")
198
+ if attn_implementation == "sdpa":
199
+ logger.info("Falling back to eager attention")
200
+ attn_implementation = "eager"
201
+ self.model = AutoModel.from_pretrained(
202
+ acestep_v15_checkpoint_path,
203
+ trust_remote_code=True,
204
+ dtype=self.dtype,
205
+ attn_implementation=attn_implementation
206
+ )
207
+ else:
208
+ raise e
209
+
210
  self.model.config._attn_implementation = attn_implementation
211
  self.config = self.model.config
212
  # Move model to device and set dtype
213
  self.model = self.model.to(device).to(self.dtype)
214
  self.model.eval()
215
+
216
+ if compile_model:
217
+ logger.info("Compiling model with torch.compile...")
218
+ self.model = torch.compile(self.model)
219
+
220
  silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
221
  if os.path.exists(silence_latent_path):
222
  self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
 
230
  vae_checkpoint_path = os.path.join(checkpoint_dir, "vae")
231
  if os.path.exists(vae_checkpoint_path):
232
  self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
233
+ # Use bfloat16 for VAE on GPU, otherwise use self.dtype (float32 on CPU)
234
+ vae_dtype = torch.bfloat16 if device in ["cuda", "xpu"] else self.dtype
235
+ self.vae = self.vae.to(device).to(vae_dtype)
236
  self.vae.eval()
237
  else:
238
  raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}")
 
262
  return f"❌ 5Hz LM model not found at {full_lm_model_path}", False
263
 
264
  # Determine actual attention implementation used
265
+ actual_attn = getattr(self.config, "_attn_implementation", "eager")
266
 
267
  status_msg = f"✅ Model initialized successfully on {device}\n"
268
  status_msg += f"Main model: {acestep_v15_checkpoint_path}\n"
 
273
  else:
274
  status_msg += f"5Hz LM model: Not loaded (checkbox not selected)\n"
275
  status_msg += f"Dtype: {self.dtype}\n"
276
+ status_msg += f"Attention: {actual_attn}\n"
277
+ status_msg += f"Compiled: {compile_model}"
278
 
279
  return status_msg, True
280
 
 
1149
  expected_latent_length = current_wav.shape[-1] // 1920
1150
  target_latent = self.silence_latent[0, :expected_latent_length, :]
1151
  else:
1152
+ # Ensure input is in VAE's dtype
1153
+ vae_input = current_wav.to(self.device).to(self.vae.dtype)
1154
+ target_latent = self.vae.encode(vae_input).latent_dist.sample()
1155
+ # Cast back to model dtype
1156
+ target_latent = target_latent.to(self.dtype)
1157
  target_latent = target_latent.squeeze(0).transpose(0, 1)
1158
  target_latents_list.append(target_latent)
1159
  latent_lengths.append(target_latent.shape[0])
 
1509
  refer_audio_order_mask.append(batch_idx)
1510
  else:
1511
  for refer_audio in refer_audios:
1512
+ # Ensure input is in VAE's dtype
1513
+ vae_input = refer_audio.unsqueeze(0).to(self.vae.dtype)
1514
+ refer_audio_latent = self.vae.encode(vae_input).latent_dist.sample()
1515
+ # Cast back to model dtype
1516
+ refer_audio_latent = refer_audio_latent.to(self.dtype)
1517
  refer_audio_latents.append(refer_audio_latent.transpose(1, 2))
1518
  refer_audio_order_mask.append(batch_idx)
1519
 
 
1844
  seed_value_for_ui, align_score_1, align_text_1, align_plot_1,
1845
  align_score_2, align_text_2, align_plot_2)
1846
  """
1847
+ if progress is None:
1848
+ def progress(*args, **kwargs):
1849
+ pass
1850
+
1851
  if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
1852
  return None, None, [], "", "❌ Model not fully initialized. Please initialize all components first.", "-1", "", "", None, "", "", None
1853
 
 
1973
  with torch.no_grad():
1974
  # Transpose for VAE decode: [batch, latent_length, latent_dim] -> [batch, latent_dim, latent_length]
1975
  pred_latents_for_decode = pred_latents.transpose(1, 2)
1976
+ # Ensure input is in VAE's dtype
1977
+ pred_latents_for_decode = pred_latents_for_decode.to(self.vae.dtype)
1978
  pred_wavs = self.vae.decode(pred_latents_for_decode).sample # [batch, channels, samples]
1979
+ # Cast output to float32 for audio processing/saving
1980
+ pred_wavs = pred_wavs.to(torch.float32)
1981
  end_time = time.time()
1982
  time_costs["vae_decode_time_cost"] = end_time - start_time
1983
  time_costs["total_time_cost"] = time_costs["total_time_cost"] + time_costs["vae_decode_time_cost"]
test.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import shutil
5
+ from acestep.handler import AceStepHandler
6
+
7
+
8
+ def main():
9
+ print("Initializing AceStepHandler...")
10
+ handler = AceStepHandler()
11
+
12
+ # Find checkpoints
13
+ checkpoints = handler.get_available_checkpoints()
14
+ if checkpoints:
15
+ project_root = checkpoints[0]
16
+ else:
17
+ # Fallback
18
+ current_file = os.path.abspath(__file__)
19
+ project_root = os.path.join(os.path.dirname(current_file), "checkpoints")
20
+
21
+ print(f"Project root (checkpoints dir): {project_root}")
22
+
23
+ # Find models
24
+ models = handler.get_available_acestep_v15_models()
25
+ if not models:
26
+ print("No models found. Using default 'acestep-v15-turbo'.")
27
+ model_name = "./acestep-v15-turbo"
28
+ else:
29
+ model_name = models[0]
30
+ print(f"Found models: {models}")
31
+ print(f"Using model: {model_name}")
32
+
33
+ # Initialize service
34
+ device = "xpu"
35
+ print(f"Using device: {device}")
36
+
37
+ status, enabled = handler.initialize_service(
38
+ project_root=project_root,
39
+ config_path=model_name,
40
+ device=device,
41
+ init_llm=True,
42
+ use_flash_attention=False, # Default in UI
43
+ compile_model=True
44
+ )
45
+
46
+ if not enabled:
47
+ print(f"Error initializing service: {status}")
48
+ return
49
+
50
+ print(status)
51
+ print("Service initialized successfully.")
52
+
53
+ # Prepare inputs
54
+ captions = "A soft pop arrangement led by light, fingerpicked guitar sets a gentle foundation, Airy keys subtly fill the background, while delicate percussion adds warmth, The sweet female voice floats above, blending naturally with minimal harmonies in the chorus for an intimate, uplifting sound"
55
+
56
+ lyrics = """[Intro]
57
+
58
+ [Verse 1]
59
+ 风吹动那年仲夏
60
+ 翻开谁青涩喧哗
61
+ 白枫书架
62
+ 第七页码
63
+
64
+ [Verse 2]
65
+ 珍藏谁的长发
66
+ 星夜似手中花洒
67
+ 淋湿旧忆木篱笆
68
+ 木槿花下
69
+ 天蓝发夹
70
+ 她默认了他
71
+
72
+ [Bridge]
73
+ 时光将青春的薄荷红蜡
74
+ 匆匆地融化
75
+ 她却沉入人海再无应答
76
+ 隐没在天涯
77
+
78
+ [Chorus]
79
+ 燕子在窗前飞掠
80
+ 寻不到的花被季节带回
81
+ 拧不干的思念如月
82
+ 初恋颜色才能够描绘
83
+
84
+ 木槿在窗外落雪
85
+ 倾泻道别的滋味
86
+ 闭上眼听见微咸的泪水
87
+ 到后来才知那故梦珍贵
88
+
89
+ [Outro]"""
90
+
91
+ seeds = "320145306, 1514681811"
92
+
93
+ print("Starting generation...")
94
+
95
+ # Call generate_music
96
+ results = handler.generate_music(
97
+ captions=captions,
98
+ lyrics=lyrics,
99
+ bpm=90,
100
+ key_scale="A major",
101
+ time_signature="4",
102
+ vocal_language="zh",
103
+ inference_steps=8,
104
+ guidance_scale=7.0,
105
+ use_random_seed=False,
106
+ seed=seeds,
107
+ audio_duration=120,
108
+ batch_size=1,
109
+ task_type="text2music",
110
+ cfg_interval_start=0.0,
111
+ cfg_interval_end=0.95,
112
+ audio_format="wav"
113
+ )
114
+
115
+ # Unpack results
116
+ (audio1, audio2, saved_files, info, status_msg, seed_val,
117
+ align_score1, align_text1, align_plot1,
118
+ align_score2, align_text2, align_plot2) = results
119
+
120
+ print("\nGeneration Complete!")
121
+ print(f"Status: {status_msg}")
122
+ print(f"Info: {info}")
123
+ print(f"Seeds used: {seed_val}")
124
+ print(f"Saved files: {saved_files}")
125
+
126
+ # Copy files
127
+ for f in saved_files:
128
+ if os.path.exists(f):
129
+ dst = os.path.basename(f)
130
+ shutil.copy(f, dst)
131
+ print(f"Saved output to: {os.path.abspath(dst)}")
132
+
133
+ if __name__ == "__main__":
134
+ main()