# cfm_svc/data/codec_targets.py
# Author: Hector Li
# Commit: Initial commit for Hugging Face (df93d13)
import os
import argparse
import glob
import torch
from pathlib import Path
import dac
import soundfile as sf
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
class CodecExtractor:
    """Frozen, pretrained DAC codec used to turn waveforms into latent targets.

    The model is loaded once, moved to ``device``, switched to eval mode and
    has gradients disabled on every parameter, so it only ever acts as a
    fixed feature extractor.
    """

    # Sample rate fed to the "44khz" DAC checkpoint; audio at any other rate
    # is resampled to this value before encoding.
    TARGET_SAMPLE_RATE = 44100

    def __init__(self, device='cuda'):
        """Load the DAC model onto ``device`` and freeze its parameters.

        Args:
            device: Torch device (e.g. ``'cuda'``, ``'mps'``, ``'cpu'``) the
                codec is moved to.
        """
        self.device = device
        print(f"Loading DAC model on {device}...")
        self.codec = dac.utils.load_model(tag="latest", model_type="44khz").to(self.device).eval()
        for param in self.codec.parameters():
            param.requires_grad = False
        print("Initialized Frozen Codec")

    @torch.no_grad()
    def extract_targets(self, wav_tensor, sample_rate):
        """Encode a waveform with the frozen codec and return ``z_target``.

        Args:
            wav_tensor: Waveform tensor; presumably shaped
                ``(batch, 1, samples)`` -- confirm against the DAC API.
            sample_rate: Sample rate of ``wav_tensor`` in Hz. Anything other
                than ``TARGET_SAMPLE_RATE`` is resampled first.

        Returns:
            The first element of the tuple returned by ``codec.encode``
            (the quantized continuous vectors ``z``).

        Raises:
            ImportError: If resampling is needed and torchaudio is missing.
        """
        target_sr = self.TARGET_SAMPLE_RATE

        # Keep the audio on the same device as the codec weights; this is a
        # no-op when the caller has already placed it there.
        wav_tensor = wav_tensor.to(self.device)

        if sample_rate != target_sr:
            # Imported lazily so torchaudio is only required when the input
            # actually has to be resampled.
            from torchaudio.functional import resample
            wav_tensor = resample(wav_tensor, sample_rate, target_sr)

        wav_tensor = self.codec.preprocess(wav_tensor, target_sr)

        # 'encode' returns z (continuous), codes (discrete), latents, _, _
        z, _, _, _, _ = self.codec.encode(wav_tensor)
        return z
def process_corpus(wav_dir, out_dir, device='cuda'):
    """Extract codec targets for every ``.wav`` file found under ``wav_dir``.

    Each file is read, reduced to its first channel if it has several,
    encoded by a ``CodecExtractor`` and saved as
    ``<out_dir>/<parent folder name>/<file stem>_ztarget.pt``.
    A file that raises any exception is reported and skipped so the rest of
    the corpus is still processed.

    Args:
        wav_dir: Root directory searched recursively for ``*.wav`` files.
        out_dir: Directory the ``.pt`` targets are written under.
        device: Torch device used for the codec and the input tensors.
    """
    extractor = CodecExtractor(device=device)
    os.makedirs(out_dir, exist_ok=True)

    pattern = os.path.join(wav_dir, "**/*.wav")
    wav_paths = glob.glob(pattern, recursive=True)
    print(f"Found {len(wav_paths)} wav files.")

    for wav_path in wav_paths:
        try:
            samples, sr = sf.read(wav_path)

            # Multi-channel audio: keep only the first channel.
            if samples.ndim > 1:
                samples = samples[:, 0]

            # Add two leading singleton axes -> (1, 1, T).
            wav_tensor = torch.from_numpy(samples)[None, None, :]
            wav_tensor = wav_tensor.float().to(device)

            z_target = extractor.extract_targets(wav_tensor, sample_rate=sr)

            source = Path(wav_path)
            file_id, speaker_dir = source.stem, source.parent.name

            # The immediate parent folder name is used as the output subfolder.
            out_spk_dir = os.path.join(out_dir, speaker_dir)
            os.makedirs(out_spk_dir, exist_ok=True)

            torch.save(
                z_target.cpu(),
                os.path.join(out_spk_dir, f"{file_id}_ztarget.pt"),
            )
            print(f"Saved extracted target for {speaker_dir}/{file_id}: shape {z_target.shape}")
        except Exception as err:
            # Best effort: report the failure and move on to the next file.
            print(f"Skipping {wav_path} due to error: {err}")
if __name__ == "__main__":
    # Command-line entry point: read the input/output directories, pick a
    # device, then run the offline extraction over the whole corpus.
    cli = argparse.ArgumentParser()
    cli.add_argument("-w", "--wav_dir", default="./data_svc/waves-32k")
    cli.add_argument("-o", "--out_dir", default="./data_svc/codec_targets")
    args = cli.parse_args()

    # Start from the CPU fallback and upgrade: CUDA first, otherwise MPS.
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"

    process_corpus(args.wav_dir, args.out_dir, device)
    print("Offline processing complete.")