# learnable-speech / dac-vae / extract_dac_latents.py
# primepake: "Update dac-vae before subtree push" (5d43438)
# extract_dac_latents.py - With random decoding check
import os
import glob
import argparse
import torch
import torch.multiprocessing as mp
from torch.utils.data import Dataset, DataLoader
import numpy as np
import soundfile as sf
import librosa
from pathlib import Path
from tqdm import tqdm
import yaml
import json
from collections import defaultdict
import random
import shutil
def process_single_audio(audio_path, model, sample_rate, device):
    """Load one audio file and encode it with ``model`` (no padding or batching).

    Args:
        audio_path: Path of the audio file to load.
        model: VAE exposing ``encode(audio, sample_rate) -> (z, mu, logs)``.
        sample_rate: Rate the audio is resampled to on load; also passed to ``encode``.
        device: Torch device the input tensor is moved to.

    Returns:
        On success, a dict with ``success=True``, the CPU tensors ``z``, ``mu`` and
        ``logs``, ``duration`` (seconds), ``samples``, ``compression_ratio``
        (input samples per latent frame, floor-divided) and ``original_audio``
        (the loaded numpy waveform, kept for later comparison).
        On failure, a dict with ``success=False``, ``error`` (message) and
        ``path``. Errors are printed, not raised.
    """
    try:
        # Resampled to ``sample_rate`` and downmixed to mono on load.
        audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
        if len(audio) == 0:
            # Fail with a clear message instead of a ZeroDivisionError further down.
            raise ValueError("audio is empty")

        # Shape [1, 1, T]: batch and channel dims for the encoder.
        audio_tensor = torch.from_numpy(audio).float().unsqueeze(0).unsqueeze(0).to(device)
        # Clamp (not normalize) samples into [-1, 1].
        audio_tensor = torch.clamp(audio_tensor, -1.0, 1.0)

        with torch.no_grad():
            z, mu, logs = model.encode(audio_tensor, sample_rate)
        if z.shape[-1] == 0:
            raise ValueError("encoder produced an empty latent sequence")

        return {
            'success': True,
            'z': z.cpu(),
            'mu': mu.cpu(),
            'logs': logs.cpu(),
            'duration': len(audio) / sample_rate,
            'samples': len(audio),
            'compression_ratio': audio_tensor.shape[-1] // z.shape[-1],
            'original_audio': audio  # Keep original audio for comparison
        }
    except Exception as e:
        # Per-file boundary: report the failure and let the caller record it.
        print(f"Error processing {audio_path}: {e}")
        return {
            'success': False,
            'error': str(e),
            'path': audio_path
        }
def _reconstruction_metrics(original, reconstructed):
    """Return ``(mse, snr_db)`` of ``reconstructed`` against ``original`` over their common length.

    Both inputs are treated as 1-D sample arrays and truncated to the shorter
    length. Raises ``ValueError`` if there is no overlap to compare.
    """
    min_len = min(len(original), len(reconstructed))
    if min_len == 0:
        raise ValueError("no overlapping samples to compare")
    # float64 so the error sum is not accumulated in float32.
    ref = np.asarray(original[:min_len], dtype=np.float64)
    est = np.asarray(reconstructed[:min_len], dtype=np.float64)
    mse = float(np.mean((ref - est) ** 2))
    # The epsilon on the signal power keeps a silent reference from giving
    # log10(0) = -inf, which would turn every averaged SNR into -inf.
    snr = float(10 * np.log10((np.var(ref) + 1e-10) / (mse + 1e-10)))
    return mse, snr


