Spaces:
Running
Running
| import json | |
| import os | |
| import random | |
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| import torchaudio | |
| from tqdm import tqdm | |
| from scipy.io.wavfile import write | |
| import argparse | |
| from TTS.tts.configs.xtts_config import XttsConfig | |
| from TTS.tts.models.xtts import Xtts | |
| device = "cuda:0" if torch.cuda.is_available() else "cpu" | |
| lang_codes = { | |
| 'English': 'en', | |
| 'Estonian': 'et', | |
| 'Russian': 'ru', | |
| } | |
| ref_metas = { | |
| 'en': '/scratch/project_465001704/data/eng/commonvoice/metadata.csv', | |
| 'et': '/scratch/project_465001704/data/est/commonvoice-14.0/metadata.csv', | |
| 'ru': '/scratch/project_465001704/data/rus/commonvoice-20.0/metadata.csv', | |
| } | |
| refs = {} | |
def load_ref_list(languages):
    """Populate the module-level ``refs`` cache with reference clip paths.

    For each language code, reads the '|'-separated CommonVoice metadata CSV
    from ``ref_metas`` and stores its 'audio_file' column as a list.

    Args:
        languages: iterable of 2-letter language codes (keys of ``ref_metas``).

    Returns:
        The populated ``refs`` dict (language code -> list of audio paths).
        Previously this returned None even though the __main__ caller assigned
        the result; returning the mapping is backward-compatible and makes the
        assignment meaningful.
    """
    for language in languages:
        refs[language] = pd.read_csv(ref_metas[language], sep='|')['audio_file'].tolist()
    return refs
def create_xtts_trainer_parser():
    """Build the CLI argument parser for the XTTS synthesis runner.

    Returns:
        argparse.ArgumentParser with model/vocab/dataset/output options.

    Note:
        Defaults point at cluster-specific paths; in production runs the
        commented-out ``required=True`` flags were presumably intended instead.
    """
    parser = argparse.ArgumentParser(description="Arguments for XTTS runner")

    parser.add_argument("--model_folder", type=str, default='/scratch/project_465001704/output/xtts-gpt/run/training/GPT_XTTS_FT-November-24-2025_01+29AM-8e59ec3',  # required=True,
                        help="Path of model file")
    parser.add_argument("--model_name", type=str, default='best_model',  # required=True,
                        help="Name of model file")
    parser.add_argument("--vocab_path", type=str, default='/project/project_465001704/rlellep/repos/XTTSv2-Finetuning-for-New-Languages/vocabs/vocab_et-100.json',  # required=True,
                        help="Path of vocab file")
    parser.add_argument("--languages", nargs='+', type=str, default=["en", "et"],  # required=True,
                        help="language1 language2")
    parser.add_argument("--dataset_meta", type=str, default='/scratch/project_465001704/data/to_synth_split/en_et-EOPC_00.jsonl',  # required=True,
                        help="Path of metadata file")
    parser.add_argument("--output_folder", type=str, default='/scratch/project_465001704/output/synth/en_et_EOPC_00',  # required=True,
                        help="Path of output folder")
    # BUG FIX: the original used ``type=bool``, which treats ANY non-empty
    # string (including "False") as True. BooleanOptionalAction gives proper
    # --stream / --no-stream flags while keeping the default of True.
    parser.add_argument("--stream", action=argparse.BooleanOptionalAction, default=True,
                        help="Run model in stream mode.")
    # Resume support: lines with id <= start_id are skipped in the main loop.
    parser.add_argument("--start_id", type=int, default=0)

    return parser
def load_model(model_folder, model_name, vocab_file):
    """Load a fine-tuned XTTS checkpoint and move it to the active device.

    Args:
        model_folder: directory containing the checkpoint and its config.json.
        model_name: checkpoint basename (without the .pth extension).
        vocab_file: path to the tokenizer vocab JSON.

    Returns:
        The initialized Xtts model, placed on the module-level ``device``.
    """
    checkpoint_file = os.path.join(model_folder, f"{model_name}.pth")
    config_file = os.path.join(model_folder, "config.json")

    config = XttsConfig()
    config.load_json(config_file)

    tts = Xtts.init_from_config(config)
    tts.load_checkpoint(
        config,
        checkpoint_path=checkpoint_file,
        vocab_path=vocab_file,
        use_deepspeed=False,
    )
    tts.to(device)
    return tts
