# VITS2-Claude / inference.py
# Uploaded by ZLSCompLing via huggingface_hub (commit b9c74dd, verified)
#!/usr/bin/env python3
"""
VITS2 Remy - Luxembourgish TTS Inference Script
Usage:
python inference.py "Moien, wéi geet et dir?"
python inference.py "Moien, wéi geet et dir?" -o output.wav
python inference.py "Moien, wéi geet et dir?" --noise_scale 0.5
"""
import argparse
import torch
import scipy.io.wavfile as wavfile
import utils
import commons
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence
def get_text(text, hps):
    """Convert raw text into a 1-D LongTensor of symbol IDs.

    Applies the text cleaners named in ``hps.data.text_cleaners`` and, when
    ``hps.data.add_blank`` is set, intersperses blank tokens (ID 0) between
    symbols — the spacing scheme this VITS config was trained with.
    """
    sequence = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        sequence = commons.intersperse(sequence, 0)
    return torch.LongTensor(sequence)
def main():
    """Parse CLI args, load the VITS2 model, and synthesize speech to a WAV.

    Loads hyperparameters and the generator checkpoint from the paths given
    on the command line (defaulting to ``config.json`` / ``model.pth`` in the
    working directory, preserving the original behavior) and writes the
    synthesized audio to ``args.output`` at the config's sampling rate.
    """
    parser = argparse.ArgumentParser(description="VITS2 Remy TTS")
    parser.add_argument("text", type=str, help="Text to synthesize")
    parser.add_argument("-o", "--output", type=str, default="output.wav", help="Output WAV file")
    parser.add_argument("--noise_scale", type=float, default=0.667, help="Noise scale (default: 0.667)")
    parser.add_argument("--noise_scale_w", type=float, default=0.8, help="Noise scale W (default: 0.8)")
    parser.add_argument("--length_scale", type=float, default=1.0, help="Length scale (default: 1.0)")
    parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU")
    # Backward-compatible additions: the defaults reproduce the previously
    # hard-coded paths, so existing invocations are unaffected.
    parser.add_argument("--config", type=str, default="config.json", help="Model config path (default: config.json)")
    parser.add_argument("--model", type=str, default="model.pth", help="Model checkpoint path (default: model.pth)")
    args = parser.parse_args()

    # Fall back to CPU instead of crashing when no CUDA device is present.
    if args.cpu:
        device = "cpu"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        print("CUDA not available; falling back to CPU")
        device = "cpu"

    # Load config
    hps = utils.get_hparams_from_file(args.config)

    # The posterior encoder consumes either mel spectrograms or linear STFT
    # frames depending on the config; the channel count must match.
    if getattr(hps.model, 'use_mel_posterior_encoder', False):
        posterior_channels = hps.data.n_mel_channels
    else:
        posterior_channels = hps.data.filter_length // 2 + 1

    # Build the generator and restore its weights (optimizer state unused).
    net_g = SynthesizerTrn(
        len(symbols),
        posterior_channels,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model
    ).to(device)
    _ = utils.load_checkpoint(args.model, net_g, None)
    net_g.eval()

    # Synthesize. Lowercasing presumably mirrors training-time normalization;
    # the cleaners listed in the config define the rest of the pipeline.
    text = args.text.lower()
    print(f"Synthesizing: {text}")
    with torch.no_grad():
        stn_tst = get_text(text, hps)
        x_tst = stn_tst.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
        # infer() returns a tuple; [0] is the waveform batch, [0, 0] selects
        # the single (batch, channel) sample as a 1-D tensor.
        audio = net_g.infer(
            x_tst, x_tst_lengths,
            noise_scale=args.noise_scale,
            noise_scale_w=args.noise_scale_w,
            length_scale=args.length_scale
        )[0][0, 0].data.cpu().float().numpy()
    wavfile.write(args.output, hps.data.sampling_rate, audio)
    print(f"Saved to: {args.output}")
# Script entry point: run synthesis only when executed directly, not on import.
if __name__ == "__main__":
    main()