def decode_and_save_sample(model, latent_data, original_audio, audio_path, tmp_dir, device):
    """Decode one stored latent and save original/reconstructed audio plus metrics for inspection.

    Writes ``<name>_original.wav``, ``<name>_reconstructed.wav`` and
    ``info.json`` into ``<tmp_dir>/<name>/``, where ``<name>`` is the audio
    file's base name without extension.

    Args:
        model: VAE exposing ``decode(z)``.
        latent_data: Dict built by the extractor; reads ``z`` (stored without
            its batch dim), ``sample_rate``, ``latent_shape`` and ``compression_ratio``.
        original_audio: The 1-D waveform that was encoded.
        audio_path: Path of the source audio (used for naming and recorded in ``info.json``).
        tmp_dir: Directory under which the per-sample folder is created.
        device: Torch device to decode on.

    Returns:
        ``(True, info)`` on success, where ``info`` includes ``mse`` and ``snr`` (dB).
        ``(False, {'error': message})`` on any failure (printed, not raised).
    """
    try:
        sample_rate = latent_data['sample_rate']
        name_without_ext = os.path.splitext(os.path.basename(audio_path))[0]
        # NOTE(review): files sharing a base name on the same GPU overwrite each other here.
        sample_dir = os.path.join(tmp_dir, name_without_ext)
        os.makedirs(sample_dir, exist_ok=True)

        # The extractor squeezes out the batch dim before saving; restore it for decode.
        z = latent_data['z'].to(device).unsqueeze(0)
        print('z shape: ', z.shape)
        with torch.no_grad():
            reconstructed = model.decode(z)

        # Collapse leading dims and keep the first channel. Unlike ``squeeze()``,
        # this still yields a 1-D array when the output has a single sample.
        reconstructed = reconstructed.detach().cpu().numpy()
        reconstructed = reconstructed.reshape(-1, reconstructed.shape[-1])[0]
        reconstructed = np.clip(reconstructed, -1.0, 1.0)

        original_path = os.path.join(sample_dir, f"{name_without_ext}_original.wav")
        sf.write(original_path, original_audio, sample_rate)
        reconstructed_path = os.path.join(sample_dir, f"{name_without_ext}_reconstructed.wav")
        sf.write(reconstructed_path, reconstructed, sample_rate)

        mse, snr = _reconstruction_metrics(original_audio, reconstructed)

        info = {
            'original_path': audio_path,
            'original_duration': len(original_audio) / sample_rate,
            'reconstructed_duration': len(reconstructed) / sample_rate,
            'latent_shape': latent_data['latent_shape'],
            'compression_ratio': latent_data['compression_ratio'],
            'mse': mse,
            'snr': snr
        }
        info_path = os.path.join(sample_dir, 'info.json')
        with open(info_path, 'w') as f:
            json.dump(info, f, indent=2)

        print(f"Sample saved to {sample_dir} - SNR: {snr:.2f} dB, MSE: {mse:.6f}")
        return True, info
    except Exception as e:
        # Quality check is best-effort: report and let extraction continue.
        print(f"Error decoding sample: {e}")
        return False, {'error': str(e)}
def _load_vae(config_path, checkpoint_path, device):
    """Build the DACVAE from the YAML config, load its weights and return ``(model, sample_rate)``."""
    from model import DACVAE as VAE  # project-local module, resolved at call time

    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    model = VAE(**config['vae'])
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # Accept either a dict holding the weights under 'generator' or a bare state dict.
    if 'generator' in checkpoint:
        model.load_state_dict(checkpoint['generator'])
    else:
        model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()
    return model, config['vae']['sample_rate']


def _shard_files(audio_files, rank, world_size):
    """Return the contiguous slice of ``audio_files`` for ``rank``; the last rank also takes the remainder."""
    files_per_gpu = len(audio_files) // world_size
    start_idx = rank * files_per_gpu
    end_idx = start_idx + files_per_gpu if rank < world_size - 1 else len(audio_files)
    return audio_files[start_idx:end_idx]


def _latent_output_path(audio_path):
    """Map ``a/b/c/d.wav`` to ``a/b/c/d_latent2x.pt`` (latents live next to the audio)."""
    return f"{os.path.splitext(audio_path)[0]}_latent2x.pt"


