# cfm_svc / preprocess_teacher.py
# Hector Li
# Initial commit for Hugging Face
# df93d13
import argparse
import glob
import os
import numpy as np
import torch
from omegaconf import OmegaConf
from models.codec_wrapper import CodecWrapper
from vits.models import SynthesizerInfer
def load_teacher_model(config_path, checkpoint_path, device):
    """Build the frozen teacher synthesizer and load its generator weights.

    Args:
        config_path: path of the config read with ``OmegaConf.load``; its
            ``data`` section supplies ``filter_length``, ``segment_size``,
            ``hop_length`` and ``sampling_rate``.
        checkpoint_path: checkpoint file holding either a dict with the
            generator weights under ``"model_g"`` or a bare state dict.
        device: torch device the model is moved to.

    Returns:
        ``(teacher, sampling_rate)``: the model in eval mode with every
        parameter's ``requires_grad`` cleared, and
        ``int(hp.data.sampling_rate)``.

    Model tensors that are absent from the checkpoint keep their freshly
    constructed values (non-strict load). Their names are printed so that a
    mismatched checkpoint does not silently yield an untrained teacher.
    """
    hp = OmegaConf.load(config_path)
    teacher = SynthesizerInfer(
        hp.data.filter_length // 2 + 1,
        hp.data.segment_size // hp.data.hop_length,
        hp,
    ).to(device)

    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    saved_state_dict = checkpoint["model_g"] if "model_g" in checkpoint else checkpoint

    model_state = teacher.state_dict()
    missing = [k for k in model_state if k not in saved_state_dict]
    if missing:
        print(
            f"Warning: {len(missing)}/{len(model_state)} teacher tensors not found in "
            f"{checkpoint_path}; keeping their initial values (e.g. {missing[:5]})"
        )
    # Start from the model's own tensors so every expected key is present and
    # let checkpoint entries override them. Checkpoint keys the model does not
    # have are never looked up, so they are ignored.
    load_state = {k: saved_state_dict.get(k, v) for k, v in model_state.items()}
    teacher.load_state_dict(load_state, strict=False)

    teacher.eval()
    for p in teacher.parameters():
        p.requires_grad = False
    return teacher, int(hp.data.sampling_rate)
def load_cond_features(speaker_dir, file_id, data_root="./data_svc"):
    """Load the four conditioning arrays for one utterance as float32 tensors.

    Files are read from ``<data_root>/<folder>/<speaker_dir>/<file_id>.<ext>.npy``
    for (folder, ext) in: whisper/ppg, hubert/vec, pitch/pit, speaker/spk.

    The PPG and HuBERT arrays have every frame duplicated along axis 0, which
    follows the so-vits inference convention of bringing 50 Hz features up to
    the pitch frame rate. PPG, HuBERT and pitch are then cut to the shortest
    of their three lengths. The speaker array is not trimmed.

    Returns:
        Tuple ``(ppg, hubert, f0, spk)`` of ``torch.float32`` tensors.
    """
    def read(folder, ext):
        return np.load(f"{data_root}/{folder}/{speaker_dir}/{file_id}.{ext}.npy")

    ppg, hubert, f0, spk = (
        read(folder, ext)
        for folder, ext in (
            ("whisper", "ppg"),
            ("hubert", "vec"),
            ("pitch", "pit"),
            ("speaker", "spk"),
        )
    )

    ppg = np.repeat(ppg, 2, axis=0)
    hubert = np.repeat(hubert, 2, axis=0)

    n_frames = min(len(f0), len(ppg), len(hubert))
    trimmed = (ppg[:n_frames], hubert[:n_frames], f0[:n_frames], spk)
    return tuple(torch.tensor(arr, dtype=torch.float32) for arr in trimmed)
