xushengyuan commited on
Commit
1daf6b4
·
1 Parent(s): 6b3112a

Dynamic quantization for the DiT model & data preparation for upcoming static quantization of the VAE

Browse files
.gitignore CHANGED
@@ -1,3 +1,4 @@
 
1
  *.mp3
2
  *.wav
3
 
 
1
+ data/
2
  *.mp3
3
  *.wav
4
 
acestep/handler.py CHANGED
@@ -155,6 +155,7 @@ class AceStepHandler:
155
  compile_model: bool = False,
156
  offload_to_cpu: bool = False,
157
  offload_dit_to_cpu: bool = False,
 
158
  ) -> Tuple[str, bool]:
159
  """
160
  Initialize model service
@@ -184,6 +185,14 @@ class AceStepHandler:
184
  self.offload_dit_to_cpu = offload_dit_to_cpu
185
  # Set dtype based on device: bfloat16 for cuda, float32 for cpu
186
  self.dtype = torch.bfloat16 if device in ["cuda","xpu"] else torch.float32
 
 
 
 
 
 
 
 
187
 
188
  # Auto-detect project root (independent of passed project_root parameter)
189
  current_file = os.path.abspath(__file__)
@@ -236,9 +245,19 @@ class AceStepHandler:
236
  self.model.eval()
237
 
238
  if compile_model:
239
- logger.info("Compiling model with torch.compile...")
240
  self.model = torch.compile(self.model)
241
 
 
 
 
 
 
 
 
 
 
 
 
242
  silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
243
  if os.path.exists(silence_latent_path):
244
  self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
@@ -265,6 +284,9 @@ class AceStepHandler:
265
  self.vae.eval()
266
  else:
267
  raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}")
 
 
 
268
 
269
  # 3. Load text encoder and tokenizer
270
  text_encoder_path = os.path.join(checkpoint_dir, "Qwen3-Embedding-0.6B")
 
155
  compile_model: bool = False,
156
  offload_to_cpu: bool = False,
157
  offload_dit_to_cpu: bool = False,
158
+ quantization: Optional[str] = None,
159
  ) -> Tuple[str, bool]:
160
  """
161
  Initialize model service
 
185
  self.offload_dit_to_cpu = offload_dit_to_cpu
186
  # Set dtype based on device: bfloat16 for cuda, float32 for cpu
187
  self.dtype = torch.bfloat16 if device in ["cuda","xpu"] else torch.float32
188
+ self.quantization = quantization
189
+ if self.quantization is not None:
190
+ assert compile_model, "Quantization requires compile_model to be True"
191
+ try:
192
+ import torchao
193
+ except ImportError:
194
+ raise ImportError("torchao is required for quantization but is not installed. Please install torchao to use quantization features.")
195
+
196
 
197
  # Auto-detect project root (independent of passed project_root parameter)
198
  current_file = os.path.abspath(__file__)
 
245
  self.model.eval()
246
 
247
  if compile_model:
 
248
  self.model = torch.compile(self.model)
249
 
250
+ if self.quantization == "int8_weight_only":
251
+ from torchao.quantization import quantize_, Int8WeightOnlyConfig
252
+ quantize_(self.model, Int8WeightOnlyConfig())
253
+ logger.info("DiT quantized with Int8WeightOnlyConfig")
254
+ elif self.quantization == "fp8_weight_only":
255
+ from torchao.quantization import quantize_, Float8WeightOnlyConfig
256
+ quantize_(self.model, Float8WeightOnlyConfig())
257
+ elif self.quantization is not None:
258
+ raise ValueError(f"Unsupported quantization type: {self.quantization}")
259
+
260
+
261
  silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
262
  if os.path.exists(silence_latent_path):
263
  self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
 
284
  self.vae.eval()
285
  else:
286
  raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}")
287
+
288
+ if compile_model:
289
+ self.vae = torch.compile(self.vae)
290
 
291
  # 3. Load text encoder and tokenizer
292
  text_encoder_path = os.path.join(checkpoint_dir, "Qwen3-Embedding-0.6B")
scripts/prepare_vae_calibration_data.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import soundfile as sf
4
+ from diffusers.models import AutoencoderOobleck
5
+ from tqdm import tqdm
6
+ import torch.nn.functional as F
7
+
8
def process_audio(audio_path, target_sr=48000):
    """Load an audio file and return a stereo tensor shaped [1, 2, samples].

    The signal is forced to exactly two channels (mono is duplicated, extra
    channels are dropped), resampled to *target_sr* with linear interpolation
    when the file's rate differs, and clamped to [-1.0, 1.0].
    Returns None when the file cannot be read or converted.
    """
    try:
        # soundfile yields [samples] for mono or [samples, channels] otherwise.
        audio_np, sr = sf.read(audio_path, dtype='float32')

        # Normalize layout to [channels, samples].
        if audio_np.ndim == 1:
            waveform = torch.from_numpy(audio_np).unsqueeze(0)
        else:
            waveform = torch.from_numpy(audio_np.T)

        # Exactly two channels: duplicate mono, keep only the first two.
        if waveform.shape[0] == 1:
            waveform = torch.cat([waveform, waveform], dim=0)
        waveform = waveform[:2]

        # NOTE(review): plain linear interpolation — no anti-aliasing filter;
        # presumably acceptable for calibration data. Confirm if quality matters.
        if sr != target_sr:
            resampled_len = int(waveform.shape[-1] * (target_sr / sr))
            waveform = F.interpolate(
                waveform.unsqueeze(0),
                size=resampled_len,
                mode='linear',
                align_corners=False,
            ).squeeze(0)

        waveform = torch.clamp(waveform, -1.0, 1.0)
        # Add the batch dimension expected by the VAE: [1, 2, samples].
        return waveform.unsqueeze(0)

    except Exception as e:
        print(f"Error processing {audio_path}: {e}")
        return None
