ChuxiJ commited on
Commit
8ff7c0c
·
2 Parent(s): a3b47b770c780d

Merge branch 'main' of github.com:ace-step/ACE-Step-1.5 into main

Browse files
.gitignore CHANGED
@@ -1,3 +1,4 @@
 
1
  *.mp3
2
  *.wav
3
 
 
1
+ data/
2
  *.mp3
3
  *.wav
4
 
acestep/handler.py CHANGED
@@ -156,6 +156,7 @@ class AceStepHandler:
156
  compile_model: bool = False,
157
  offload_to_cpu: bool = False,
158
  offload_dit_to_cpu: bool = False,
 
159
  ) -> Tuple[str, bool]:
160
  """
161
  Initialize model service
@@ -186,6 +187,14 @@ class AceStepHandler:
186
  self.offload_dit_to_cpu = offload_dit_to_cpu
187
  # Set dtype based on device: bfloat16 for cuda, float32 for cpu
188
  self.dtype = torch.bfloat16 if device in ["cuda","xpu"] else torch.float32
 
 
 
 
 
 
 
 
189
 
190
  # Auto-detect project root (independent of passed project_root parameter)
191
  current_file = os.path.abspath(__file__)
@@ -238,9 +247,26 @@ class AceStepHandler:
238
  self.model.eval()
239
 
240
  if compile_model:
241
- logger.info("Compiling model with torch.compile...")
242
  self.model = torch.compile(self.model)
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
245
  if os.path.exists(silence_latent_path):
246
  self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
@@ -267,6 +293,9 @@ class AceStepHandler:
267
  self.vae.eval()
268
  else:
269
  raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}")
 
 
 
270
 
271
  # 3. Load text encoder and tokenizer
272
  text_encoder_path = os.path.join(checkpoint_dir, "Qwen3-Embedding-0.6B")
 
156
  compile_model: bool = False,
157
  offload_to_cpu: bool = False,
158
  offload_dit_to_cpu: bool = False,
159
+ quantization: Optional[str] = None,
160
  ) -> Tuple[str, bool]:
161
  """
162
  Initialize model service
 
187
  self.offload_dit_to_cpu = offload_dit_to_cpu
188
  # Set dtype based on device: bfloat16 for cuda, float32 for cpu
189
  self.dtype = torch.bfloat16 if device in ["cuda","xpu"] else torch.float32
190
+ self.quantization = quantization
191
+ if self.quantization is not None:
192
+ assert compile_model, "Quantization requires compile_model to be True"
193
+ try:
194
+ import torchao
195
+ except ImportError:
196
+ raise ImportError("torchao is required for quantization but is not installed. Please install torchao to use quantization features.")
197
+
198
 
199
  # Auto-detect project root (independent of passed project_root parameter)
200
  current_file = os.path.abspath(__file__)
 
247
  self.model.eval()
248
 
249
  if compile_model:
 
250
  self.model = torch.compile(self.model)
251
 
252
+ if self.quantization is not None:
253
+ from torchao.quantization import quantize_
254
+ if self.quantization == "int8_weight_only":
255
+ from torchao.quantization import Int8WeightOnlyConfig
256
+ quant_config = Int8WeightOnlyConfig()
257
+ elif self.quantization == "fp8_weight_only":
258
+ from torchao.quantization import Float8WeightOnlyConfig
259
+ quant_config = Float8WeightOnlyConfig()
260
+ elif self.quantization == "w8a8_dynamic":
261
+ from torchao.quantization import Int8DynamicActivationInt8WeightConfig, MappingType
262
+ quant_config = Int8DynamicActivationInt8WeightConfig(act_mapping_type=MappingType.ASYMMETRIC)
263
+ else:
264
+ raise ValueError(f"Unsupported quantization type: {self.quantization}")
265
+
266
+ quantize_(self.model, quant_config)
267
+ logger.info("DiT quantized with: %s", self.quantization)
268
+
269
+
270
  silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
271
  if os.path.exists(silence_latent_path):
272
  self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
 
293
  self.vae.eval()
294
  else:
295
  raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}")
296
+
297
+ if compile_model:
298
+ self.vae = torch.compile(self.vae)
299
 
300
  # 3. Load text encoder and tokenizer
301
  text_encoder_path = os.path.join(checkpoint_dir, "Qwen3-Embedding-0.6B")
scripts/prepare_vae_calibration_data.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import soundfile as sf
4
+ from diffusers.models import AutoencoderOobleck
5
+ from tqdm import tqdm
6
+ import torch.nn.functional as F
7
+
def process_audio(audio_path, target_sr=48000):
    """Load an audio file and normalize it to a stereo tensor at ``target_sr``.

    Returns a float32 tensor shaped [1, 2, samples] clamped to [-1, 1],
    or None when the file cannot be read or processed.
    """
    try:
        # soundfile yields [samples] for mono or [samples, channels] otherwise
        samples, src_sr = sf.read(audio_path, dtype='float32')

        # Normalize layout to [channels, samples]
        waveform = (
            torch.from_numpy(samples).unsqueeze(0)
            if samples.ndim == 1
            else torch.from_numpy(samples.T)
        )

        # Duplicate a mono channel, then keep at most the first two channels
        if waveform.shape[0] == 1:
            waveform = torch.cat([waveform, waveform], dim=0)
        waveform = waveform[:2]

        # Linear-interpolation resample when the sample rates differ
        if src_sr != target_sr:
            out_len = int(waveform.shape[-1] * (target_sr / src_sr))
            waveform = F.interpolate(
                waveform.unsqueeze(0),
                size=out_len,
                mode='linear',
                align_corners=False,
            ).squeeze(0)

        # Clamp and add a batch dimension: [1, 2, samples]
        return torch.clamp(waveform, -1.0, 1.0).unsqueeze(0)

    except Exception as e:
        print(f"Error processing {audio_path}: {e}")
        return None