def get_random_reference(language):
    """Pick a random reference clip longer than one second for ``language``.

    Args:
        language: 2-letter code; key into the module-level ``refs`` lists
            populated by load_ref_list().

    Returns:
        Path of a reference audio file whose duration exceeds 1 second.

    Raises:
        RuntimeError: if no usable clip exists. The original looped forever
            (``while True`` with replacement) in that case; sampling each
            candidate at most once makes the search bounded.
    """
    candidates = refs[language]
    # Shuffled pass over every candidate exactly once — still random, but
    # guaranteed to terminate even if all clips are unreadable or too short.
    for ref_file in random.sample(candidates, len(candidates)):
        try:
            metadata = torchaudio.info(ref_file)
            duration = metadata.num_frames / metadata.sample_rate
            if duration > 1:
                return ref_file
        except Exception:
            # Unreadable/corrupt file: try the next candidate (best-effort).
            continue
    raise RuntimeError(f"No reference clip longer than 1s found for language {language!r}")
def reference_latents_and_embedding(language, tts_model=None):
    """Compute XTTS conditioning latents from a random reference clip.

    Args:
        language: 2-letter code used to pick a speaker-reference clip.
        tts_model: optional Xtts model to use. Defaults to the module-level
            ``model`` assigned in __main__ — kept for backward compatibility,
            but passing the model explicitly avoids the hidden global
            dependency (a NameError if called before __main__ runs).

    Returns:
        (gpt_cond_latent, speaker_embedding) as returned by
        Xtts.get_conditioning_latents.
    """
    xtts = model if tts_model is None else tts_model
    ref_clip = get_random_reference(language)
    return xtts.get_conditioning_latents(
        audio_path=ref_clip,
        gpt_cond_len=xtts.config.gpt_cond_len,
        max_ref_length=xtts.config.max_ref_len,
        sound_norm_refs=xtts.config.sound_norm_refs,
    )
def perform_synthesis(model, gpt_cond_latent, speaker_embedding, text, language, stream=True):
    """Synthesize ``text`` with XTTS and return the waveform as a numpy array.

    Args:
        model: an Xtts model exposing inference_stream()/inference().
        gpt_cond_latent, speaker_embedding: speaker conditioning from
            get_conditioning_latents().
        text: text to speak.
        language: 2-letter language code for synthesis.
        stream: if True, use chunked streaming inference; otherwise one
            blocking inference() call.

    Returns:
        1-D numpy float array with the waveform samples (written out at
        24 kHz by write_output).

    Raises:
        RuntimeError: if synthesis produced no audio chunks at all.
    """
    # Sampling settings are deliberately conservative (low temperature,
    # high repetition penalty) for stable batch synthesis.
    synth_kwargs = dict(
        text=text,
        language=language,
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        temperature=0.1,
        length_penalty=1.0,
        repetition_penalty=10.0,
        top_k=10,
        top_p=0.3,
    )
    wav_chunks = []
    if stream:
        for chunk in model.inference_stream(**synth_kwargs):
            # The stream may yield None sentinels; keep only real audio.
            if chunk is not None:
                wav_chunks.append(chunk)
    else:
        wav_chunk = model.inference(**synth_kwargs)
        wav_chunks.append(torch.tensor(wav_chunk["wav"]))
    if not wav_chunks:
        raise RuntimeError(f"XTTS produced no audio for text: {text!r}")
    # FIX: the original chained ``.unsqueeze(0)[0]`` — a no-op pair (add a
    # leading dim, then immediately index it away) — removed.
    return torch.cat(wav_chunks, dim=0).detach().cpu().numpy()
def write_output(clip, output_folder, language, id):
    """Save a synthesized clip as a 24 kHz WAV under ``output_folder``.

    Args:
        clip: waveform samples (numpy array) to write.
        output_folder: root output directory; the per-language subdirectory
            must already exist (created in __main__).
        language: 2-letter code, used both as subdirectory and filename prefix.
        id: numeric clip id, zero-padded to 7 digits in the filename.

    Returns:
        The clip path relative to ``output_folder`` (e.g. 'en/en_0000042.wav').
    """
    filename = f'{language}_{id:07}.wav'
    relative_path = os.path.join(language, filename)
    absolute_path = os.path.join(output_folder, relative_path)
    write(absolute_path, 24000, clip)
    return relative_path