37
+
38
def main():
    """Encode .flac files under data/quant_data into fixed-size latent chunks.

    Each audio file is split into non-overlapping windows of
    chunk_size * samples_per_latent samples, encoded with the Oobleck VAE,
    and the resulting [1, 64, chunk_size] latent tensors are collected and
    saved to data/calibration_latents.pt for later static-quantization
    calibration of the VAE.
    """
    print("Initializing Calibration Data Preparation...")

    # Resolve project-relative paths from this script's own location.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(current_dir)
    data_dir = os.path.join(project_root, "data", "quant_data")
    output_path = os.path.join(project_root, "data", "calibration_latents.pt")
    vae_path = os.path.join(project_root, "checkpoints", "vae")

    if not os.path.exists(data_dir):
        print(f"Error: Data directory not found at {data_dir}")
        return

    print(f"Loading VAE from {vae_path}...")
    try:
        vae = AutoencoderOobleck.from_pretrained(vae_path)
    except Exception as e:
        print(f"Failed to load VAE: {e}")
        return

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Prefer XPU when present (Intel GPU builds of torch).
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        device = "xpu"

    print(f"Using device: {device}")
    vae = vae.to(device)
    vae.eval()

    # Sort for a deterministic calibration set across runs/platforms
    # (os.listdir order is filesystem-dependent).
    audio_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.flac'))
    print(f"Found {len(audio_files)} audio files.")

    all_chunks = []
    chunk_size = 512  # latent frames per calibration sample
    # Audio samples per latent frame — assumed VAE downsampling factor; TODO confirm.
    samples_per_latent = 1920
    audio_chunk_size = chunk_size * samples_per_latent

    pbar = tqdm(audio_files, desc="Processing audio")
    for audio_file in pbar:
        file_path = os.path.join(data_dir, audio_file)
        full_audio = process_audio(file_path)

        if full_audio is None:
            continue

        num_samples = full_audio.shape[-1]

        # Non-overlapping windows; the trailing partial window is skipped so
        # every latent chunk has exactly chunk_size frames.
        for start_idx in range(0, num_samples, audio_chunk_size):
            end_idx = start_idx + audio_chunk_size
            if end_idx > num_samples:
                break  # Skip incomplete chunks

            audio_input = full_audio[:, :, start_idx:end_idx].to(device)

            try:
                with torch.no_grad():
                    # encode() expects [batch, channels, samples] and returns a
                    # distribution; sample() yields [1, 64, latent_len].
                    posterior = vae.encode(audio_input).latent_dist
                    latents = posterior.sample()

                # Should be exactly chunk_size, but truncate defensively.
                if latents.shape[-1] >= chunk_size:
                    all_chunks.append(latents[:, :, :chunk_size].cpu())

                pbar.set_postfix({"chunks": len(all_chunks)})

            except Exception as e:
                # Best-effort: report, free device memory, continue with the
                # next chunk. Only touch the cache of the device actually in
                # use (the original called torch.cuda.empty_cache() even on
                # cpu/xpu).
                print(f"Error encoding chunk {start_idx}-{end_idx} of {audio_file}: {e}")
                if device == "cuda":
                    torch.cuda.empty_cache()
                elif device == "xpu":
                    torch.xpu.empty_cache()

    print(f"Collected {len(all_chunks)} chunks of size {chunk_size}.")

    if all_chunks:
        print(f"Saving to {output_path}...")
        torch.save(all_chunks, output_path)
        print("Done.")
    else:
        print("No chunks collected.")
121
+
122
# Standard script entry point: build the VAE calibration dataset.
if __name__ == "__main__":
    main()
test.py CHANGED
@@ -46,6 +46,7 @@ def main():
46
  compile_model=True,
47
  offload_to_cpu=True,
48
  offload_dit_to_cpu=False, # Keep DiT on GPU
 
49
  )
50
 
51
  if not enabled:
 
46
  compile_model=True,
47
  offload_to_cpu=True,
48
  offload_dit_to_cpu=False, # Keep DiT on GPU
49
+ quantization="fp8_weight_only", # Enable FP8 weight-only quantization
50
  )
51
 
52
  if not enabled: