primepake committed on
Commit
f6beba0
·
1 Parent(s): ea8cd35

extract dac latent

Browse files
Files changed (2) hide show
  1. dac-vae/extract.sh +17 -0
  2. dac-vae/extract_dac_latents.py +447 -0
dac-vae/extract.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Run the latent extractor with the dataset-specific options passed by the
# caller, followed by the checkpoint/config/GPU settings shared by every run.
run_extract() {
    python extract_dac_latents.py "$@" \
        --checkpoint ./checkpoint.pt \
        --config ./config.yml \
        --num_gpus 1 \
        --num_decode_samples 10
}

# Dataset under /data/dataset, using the audio paths listed in files.txt.
run_extract \
    --root_path /data/dataset \
    --file_list files.txt \
    --output_dir /data/dataset/metadata


# Audio files discovered by searching data_test.
run_extract \
    --root_path data_test \
    --output_dir data_test/metadata
dac-vae/extract_dac_latents.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # extract_dac_latents.py - With random decoding check
2
+
3
+ import os
4
+ import glob
5
+ import argparse
6
+ import torch
7
+ import torch.multiprocessing as mp
8
+ from torch.utils.data import Dataset, DataLoader
9
+ import numpy as np
10
+ import soundfile as sf
11
+ import librosa
12
+ from pathlib import Path
13
+ from tqdm import tqdm
14
+ import yaml
15
+ import json
16
+ from collections import defaultdict
17
+ import random
18
+ import shutil
19
+
20
def process_single_audio(audio_path, model, sample_rate, device):
    """Encode one audio file into latents, without any padding.

    The file is loaded as mono at ``sample_rate``, clamped to [-1, 1] and
    passed through ``model.encode``.

    Args:
        audio_path: Path of the audio file to load.
        model: Object exposing ``encode(audio, sample_rate)`` that returns
            the triple ``(z, mu, logs)``.
        sample_rate: Rate the audio is resampled to on load.
        device: Torch device the input tensor is moved to.

    Returns:
        On success, a dict with ``success=True``, the CPU tensors ``z``,
        ``mu`` and ``logs``, the clip ``duration`` in seconds, its length in
        ``samples``, the integer ``compression_ratio`` (input samples per
        latent step along the last axis) and the loaded ``original_audio``.
        On any exception, a dict with ``success=False``, the ``error``
        message and the offending ``path``.
    """
    try:
        waveform, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
        num_samples = len(waveform)

        # Two leading singleton axes turn the mono signal into [1, 1, T].
        batch = torch.from_numpy(waveform).float()[None, None].to(device)

        # Keep every sample inside the [-1, 1] range before encoding.
        batch = batch.clamp(-1.0, 1.0)

        with torch.no_grad():
            z, mu, logs = model.encode(batch, sample_rate)

        outcome = {
            'success': True,
            'z': z.cpu(),
            'mu': mu.cpu(),
            'logs': logs.cpu(),
            'duration': num_samples / sample_rate,
            'samples': num_samples,
            'compression_ratio': batch.shape[-1] // z.shape[-1],
            # Retained so a caller can compare against a reconstruction.
            'original_audio': waveform,
        }
    except Exception as exc:
        # Best effort per file: report the problem and let the caller go on.
        print(f"Error processing {audio_path}: {exc}")
        outcome = {
            'success': False,
            'error': str(exc),
            'path': audio_path,
        }

    return outcome
