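"""Chunked inference driver for the CADiT speech detokenizer.

Reads a testset file (pipe-separated lines, see ``get_test_list``), splits each
target token sequence into overlapping chunks, and synthesizes audio chunk by
chunk conditioned on a reference prompt, spawning one worker process per visible
device (CUDA or Ascend NPU).

Example invocation (script name and paths below are placeholders):

    python chunk_infer.py \\
        --vocoder /path/to/vocoder \\
        --testset /path/to/testset.txt \\
        --output /path/to/output_dir \\
        --chunk-token 25 \\
        --chunk-cond-portion 0.5
"""
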
import argparse
import logging
import os
import sys
import time
from pathlib import Path

import soundfile as sf
import torch.multiprocessing as mp
import torchaudio
import tqdm
from omegaconf import OmegaConf

from f5_tts.infer.utils_infer import (
    load_model,
    load_vocoder,
    remove_silence_for_generated_wav,
)

sys.path.append(str(Path(__file__).parent))
from utils_infer import (
    mel_spec_type,
    target_rms,
    nfe_step,
    cfg_strength,
    sway_sampling_coef,
    speed,
    fix_duration,
    chunk_infer_batch_process,
)
from model.cadit import CADiT


console_format = logging.Formatter(
    "[%(asctime)s][%(filename)s:%(levelname)s][%(process)d:%(threadName)s]%(message)s"
)
console_handler = logging.StreamHandler()
console_handler.setFormatter(console_format)
console_handler.setLevel(logging.INFO)
# Replace any pre-existing root handlers with the console handler.
# Iterate over a copy, since removeHandler() mutates logging.root.handlers.
for handler in list(logging.root.handlers):
    logging.root.removeHandler(handler)
logging.root.addHandler(console_handler)
logging.root.setLevel(logging.INFO)


if os.environ.get("TOKENIZE_ON_NPU") == "1":
    import torch_npu  # registers the NPU backend
    import f5tts_npu_patch
    # Imported for its side effect: transfer_to_npu redirects torch.cuda calls to the NPU backend.
    from torch_npu.contrib import transfer_to_npu

    logging.info("Applying patches for NPU")
    f5tts_npu_patch.patch_for_npu()


class SpeechDetokenizer:
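    """Wraps the CADiT model and a vocoder to turn token sequences back into waveforms.

    The model config, checkpoint, and vocab default to the files under ``ckpt/``
    next to this script; the vocoder path must be supplied explicitly.
    """
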
    def __init__(
        self,
        vocoder_path: str,
        model_cfg: str = str((Path(__file__).parent / "ckpt/model.yaml").absolute()),
        ckpt_file: str = str((Path(__file__).parent / "ckpt/model.pt").absolute()),
        vocab_file: str = str((Path(__file__).parent / "ckpt/vocab_4096.txt").absolute()),
        device: str = "cuda:0",
    ):
        self.model_cfg = model_cfg
        self.ckpt_file = ckpt_file
        self.vocab_file = vocab_file
        self.vocoder_path = vocoder_path
        self.device = device

        self.cross_fade_duration = 0
        self.initialize()

    def initialize(self):
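        """Load the vocoder and the CADiT checkpoint onto the target device."""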
        self.model = "CADiT"
        load_vocoder_from_local = True

        self.vocoder_name = mel_spec_type
        vocoder_local_path = self.vocoder_path

        model_cls = CADiT
        model_cfg = OmegaConf.load(self.model_cfg).model.arch
        logging.info(f"Using {self.model}...")

        self.vocoder = load_vocoder(
            vocoder_name=self.vocoder_name,
            is_local=load_vocoder_from_local,
            local_path=vocoder_local_path,
            device=self.device,
        )

        self.ema_model = load_model(
            model_cls,
            model_cfg,
            self.ckpt_file,
            mel_spec_type=self.vocoder_name,
            vocab_file=self.vocab_file,
            device=self.device,
        )

    def chunk_text_with_look_ahead(
        self, text_list, chunk_look_ahead_len, chunk_len=135, merge_short_last=False
    ):
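        """Split ``text_list`` into chunks of up to ``chunk_len`` tokens.

        Consecutive chunks overlap by ``chunk_look_ahead_len`` tokens: the last
        ``chunk_look_ahead_len`` tokens of each non-final chunk reappear at the
        start of the next chunk. With ``merge_short_last``, a final chunk shorter
        than the stride is folded into the previous one.
        """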
        chunks = []

        stride = chunk_len - chunk_look_ahead_len
        for i in range(0, len(text_list), stride):
            chk = text_list[i : i + chunk_len]
            chunks.append(chk)
            if i + chunk_len >= len(text_list):
                break

        if (
            merge_short_last
            and len(chunks) >= 2
            and len(chunks[-1]) < stride
        ):
            # Fold the short final chunk into the previous one, dropping the overlap.
            last = chunks.pop()
            second_last = chunks.pop()
            if chunk_look_ahead_len <= 0:
                chunks.append(second_last + last)
            else:
                chunks.append(second_last[:-chunk_look_ahead_len] + last)

        # Sanity check: dropping the look-ahead of every non-final chunk must
        # reconstruct the original token list exactly.
        actual_chunks = []
        for idx, chk in enumerate(chunks):
            if chunk_look_ahead_len <= 0 or idx == len(chunks) - 1:
                actual_chunks.extend(chk)
            else:
                actual_chunks.extend(chk[:-chunk_look_ahead_len])

        assert len(actual_chunks) == len(text_list)
        assert actual_chunks == text_list
        return chunks

    def chunk_generate(
        self,
        ref_audio,
        ref_text_list,
        gen_text_list,
        token_chunk_len,
        chunk_cond_proportion,
        chunk_look_ahead_len=0,
        max_ref_duration=4.5,
        ref_head_cut=False,
    ):
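        """Synthesize audio for ``gen_text_list`` chunk by chunk.

        The token sequence is chunked with ``chunk_text_with_look_ahead`` and passed
        to ``chunk_infer_batch_process``, which conditions each chunk on the reference
        audio/text and on the tail ``chunk_cond_proportion`` of the previous chunk.
        Returns ``(wave, sample_rate)``, or ``(None, None)`` if there is nothing to
        generate.
        """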
        gen_text_batches = self.chunk_text_with_look_ahead(
            gen_text_list,
            chunk_look_ahead_len,
            chunk_len=token_chunk_len,
            merge_short_last=True,
        )

        if len(gen_text_batches) == 0:
            return None, None

        for i, gen_text in enumerate(gen_text_batches):
            logging.info(f"gen_text {i} with {len(gen_text)} tokens: {gen_text}")

        audio, sr = torchaudio.load(ref_audio)
        logging.info(f"Generating audio in {len(gen_text_batches)} batches...")

        target_wave, target_sample_rate, combined_spectrogram = chunk_infer_batch_process(
            (audio, sr),
            ref_text_list,
            gen_text_batches,
            self.ema_model,
            self.vocoder,
            mel_spec_type=mel_spec_type,
            progress=tqdm,
            target_rms=target_rms,
            cross_fade_duration=self.cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=self.device,
            chunk_cond_proportion=chunk_cond_proportion,
            chunk_look_ahead=chunk_look_ahead_len,
            max_ref_duration=max_ref_duration,
            ref_head_cut=ref_head_cut,
        )
        return target_wave, target_sample_rate


def get_audio_duration(audio_path):
    audio, sample_rate = torchaudio.load(audio_path)
    return audio.shape[1] / sample_rate


def get_test_list(testset_path):
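    """Parse the testset file: one task per line, fields separated by ``|``.

    As consumed by ``infer()``, lines with 2 or 3 fields are read as
    ``ref_audio|ref_tokens`` (the reference tokens themselves are re-generated),
    and lines with 4 or 5 fields as ``ref_audio|ref_tokens|gen_audio|gen_tokens``;
    any extra trailing field is ignored.
    """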
    testset_list = []

    with open(testset_path, "r") as f:
        for line in f:
            content = line.strip().split("|")
            if len(content) in (2, 3):
                testset_list.append([content[1], content[0], content[1]])
            elif len(content) in (4, 5):
                testset_list.append([content[1], content[0], content[3], content[2]])
    return testset_list


def infer(args, task_queue, rank=0):
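    """Worker loop: pull tasks from ``task_queue`` until a ``None`` sentinel is received.

    ``rank`` indexes the visible devices (renumbered from 0), so each worker binds
    to ``cuda:{rank}``. Tasks whose ``*_gen.wav`` output already exists are skipped.
    """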
    device_spec = f"cuda:{rank}"

    if args.model_cfg is None or args.ckpt is None or args.vocab is None:
        detoker = SpeechDetokenizer(
            vocoder_path=args.vocoder,
            device=device_spec,
        )
    else:
        detoker = SpeechDetokenizer(
            vocoder_path=args.vocoder,
            model_cfg=args.model_cfg,
            ckpt_file=args.ckpt,
            vocab_file=args.vocab,
            device=device_spec,
        )

    token_chunk_len = args.chunk_token
    chunk_cond_proportion = args.chunk_cond_portion
    if chunk_cond_proportion > 1 or chunk_cond_proportion <= 0:
        chunk_cond_proportion = 0.5

    chunk_look_ahead = args.chunk_look_ahead
    if chunk_look_ahead >= token_chunk_len:
        chunk_look_ahead = 0

    remove_silence = False

    output_dir = args.output
    os.makedirs(output_dir, exist_ok=True)

    logging.info(f"Inferring with chunks of {token_chunk_len} tokens")
    logging.info(f"The last {chunk_cond_proportion} of each chunk is added to the condition")
    logging.info(f"Using the last {chunk_look_ahead} tokens as look-ahead")

    gen_nums = 0
    while True:
        try:
            _tst = task_queue.get()
            if _tst is None:
                logging.info("Finished processing all inputs")
                break

            ref_text_list = _tst[0].split()
            ref_audio = _tst[1]
            gen_text_list = _tst[2].split()

            if len(_tst) == 4:
                gen_audio = _tst[3]
            else:
                gen_audio = None

            ref_wave_path = (
                Path(output_dir) / f"{ref_audio.split('/')[-1].split('.')[0]}_ref.wav"
            )
            if gen_audio is None:
                gen_wave_path = (
                    Path(output_dir)
                    / f"{ref_audio.split('/')[-1].split('.')[0]}_gen.wav"
                )
                orig_wave_path = None
            else:
                gen_wave_path = (
                    Path(output_dir)
                    / f"{gen_audio.split('/')[-1].split('.')[0]}_gen.wav"
                )
                orig_wave_path = (
                    Path(output_dir)
                    / f"{gen_audio.split('/')[-1].split('.')[0]}_orig.wav"
                )

            if os.path.exists(gen_wave_path):
                logging.info(f"{gen_wave_path} already exists, skipping")
                continue

            if not os.path.exists(ref_wave_path):
                os.system(f"cp {ref_audio} {ref_wave_path}")

            if gen_audio is not None and os.path.exists(gen_audio) and orig_wave_path:
                os.system(f"cp {gen_audio} {orig_wave_path}")

            generated_wave, target_sample_rate = detoker.chunk_generate(
                ref_audio,
                ref_text_list,
                gen_text_list,
                token_chunk_len,
                chunk_cond_proportion,
                chunk_look_ahead,
                args.max_ref_duration,
                args.ref_audio_cut_from_head,
            )

            if generated_wave is None:
                continue

            with open(gen_wave_path, "wb") as f:
                sf.write(f.name, generated_wave, target_sample_rate)
                if remove_silence:
                    remove_silence_for_generated_wav(f.name)
                logging.info(f"Wrote output to: {f.name}")

            gen_nums += 1
        except Exception:
            logging.exception("Failed to fetch or process a task")


def run_infer_mp(args):
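    """Spawn one ``infer`` worker per visible device (CUDA or Ascend) and feed all
    workers from a shared task queue, terminated by one ``None`` sentinel per worker.
    """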
    device_list = [0]
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        device_list = [int(x.strip()) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
    elif "ASCEND_RT_VISIBLE_DEVICES" in os.environ:
        device_list = [int(x.strip()) for x in os.environ["ASCEND_RT_VISIBLE_DEVICES"].split(",")]

    logging.info(f"Using devices: {device_list}")
    n_procs = len(device_list)

    testset_list = get_test_list(args.testset_path)

    ctx = mp.get_context("spawn")
    with ctx.Manager() as manager:
        task_queue = manager.Queue()
        for task in testset_list:
            task_queue.put(task)

        processes = []
        for rank in range(n_procs):
            # One None sentinel per worker so every process can exit its loop.
            task_queue.put(None)
            p = ctx.Process(target=infer, args=(args, task_queue, rank))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

    os.system(f"cp {args.testset_path} {args.output}")
    logging.info(f"Finished processing with {n_procs} processes")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt",
        required=False,
        help="path to the model checkpoint",
    )
    parser.add_argument(
        "--model-cfg",
        required=False,
        help="path to the model config (YAML)",
    )
    parser.add_argument(
        "--vocab",
        required=False,
        help="path to the vocab file",
    )
    parser.add_argument(
        "--vocoder",
        required=True,
        help="path to the local vocoder",
    )
    parser.add_argument(
        "--testset",
        dest="testset_path",
        required=True,
        help="path to the testset file",
    )
    parser.add_argument(
        "--output",
        required=True,
        help="directory to write the generated audio to",
    )
    parser.add_argument(
        "--chunk-token",
        required=True,
        type=int,
        default=25,
        help="max number of tokens in a chunk",
    )
    parser.add_argument(
        "--chunk-look-ahead",
        required=False,
        type=int,
        default=0,
        help="number of tokens at the end of a chunk used as look-ahead",
    )
    parser.add_argument(
        "--chunk-cond-portion",
        required=True,
        type=float,
        default=0.5,
        help="portion at the tail of the previous chunk used as condition",
    )
    parser.add_argument(
        "--max-ref-duration",
        required=False,
        type=float,
        default=4.5,
        help="max duration of the reference audio in seconds",
    )
    parser.add_argument(
        "--ref-audio-cut-from-head",
        default=False,
        action="store_true",
        help="cut the reference audio from the head; by default it is cut from the tail",
    )

    args = parser.parse_args()

    start_time = time.perf_counter()

    run_infer_mp(args)

    end_time = time.perf_counter()
    logging.info("Processing time: %f sec", end_time - start_time)
    logging.info(f"Finished! Output written to: {args.output}")