# extract_dac_latents.py - With random decoding check
"""Extract DAC VAE latents from audio files, one worker process per GPU.

For every audio file ``a/b/c.wav`` a sibling ``a/b/c_latent2x.pt`` is written
holding the encoder outputs.  A few randomly chosen files per GPU are also
decoded again and saved next to their originals under ``--tmp_dir`` so the
reconstruction quality can be checked (by ear and via SNR / MSE).
"""
import os
import glob
import argparse
import torch
import torch.multiprocessing as mp
from torch.utils.data import Dataset, DataLoader
import numpy as np
import soundfile as sf
import librosa
from pathlib import Path
from tqdm import tqdm
import yaml
import json
from collections import defaultdict
import random
import shutil

# Suffix appended to the audio path (minus extension) for the saved latent.
LATENT_SUFFIX = '_latent2x.pt'
# Suffix of latents written by an earlier version; removed by --skip_existing.
LEGACY_LATENT_SUFFIX = '_latent.pt'
# Extensions searched by find_audio_files() when none are given.
DEFAULT_AUDIO_EXTENSIONS = ('.wav', '.flac', '.mp3')
# Floor that keeps the SNR ratio and its logarithm finite.
_EPS = 1e-10


def _sibling_path(audio_path, suffix):
    """Return ``audio_path`` with its extension replaced by ``suffix``."""
    return f"{os.path.splitext(audio_path)[0]}{suffix}"


def _mean_of(records, key):
    """Average ``record[key]`` over the records that carry ``key``, as a float."""
    return float(np.mean([r[key] for r in records if key in r]))


def process_single_audio(audio_path, model, sample_rate, device):
    """Load one audio file and encode it with ``model`` (no padding, no batching).

    Args:
        audio_path: Path of the audio file to load.
        model: Object exposing ``encode(audio_tensor, sample_rate)`` that
            returns a ``(z, mu, logs)`` tuple of tensors.
        sample_rate: Rate the audio is resampled to on load.
        device: Torch device the input tensor is moved to.

    Returns:
        On success a dict with ``success=True``, the CPU tensors ``z``, ``mu``
        and ``logs``, ``duration`` (seconds), ``samples``, the integer
        ``compression_ratio`` (input length // latent length) and the loaded
        ``original_audio`` array.  On any exception a dict with
        ``success=False``, ``error`` and ``path``; the error is also printed.
    """
    try:
        # Load audio (resampled to sample_rate, mixed down to mono)
        audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True)

        # Convert to tensor [1, 1, T]
        audio_tensor = torch.from_numpy(audio).float()
        audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0).to(device)

        # Clamp to [-1, 1]; no gain normalization is applied.
        audio_tensor = torch.clamp(audio_tensor, -1.0, 1.0)

        # Encode
        with torch.no_grad():
            z, mu, logs = model.encode(audio_tensor, sample_rate)

        return {
            'success': True,
            'z': z.cpu(),
            'mu': mu.cpu(),
            'logs': logs.cpu(),
            'duration': len(audio) / sample_rate,
            'samples': len(audio),
            'compression_ratio': audio_tensor.shape[-1] // z.shape[-1],
            'original_audio': audio  # Keep original audio for comparison
        }
    except Exception as e:
        # Best effort: one unreadable file must not abort the whole run.
        print(f"Error processing {audio_path}: {e}")
        return {
            'success': False,
            'error': str(e),
            'path': audio_path
        }


