| import argparse |
| import glob |
| import os |
|
|
| import numpy as np |
| import torch |
| from omegaconf import OmegaConf |
|
|
| from models.codec_wrapper import CodecWrapper |
| from vits.models import SynthesizerInfer |
|
|
|
|
def load_teacher_model(config_path, checkpoint_path, device):
    """Build the frozen teacher synthesizer and load its weights.

    Args:
        config_path: path to the OmegaConf YAML describing the teacher.
        checkpoint_path: checkpoint file holding either a raw state dict or
            one nested under the ``"model_g"`` key.
        device: torch device the teacher is moved to.

    Returns:
        ``(teacher, sampling_rate)`` — the model in eval mode with every
        parameter frozen, and the config's sampling rate as an int.
    """
    hp = OmegaConf.load(config_path)
    spec_channels = hp.data.filter_length // 2 + 1
    segment_frames = hp.data.segment_size // hp.data.hop_length
    teacher = SynthesizerInfer(spec_channels, segment_frames, hp).to(device)

    ckpt = torch.load(checkpoint_path, map_location="cpu")
    # Some checkpoints wrap the generator weights under "model_g".
    state = ckpt.get("model_g", ckpt)

    # Merge: take the checkpoint tensor when present, otherwise keep the
    # teacher's freshly initialized one, so partial checkpoints still load.
    current = teacher.state_dict()
    merged = {name: state.get(name, tensor) for name, tensor in current.items()}
    teacher.load_state_dict(merged, strict=False)

    teacher.eval()
    teacher.requires_grad_(False)
    return teacher, int(hp.data.sampling_rate)
|
|
|
|
def load_cond_features(speaker_dir, file_id, data_root="./data_svc"):
    """Load the conditioning features for one utterance.

    Reads whisper PPG, HuBERT vectors, pitch (f0), and a speaker embedding
    from ``data_root``, doubles the frame rate of the content features, and
    trims all time-varying streams to a common length.

    Returns:
        ``(ppg, hubert, f0, spk)`` as float32 torch tensors.
    """
    def _load(kind, suffix):
        return np.load(f"{data_root}/{kind}/{speaker_dir}/{file_id}.{suffix}.npy")

    ppg = _load("whisper", "ppg")
    hubert = _load("hubert", "vec")
    f0 = _load("pitch", "pit")
    spk = _load("speaker", "spk")

    # Content features run at half the pitch frame rate; duplicate each frame
    # along time. (Assumes 2x upsampling matches the pitch hop — the training
    # pipeline uses the same factor.)
    ppg = ppg.repeat(2, axis=0)
    hubert = hubert.repeat(2, axis=0)

    # Align every time-varying stream to the shortest one.
    frames = min(len(f0), len(ppg), len(hubert))
    ppg, hubert, f0 = ppg[:frames], hubert[:frames], f0[:frames]

    def _as_float(arr):
        return torch.tensor(arr, dtype=torch.float32)

    return _as_float(ppg), _as_float(hubert), _as_float(f0), _as_float(spk)
|
|
|
|
| @torch.no_grad() |
| def generate_teacher_codec_targets(args): |
| if torch.cuda.is_available(): |
| device = torch.device("cuda") |
| elif torch.backends.mps.is_available(): |
| device = torch.device("mps") |
| else: |
| device = torch.device("cpu") |
|
|
| teacher, teacher_sr = load_teacher_model(args.teacher_config, args.teacher_ckpt, device) |
| codec = CodecWrapper(latent_dim=1024).to(device).eval() |
|
|
| src_files = glob.glob(os.path.join(args.codec_target_dir, "**", "*.pt"), recursive=True) |
| if not src_files: |
| raise RuntimeError(f"No source codec targets found under {args.codec_target_dir}") |
|
|
| os.makedirs(args.out_dir, exist_ok=True) |
|
|
| processed = 0 |
| skipped = 0 |
| for src in src_files: |
| rel = os.path.relpath(src, args.codec_target_dir) |
| speaker_dir = os.path.basename(os.path.dirname(src)) |
| file_id = os.path.basename(src).replace(".pt", "").replace("_ztarget", "") |
|
|
| out_path = os.path.join(args.out_dir, rel) |
| out_dir = os.path.dirname(out_path) |
| os.makedirs(out_dir, exist_ok=True) |
|
|
| if os.path.isfile(out_path) and not args.overwrite: |
| continue |
|
|
| try: |
| ppg, hubert, f0, spk = load_cond_features(speaker_dir, file_id, data_root=args.data_root) |
|
|
| ppg = ppg.unsqueeze(0).to(device) |
| hubert = hubert.unsqueeze(0).to(device) |
| pit = f0.unsqueeze(0).to(device) |
| spk = spk.unsqueeze(0).to(device) |
| lengths = torch.LongTensor([pit.shape[1]]).to(device) |
|
|
| source = teacher.pitch2source(pit) |
| wav_teacher = teacher.inference(ppg, hubert, pit, spk, lengths, source) |
|
|
| z_teacher = codec.encode(wav_teacher, sample_rate=teacher_sr) |
| torch.save(z_teacher.cpu(), out_path) |
| processed += 1 |
| if processed % args.log_interval == 0: |
| print(f"Processed {processed} samples...") |
| except Exception as e: |
| skipped += 1 |
| print(f"Skip {speaker_dir}/{file_id}: {e}") |
|
|
| print(f"Teacher preprocessing done. processed={processed}, skipped={skipped}") |
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: configure paths, then run the teacher preprocessing.
    cli = argparse.ArgumentParser()
    cli.add_argument("--teacher_ckpt", type=str, required=True)
    # String options with defaults, declared in one pass.
    for flag, default in (
        ("--teacher_config", "configs/base.yaml"),
        ("--codec_target_dir", "./data_svc/codec_targets"),
        ("--data_root", "./data_svc"),
        ("--out_dir", "./data_svc/teacher_codec_targets"),
    ):
        cli.add_argument(flag, type=str, default=default)
    cli.add_argument("--overwrite", action="store_true")
    cli.add_argument("--log_interval", type=int, default=20)
    generate_teacher_codec_targets(cli.parse_args())
|