File size: 4,225 Bytes
07b5cfc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
#!/usr/bin/env python3
"""
inference_first.py β quick Stage-1 sanity-check for StyleTTS-2
Example
-------
python inference_first.py \
--ckpt logs/pod_90h_30k/epoch_1st_0004.pth \
--ref data/wavs/123_abcd_part042_00.wav \
--text "<evt_gasp> Γ°Ιͺs Ιͺz Ι tΙst ΛsΙntΙns"
It writes preview.wav in the current directory.
"""
import argparse, yaml, torch, torchaudio
from models import build_model, load_ASR_models, load_F0_models
from Utils.PLBERT.util import load_plbert
from utils import recursive_munch, log_norm, length_to_mask
from meldataset import TextCleaner, preprocess
# ────────────────────────── helpers ────────────────────────────
def _restore_batch(x):
"""(T,) βΈ (1,T) or (C,T) βΈ (1,C,T) (handles squeeze in JDCNet)."""
return x.unsqueeze(0) if x.dim() == 1 else x
def _match_len(x, target_len):
"""Crop or zero-pad last axis to target_len."""
cur = x.shape[-1]
if cur > target_len:
return x[..., :target_len]
if cur < target_len:
pad = target_len - cur
return torch.nn.functional.pad(x, (0, pad))
return x
# ────────────────────────── CLI ────────────────────────────────
# Flag specs kept in one table so the argument list is easy to scan.
_CLI_FLAGS = [
    ("--ckpt", dict(required=True, help="epoch_1st_*.pth")),
    ("--ref", dict(required=True, help="reference wav (24 kHz mono)")),
    ("--text", dict(required=True, help="IPA / phoneme string")),
    ("--cfg", dict(default="Configs/config_ft_single.yml")),
]
parser = argparse.ArgumentParser()
for flag, opts in _CLI_FLAGS:
    parser.add_argument(flag, **opts)
args = parser.parse_args()
# ───────────────────────── net & cfg ───────────────────────────
# FIX: load the YAML config via a context manager — the original
# `yaml.safe_load(open(args.cfg))` left the file handle open.
with open(args.cfg, encoding="utf-8") as cfg_file:
    cfg = yaml.safe_load(cfg_file)
sr = cfg["preprocess_params"]["sr"]  # sample rate used when writing preview.wav
device = "cuda"  # NOTE(review): hard-coded GPU device — no CPU fallback here
# Pretrained sub-networks needed to assemble the full StyleTTS-2 model.
asr = load_ASR_models(cfg["ASR_path"], cfg["ASR_config"])
f0 = load_F0_models(cfg["F0_path"])
bert = load_plbert(cfg["PLBERT_dir"])
model = build_model(recursive_munch(cfg["model_params"]), asr, f0, bert)
# Stage-1 checkpoint: weights are loaded non-strictly per sub-module so
# keys missing from this checkpoint are skipped rather than erroring.
state = torch.load(args.ckpt, map_location="cpu")["net"]
for k in model:
    model[k].load_state_dict(state[k], strict=False)
    model[k].eval().to(device)
# ───────────────────────── prepare inputs ──────────────────────
cleaner = TextCleaner()
# Phoneme string -> integer token ids, shaped (1, L) for a batch of one.
text_ids = torch.LongTensor(cleaner(args.text)).unsqueeze(0).to(device)
input_lengths = torch.LongTensor([text_ids.shape[1]]).to(device)
text_mask = length_to_mask(input_lengths).to(device)
# Reference audio: assumed 24 kHz mono (per --ref help); no resampling
# happens here — TODO confirm preprocess() matches the training sr.
wav, _ = torchaudio.load(args.ref) # (1,N)
mel_ref = preprocess(wav.squeeze().numpy()).to(device) # (1,80,T)
# Style vector and ground-truth pitch/energy tracks from the reference mel.
style = model.style_encoder(mel_ref.unsqueeze(1)) # (1,128)
F0_real, _, _ = model.pitch_extractor(mel_ref.unsqueeze(1))
F0_real = _restore_batch(F0_real) # (1,T')
real_norm = log_norm(mel_ref.unsqueeze(1)).squeeze(1) # (1,T")
real_norm = _restore_batch(real_norm) # (1,T")
# ───────────────────────── align lengths ───────────────────────
enc = model.text_encoder(text_ids, input_lengths, text_mask) # (1,512,L)
enc_len = enc.shape[-1] # L
# Pitch/energy tracks are cropped or zero-padded to twice the text-encoder
# length, the frame count the decoder pairs with `enc` below.
target = enc_len * 2 # decoder expects 2×L
F0_real = _match_len(F0_real, target) # (1,2L)
real_norm = _match_len(real_norm, target) # (1,2L)
# ───────────────────────── decode & save ───────────────────────
with torch.no_grad():
    y = model.decoder(enc, F0_real, real_norm, style)
# torchaudio.save expects (channels, samples); drop the batch dim -> (1, T).
y = y.squeeze(0)
torchaudio.save("preview.wav", y.cpu(), sr)
# BUG FIX: the success message was split across two lines by a mangled
# emoji, leaving an unterminated string literal (a SyntaxError).
print("✅ wrote preview.wav")