38
+ def main():
39
+ print("Initializing Calibration Data Preparation...")
40
+
41
+ current_dir = os.path.dirname(os.path.abspath(__file__))
42
+ project_root = os.path.dirname(current_dir)
43
+ data_dir = os.path.join(project_root, "data", "quant_data")
44
+ output_path = os.path.join(project_root, "data", "calibration_latents.pt")
45
+ vae_path = os.path.join(project_root, "checkpoints", "vae")
46
+
47
+ if not os.path.exists(data_dir):
48
+ print(f"Error: Data directory not found at {data_dir}")
49
+ return
50
+
51
+ print(f"Loading VAE from {vae_path}...")
52
+ try:
53
+ vae = AutoencoderOobleck.from_pretrained(vae_path)
54
+ except Exception as e:
55
+ print(f"Failed to load VAE: {e}")
56
+ return
57
+
58
+ device = "cuda" if torch.cuda.is_available() else "cpu"
59
+ # Check for XPU
60
+ if hasattr(torch, "xpu") and torch.xpu.is_available():
61
+ device = "xpu"
62
+
63
+ print(f"Using device: {device}")
64
+ vae = vae.to(device)
65
+ vae.eval()
66
+
67
+ audio_files = [f for f in os.listdir(data_dir) if f.endswith('.flac')]
68
+ print(f"Found {len(audio_files)} audio files.")
69
+
70
+ all_chunks = []
71
+ chunk_size = 512 # Latent frames
72
+ samples_per_latent = 1920
73
+ audio_chunk_size = chunk_size * samples_per_latent
74
+
75
+ pbar = tqdm(audio_files, desc="Processing audio")
76
+ for audio_file in pbar:
77
+ file_path = os.path.join(data_dir, audio_file)
78
+ full_audio = process_audio(file_path)
79
+
80
+ if full_audio is None:
81
+ continue
82
+
83
+ # Split audio into chunks
84
+ num_samples = full_audio.shape[-1]
85
+
86
+ for start_idx in range(0, num_samples, audio_chunk_size):
87
+ end_idx = start_idx + audio_chunk_size
88
+ if end_idx > num_samples:
89
+ break # Skip incomplete chunks
90
+
91
+ audio_input = full_audio[:, :, start_idx:end_idx].to(device)
92
+
93
+ try:
94
+ with torch.no_grad():
95
+ # Encode
96
+ # VAE encode expects [Batch, Channels, Samples]
97
+ # Returns DiagonalGaussianDistribution
98
+ posterior = vae.encode(audio_input).latent_dist
99
+ latents = posterior.sample() # [1, 64, LatentLength]
100
+
101
+ # It should be exactly chunk_size, but let's be safe
102
+ if latents.shape[-1] >= chunk_size:
103
+ all_chunks.append(latents[:, :, :chunk_size].cpu())
104
+
105
+ pbar.set_postfix({"chunks": len(all_chunks)})
106
+
107
+ except Exception as e:
108
+ print(f"Error encoding chunk {start_idx}-{end_idx} of {audio_file}: {e}")
109
+ torch.cuda.empty_cache()
110
+ if device == "xpu":
111
+ torch.xpu.empty_cache()
112
+
113
+ print(f"Collected {len(all_chunks)} chunks of size {chunk_size}.")
114
+
115
+ if len(all_chunks) > 0:
116
+ print(f"Saving to {output_path}...")
117
+ torch.save(all_chunks, output_path)
118
+ print("Done.")
119
+ else:
120
+ print("No chunks collected.")
121
+
122
+ if __name__ == "__main__":
123
+ main()
test.py CHANGED
@@ -46,6 +46,7 @@ def main():
46
  compile_model=True,
47
  offload_to_cpu=True,
48
  offload_dit_to_cpu=False, # Keep DiT on GPU
 
49
  )
50
 
51
  if not enabled:
@@ -107,7 +108,12 @@ def main():
107
  print(f"Generated Audio Codes (first 50 chars): {audio_codes[:50]}...")
108
  else:
109
  print("Skipping 5Hz LLM generation...")
110
- metadata = {}
 
 
 
 
 
111
  audio_codes = None
112
  lm_status = "Skipped"
113
 
 
46
  compile_model=True,
47
  offload_to_cpu=True,
48
  offload_dit_to_cpu=False, # Keep DiT on GPU
49
+ quantization="int8_weight_only", # Enable INT8 weight-only quantization
50
  )
51
 
52
  if not enabled:
 
108
  print(f"Generated Audio Codes (first 50 chars): {audio_codes[:50]}...")
109
  else:
110
  print("Skipping 5Hz LLM generation...")
111
+ metadata = {
112
+ 'bpm': 90,
113
+ 'keyscale': 'A major',
114
+ 'timesignature': '4',
115
+ 'duration': 240,
116
+ }
117
  audio_codes = None
118
  lm_status = "Skipped"
119