| import os |
| import argparse |
| import glob |
| import torch |
| from pathlib import Path |
| import dac |
| import soundfile as sf |
| import warnings |
|
|
# Silence FutureWarning noise emitted by dependencies so the per-file progress
# output stays readable.
warnings.filterwarnings(action="ignore", category=FutureWarning)
|
|
class CodecExtractor:
    """Frozen Descript Audio Codec (DAC) wrapper used to extract encoder targets.

    The pretrained codec is loaded once, moved to ``device``, switched to
    eval mode and has gradients disabled on every parameter, so it can only
    be used for inference.

    Args:
        device: Torch device string the codec is moved to (default ``'cuda'``).
        model_type: DAC model variant passed to ``dac.utils.load_model``
            (default ``"44khz"``, the variant the original code hard-coded).
        tag: Pretrained weights tag passed to ``dac.utils.load_model``
            (default ``"latest"``).
    """

    # Fallback rate used when the loaded codec does not expose ``sample_rate``.
    DEFAULT_SAMPLE_RATE = 44100

    def __init__(self, device='cuda', model_type="44khz", tag="latest"):
        self.device = device
        print(f"Loading DAC model on {device}...")
        self.codec = dac.utils.load_model(tag=tag, model_type=model_type).to(self.device).eval()
        # Freeze the codec: it is only ever used as a fixed feature extractor.
        for param in self.codec.parameters():
            param.requires_grad = False
        print("Initialized Frozen Codec")

    @property
    def sample_rate(self):
        """Sample rate (Hz) the codec expects its input audio to have.

        Read from the loaded codec so that a non-default ``model_type`` is
        resampled to the right rate; falls back to ``DEFAULT_SAMPLE_RATE``
        when the codec object has no ``sample_rate`` attribute.
        """
        return int(getattr(self.codec, "sample_rate", self.DEFAULT_SAMPLE_RATE))

    @torch.no_grad()
    def extract_targets(self, wav_tensor, sample_rate):
        """Encode a waveform with the frozen codec and return ``z_target``.

        Args:
            wav_tensor: Waveform tensor on the codec's device. The caller in
                this file passes a ``(1, 1, num_samples)`` float tensor.
            sample_rate: Sample rate of ``wav_tensor`` in Hz. Audio at any
                other rate than the codec's is resampled first.

        Returns:
            The first element returned by ``codec.encode`` — the quantized
            continuous latent ``z``. Computed under ``torch.no_grad()``, so
            it carries no autograd graph.
        """
        target_rate = self.sample_rate

        if sample_rate != target_rate:
            # Imported lazily so torchaudio is only required when resampling
            # is actually needed.
            from torchaudio.functional import resample

            wav_tensor = resample(wav_tensor, sample_rate, target_rate)

        wav_tensor = self.codec.preprocess(wav_tensor, target_rate)

        # encode() returns several values; only the latent ``z`` is needed.
        z, *_unused = self.codec.encode(wav_tensor)
        return z
|
|
def process_corpus(wav_dir, out_dir, device='cuda'):
    """Extract and save codec targets for every ``.wav`` file under ``wav_dir``.

    Files are discovered recursively. For each one the first channel is
    encoded by a ``CodecExtractor`` and the result is written to
    ``<out_dir>/<parent directory name>/<file stem>_ztarget.pt``.
    Any exception raised while handling a file is reported on stdout and
    that file is skipped; processing continues with the next file.

    Args:
        wav_dir: Root directory searched recursively for ``*.wav`` files.
        out_dir: Root directory for the saved tensors; created if missing.
        device: Torch device string used for the codec and the input tensors.
    """
    extractor = CodecExtractor(device=device)
    os.makedirs(out_dir, exist_ok=True)

    search_pattern = os.path.join(wav_dir, "**/*.wav")
    wav_paths = glob.glob(search_pattern, recursive=True)
    print(f"Found {len(wav_paths)} wav files.")

    for wav_path in wav_paths:
        try:
            samples, sr = sf.read(wav_path)

            # Multi-channel audio: keep only the first channel.
            if samples.ndim > 1:
                samples = samples[:, 0]

            # Add batch and channel axes -> (1, 1, num_samples), as float32.
            wav_tensor = torch.from_numpy(samples)[None, None, :].float().to(device)

            z_target = extractor.extract_targets(wav_tensor, sample_rate=sr)

            # The immediate parent directory name is used as the speaker folder.
            source = Path(wav_path)
            file_id = source.stem
            speaker_dir = source.parent.name

            out_spk_dir = os.path.join(out_dir, speaker_dir)
            os.makedirs(out_spk_dir, exist_ok=True)

            out_path = os.path.join(out_spk_dir, f"{file_id}_ztarget.pt")
            torch.save(z_target.cpu(), out_path)
            print(f"Saved extracted target for {speaker_dir}/{file_id}: shape {z_target.shape}")
        except Exception as e:
            # Best-effort batch job: report the failure and move on.
            print(f"Skipping {wav_path} due to error: {e}")
|
|
def _resolve_device(requested=None):
    """Return the torch device string to run on.

    An explicit ``requested`` value is returned unchanged. Otherwise the
    preference order is CUDA, then Apple MPS, then CPU.
    """
    if requested:
        return requested
    if torch.cuda.is_available():
        return "cuda"
    # ``torch.backends.mps`` does not exist on older PyTorch builds; guard the
    # lookup so those fall through to CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Extract frozen-codec targets for every wav file in a corpus."
    )
    parser.add_argument(
        "-w", "--wav_dir", default="./data_svc/waves-32k",
        help="root directory searched recursively for .wav files",
    )
    parser.add_argument(
        "-o", "--out_dir", default="./data_svc/codec_targets",
        help="root directory the extracted target tensors are written to",
    )
    parser.add_argument(
        "-d", "--device", default=None,
        help="torch device to use (default: auto-detect cuda, then mps, then cpu)",
    )
    args = parser.parse_args()

    device = _resolve_device(args.device)
    process_corpus(args.wav_dir, args.out_dir, device)
    print("Offline processing complete.")
|
|