File size: 4,225 Bytes
07b5cfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python3
"""
inference_first.py — quick Stage-1 sanity check for StyleTTS-2

Example
-------
python inference_first.py \
      --ckpt logs/pod_90h_30k/epoch_1st_0004.pth \
      --ref  data/wavs/123_abcd_part042_00.wav \
      --text "<evt_gasp> ðɪs ɪz ɐ tɛst ˈsɛntəns"
It writes preview.wav in the current directory.
"""
import argparse, yaml, torch, torchaudio
from models            import build_model, load_ASR_models, load_F0_models
from Utils.PLBERT.util import load_plbert
from utils             import recursive_munch, log_norm, length_to_mask
from meldataset        import TextCleaner, preprocess

# ────────────────────────── helpers ────────────────────────────
def _restore_batch(x):
    """(T,) β–Έ (1,T)  or  (C,T) β–Έ (1,C,T)   (handles squeeze in JDCNet)."""
    return x.unsqueeze(0) if x.dim() == 1 else x

def _match_len(x, target_len):
    """Crop or zero-pad last axis to target_len."""
    cur = x.shape[-1]
    if cur > target_len:
        return x[..., :target_len]
    if cur < target_len:
        pad = target_len - cur
        return torch.nn.functional.pad(x, (0, pad))
    return x

# ────────────────────────── CLI ────────────────────────────────
# Command-line interface: checkpoint, reference clip, phoneme text, config.
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", required=True, help="epoch_1st_*.pth")
parser.add_argument("--ref", required=True, help="reference wav (24 kHz mono)")
parser.add_argument("--text", required=True, help="IPA / phoneme string")
parser.add_argument("--cfg", default="Configs/config_ft_single.yml")
args = parser.parse_args()

# ───────────────── net & cfg ───────────────────────────────────
# Read the config with a context manager so the file handle is closed.
with open(args.cfg, encoding="utf-8") as cfg_file:
    cfg = yaml.safe_load(cfg_file)
sr     = cfg["preprocess_params"]["sr"]
# Fall back to CPU so the sanity check also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

asr   = load_ASR_models(cfg["ASR_path"], cfg["ASR_config"])
f0    = load_F0_models(cfg["F0_path"])
bert  = load_plbert(cfg["PLBERT_dir"])
model = build_model(recursive_munch(cfg["model_params"]), asr, f0, bert)

# NOTE: torch.load unpickles the file — only load checkpoints you trust.
state = torch.load(args.ckpt, map_location="cpu")["net"]
_PREFIX = "module."
for k in model:
    sub_state = state.get(k)
    if sub_state is None:
        print(f"warning: '{k}' not in checkpoint; keeping weights from build_model")
    else:
        # Checkpoints saved from a DataParallel/DDP-wrapped module may prefix
        # every key with "module."; with strict=False such keys would be
        # silently ignored and the module would stay untrained.
        sub_state = {
            (name[len(_PREFIX):] if name.startswith(_PREFIX) else name): tensor
            for name, tensor in sub_state.items()
        }
        result = model[k].load_state_dict(sub_state, strict=False)
        # Surface partial loads instead of hiding them behind strict=False.
        if result.missing_keys or result.unexpected_keys:
            print(f"warning: '{k}': {len(result.missing_keys)} missing / "
                  f"{len(result.unexpected_keys)} unexpected keys")
    model[k].eval().to(device)

# ───────────────── prepare inputs ──────────────────────────────
cleaner   = TextCleaner()
token_ids = cleaner(args.text)
if len(token_ids) == 0:
    # An empty sequence would only fail later inside the text encoder with an
    # opaque shape error; stop early with a readable message instead.
    raise SystemExit("error: --text produced no symbols known to TextCleaner")
text_ids       = torch.LongTensor(token_ids).unsqueeze(0).to(device)
input_lengths  = torch.LongTensor([text_ids.shape[1]]).to(device)
text_mask      = length_to_mask(input_lengths).to(device)

# Reference clip: resample to the config rate and down-mix to mono. The
# --ref help asks for 24 kHz mono, and a multi-channel (C,N) input would
# otherwise reach preprocess() un-squeezed. Presumably preprocess() expects
# cfg's sample rate — TODO confirm against meldataset.
wav, ref_sr = torchaudio.load(args.ref)                    # (C,N)
if ref_sr != sr:
    wav = torchaudio.functional.resample(wav, ref_sr, sr)
mel_ref = preprocess(wav.mean(dim=0).numpy()).to(device)   # expected (1,80,T)

# Inference only: don't build autograd graphs for the reference features.
with torch.no_grad():
    style = model.style_encoder(mel_ref.unsqueeze(1))       # expected (1,128)

    F0_real, _, _ = model.pitch_extractor(mel_ref.unsqueeze(1))
    F0_real       = _restore_batch(F0_real)                 # (1,T')

    real_norm     = log_norm(mel_ref.unsqueeze(1)).squeeze(1)  # (1,T")
    real_norm     = _restore_batch(real_norm)               # (1,T")

# ───────────────── align lengths ───────────────────────────────
# NOTE(review): the text-encoder output is fed to the decoder directly, with
# no duration/alignment step, so each token gets one text frame. That is fine
# for a smoke test but not representative of real synthesis quality.
with torch.no_grad():   # inference only — don't keep an autograd graph
    enc = model.text_encoder(text_ids, input_lengths, text_mask)  # (1,512,L)
enc_len  = enc.shape[-1]                             # L
target   = enc_len * 2   # decoder expects F0/norm curves at 2× the text frames

F0_real   = _match_len(F0_real,   target)            # (1,2L)
real_norm = _match_len(real_norm, target)            # (1,2L)

# ───────────────── decode & save ───────────────────────────────
with torch.no_grad():
    y = model.decoder(enc, F0_real, real_norm, style)

# torchaudio.save wants (channels, samples). Flatten this single-utterance
# batch into one mono channel. Unlike squeeze(0), this also copes with a
# decoder that returns (1, T) rather than (1, 1, T).
y = y.reshape(1, -1)        # (1, T)

torchaudio.save("preview.wav", y.cpu(), sr)
print("✅  wrote preview.wav")