def extract_latents_gpu(rank, world_size, args, audio_files):
    """Extract VAE latents for this rank's share of ``audio_files`` on ``cuda:{rank}``.

    Each successfully encoded file is saved as ``<audio base>_latent2x.pt`` next
    to the audio. A random subset of at most ``args.num_decode_samples`` files is
    also decoded into ``<args.tmp_dir>/gpu_<rank>/`` for a quality check. The
    per-rank results go to ``<args.output_dir>/metadata_gpu<rank>.json``.

    Args:
        rank: Index of this worker; also selects the CUDA device.
        world_size: Number of workers the file list is split across.
        args: Parsed CLI namespace (reads ``config``, ``checkpoint``, ``tmp_dir``,
            ``num_decode_samples`` and ``output_dir``).
        audio_files: Full list of audio paths; this worker processes its own shard.
    """
    device = torch.device(f'cuda:{rank}')
    torch.cuda.set_device(device)

    print(f"[GPU {rank}] Loading DAC model...")
    model, sample_rate = _load_vae(args.config, args.checkpoint, device)

    gpu_files = _shard_files(audio_files, rank, world_size)
    print(f"[GPU {rank}] Processing {len(gpu_files)} files...")

    tmp_dir = os.path.join(args.tmp_dir, f'gpu_{rank}')
    os.makedirs(tmp_dir, exist_ok=True)

    # Randomly pick files for the decode check; clamp so a negative count cannot crash random.sample.
    num_samples = max(0, min(args.num_decode_samples, len(gpu_files)))
    sample_indices = set(random.sample(range(len(gpu_files)), num_samples))

    results = []
    decode_results = []
    for idx, audio_path in enumerate(tqdm(gpu_files, desc=f'GPU {rank}', position=rank)):
        result = process_single_audio(audio_path, model, sample_rate, device)
        if not result['success']:
            print(f"[GPU {rank}] Failed to process: {audio_path}")
            results.append({
                'path': audio_path,
                'error': result['error'],
                'status': 'failed'
            })
            continue

        output_path = _latent_output_path(audio_path)
        z = result['z'].squeeze(0)  # Remove batch dim
        mu = result['mu'].squeeze(0)
        logs = result['logs'].squeeze(0)
        latent_data = {
            'z': z,
            'mu': mu,
            'logs': logs,
            'sample_rate': sample_rate,
            'compression_ratio': result['compression_ratio'],
            'original_duration': result['duration'],
            'original_samples': result['samples'],
            'latent_shape': list(z.shape),
            'original_path': audio_path
        }

        try:
            output_dir = os.path.dirname(output_path)
            # A bare filename has no directory part, and os.makedirs('') raises.
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            torch.save(latent_data, output_path)
        except OSError as e:
            # One unwritable path should not kill the whole shard (and its metadata).
            print(f"[GPU {rank}] Failed to save latent for {audio_path}: {e}")
            results.append({
                'path': audio_path,
                'error': str(e),
                'status': 'failed'
            })
            continue

        results.append({
            'path': audio_path,
            'output_path': output_path,
            'latent_shape': latent_data['latent_shape'],
            'duration': result['duration'],
            'compression_ratio': result['compression_ratio']
        })

        if idx in sample_indices:
            print(f"\n[GPU {rank}] Decoding sample {idx}: {os.path.basename(audio_path)}")
            success, decode_info = decode_and_save_sample(
                model, latent_data, result['original_audio'],
                audio_path, tmp_dir, device
            )
            if success:
                decode_results.append(decode_info)

        if rank == 0 and len(results) % 100 == 0:
            print(f"[GPU {rank}] Processed {len(results)} files...")

    metadata_path = os.path.join(args.output_dir, f'metadata_gpu{rank}.json')
    with open(metadata_path, 'w') as f:
        json.dump(results, f, indent=2)

    avg_snr = None
    if decode_results:
        snrs = [r['snr'] for r in decode_results if 'snr' in r]
        mses = [r['mse'] for r in decode_results if 'mse' in r]
        avg_snr = float(np.mean(snrs)) if snrs else None
        decode_path = os.path.join(tmp_dir, 'decode_results.json')
        with open(decode_path, 'w') as f:
            json.dump({
                'num_samples': len(decode_results),
                'samples': decode_results,
                'average_snr': avg_snr,
                'average_mse': float(np.mean(mses)) if mses else None
            }, f, indent=2)

    print(f"[GPU {rank}] Completed processing {len(results)} files")
    if avg_snr is not None:
        print(f"[GPU {rank}] Average SNR for decoded samples: {avg_snr:.2f} dB")
def find_audio_files(root_path, extensions=('.wav', '.flac', '.mp3')):
    """Return a sorted, de-duplicated list of audio files under ``root_path``.

    If ``root_path`` is a file, it is returned as a one-element list when its
    suffix is in ``extensions``, and ``[]`` otherwise. If it is a directory, it
    is searched recursively, top level included. Suffix matching is
    case-sensitive, as with ``glob``.

    Args:
        root_path: Directory to search, or a single audio file.
        extensions: Accepted suffixes including the leading dot. May be a single
            string or any iterable of strings.

    Returns:
        Sorted list of matching paths; empty if nothing matches.
    """
    # A lone string would otherwise be iterated character by character.
    if isinstance(extensions, str):
        extensions = (extensions,)
    extensions = tuple(extensions)

    if os.path.isfile(root_path):
        return [root_path] if root_path.endswith(extensions) else []

    # Escape the root so directory names containing glob metacharacters
    # ('[', '*', '?') are matched literally. With recursive=True, '**' also
    # matches zero directories, so files directly in the root are included.
    pattern_root = glob.escape(root_path)
    audio_files = set()
    for ext in extensions:
        audio_files.update(
            glob.glob(os.path.join(pattern_root, '**', f'*{ext}'), recursive=True)
        )
    return sorted(audio_files)