def decode_and_save_sample(model, latent_data, original_audio, audio_path, tmp_dir, device):
    """Decode a latent and save original and reconstructed audio for comparison.

    Writes ``<name>_original.wav``, ``<name>_reconstructed.wav`` and
    ``info.json`` into ``tmp_dir/<name>/`` where ``<name>`` is the basename of
    ``audio_path`` without extension.

    Args:
        model: Object exposing ``decode(z)``.
        latent_data: Dict with at least ``z``, ``sample_rate``,
            ``latent_shape`` and ``compression_ratio``.
        original_audio: Waveform array the latent was computed from.
        audio_path: Source path, used for naming and recorded in ``info.json``.
        tmp_dir: Directory under which the per-sample folder is created.
        device: Torch device used for decoding.

    Returns:
        ``(True, info)`` on success, where ``info`` is the dict written to
        ``info.json``; ``(False, {'error': message})`` on any exception.
    """
    try:
        # Extract info from path
        base_name = os.path.basename(audio_path)
        name_without_ext = os.path.splitext(base_name)[0]

        # Create subdirectory in tmp for this sample
        sample_dir = os.path.join(tmp_dir, name_without_ext)
        os.makedirs(sample_dir, exist_ok=True)

        # Decode latent (re-add the batch dimension removed before saving)
        z = latent_data['z'].to(device)
        z = z.unsqueeze(0)
        print('z shape: ', z.shape)

        with torch.no_grad():
            reconstructed = model.decode(z)

        # Convert to numpy
        reconstructed = reconstructed.squeeze().cpu().numpy()
        if reconstructed.ndim == 2:
            reconstructed = reconstructed[0]
        reconstructed = np.clip(reconstructed, -1.0, 1.0)

        sample_rate = latent_data['sample_rate']

        # Save original audio
        original_path = os.path.join(sample_dir, f"{name_without_ext}_original.wav")
        sf.write(original_path, original_audio, sample_rate)

        # Save reconstructed audio
        reconstructed_path = os.path.join(sample_dir, f"{name_without_ext}_reconstructed.wav")
        sf.write(reconstructed_path, reconstructed, sample_rate)

        # Calculate metrics over the overlapping part of both signals
        min_len = min(len(original_audio), len(reconstructed))
        original_trimmed = original_audio[:min_len]
        reconstructed_trimmed = reconstructed[:min_len]

        mse = np.mean((original_trimmed - reconstructed_trimmed) ** 2)
        # Floor the signal power: a silent reference would otherwise give
        # log10(0) = -inf, which breaks the averages and is not valid JSON.
        signal_power = max(float(np.var(original_trimmed)), _EPS)
        snr = 10 * np.log10(signal_power / (mse + _EPS))

        # Save info file
        info = {
            'original_path': audio_path,
            'original_duration': len(original_audio) / sample_rate,
            'reconstructed_duration': len(reconstructed) / sample_rate,
            'latent_shape': latent_data['latent_shape'],
            'compression_ratio': latent_data['compression_ratio'],
            'mse': float(mse),
            'snr': float(snr)
        }

        info_path = os.path.join(sample_dir, 'info.json')
        with open(info_path, 'w') as f:
            json.dump(info, f, indent=2)

        print(f"Sample saved to {sample_dir} - SNR: {snr:.2f} dB, MSE: {mse:.6f}")
        return True, info

    except Exception as e:
        # The decode check is diagnostic only; never fail the extraction for it.
        print(f"Error decoding sample: {e}")
        return False, {'error': str(e)}