55
+
56
+
57
def decode_and_save_sample(model, latent_data, original_audio, audio_path, tmp_dir, device):
    """Decode a latent and save original and reconstructed audio for comparison.

    A sub-directory named after the audio file (basename without extension)
    is created under ``tmp_dir`` and receives ``<name>_original.wav``,
    ``<name>_reconstructed.wav`` and ``info.json``.

    Args:
        model: Object exposing ``decode(z)``.
        latent_data: Dict providing ``'z'``, ``'sample_rate'``,
            ``'latent_shape'`` and ``'compression_ratio'``.
        original_audio: 1-D array of the samples that were encoded.
        audio_path: Path of the source audio file (used for naming and
            recorded in the info file).
        tmp_dir: Directory under which the per-sample folder is created.
        device: Torch device used for decoding.

    Returns:
        ``(True, info)`` on success, where ``info`` holds durations, latent
        shape, compression ratio, ``mse`` and ``snr`` (dB);
        ``(False, {'error': message})`` if anything fails.
    """
    try:
        # Name the output folder after the audio file.
        base_name = os.path.basename(audio_path)
        name_without_ext = os.path.splitext(base_name)[0]

        sample_dir = os.path.join(tmp_dir, name_without_ext)
        os.makedirs(sample_dir, exist_ok=True)

        # Decode the latent; a leading axis is added before calling the model.
        z = latent_data['z'].to(device)
        z = z.unsqueeze(0)
        print('z shape: ', z.shape)
        with torch.no_grad():
            reconstructed = model.decode(z)

        sample_rate = latent_data['sample_rate']

        # Reduce to a 1-D numpy signal; if two axes survive the squeeze,
        # keep the first row only.
        reconstructed = reconstructed.squeeze().cpu().numpy()
        if reconstructed.ndim == 2:
            reconstructed = reconstructed[0]
        reconstructed = np.clip(reconstructed, -1.0, 1.0)

        original_path = os.path.join(sample_dir, f"{name_without_ext}_original.wav")
        sf.write(original_path, original_audio, sample_rate)

        reconstructed_path = os.path.join(sample_dir, f"{name_without_ext}_reconstructed.wav")
        sf.write(reconstructed_path, reconstructed, sample_rate)

        # Compare only the overlapping part of the two signals.
        min_len = min(len(original_audio), len(reconstructed))
        if min_len == 0:
            # An empty overlap would yield NaN metrics and corrupt averages.
            raise ValueError("no overlapping samples to compare")
        original_trimmed = original_audio[:min_len]
        reconstructed_trimmed = reconstructed[:min_len]

        eps = 1e-10
        mse = np.mean((original_trimmed - reconstructed_trimmed) ** 2)
        # Floor the signal power so a silent clip gives a finite SNR instead
        # of log10(0) = -inf, which is not valid JSON and breaks averaging.
        signal_power = max(float(np.var(original_trimmed)), eps)
        snr = 10 * np.log10(signal_power / (mse + eps))

        info = {
            'original_path': audio_path,
            'original_duration': len(original_audio) / sample_rate,
            'reconstructed_duration': len(reconstructed) / sample_rate,
            'latent_shape': latent_data['latent_shape'],
            'compression_ratio': latent_data['compression_ratio'],
            'mse': float(mse),
            'snr': float(snr)
        }

        info_path = os.path.join(sample_dir, 'info.json')
        with open(info_path, 'w') as f:
            json.dump(info, f, indent=2)

        print(f"Sample saved to {sample_dir} - SNR: {snr:.2f} dB, MSE: {mse:.6f}")

        return True, info

    except Exception as e:
        # The quality check is best effort; a failure must not stop extraction.
        print(f"Error decoding sample: {e}")
        return False, {'error': str(e)}
