File size: 2,633 Bytes
b9c74dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
#!/usr/bin/env python3
"""
VITS2 Remy - Luxembourgish TTS Inference Script
Usage:
python inference.py "Moien, wéi geet et dir?"
python inference.py "Moien, wéi geet et dir?" -o output.wav
python inference.py "Moien, wéi geet et dir?" --noise_scale 0.5
"""
import argparse
import torch
import scipy.io.wavfile as wavfile
import utils
import commons
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence
def get_text(text, hps):
    """Convert raw text into a LongTensor of symbol IDs for the model.

    Runs the configured text cleaners, optionally interleaves blank
    tokens (id 0) between symbols when ``hps.data.add_blank`` is set,
    and wraps the result in a ``torch.LongTensor``.
    """
    sequence = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        # VITS-style blank insertion: a 0 between every pair of symbols.
        sequence = commons.intersperse(sequence, 0)
    return torch.LongTensor(sequence)
def main():
    """Parse CLI arguments, load the VITS2 checkpoint, and synthesize speech.

    Reads the config and checkpoint from paths given on the command line
    (defaulting to ``config.json`` / ``model.pth``), runs inference on the
    provided text, and writes a 1-channel WAV at the configured sample rate.
    """
    parser = argparse.ArgumentParser(description="VITS2 Remy TTS")
    parser.add_argument("text", type=str, help="Text to synthesize")
    parser.add_argument("-o", "--output", type=str, default="output.wav", help="Output WAV file")
    parser.add_argument("--noise_scale", type=float, default=0.667, help="Noise scale (default: 0.667)")
    parser.add_argument("--noise_scale_w", type=float, default=0.8, help="Noise scale W (default: 0.8)")
    parser.add_argument("--length_scale", type=float, default=1.0, help="Length scale (default: 1.0)")
    parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU")
    # Generalized: config/checkpoint paths are now flags (defaults keep old behavior).
    parser.add_argument("--config", type=str, default="config.json", help="Path to config JSON (default: config.json)")
    parser.add_argument("--model", type=str, default="model.pth", help="Path to model checkpoint (default: model.pth)")
    args = parser.parse_args()

    # Robustness fix: the original selected "cuda" unconditionally, which
    # crashes on CPU-only machines. Fall back gracefully with a warning.
    if args.cpu or not torch.cuda.is_available():
        if not args.cpu:
            print("CUDA not available, falling back to CPU")
        device = "cpu"
    else:
        device = "cuda"

    # Load config
    hps = utils.get_hparams_from_file(args.config)

    # The posterior encoder input size depends on whether the model was
    # trained on mel spectrograms or linear spectrograms.
    if getattr(hps.model, 'use_mel_posterior_encoder', False):
        posterior_channels = hps.data.n_mel_channels
    else:
        posterior_channels = hps.data.filter_length // 2 + 1

    net_g = SynthesizerTrn(
        len(symbols),
        posterior_channels,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model
    ).to(device)
    _ = utils.load_checkpoint(args.model, net_g, None)
    net_g.eval()

    # Synthesize. Lowercasing matches the casing the model was trained on.
    text = args.text.lower()
    print(f"Synthesizing: {text}")
    with torch.no_grad():
        stn_tst = get_text(text, hps)
        x_tst = stn_tst.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
        audio = net_g.infer(
            x_tst, x_tst_lengths,
            noise_scale=args.noise_scale,
            noise_scale_w=args.noise_scale_w,
            length_scale=args.length_scale
        )[0][0, 0].data.cpu().float().numpy()

    wavfile.write(args.output, hps.data.sampling_rate, audio)
    print(f"Saved to: {args.output}")


if __name__ == "__main__":
    main()
|