| import torch
|
| import os
|
| import soundfile as sf
|
| from diffusers.models import AutoencoderOobleck
|
| from tqdm import tqdm
|
| import torch.nn.functional as F
|
|
|
def process_audio(audio_path, target_sr=48000):
    """Load an audio file as a stereo float32 tensor at ``target_sr``.

    The file is read with ``soundfile``, forced to exactly two channels
    (a mono signal is duplicated; channels beyond the first two are
    dropped), resampled by linear interpolation when its sample rate
    differs from ``target_sr``, and finally clamped to ``[-1.0, 1.0]``.

    Args:
        audio_path: Path of the audio file to read.
        target_sr: Desired sample rate in Hz.

    Returns:
        A tensor of shape ``(1, 2, num_samples)``, or ``None`` when the
        file could not be read or processed. Failures are printed rather
        than raised so that a batch run can skip bad files. A file with
        zero frames yields an empty ``(1, 2, 0)`` tensor.
    """
    try:
        # always_2d=True yields (frames, channels) even for mono input,
        # so no separate 1-D branch is needed.
        audio_np, sr = sf.read(audio_path, dtype='float32', always_2d=True)
        audio = torch.from_numpy(audio_np.T)  # -> (channels, frames)

        if audio.shape[0] == 1:
            # Duplicate the single channel to get a stereo signal.
            audio = torch.cat([audio, audio], dim=0)
        audio = audio[:2]

        num_frames = audio.shape[-1]
        # Empty audio is passed through untouched: interpolating to a
        # zero-length output is not meaningful.
        if sr != target_sr and num_frames > 0:
            # Round instead of truncating: target_sr / sr is inexact in
            # floating point, so int() could land one sample short.
            new_length = max(1, round(num_frames * target_sr / sr))
            # NOTE(review): plain linear interpolation with no low-pass
            # step; presumably adequate for calibration data, but confirm
            # before reusing this where downsampling quality matters.
            audio = F.interpolate(
                audio.unsqueeze(0),
                size=new_length,
                mode='linear',
                align_corners=False,
            ).squeeze(0)

        audio = torch.clamp(audio, -1.0, 1.0)
        return audio.unsqueeze(0)

    except Exception as e:
        # Deliberate best-effort: report the failure and let the caller skip.
        print(f"Error processing {audio_path}: {e}")
        return None
|
|
|
def main():
    """Encode calibration audio into fixed-size VAE latent chunks and save them.

    Paths are resolved relative to the parent directory of the directory
    containing this file:

    * input audio:  ``<root>/data/quant_data/*.flac``
    * VAE weights:  ``<root>/checkpoints/vae``
    * output file:  ``<root>/data/calibration_latents.pt``

    Each file is loaded with ``process_audio``, cut into non-overlapping
    windows of ``chunk_size * samples_per_latent`` samples (a trailing
    partial window is discarded), and every window is encoded by the VAE.
    The resulting latents, truncated to ``chunk_size`` along the last
    axis and moved to the CPU, are collected in a list that is written
    with ``torch.save``.

    Returns:
        None. Progress and errors are reported on stdout; the function
        returns early if the data directory is missing or the VAE cannot
        be loaded.
    """
    print("Initializing Calibration Data Preparation...")

    current_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(current_dir)
    data_dir = os.path.join(project_root, "data", "quant_data")
    output_path = os.path.join(project_root, "data", "calibration_latents.pt")
    vae_path = os.path.join(project_root, "checkpoints", "vae")

    # isdir (not exists): a regular file at this path would otherwise
    # crash later in os.listdir.
    if not os.path.isdir(data_dir):
        print(f"Error: Data directory not found at {data_dir}")
        return

    print(f"Loading VAE from {vae_path}...")
    try:
        vae = AutoencoderOobleck.from_pretrained(vae_path)
    except Exception as e:
        print(f"Failed to load VAE: {e}")
        return

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # An available XPU backend takes precedence over the choice above.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        device = "xpu"

    print(f"Using device: {device}")
    vae = vae.to(device)
    vae.eval()

    # os.listdir returns entries in arbitrary order; sort so the file
    # order is stable across runs and machines. The extension match is
    # case-insensitive so ".FLAC" files are not silently skipped.
    audio_files = sorted(
        name for name in os.listdir(data_dir) if name.lower().endswith('.flac')
    )
    print(f"Found {len(audio_files)} audio files.")

    all_chunks = []
    chunk_size = 512  # latent frames kept per calibration chunk
    # NOTE(review): assumed number of audio samples per latent frame for
    # this VAE -- confirm against the checkpoint's configuration.
    samples_per_latent = 1920
    audio_chunk_size = chunk_size * samples_per_latent

    pbar = tqdm(audio_files, desc="Processing audio")
    for audio_file in pbar:
        file_path = os.path.join(data_dir, audio_file)
        full_audio = process_audio(file_path)

        if full_audio is None:
            continue

        num_samples = full_audio.shape[-1]

        # The stop bound admits only complete windows; a trailing
        # remainder shorter than audio_chunk_size is dropped.
        last_start = num_samples - audio_chunk_size
        for start_idx in range(0, last_start + 1, audio_chunk_size):
            end_idx = start_idx + audio_chunk_size

            try:
                # The device transfer is inside the try so that a failed
                # copy skips this chunk instead of aborting the whole run.
                audio_input = full_audio[:, :, start_idx:end_idx].to(device)

                with torch.no_grad():
                    posterior = vae.encode(audio_input).latent_dist
                    # NOTE(review): sample() presumably draws from the
                    # posterior, so saved latents may differ run to run.
                    latents = posterior.sample()

                if latents.shape[-1] >= chunk_size:
                    all_chunks.append(latents[:, :, :chunk_size].cpu())

                pbar.set_postfix({"chunks": len(all_chunks)})

            except Exception as e:
                print(f"Error encoding chunk {start_idx}-{end_idx} of {audio_file}: {e}")
                # Release cached accelerator memory before the next chunk.
                if device == "cuda":
                    torch.cuda.empty_cache()
                elif device == "xpu":
                    torch.xpu.empty_cache()

    print(f"Collected {len(all_chunks)} chunks of size {chunk_size}.")

    if not all_chunks:
        print("No chunks collected.")
        return

    print(f"Saving to {output_path}...")
    torch.save(all_chunks, output_path)
    print("Done.")
|
|
|
# Run the calibration pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|