119
+
120
+
121
def extract_latents_gpu(rank, world_size, args, audio_files):
    """Extract latents on a single GPU.

    Loads the VAE described by ``args.config`` / ``args.checkpoint`` onto
    ``cuda:<rank>``, encodes this rank's contiguous share of ``audio_files``
    and writes each result next to its audio file as ``<stem>_latent.pt``.
    A random subset of the files is also decoded for a quality check.

    Args:
        rank: Index of this worker; also the CUDA device index.
        world_size: Total number of workers.
        args: Parsed command-line namespace; reads ``config``, ``checkpoint``,
            ``tmp_dir``, ``num_decode_samples`` and ``output_dir``.
        audio_files: Full list of audio paths shared by all workers.

    Side effects:
        Writes ``metadata_gpu<rank>.json`` into ``args.output_dir`` (which
        must already exist) and, if any sample was decoded,
        ``decode_results.json`` into ``<args.tmp_dir>/gpu_<rank>``.
    """
    # Setup device
    device = torch.device(f'cuda:{rank}')
    torch.cuda.set_device(device)

    # Project-local model definition, imported inside the worker.
    from model import DACVAE as VAE

    print(f"[GPU {rank}] Loading DAC model...")
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)

    model = VAE(**config['vae'])
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    # The weights may be nested under a 'generator' key or stored directly.
    state_dict = checkpoint['generator'] if 'generator' in checkpoint else checkpoint
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    sample_rate = config['vae']['sample_rate']

    # Contiguous split across workers; the last rank takes the remainder.
    files_per_gpu = len(audio_files) // world_size
    start_idx = rank * files_per_gpu
    end_idx = start_idx + files_per_gpu if rank < world_size - 1 else len(audio_files)
    gpu_files = audio_files[start_idx:end_idx]

    print(f"[GPU {rank}] Processing {len(gpu_files)} files...")

    # Per-worker folder for the decoded comparison samples.
    tmp_dir = os.path.join(args.tmp_dir, f'gpu_{rank}')
    os.makedirs(tmp_dir, exist_ok=True)

    # Randomly pick the files to decode. Clamp at zero so a negative request
    # does not make random.sample raise; a set keeps membership tests O(1).
    num_samples = max(0, min(args.num_decode_samples, len(gpu_files)))
    sample_indices = set(random.sample(range(len(gpu_files)), num_samples))

    results = []
    decode_results = []

    for idx, audio_path in enumerate(tqdm(gpu_files, desc=f'GPU {rank}', position=rank)):
        result = process_single_audio(audio_path, model, sample_rate, device)

        if not result['success']:
            print(f"[GPU {rank}] Failed to process: {audio_path}")
            results.append({
                'path': audio_path,
                'error': result['error'],
                'status': 'failed'
            })
            continue

        # Output path: a/b/c/d.wav -> a/b/c/d_latent.pt
        base_path = os.path.splitext(audio_path)[0]
        output_path = f"{base_path}_latent.pt"

        # A bare filename has an empty dirname, and os.makedirs('') raises,
        # so only create the directory when there is one to create.
        output_parent = os.path.dirname(output_path)
        if output_parent:
            os.makedirs(output_parent, exist_ok=True)

        # Drop the leading (batch) axis before saving.
        z = result['z'].squeeze(0)
        mu = result['mu'].squeeze(0)
        logs = result['logs'].squeeze(0)

        latent_data = {
            'z': z,
            'mu': mu,
            'logs': logs,
            'sample_rate': sample_rate,
            'compression_ratio': result['compression_ratio'],
            'original_duration': result['duration'],
            'original_samples': result['samples'],
            'latent_shape': list(z.shape),
            'original_path': audio_path
        }

        torch.save(latent_data, output_path)

        results.append({
            'path': audio_path,
            'output_path': output_path,
            'latent_shape': latent_data['latent_shape'],
            'duration': result['duration'],
            'compression_ratio': result['compression_ratio']
        })

        # Decode the randomly chosen files for the quality check.
        if idx in sample_indices:
            print(f"\n[GPU {rank}] Decoding sample {idx}: {os.path.basename(audio_path)}")
            success, decode_info = decode_and_save_sample(
                model, latent_data, result['original_audio'],
                audio_path, tmp_dir, device
            )
            if success:
                decode_results.append(decode_info)

        if rank == 0 and len(results) % 100 == 0:
            print(f"[GPU {rank}] Processed {len(results)} files...")

    # Save metadata for this GPU
    metadata_path = os.path.join(args.output_dir, f'metadata_gpu{rank}.json')
    with open(metadata_path, 'w') as f:
        json.dump(results, f, indent=2)

    print_avg_snr = None
    if decode_results:
        # Compute the averages once and reuse them for the file and the log.
        avg_snr = float(np.mean([r['snr'] for r in decode_results if 'snr' in r]))
        avg_mse = float(np.mean([r['mse'] for r in decode_results if 'mse' in r]))
        print_avg_snr = avg_snr

        decode_path = os.path.join(tmp_dir, 'decode_results.json')
        with open(decode_path, 'w') as f:
            json.dump({
                'num_samples': len(decode_results),
                'samples': decode_results,
                'average_snr': avg_snr,
                'average_mse': avg_mse
            }, f, indent=2)

    print(f"[GPU {rank}] Completed processing {len(results)} files")
    if print_avg_snr is not None:
        print(f"[GPU {rank}] Average SNR for decoded samples: {print_avg_snr:.2f} dB")
