import argparse
import glob
import os
import warnings
from pathlib import Path

import torch
import dac
import soundfile as sf

warnings.filterwarnings("ignore", category=FutureWarning)


class CodecExtractor:
    """Wrapper around a frozen DAC "44khz" codec used to produce latent targets."""

    def __init__(self, device='cuda'):
        """Load the DAC model, move it to ``device``, and freeze its parameters.

        Args:
            device: Torch device string the codec is moved to.
        """
        self.device = device
        print(f"Loading DAC model on {device}...")

        model = dac.utils.load_model(tag="latest", model_type="44khz")
        self.codec = model.to(self.device).eval()

        # The codec is only used for inference here, so no parameter needs
        # gradients.
        for weight in self.codec.parameters():
            weight.requires_grad = False

        print("Initialized Frozen Codec")

    def extract_targets(self, wav_tensor, sample_rate):
        """Encode a waveform with the frozen codec and return ``z``.

        The waveform is resampled to 44100 Hz when ``sample_rate`` differs,
        passed through ``codec.preprocess`` and then ``codec.encode``. Only the
        first element of the encoder output (the continuous ``z``) is returned.

        Args:
            wav_tensor: Waveform tensor; ``process_corpus`` passes it shaped
                ``(1, 1, T)`` and already on the codec's device.
            sample_rate: Sample rate of ``wav_tensor`` in Hz.

        Returns:
            The ``z`` tensor produced by ``codec.encode``.
        """
        from torchaudio.functional import resample

        target_sr = 44100  # the DAC 44khz model requires exactly this rate

        with torch.no_grad():
            if sample_rate != target_sr:
                wav_tensor = resample(wav_tensor, sample_rate, target_sr)

            prepared = self.codec.preprocess(wav_tensor, target_sr)

            # 'encode' returns z (continuous), codes (discrete), latents, _, _
            z, _codes, _latents, _unused_a, _unused_b = self.codec.encode(prepared)
            return z


def process_corpus(wav_dir, out_dir, device='cuda'):
    """Extract and save codec targets for every ``.wav`` file under ``wav_dir``.

    Files are discovered recursively. Each one is read with ``soundfile``,
    reduced to its first channel when it has more than one, encoded by a
    ``CodecExtractor`` and stored as
    ``<out_dir>/<parent directory name>/<file stem>_ztarget.pt``. A file whose
    processing raises is reported on stdout and skipped so the rest of the
    corpus is still handled.

    Args:
        wav_dir: Root directory searched for ``.wav`` files.
        out_dir: Directory that receives the saved tensors (created if missing).
        device: Torch device string used for the codec and the input tensors.
    """
    extractor = CodecExtractor(device=device)
    os.makedirs(out_dir, exist_ok=True)

    pattern = os.path.join(wav_dir, "**/*.wav")
    wav_paths = glob.glob(pattern, recursive=True)
    print(f"Found {len(wav_paths)} wav files.")

    for wav_path in wav_paths:
        try:
            samples, sr = sf.read(wav_path)

            # Multi-channel audio: keep only the first channel.
            if samples.ndim > 1:
                samples = samples[:, 0]

            # Build a (1, 1, T) float tensor on the target device.
            mono = torch.from_numpy(samples)
            wav_tensor = mono.unsqueeze(0).unsqueeze(0).float().to(device)

            z_target = extractor.extract_targets(wav_tensor, sample_rate=sr)

            source = Path(wav_path)
            file_id = source.stem
            speaker_dir = source.parent.name

            out_spk_dir = os.path.join(out_dir, speaker_dir)
            os.makedirs(out_spk_dir, exist_ok=True)

            out_path = os.path.join(out_spk_dir, f"{file_id}_ztarget.pt")
            torch.save(z_target.cpu(), out_path)
            print(f"Saved extracted target for {speaker_dir}/{file_id}: shape {z_target.shape}")
        except Exception as e:
            print(f"Skipping {wav_path} due to error: {e}")
def pick_device():
    """Return the name of the preferred available torch device.

    Preference order is "cuda", then "mps", then "cpu".

    Returns:
        One of the strings "cuda", "mps" or "cpu".
    """
    if torch.cuda.is_available():
        return "cuda"

    # Not every PyTorch build exposes ``torch.backends.mps``; probe for it so
    # such installs fall back to CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"

    return "cpu"


def main():
    """Parse command-line arguments and run the offline target extraction."""
    parser = argparse.ArgumentParser(
        description="Extract frozen-codec targets for a directory of wav files."
    )
    parser.add_argument(
        "-w", "--wav_dir",
        default="./data_svc/waves-32k",
        help="directory searched recursively for .wav files",
    )
    parser.add_argument(
        "-o", "--out_dir",
        default="./data_svc/codec_targets",
        help="directory where the extracted targets are written",
    )
    args = parser.parse_args()

    device = pick_device()
    process_corpus(args.wav_dir, args.out_dir, device)
    print("Offline processing complete.")


if __name__ == "__main__":
    main()