| if __name__ == "__main__": | |
| parser = create_xtts_trainer_parser() | |
| args = parser.parse_args() | |
| src_metadata_1 = os.path.join(args.output_folder, f'{args.languages[0]}-src.csv') | |
| src_metadata_2 = os.path.join(args.output_folder, f'{args.languages[1]}-src.csv') | |
| tgt_metadata_1 = os.path.join(args.output_folder, f'{args.languages[0]}-tgt.csv') | |
| tgt_metadata_2 = os.path.join(args.output_folder, f'{args.languages[1]}-tgt.csv') | |
| need_header_src_1 = not os.path.exists(src_metadata_1) | |
| need_header_src_2 = not os.path.exists(src_metadata_2) | |
| need_header_tgt_1 = not os.path.exists(tgt_metadata_1) | |
| need_header_tgt_2 = not os.path.exists(tgt_metadata_2) | |
| ref_clips = load_ref_list(args.languages) | |
| meta_paths = {} | |
| for language in args.languages: | |
| os.makedirs(os.path.join(args.output_folder, language), exist_ok=True) | |
| model = load_model(model_folder=args.model_folder, model_name=args.model_name, vocab_file=args.vocab_path) | |
| outer_id = int(args.output_folder[-2:]) * 100000 | |
| # 1. Open the file (using utf-8 to handle special characters like "õ") | |
| with open(args.dataset_meta, 'r', encoding='utf-8', buffering=1) as source_file, \ | |
| open(src_metadata_1, 'a', encoding='utf-8', buffering=1) as f_src_1, \ | |
| open(src_metadata_2, 'a', encoding='utf-8', buffering=1) as f_src_2, \ | |
| open(tgt_metadata_1, 'a', encoding='utf-8', buffering=1) as f_tgt_1, \ | |
| open(tgt_metadata_2, 'a', encoding='utf-8', buffering=1) as f_tgt_2: | |
| if need_header_src_1: | |
| f_src_1.write("audio_file|text\n") | |
| if need_header_src_2: | |
| f_src_2.write("audio_file|text\n") | |
| if need_header_tgt_1: | |
| f_tgt_1.write("audio_file|text\n") | |
| if need_header_tgt_2: | |
| f_tgt_2.write("audio_file|text\n") | |
| # 2. Iterate through the file line by line | |
| id = 1 | |
| line = source_file.readline() | |
| with tqdm() as pbar: | |
| while line: | |
| try: | |
| if id <= args.start_id: | |
| continue | |
| # 3. Parse the current line into a dictionary | |
| data = json.loads(line) | |
| # 4. Assign values to variables as requested | |
| src_segm = data.get('src_segm') | |
| tgt_segm = data.get('tgt_segm') | |
| if len(src_segm.split(" ")) < 3 or max(len(src_segm), len(tgt_segm)) > 400: | |
| continue | |
| src_lang = lang_codes[data.get('src_lang')] | |
| tgt_lang = lang_codes[data.get('tgt_lang')] | |
| gpt_cond_latent, speaker_embedding = reference_latents_and_embedding(src_lang) | |
| # print(f"Source language: {src_lang}, source text: {src_segm}") | |
| src_clip = perform_synthesis(model, gpt_cond_latent, speaker_embedding, src_segm, src_lang) | |
| src_path = write_output(src_clip, args.output_folder, src_lang, id + outer_id) | |
| if src_lang == args.languages[0]: | |
| f_src_1.write('|'.join([src_path, src_segm]) + '\n') | |
| else: | |
| f_src_2.write('|'.join([src_path, src_segm]) + '\n') | |
| # print(f"Target language: {tgt_lang}, target text: {tgt_segm}") | |
| tgt_clip = perform_synthesis(model, gpt_cond_latent, speaker_embedding, tgt_segm, tgt_lang) | |
| tgt_path = write_output(tgt_clip, args.output_folder, tgt_lang, id + outer_id) | |
| if tgt_lang == args.languages[0]: | |
| f_tgt_1.write('|'.join([tgt_path, tgt_segm]) + '\n') | |
| else: | |
| f_tgt_2.write('|'.join([tgt_path, tgt_segm]) + '\n') | |
| except json.JSONDecodeError: | |
| print(f"Skipping invalid JSON line: {line}") | |
| finally: | |
| if not id <= args.start_id and id % 100 == 0: | |
| for f in (f_src_1, f_src_2, f_tgt_1, f_tgt_2): | |
| f.flush() | |
| id += 1 | |
| line = source_file.readline() | |
| pbar.update(1) | |