def extract_latents_gpu(rank, world_size, args, audio_files):
    """Extract latents for this worker's share of ``audio_files`` on GPU ``rank``.

    Loads the model from ``args.config`` / ``args.checkpoint``, encodes a
    contiguous slice of ``audio_files``, saves one ``*_latent2x.pt`` next to
    each audio file, decodes up to ``args.num_decode_samples`` random files for
    a quality check, and writes ``metadata_gpu{rank}.json`` into
    ``args.output_dir`` (plus ``decode_results.json`` into
    ``args.tmp_dir/gpu_{rank}`` when anything was decoded).

    Args:
        rank: Index of this worker; also the CUDA device index.
        world_size: Total number of workers.
        args: Parsed command-line namespace (see ``main``).
        audio_files: Full list of audio paths; every worker receives the same
            list and picks its own slice.
    """
    # Setup device
    device = torch.device(f'cuda:{rank}')
    torch.cuda.set_device(device)

    # Load DAC model (project-local import, resolved inside the worker)
    from model import DACVAE as VAE

    print(f"[GPU {rank}] Loading DAC model...")
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)

    model = VAE(**config['vae'])

    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    if 'generator' in checkpoint:
        model.load_state_dict(checkpoint['generator'])
    else:
        model.load_state_dict(checkpoint)

    model.to(device)
    model.eval()

    sample_rate = config['vae']['sample_rate']

    # Split files across GPUs; the last rank also takes the remainder.
    files_per_gpu = len(audio_files) // world_size
    start_idx = rank * files_per_gpu
    end_idx = start_idx + files_per_gpu if rank < world_size - 1 else len(audio_files)
    gpu_files = audio_files[start_idx:end_idx]

    print(f"[GPU {rank}] Processing {len(gpu_files)} files...")

    # Create tmp directory for this GPU
    tmp_dir = os.path.join(args.tmp_dir, f'gpu_{rank}')
    os.makedirs(tmp_dir, exist_ok=True)

    # Randomly select files for decoding check (set: O(1) membership test)
    num_samples = min(args.num_decode_samples, len(gpu_files))
    sample_indices = set(random.sample(range(len(gpu_files)), num_samples))

    # Process files one by one
    results = []
    decode_results = []

    for idx, audio_path in enumerate(tqdm(gpu_files, desc=f'GPU {rank}', position=rank)):
        # Process single audio
        result = process_single_audio(audio_path, model, sample_rate, device)

        if not result['success']:
            print(f"[GPU {rank}] Failed to process: {audio_path}")
            results.append({
                'path': audio_path,
                'error': result['error'],
                'status': 'failed'
            })
            continue

        # Create output path: a/b/c/d.wav -> a/b/c/d_latent2x.pt
        output_path = _sibling_path(audio_path, LATENT_SUFFIX)

        # Create directory if it doesn't exist.  A bare filename has an empty
        # dirname, and os.makedirs('') raises FileNotFoundError.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Extract data
        z = result['z'].squeeze(0)  # Remove batch dim
        mu = result['mu'].squeeze(0)
        logs = result['logs'].squeeze(0)

        # Save as torch tensor
        latent_data = {
            'z': z,
            'mu': mu,
            'logs': logs,
            'sample_rate': sample_rate,
            'compression_ratio': result['compression_ratio'],
            'original_duration': result['duration'],
            'original_samples': result['samples'],
            'latent_shape': list(z.shape),
            'original_path': audio_path
        }
        torch.save(latent_data, output_path)

        results.append({
            'path': audio_path,
            'output_path': output_path,
            'latent_shape': latent_data['latent_shape'],
            'duration': result['duration'],
            'compression_ratio': result['compression_ratio']
        })

        # Check if this is a sample to decode
        if idx in sample_indices:
            print(f"\n[GPU {rank}] Decoding sample {idx}: {os.path.basename(audio_path)}")
            success, decode_info = decode_and_save_sample(
                model, latent_data, result['original_audio'],
                audio_path, tmp_dir, device
            )
            if success:
                decode_results.append(decode_info)

        if rank == 0 and len(results) % 100 == 0:
            print(f"[GPU {rank}] Processed {len(results)} files...")

    # Save metadata for this GPU
    metadata_path = os.path.join(args.output_dir, f'metadata_gpu{rank}.json')
    with open(metadata_path, 'w') as f:
        json.dump(results, f, indent=2)

    print_avg_snr = None

    # Save decode results
    if decode_results:
        print_avg_snr = _mean_of(decode_results, 'snr')
        decode_path = os.path.join(tmp_dir, 'decode_results.json')
        with open(decode_path, 'w') as f:
            json.dump({
                'num_samples': len(decode_results),
                'samples': decode_results,
                'average_snr': print_avg_snr,
                'average_mse': _mean_of(decode_results, 'mse')
            }, f, indent=2)

    print(f"[GPU {rank}] Completed processing {len(results)} files")
    if print_avg_snr is not None:
        print(f"[GPU {rank}] Average SNR for decoded samples: {print_avg_snr:.2f} dB")


