XTTSv2-multi / generate_parallel_dataset.py
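"""Generate a parallel synthetic-speech dataset with a fine-tuned XTTS v2 model.

Reads a JSONL file of aligned source/target text segments, synthesizes both
sides of each pair with the same randomly drawn speaker reference, and writes
the WAV clips plus pipe-separated `audio_file|text` metadata (one source and
one target file per language).
"""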
import json
import os
import random
import pandas as pd
import torch
import torchaudio
from tqdm import tqdm
from scipy.io.wavfile import write
import argparse
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
device = "cuda:0" if torch.cuda.is_available() else "cpu"
lang_codes = {
'English': 'en',
'Estonian': 'et',
'Russian': 'ru',
}
ref_metas = {
'en': '/scratch/project_465001704/data/eng/commonvoice/metadata.csv',
'et': '/scratch/project_465001704/data/est/commonvoice-14.0/metadata.csv',
'ru': '/scratch/project_465001704/data/rus/commonvoice-20.0/metadata.csv',
}
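# Filled by load_ref_list(): language code -> list of reference-audio paths.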
refs = {}
def load_ref_list(languages):
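    """Populate the global `refs` dict with reference-clip paths for each language."""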
for language in languages:
refs[language] = pd.read_csv(ref_metas[language], sep='|')['audio_file'].tolist()
def create_xtts_trainer_parser():
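    """Build the CLI argument parser; the defaults reflect one specific cluster setup."""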
parser = argparse.ArgumentParser(description="Arguments for XTTS runner")
parser.add_argument("--model_folder", type=str, default='/scratch/project_465001704/output/xtts-gpt/run/training/GPT_XTTS_FT-November-24-2025_01+29AM-8e59ec3', #required=True,
help="Path of model file")
parser.add_argument("--model_name", type=str, default='best_model', # required=True,
help="Name of model file")
parser.add_argument("--vocab_path", type=str, default='/project/project_465001704/rlellep/repos/XTTSv2-Finetuning-for-New-Languages/vocabs/vocab_et-100.json', # required=True,
help="Path of vocab file")
parser.add_argument("--languages", nargs='+', type=str, default=["en", "et"], # required=True,
help="language1 language2")
parser.add_argument("--dataset_meta", type=str, default='/scratch/project_465001704/data/to_synth_split/en_et-EOPC_00.jsonl', # required=True,
help="Path of metadata file")
parser.add_argument("--output_folder", type=str, default='/scratch/project_465001704/output/synth/en_et_EOPC_00', # required=True,
help="Path of output folder")
parser.add_argument("--stream", type=bool, default=True,
help="Run model in stream mode.")
parser.add_argument("--start_id", type=int, default=0)
return parser
def load_model(model_folder, model_name, vocab_file):
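    """Load a fine-tuned XTTS checkpoint, its config, and the vocab file onto `device`."""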
# Model paths
xtts_checkpoint = os.path.join(model_folder, f"{model_name}.pth")
xtts_config = os.path.join(model_folder, "config.json")
# Load model
config = XttsConfig()
config.load_json(xtts_config)
XTTS_MODEL = Xtts.init_from_config(config)
XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=vocab_file, use_deepspeed=False)
XTTS_MODEL.to(device)
return XTTS_MODEL
def get_random_reference(language):
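    """Sample random reference clips until one longer than one second is found."""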
while True:
ref_file = random.choice(refs[language])
try:
metadata = torchaudio.info(ref_file)
duration = metadata.num_frames / metadata.sample_rate
if duration > 1:
return ref_file
        except Exception:
            # Unreadable or corrupt audio file; sample another reference.
            continue
def reference_latents_and_embedding(model, language):
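    """Pick a random reference clip and return its GPT conditioning latents and speaker embedding."""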
ref_clip = get_random_reference(language)
return model.get_conditioning_latents(
audio_path=ref_clip,
gpt_cond_len=model.config.gpt_cond_len,
max_ref_length=model.config.max_ref_len,
sound_norm_refs=model.config.sound_norm_refs,
)
def perform_synthesis(model, gpt_cond_latent, speaker_embedding, text, language, stream=True):
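    """Synthesize `text` in `language` (streaming or one-shot) and return a 1-D numpy waveform."""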
wav_chunks = []
if stream:
for chunk in model.inference_stream(
text=text,
language=language,
gpt_cond_latent=gpt_cond_latent,
speaker_embedding=speaker_embedding,
temperature=0.1,
length_penalty=1.0,
repetition_penalty=10.0,
top_k=10,
top_p=0.3,
):
if chunk is not None:
wav_chunks.append(chunk)
else:
wav_chunk = model.inference(
text=text,
language=language,
gpt_cond_latent=gpt_cond_latent,
speaker_embedding=speaker_embedding,
temperature=0.1,
length_penalty=1.0,
repetition_penalty=10.0,
top_k=10,
top_p=0.3,
)
wav_chunks.append(torch.tensor(wav_chunk["wav"]))
    # Concatenate all chunks into a single 1-D waveform on the CPU.
    out_wav = torch.cat(wav_chunks, dim=0).detach().cpu().numpy()
return out_wav
def write_output(clip, output_folder, language, id):
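    """Write `clip` as <language>/<language>_<id>.wav at XTTS's native 24 kHz rate and return the relative path."""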
relative_path = os.path.join(language, f'{language}_{id:07}.wav')
write(os.path.join(output_folder, relative_path), 24000, clip)
return relative_path
if __name__ == "__main__":
parser = create_xtts_trainer_parser()
args = parser.parse_args()
src_metadata_1 = os.path.join(args.output_folder, f'{args.languages[0]}-src.csv')
src_metadata_2 = os.path.join(args.output_folder, f'{args.languages[1]}-src.csv')
tgt_metadata_1 = os.path.join(args.output_folder, f'{args.languages[0]}-tgt.csv')
tgt_metadata_2 = os.path.join(args.output_folder, f'{args.languages[1]}-tgt.csv')
need_header_src_1 = not os.path.exists(src_metadata_1)
need_header_src_2 = not os.path.exists(src_metadata_2)
need_header_tgt_1 = not os.path.exists(tgt_metadata_1)
need_header_tgt_2 = not os.path.exists(tgt_metadata_2)
    load_ref_list(args.languages)  # populates the global `refs` dict in place
for language in args.languages:
os.makedirs(os.path.join(args.output_folder, language), exist_ok=True)
model = load_model(model_folder=args.model_folder, model_name=args.model_name, vocab_file=args.vocab_path)
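    # Derive a per-shard id offset from the folder's two-digit suffix (e.g. ..._00 -> 0)
    # so that clip ids stay unique across shards of up to 100000 lines each.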
outer_id = int(args.output_folder[-2:]) * 100000
# 1. Open the file (using utf-8 to handle special characters like "õ")
with open(args.dataset_meta, 'r', encoding='utf-8', buffering=1) as source_file, \
open(src_metadata_1, 'a', encoding='utf-8', buffering=1) as f_src_1, \
open(src_metadata_2, 'a', encoding='utf-8', buffering=1) as f_src_2, \
open(tgt_metadata_1, 'a', encoding='utf-8', buffering=1) as f_tgt_1, \
open(tgt_metadata_2, 'a', encoding='utf-8', buffering=1) as f_tgt_2:
if need_header_src_1:
f_src_1.write("audio_file|text\n")
if need_header_src_2:
f_src_2.write("audio_file|text\n")
if need_header_tgt_1:
f_tgt_1.write("audio_file|text\n")
if need_header_tgt_2:
f_tgt_2.write("audio_file|text\n")
# 2. Iterate through the file line by line
id = 1
line = source_file.readline()
with tqdm() as pbar:
while line:
try:
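                    # When resuming, skip lines up to start_id; the `finally` block
                    # still advances the line counter and the progress bar.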
if id <= args.start_id:
continue
# 3. Parse the current line into a dictionary
data = json.loads(line)
# 4. Assign values to variables as requested
                    src_segm = data.get('src_segm')
                    tgt_segm = data.get('tgt_segm')
                    # Skip pairs that are missing a side, shorter than 3 words, or longer than 400 chars.
                    if not src_segm or not tgt_segm:
                        continue
                    if len(src_segm.split(" ")) < 3 or max(len(src_segm), len(tgt_segm)) > 400:
                        continue
src_lang = lang_codes[data.get('src_lang')]
tgt_lang = lang_codes[data.get('tgt_lang')]
                    gpt_cond_latent, speaker_embedding = reference_latents_and_embedding(model, src_lang)
# print(f"Source language: {src_lang}, source text: {src_segm}")
                    src_clip = perform_synthesis(model, gpt_cond_latent, speaker_embedding, src_segm, src_lang, stream=args.stream)
src_path = write_output(src_clip, args.output_folder, src_lang, id + outer_id)
if src_lang == args.languages[0]:
f_src_1.write('|'.join([src_path, src_segm]) + '\n')
else:
f_src_2.write('|'.join([src_path, src_segm]) + '\n')
# print(f"Target language: {tgt_lang}, target text: {tgt_segm}")
                    tgt_clip = perform_synthesis(model, gpt_cond_latent, speaker_embedding, tgt_segm, tgt_lang, stream=args.stream)
tgt_path = write_output(tgt_clip, args.output_folder, tgt_lang, id + outer_id)
if tgt_lang == args.languages[0]:
f_tgt_1.write('|'.join([tgt_path, tgt_segm]) + '\n')
else:
f_tgt_2.write('|'.join([tgt_path, tgt_segm]) + '\n')
except json.JSONDecodeError:
print(f"Skipping invalid JSON line: {line}")
finally:
                    if id > args.start_id and id % 100 == 0:
for f in (f_src_1, f_src_2, f_tgt_1, f_tgt_2):
f.flush()
id += 1
line = source_file.readline()
pbar.update(1)
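
# Example invocation (a sketch: the paths are just the argument defaults above and
# must be adapted to your environment):
#   python generate_parallel_dataset.py \
#       --model_folder /scratch/project_465001704/output/xtts-gpt/run/training/GPT_XTTS_FT-November-24-2025_01+29AM-8e59ec3 \
#       --vocab_path /project/project_465001704/rlellep/repos/XTTSv2-Finetuning-for-New-Languages/vocabs/vocab_et-100.json \
#       --languages en et \
#       --dataset_meta /scratch/project_465001704/data/to_synth_split/en_et-EOPC_00.jsonl \
#       --output_folder /scratch/project_465001704/output/synth/en_et_EOPC_00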