def _pick_device():
    """Return CUDA if available, else Apple MPS if available, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _target_file_id(path):
    """Return the feature file id for a codec-target path.

    ``<dir>/000123_ztarget.pt`` -> ``000123``. Only a trailing ``.pt`` and then
    a trailing ``_ztarget`` are removed, so ids that merely contain those
    substrings elsewhere are left intact.
    """
    name = os.path.basename(path)
    for suffix in (".pt", "_ztarget"):
        if name.endswith(suffix):
            name = name[: -len(suffix)]
    return name


@torch.no_grad()
def generate_teacher_codec_targets(args):
    """Produce teacher codec latents for every existing codec-target file.

    Every ``*.pt`` file found recursively under ``args.codec_target_dir``
    names one (speaker, file id) pair. The source file's contents are not
    read. For each pair:

    1. load the conditioning features with ``load_cond_features``;
    2. synthesize a waveform with the frozen teacher;
    3. encode it with ``CodecWrapper``;
    4. save the latent at the same relative path under ``args.out_dir``.

    Args:
        args: namespace with ``teacher_config``, ``teacher_ckpt``,
            ``codec_target_dir``, ``data_root``, ``out_dir``, ``overwrite``
            and ``log_interval``. Existing outputs are skipped unless
            ``overwrite`` is set. A ``log_interval`` <= 0 disables progress
            messages.

    Raises:
        RuntimeError: if no ``*.pt`` files exist under ``codec_target_dir``.
            This is checked before any model is loaded.

    Per-sample failures are printed and counted as skipped. They do not
    abort the run.
    """
    # Sorted so that runs process (and log) files in a reproducible order.
    src_files = sorted(
        glob.glob(os.path.join(args.codec_target_dir, "**", "*.pt"), recursive=True)
    )
    if not src_files:
        raise RuntimeError(f"No source codec targets found under {args.codec_target_dir}")

    device = _pick_device()
    teacher, teacher_sr = load_teacher_model(args.teacher_config, args.teacher_ckpt, device)
    codec = CodecWrapper(latent_dim=1024).to(device).eval()
    os.makedirs(args.out_dir, exist_ok=True)

    processed = 0
    skipped = 0
    for src in src_files:
        rel = os.path.relpath(src, args.codec_target_dir)
        out_path = os.path.join(args.out_dir, rel)
        if os.path.isfile(out_path) and not args.overwrite:
            continue
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

        # Layout assumption: the immediate parent directory of each target
        # file is the speaker folder used under data_root.
        speaker_dir = os.path.basename(os.path.dirname(src))
        file_id = _target_file_id(src)
        try:
            ppg, hubert, f0, spk = load_cond_features(speaker_dir, file_id, data_root=args.data_root)
            ppg = ppg.unsqueeze(0).to(device)
            hubert = hubert.unsqueeze(0).to(device)
            pit = f0.unsqueeze(0).to(device)  # (1, T)
            spk = spk.unsqueeze(0).to(device)
            lengths = torch.tensor([pit.shape[1]], dtype=torch.long, device=device)
            source = teacher.pitch2source(pit)
            wav_teacher = teacher.inference(ppg, hubert, pit, spk, lengths, source)
            z_teacher = codec.encode(wav_teacher, sample_rate=teacher_sr)  # (1, 1024, T)
            # Save to a temp file and rename, so an interrupted write never
            # leaves a truncated .pt that a later resume (without --overwrite)
            # would treat as finished.
            tmp_path = out_path + ".tmp"
            torch.save(z_teacher.cpu(), tmp_path)
            os.replace(tmp_path, out_path)
        except Exception as e:  # best-effort batch job: report and continue
            skipped += 1
            print(f"Skip {speaker_dir}/{file_id}: {e}")
            continue

        # Counting/logging live outside the try so a logging problem can never
        # be misreported as a skipped sample.
        processed += 1
        if args.log_interval > 0 and processed % args.log_interval == 0:
            print(f"Processed {processed} samples...")

    print(f"Teacher preprocessing done. processed={processed}, skipped={skipped}")
def build_parser():
    """Build the command-line parser for teacher codec-target preprocessing."""
    parser = argparse.ArgumentParser(
        description=(
            "Synthesize audio with a frozen teacher model and save its codec "
            "latents at the same relative paths as the files under "
            "--codec_target_dir."
        )
    )
    parser.add_argument(
        "--teacher_ckpt",
        type=str,
        required=True,
        help="Teacher checkpoint: a dict with a 'model_g' entry or a bare state dict.",
    )
    parser.add_argument(
        "--teacher_config",
        type=str,
        default="configs/base.yaml",
        help="Teacher config file loaded with OmegaConf.",
    )
    parser.add_argument(
        "--codec_target_dir",
        type=str,
        default="./data_svc/codec_targets",
        help=(
            "Directory scanned recursively for *.pt files; each file's parent "
            "folder is the speaker and its name defines the feature file id."
        ),
    )
    parser.add_argument(
        "--data_root",
        type=str,
        default="./data_svc",
        help="Root containing the whisper/, hubert/, pitch/ and speaker/ feature folders.",
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        default="./data_svc/teacher_codec_targets",
        help="Output root; files mirror the layout of --codec_target_dir.",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Regenerate outputs that already exist instead of skipping them.",
    )
    parser.add_argument(
        "--log_interval",
        type=int,
        default=20,
        help="Print a progress line every N processed samples.",
    )
    return parser


def main(argv=None):
    """Parse ``argv`` (``sys.argv[1:]`` when None) and run the preprocessing."""
    generate_teacher_codec_targets(build_parser().parse_args(argv))


if __name__ == "__main__":
    main()