246
+
247
+
248
def find_audio_files(root_path, extensions=('.wav', '.flac', '.mp3')):
    """Find all audio files under ``root_path``.

    Args:
        root_path: A directory to search recursively, or the path of a single
            audio file.
        extensions: File-name suffixes (including the dot) to accept.

    Returns:
        A sorted list of unique matching paths. If ``root_path`` is itself a
        file with an accepted suffix, the list contains just that path.
    """
    extensions = tuple(extensions)

    # A single audio file is returned as-is.
    if os.path.isfile(root_path) and root_path.endswith(extensions):
        return [root_path]

    # Escape glob metacharacters in the root so that directories such as
    # "data[1]" are searched literally instead of being read as a pattern.
    safe_root = glob.escape(root_path)

    matches = set()
    for ext in extensions:
        # '**' matches zero or more directories, so this one pattern covers
        # files directly in the root as well as in nested folders.
        pattern = os.path.join(safe_root, '**', f'*{ext}')
        matches.update(glob.glob(pattern, recursive=True))

    return sorted(matches)
269
+
270
+
271
def merge_metadata(output_dir, tmp_dir, world_size):
    """Merge the per-GPU metadata files into combined reports.

    Reads ``metadata_gpu<rank>.json`` from ``output_dir`` for every rank
    (deleting each after reading) and ``gpu_<rank>/decode_results.json`` from
    ``tmp_dir`` when present. Writes ``metadata.json`` and ``summary.json``
    into ``output_dir``, plus ``failed_files.json`` if any entry carries an
    ``'error'`` key, and prints a short report.

    Args:
        output_dir: Directory holding the per-GPU metadata; also receives the
            merged files.
        tmp_dir: Directory holding the per-GPU decode results.
        world_size: Number of ranks whose files should be merged.
    """
    all_results = []
    failed_files = []
    all_decode_results = []

    for rank in range(world_size):
        metadata_path = os.path.join(output_dir, f'metadata_gpu{rank}.json')
        if os.path.exists(metadata_path):
            with open(metadata_path, 'r') as f:
                results = json.load(f)
            for r in results:
                # Entries with an 'error' key are the files that failed.
                if 'error' in r:
                    failed_files.append(r)
                else:
                    all_results.append(r)
            # The per-GPU file is no longer needed once merged.
            os.remove(metadata_path)

        decode_path = os.path.join(tmp_dir, f'gpu_{rank}', 'decode_results.json')
        if os.path.exists(decode_path):
            with open(decode_path, 'r') as f:
                decode_data = json.load(f)
            all_decode_results.extend(decode_data['samples'])

    # Save merged metadata
    metadata_path = os.path.join(output_dir, 'metadata.json')
    with open(metadata_path, 'w') as f:
        json.dump({
            'total_files': len(all_results),
            'failed_files': len(failed_files),
            'files': all_results
        }, f, indent=2)

    # Save failed files list if any
    if failed_files:
        failed_path = os.path.join(output_dir, 'failed_files.json')
        with open(failed_path, 'w') as f:
            json.dump(failed_files, f, indent=2)

    # Summary statistics: total duration plus histograms of latent shapes
    # and compression ratios.
    total_duration = sum(r['duration'] for r in all_results)
    latent_dims = defaultdict(int)
    compression_ratios = defaultdict(int)

    for r in all_results:
        # Shapes are lists, so use their string form as the histogram key.
        latent_dims[str(r['latent_shape'])] += 1
        compression_ratios[r['compression_ratio']] += 1

    summary = {
        'total_files': len(all_results),
        'failed_files': len(failed_files),
        'total_duration_hours': total_duration / 3600,
        'latent_dimensions': dict(latent_dims),
        'compression_ratios': dict(compression_ratios),
        'average_duration': total_duration / len(all_results) if all_results else 0,
        'decode_samples': len(all_decode_results)
    }

    if all_decode_results:
        # Plain floats keep the summary independent of numpy scalar types.
        summary['average_snr'] = float(np.mean([r['snr'] for r in all_decode_results if 'snr' in r]))
        summary['average_mse'] = float(np.mean([r['mse'] for r in all_decode_results if 'mse' in r]))

    summary_path = os.path.join(output_dir, 'summary.json')
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)

    print("\nProcessing complete!")
    print(f"Successfully processed: {len(all_results)} files")
    print(f"Failed: {len(failed_files)} files")
    print(f"Total duration: {total_duration/3600:.2f} hours")
    print(f"Average duration: {summary['average_duration']:.2f} seconds")
    print(f"Compression ratios: {dict(compression_ratios)}")

    if all_decode_results:
        print("\nDecode Quality Check:")
        print(f"Samples decoded: {len(all_decode_results)}")
        print(f"Average SNR: {summary['average_snr']:.2f} dB")
        print(f"Average MSE: {summary['average_mse']:.6f}")
        # Point at the directory actually used rather than a fixed "tmp/".
        print(f"Check {tmp_dir} folder for audio comparisons")

    print(f"\nResults saved to: {output_dir}")