def merge_metadata(output_dir, tmp_dir, world_size):
    """Combine per-GPU outputs into ``metadata.json``, ``failed_files.json`` and ``summary.json``.

    For each rank in ``range(world_size)``, reads
    ``<output_dir>/metadata_gpu<rank>.json`` and
    ``<tmp_dir>/gpu_<rank>/decode_results.json``, skipping any that are missing.
    Entries carrying an ``'error'`` key are counted as failures.

    The per-GPU metadata files are deleted only after the merged
    ``metadata.json`` has been written, so a failed merge does not lose them.

    Args:
        output_dir: Directory holding per-GPU metadata; merged files are written here.
        tmp_dir: Root of the per-GPU decode-check directories.
        world_size: Number of ranks whose outputs should be merged.
    """
    all_results = []
    failed_files = []
    all_decode_results = []
    merged_sources = []

    for rank in range(world_size):
        gpu_metadata_path = os.path.join(output_dir, f'metadata_gpu{rank}.json')
        if os.path.exists(gpu_metadata_path):
            with open(gpu_metadata_path, 'r') as f:
                for r in json.load(f):
                    if 'error' in r:
                        failed_files.append(r)
                    else:
                        all_results.append(r)
            merged_sources.append(gpu_metadata_path)

        decode_path = os.path.join(tmp_dir, f'gpu_{rank}', 'decode_results.json')
        if os.path.exists(decode_path):
            with open(decode_path, 'r') as f:
                all_decode_results.extend(json.load(f)['samples'])

    metadata_path = os.path.join(output_dir, 'metadata.json')
    with open(metadata_path, 'w') as f:
        json.dump({
            'total_files': len(all_results),
            'failed_files': len(failed_files),
            'files': all_results
        }, f, indent=2)
    # Safe to drop the per-GPU files only now that the merged copy exists.
    for path in merged_sources:
        os.remove(path)

    if failed_files:
        failed_path = os.path.join(output_dir, 'failed_files.json')
        with open(failed_path, 'w') as f:
            json.dump(failed_files, f, indent=2)

    total_duration = sum(r['duration'] for r in all_results)
    latent_dims = defaultdict(int)
    compression_ratios = defaultdict(int)
    for r in all_results:
        latent_dims[str(r['latent_shape'])] += 1
        compression_ratios[r['compression_ratio']] += 1

    summary = {
        'total_files': len(all_results),
        'failed_files': len(failed_files),
        'total_duration_hours': total_duration / 3600,
        'latent_dimensions': dict(latent_dims),
        'compression_ratios': dict(compression_ratios),
        'average_duration': total_duration / len(all_results) if all_results else 0,
        'decode_samples': len(all_decode_results)
    }
    # Only record averages that have data; np.mean([]) would be NaN.
    snrs = [r['snr'] for r in all_decode_results if 'snr' in r]
    mses = [r['mse'] for r in all_decode_results if 'mse' in r]
    if snrs:
        summary['average_snr'] = float(np.mean(snrs))
    if mses:
        summary['average_mse'] = float(np.mean(mses))

    summary_path = os.path.join(output_dir, 'summary.json')
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)

    print("\nProcessing complete!")
    print(f"Successfully processed: {len(all_results)} files")
    print(f"Failed: {len(failed_files)} files")
    print(f"Total duration: {total_duration/3600:.2f} hours")
    print(f"Average duration: {summary['average_duration']:.2f} seconds")
    print(f"Compression ratios: {dict(compression_ratios)}")

    if all_decode_results:
        print("\nDecode Quality Check:")
        print(f"Samples decoded: {len(all_decode_results)}")
        if 'average_snr' in summary:
            print(f"Average SNR: {summary['average_snr']:.2f} dB")
        if 'average_mse' in summary:
            print(f"Average MSE: {summary['average_mse']:.6f}")
        print(f"Check {tmp_dir} for audio comparisons")

    print(f"\nResults saved to: {output_dir}")