def find_audio_files(root_path, extensions=DEFAULT_AUDIO_EXTENSIONS):
    """Find all audio files under ``root_path``.

    Args:
        root_path: A directory to search recursively, or a single file.
        extensions: Iterable of file suffixes (including the dot) to accept.
            Matching is case-sensitive.

    Returns:
        A sorted list of unique paths.  If ``root_path`` is a file, the list
        holds that file when its suffix matches and is empty otherwise.
    """
    suffixes = tuple(extensions)

    # Check if root_path is a file
    if os.path.isfile(root_path):
        return [root_path] if root_path.endswith(suffixes) else []

    # Recursive search; '**' also matches zero directories, so files directly
    # inside root_path are covered by the same pattern.
    found = set()
    for ext in suffixes:
        pattern = os.path.join(root_path, '**', f'*{ext}')
        found.update(glob.glob(pattern, recursive=True))

    # Remove duplicates and sort
    return sorted(found)


def merge_metadata(output_dir, tmp_dir, world_size):
    """Merge the per-GPU metadata files into run-level reports.

    Reads (and then deletes) ``metadata_gpu{rank}.json`` from ``output_dir``
    and reads ``decode_results.json`` from ``tmp_dir/gpu_{rank}`` for every
    rank, then writes ``metadata.json``, ``summary.json`` and, if any file
    failed, ``failed_files.json`` into ``output_dir``.  A summary is printed.

    Args:
        output_dir: Directory holding the per-GPU metadata and the reports.
        tmp_dir: Root directory of the per-GPU decode check folders.
        world_size: Number of ranks whose files should be merged.
    """
    all_results = []
    failed_files = []
    all_decode_results = []

    for rank in range(world_size):
        metadata_path = os.path.join(output_dir, f'metadata_gpu{rank}.json')
        if os.path.exists(metadata_path):
            with open(metadata_path, 'r') as f:
                results = json.load(f)
            for r in results:
                if 'error' in r:
                    failed_files.append(r)
                else:
                    all_results.append(r)
            # Remove individual metadata files
            os.remove(metadata_path)

        # Load decode results
        decode_path = os.path.join(tmp_dir, f'gpu_{rank}', 'decode_results.json')
        if os.path.exists(decode_path):
            with open(decode_path, 'r') as f:
                decode_data = json.load(f)
            all_decode_results.extend(decode_data['samples'])

    # Save merged metadata
    metadata_path = os.path.join(output_dir, 'metadata.json')
    with open(metadata_path, 'w') as f:
        json.dump({
            'total_files': len(all_results),
            'failed_files': len(failed_files),
            'files': all_results
        }, f, indent=2)

    # Save failed files list if any
    if failed_files:
        failed_path = os.path.join(output_dir, 'failed_files.json')
        with open(failed_path, 'w') as f:
            json.dump(failed_files, f, indent=2)

    # Create summary statistics
    total_duration = sum(r['duration'] for r in all_results)
    latent_dims = defaultdict(int)
    compression_ratios = defaultdict(int)

    for r in all_results:
        shape_key = str(r['latent_shape'])
        latent_dims[shape_key] += 1
        compression_ratios[r['compression_ratio']] += 1

    summary = {
        'total_files': len(all_results),
        'failed_files': len(failed_files),
        'total_duration_hours': total_duration / 3600,
        'latent_dimensions': dict(latent_dims),
        'compression_ratios': dict(compression_ratios),
        'average_duration': total_duration / len(all_results) if all_results else 0,
        'decode_samples': len(all_decode_results)
    }

    if all_decode_results:
        summary['average_snr'] = _mean_of(all_decode_results, 'snr')
        summary['average_mse'] = _mean_of(all_decode_results, 'mse')

    summary_path = os.path.join(output_dir, 'summary.json')
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)

    print(f"\nProcessing complete!")
    print(f"Successfully processed: {len(all_results)} files")
    print(f"Failed: {len(failed_files)} files")
    print(f"Total duration: {total_duration/3600:.2f} hours")
    print(f"Average duration: {summary['average_duration']:.2f} seconds")
    print(f"Compression ratios: {dict(compression_ratios)}")

    if all_decode_results:
        print(f"\nDecode Quality Check:")
        print(f"Samples decoded: {len(all_decode_results)}")
        print(f"Average SNR: {summary['average_snr']:.2f} dB")
        print(f"Average MSE: {summary['average_mse']:.6f}")
        # Point at the directory actually used; --tmp_dir is configurable.
        print(f"Check {tmp_dir} folder for audio comparisons")

    print(f"\nResults saved to: {output_dir}")


