Ace-Step-v1.5 / scripts /prepare_vae_calibration_data.py
xushengyuan's picture
dynamic quantization for dit model & data prepare for further static quantized vae
1daf6b4
raw
history blame
4.33 kB
import torch
import os
import soundfile as sf
from diffusers.models import AutoencoderOobleck
from tqdm import tqdm
import torch.nn.functional as F
def process_audio(audio_path, target_sr=48000):
try:
# Load audio using soundfile
audio_np, sr = sf.read(audio_path, dtype='float32')
# Convert to torch: [samples, channels] or [samples] -> [channels, samples]
if audio_np.ndim == 1:
audio = torch.from_numpy(audio_np).unsqueeze(0)
else:
audio = torch.from_numpy(audio_np.T)
# Ensure stereo
if audio.shape[0] == 1:
audio = torch.cat([audio, audio], dim=0)
audio = audio[:2]
# Resample if needed
if sr != target_sr:
ratio = target_sr / sr
new_length = int(audio.shape[-1] * ratio)
audio = F.interpolate(audio.unsqueeze(0), size=new_length, mode='linear', align_corners=False).squeeze(0)
audio = torch.clamp(audio, -1.0, 1.0)
return audio.unsqueeze(0) # Add batch dim: [1, 2, samples]
except Exception as e:
print(f"Error processing {audio_path}: {e}")
return None
def main():
print("Initializing Calibration Data Preparation...")
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
data_dir = os.path.join(project_root, "data", "quant_data")
output_path = os.path.join(project_root, "data", "calibration_latents.pt")
vae_path = os.path.join(project_root, "checkpoints", "vae")
if not os.path.exists(data_dir):
print(f"Error: Data directory not found at {data_dir}")
return
print(f"Loading VAE from {vae_path}...")
try:
vae = AutoencoderOobleck.from_pretrained(vae_path)
except Exception as e:
print(f"Failed to load VAE: {e}")
return
device = "cuda" if torch.cuda.is_available() else "cpu"
# Check for XPU
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = "xpu"
print(f"Using device: {device}")
vae = vae.to(device)
vae.eval()
audio_files = [f for f in os.listdir(data_dir) if f.endswith('.flac')]
print(f"Found {len(audio_files)} audio files.")
all_chunks = []
chunk_size = 512 # Latent frames
samples_per_latent = 1920
audio_chunk_size = chunk_size * samples_per_latent
pbar = tqdm(audio_files, desc="Processing audio")
for audio_file in pbar:
file_path = os.path.join(data_dir, audio_file)
full_audio = process_audio(file_path)
if full_audio is None:
continue
# Split audio into chunks
num_samples = full_audio.shape[-1]
for start_idx in range(0, num_samples, audio_chunk_size):
end_idx = start_idx + audio_chunk_size
if end_idx > num_samples:
break # Skip incomplete chunks
audio_input = full_audio[:, :, start_idx:end_idx].to(device)
try:
with torch.no_grad():
# Encode
# VAE encode expects [Batch, Channels, Samples]
# Returns DiagonalGaussianDistribution
posterior = vae.encode(audio_input).latent_dist
latents = posterior.sample() # [1, 64, LatentLength]
# It should be exactly chunk_size, but let's be safe
if latents.shape[-1] >= chunk_size:
all_chunks.append(latents[:, :, :chunk_size].cpu())
pbar.set_postfix({"chunks": len(all_chunks)})
except Exception as e:
print(f"Error encoding chunk {start_idx}-{end_idx} of {audio_file}: {e}")
torch.cuda.empty_cache()
if device == "xpu":
torch.xpu.empty_cache()
print(f"Collected {len(all_chunks)} chunks of size {chunk_size}.")
if len(all_chunks) > 0:
print(f"Saving to {output_path}...")
torch.save(all_chunks, output_path)
print("Done.")
else:
print("No chunks collected.")
if __name__ == "__main__":
main()