def main():
    """Command-line entry point: find audio, extract latents on one or more GPUs, then merge metadata.

    Latents are written next to each audio file as ``<name>_latent2x.pt``.
    Merged metadata and summaries go to ``--output_dir``, and decoded
    quality-check samples go to ``--tmp_dir``. Exits with an error message if
    no usable CUDA GPU is available or more GPUs are requested than are visible.
    """
    parser = argparse.ArgumentParser(description='Extract DAC latents with multi-GPU support')
    parser.add_argument('--root_path', type=str, required=True,
                        help='Root path containing audio files')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save metadata (latents saved alongside audio)')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to DAC checkpoint')
    parser.add_argument('--config', type=str, required=True,
                        help='Path to DAC config')
    parser.add_argument('--num_gpus', type=int, default=None,
                        help='Number of GPUs to use (default: all available)')
    parser.add_argument('--file_list', type=str, default=None,
                        help='Optional text file containing list of audio paths')
    parser.add_argument('--skip_existing', action='store_true',
                        help='Skip files that already have latents')
    parser.add_argument('--tmp_dir', type=str, default='./tmp',
                        help='Directory to save decoded samples for checking')
    parser.add_argument('--num_decode_samples', type=int, default=5,
                        help='Number of random samples to decode per GPU for quality check')
    parser.add_argument('--clean_tmp', action='store_true',
                        help='Clean tmp directory before starting')
    args = parser.parse_args()

    # Validate the GPU request before touching the filesystem. Workers use
    # cuda:0..N-1, and with N == 0 mp.spawn would silently run nothing.
    available_gpus = torch.cuda.device_count()
    if args.num_gpus is None:
        args.num_gpus = available_gpus
    if args.num_gpus < 1:
        raise SystemExit("No CUDA GPUs to use: latent extraction needs at least one "
                         "(check --num_gpus / CUDA_VISIBLE_DEVICES).")
    if args.num_gpus > available_gpus:
        raise SystemExit(f"Requested {args.num_gpus} GPUs but only {available_gpus} are available.")

    if args.clean_tmp and os.path.exists(args.tmp_dir):
        print(f"Cleaning tmp directory: {args.tmp_dir}")
        shutil.rmtree(args.tmp_dir)
    os.makedirs(args.tmp_dir, exist_ok=True)

    if args.file_list:
        print(f"Loading file list from {args.file_list}")
        with open(args.file_list, 'r') as f:
            audio_files = [line.strip() for line in f if line.strip()]
    else:
        print(f"Searching for audio files in {args.root_path}")
        audio_files = find_audio_files(args.root_path)

    if not audio_files:
        print("No audio files found!")
        return

    if args.skip_existing:
        filtered_files = []
        for audio_path in audio_files:
            base_path = os.path.splitext(audio_path)[0]
            latent_path = f"{base_path}_latent2x.pt"
            # NOTE: this flag also DELETES legacy '<name>_latent.pt' files.
            old_latent_path = f"{base_path}_latent.pt"
            if os.path.exists(old_latent_path):
                os.remove(old_latent_path)
                print(f"Removed old latent file: {old_latent_path}")
            if not os.path.exists(latent_path):
                filtered_files.append(audio_path)
        print(f"Skipping {len(audio_files) - len(filtered_files)} existing files")
        audio_files = filtered_files

    print(f"Found {len(audio_files)} audio files to process")
    if not audio_files:
        print("No files to process!")
        return

    os.makedirs(args.output_dir, exist_ok=True)

    # Never start more workers than files. Extra ranks would get an empty
    # shard but still load the model.
    args.num_gpus = min(args.num_gpus, len(audio_files))
    print(f"Using {args.num_gpus} GPUs")
    print(f"Will decode {args.num_decode_samples} random samples per GPU for quality check")

    if args.num_gpus == 1:
        # Single GPU: run in-process.
        extract_latents_gpu(0, 1, args, audio_files)
    else:
        mp.spawn(
            extract_latents_gpu,
            args=(args.num_gpus, args, audio_files),
            nprocs=args.num_gpus,
            join=True
        )

    merge_metadata(args.output_dir, args.tmp_dir, args.num_gpus)


if __name__ == '__main__':
    main()