Spaces:
Running
on
A100
Running
on
A100
| import torch | |
| import os | |
| import soundfile as sf | |
| from diffusers.models import AutoencoderOobleck | |
| from tqdm import tqdm | |
| import torch.nn.functional as F | |
def process_audio(audio_path, target_sr=48000):
    """Load an audio file as a stereo waveform tensor at ``target_sr``.

    The file is decoded with soundfile as float32, rearranged to
    channels-first, forced to exactly two channels (a single channel is
    duplicated, channels beyond the second are dropped), linearly
    interpolated to ``target_sr`` when the file's rate differs, and finally
    clamped to [-1.0, 1.0].

    Args:
        audio_path: Path of the audio file to read.
        target_sr: Sample rate, in Hz, of the returned waveform.

    Returns:
        A tensor shaped [1, 2, samples], or ``None`` when any step fails.
        Failures are printed, never raised, so callers can skip bad files.
    """
    try:
        samples, source_sr = sf.read(audio_path, dtype='float32')

        # The decoded array is either [samples] (mono) or [samples, channels];
        # both are turned into a channels-first [channels, samples] tensor.
        if samples.ndim == 1:
            waveform = torch.from_numpy(samples)[None, :]
        else:
            waveform = torch.from_numpy(samples.T)

        # Force exactly two channels.
        if waveform.shape[0] == 1:
            waveform = waveform.repeat(2, 1)
        waveform = waveform[:2]

        # Plain linear interpolation; F.interpolate needs a leading batch
        # dimension, which is added for the call and removed afterwards.
        if source_sr != target_sr:
            resampled_len = int(waveform.shape[-1] * (target_sr / source_sr))
            waveform = F.interpolate(
                waveform[None],
                size=resampled_len,
                mode='linear',
                align_corners=False,
            )[0]

        # Clamp, then add the batch dimension: [1, 2, samples].
        return waveform.clamp(-1.0, 1.0).unsqueeze(0)
    except Exception as e:
        print(f"Error processing {audio_path}: {e}")
        return None
def main():
    """Encode calibration audio into fixed-length VAE latent chunks and save them.

    Paths are resolved relative to the project root, taken to be the parent
    of the directory containing this script:

    * input audio:  ``<root>/data/quant_data/*.flac``
    * VAE weights:  ``<root>/checkpoints/vae``
    * output file:  ``<root>/data/calibration_latents.pt``

    Each audio file is loaded with ``process_audio``, cut into complete
    windows of ``chunk_size * samples_per_latent`` samples, and every window
    is encoded with the VAE. One sampled latent per window, truncated to
    ``chunk_size`` frames and moved to the CPU, is collected; the resulting
    list of tensors is written with ``torch.save``.

    Nothing is raised for the expected failure modes: a missing data
    directory, a VAE that fails to load, an unreadable audio file, or an
    encode error are all reported with ``print`` and skipped or cause an
    early return.
    """
    print("Initializing Calibration Data Preparation...")

    current_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(current_dir)
    data_dir = os.path.join(project_root, "data", "quant_data")
    output_path = os.path.join(project_root, "data", "calibration_latents.pt")
    vae_path = os.path.join(project_root, "checkpoints", "vae")

    if not os.path.exists(data_dir):
        print(f"Error: Data directory not found at {data_dir}")
        return

    print(f"Loading VAE from {vae_path}...")
    try:
        vae = AutoencoderOobleck.from_pretrained(vae_path)
    except Exception as e:
        print(f"Failed to load VAE: {e}")
        return

    # Device priority: an available XPU wins over CUDA; CPU is the fallback.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        device = "xpu"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    vae = vae.to(device)
    vae.eval()

    # os.listdir returns entries in arbitrary order; sorting makes the order
    # of the saved calibration chunks reproducible across runs and machines.
    audio_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.flac'))
    print(f"Found {len(audio_files)} audio files.")

    all_chunks = []
    chunk_size = 512  # Latent frames per saved chunk
    samples_per_latent = 1920  # Audio samples assumed per latent frame
    audio_chunk_size = chunk_size * samples_per_latent

    pbar = tqdm(audio_files, desc="Processing audio")
    for audio_file in pbar:
        file_path = os.path.join(data_dir, audio_file)
        full_audio = process_audio(file_path)
        if full_audio is None:
            # process_audio already reported the problem; skip this file.
            continue

        num_samples = full_audio.shape[-1]
        # The stop bound admits only complete windows, so a trailing partial
        # window (or a file shorter than one window) is dropped.
        last_start = num_samples - audio_chunk_size
        for start_idx in range(0, last_start + 1, audio_chunk_size):
            end_idx = start_idx + audio_chunk_size
            audio_input = full_audio[:, :, start_idx:end_idx].to(device)

            try:
                with torch.no_grad():
                    # VAE encode expects [Batch, Channels, Samples] and
                    # returns a distribution to sample the latents from.
                    posterior = vae.encode(audio_input).latent_dist
                    latents = posterior.sample()  # [1, 64, LatentLength]

                # The latent length should equal chunk_size; anything
                # shorter is discarded and anything longer is truncated.
                if latents.shape[-1] >= chunk_size:
                    all_chunks.append(latents[:, :, :chunk_size].cpu())
                    pbar.set_postfix({"chunks": len(all_chunks)})
            except Exception as e:
                print(f"Error encoding chunk {start_idx}-{end_idx} of {audio_file}: {e}")

        # Release cached accelerator memory between files, using only the
        # backend that is actually in use.
        if device == "cuda":
            torch.cuda.empty_cache()
        elif device == "xpu":
            torch.xpu.empty_cache()

    print(f"Collected {len(all_chunks)} chunks of size {chunk_size}.")

    if not all_chunks:
        print("No chunks collected.")
        return

    print(f"Saving to {output_path}...")
    torch.save(all_chunks, output_path)
    print("Done.")
# Run the calibration-data preparation only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()