def main():
    """Parse the command line, dispatch one worker per GPU and merge results."""
    parser = argparse.ArgumentParser(description='Extract DAC latents with multi-GPU support')
    parser.add_argument('--root_path', type=str, required=True,
                        help='Root path containing audio files')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save metadata (latents saved alongside audio)')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to DAC checkpoint')
    parser.add_argument('--config', type=str, required=True,
                        help='Path to DAC config')
    parser.add_argument('--num_gpus', type=int, default=None,
                        help='Number of GPUs to use (default: all available)')
    parser.add_argument('--file_list', type=str, default=None,
                        help='Optional text file containing list of audio paths')
    parser.add_argument('--skip_existing', action='store_true',
                        help='Skip files that already have latents')
    parser.add_argument('--tmp_dir', type=str, default='./tmp',
                        help='Directory to save decoded samples for checking')
    parser.add_argument('--num_decode_samples', type=int, default=5,
                        help='Number of random samples to decode per GPU for quality check')
    parser.add_argument('--clean_tmp', action='store_true',
                        help='Clean tmp directory before starting')

    args = parser.parse_args()

    # Clean tmp directory if requested
    if args.clean_tmp and os.path.exists(args.tmp_dir):
        print(f"Cleaning tmp directory: {args.tmp_dir}")
        shutil.rmtree(args.tmp_dir)

    # Create tmp directory
    os.makedirs(args.tmp_dir, exist_ok=True)

    # Find audio files
    if args.file_list:
        print(f"Loading file list from {args.file_list}")
        with open(args.file_list, 'r') as f:
            audio_files = [line.strip() for line in f if line.strip()]
    else:
        print(f"Searching for audio files in {args.root_path}")
        audio_files = find_audio_files(args.root_path)

    if not audio_files:
        print("No audio files found!")
        return

    # Filter out existing if requested
    if args.skip_existing:
        filtered_files = []
        for audio_path in audio_files:
            latent_path = _sibling_path(audio_path, LATENT_SUFFIX)
            old_latent_path = _sibling_path(audio_path, LEGACY_LATENT_SUFFIX)

            # NOTE: --skip_existing also deletes latents in the legacy naming.
            if os.path.exists(old_latent_path):
                os.remove(old_latent_path)
                print(f"Removed old latent file: {old_latent_path}")

            if not os.path.exists(latent_path):
                filtered_files.append(audio_path)

        print(f"Skipping {len(audio_files) - len(filtered_files)} existing files")
        audio_files = filtered_files

    print(f"Found {len(audio_files)} audio files to process")

    if not audio_files:
        print("No files to process!")
        return

    # Create output directory for metadata
    os.makedirs(args.output_dir, exist_ok=True)

    # Determine number of GPUs and fail early with a clear message instead of
    # crashing inside a worker (or spawning zero workers).
    available_gpus = torch.cuda.device_count()
    if args.num_gpus is None:
        args.num_gpus = available_gpus
    if available_gpus < 1:
        parser.error('no CUDA devices available')
    if args.num_gpus < 1 or args.num_gpus > available_gpus:
        parser.error(
            f'--num_gpus must be between 1 and {available_gpus}, got {args.num_gpus}'
        )

    print(f"Using {args.num_gpus} GPUs")
    print(f"Will decode {args.num_decode_samples} random samples per GPU for quality check")

    if args.num_gpus == 1:
        # Single GPU
        extract_latents_gpu(0, 1, args, audio_files)
    else:
        # Multi-GPU
        mp.spawn(
            extract_latents_gpu,
            args=(args.num_gpus, args, audio_files),
            nprocs=args.num_gpus,
            join=True
        )

    # Merge metadata
    merge_metadata(args.output_dir, args.tmp_dir, args.num_gpus)


if __name__ == '__main__':
    main()