355
+
356
+
357
def main():
    """Command-line entry point: extract latents for a set of audio files.

    Parses the arguments, collects the audio files (from ``--file_list`` or
    by searching ``--root_path``), optionally drops files that already have a
    ``<stem>_latent.pt`` next to them, runs one extraction worker per GPU and
    finally merges the per-GPU metadata.
    """
    parser = argparse.ArgumentParser(description='Extract DAC latents with multi-GPU support')
    parser.add_argument('--root_path', type=str, required=True,
                        help='Root path containing audio files')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save metadata (latents saved alongside audio)')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to DAC checkpoint')
    parser.add_argument('--config', type=str, required=True,
                        help='Path to DAC config')
    parser.add_argument('--num_gpus', type=int, default=None,
                        help='Number of GPUs to use (default: all available)')
    parser.add_argument('--file_list', type=str, default=None,
                        help='Optional text file containing list of audio paths')
    parser.add_argument('--skip_existing', action='store_true',
                        help='Skip files that already have latents')
    parser.add_argument('--tmp_dir', type=str, default='./tmp',
                        help='Directory to save decoded samples for checking')
    parser.add_argument('--num_decode_samples', type=int, default=5,
                        help='Number of random samples to decode per GPU for quality check')
    parser.add_argument('--clean_tmp', action='store_true',
                        help='Clean tmp directory before starting')

    args = parser.parse_args()

    # Clean tmp directory if requested
    if args.clean_tmp and os.path.exists(args.tmp_dir):
        print(f"Cleaning tmp directory: {args.tmp_dir}")
        shutil.rmtree(args.tmp_dir)

    # Create tmp directory
    os.makedirs(args.tmp_dir, exist_ok=True)

    # Find audio files
    if args.file_list:
        print(f"Loading file list from {args.file_list}")
        audio_files = []
        with open(args.file_list, 'r') as f:
            for line in f:
                entry = line.strip()
                if not entry:
                    continue
                # Entries that do not resolve as given are tried relative to
                # --root_path, so a list written relative to the dataset root
                # works; paths that already resolve are left untouched.
                if not os.path.isabs(entry) and not os.path.exists(entry):
                    rooted = os.path.join(args.root_path, entry)
                    if os.path.exists(rooted):
                        entry = rooted
                audio_files.append(entry)
    else:
        print(f"Searching for audio files in {args.root_path}")
        audio_files = find_audio_files(args.root_path)

    if not audio_files:
        print("No audio files found!")
        return

    # Filter out existing if requested
    if args.skip_existing:
        filtered_files = [
            audio_path for audio_path in audio_files
            if not os.path.exists(f"{os.path.splitext(audio_path)[0]}_latent.pt")
        ]
        print(f"Skipping {len(audio_files) - len(filtered_files)} existing files")
        audio_files = filtered_files

    print(f"Found {len(audio_files)} audio files to process")

    if not audio_files:
        print("No files to process!")
        return

    # Create output directory for metadata
    os.makedirs(args.output_dir, exist_ok=True)

    # Determine number of GPUs and reject counts the machine cannot satisfy:
    # each worker uses cuda:<rank>, so every rank needs a visible device.
    available_gpus = torch.cuda.device_count()
    if args.num_gpus is None:
        args.num_gpus = available_gpus
    if args.num_gpus < 1 or args.num_gpus > available_gpus:
        parser.error(
            f"--num_gpus must be between 1 and the number of visible CUDA devices "
            f"({available_gpus}), got {args.num_gpus}"
        )

    print(f"Using {args.num_gpus} GPUs")
    print(f"Will decode {args.num_decode_samples} random samples per GPU for quality check")

    if args.num_gpus == 1:
        # Single GPU: run in-process.
        extract_latents_gpu(0, 1, args, audio_files)
    else:
        # Multi-GPU: one spawned process per device; the rank is passed as
        # the first argument by mp.spawn.
        mp.spawn(
            extract_latents_gpu,
            args=(args.num_gpus, args, audio_files),
            nprocs=args.num_gpus,
            join=True
        )

    # Merge metadata
    merge_metadata(args.output_dir, args.tmp_dir, args.num_gpus)


if __name__ == '__main__':
    main()