| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import multiprocessing as mp |
| import torch |
| import os |
| from functools import partial |
| import gradio as gr |
| import traceback |
| from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav |
|
|
|
|
def model_worker(input_queue, output_queue, device_id):
    """Run the inference loop of one worker process.

    Builds a single ``MegaTTS3DiTInfer`` pipeline, then blocks on
    ``input_queue`` forever. For every task taken from the queue exactly one
    item is put on ``output_queue``: the value returned by
    ``infer_pipe.forward`` on success, or ``None`` if anything failed.

    Args:
        input_queue: Queue of tasks. Each task is a 6-item sequence
            ``(inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w)``.
        output_queue: Queue that receives one result per task.
        device_id: CUDA device index for the pipeline, or ``None`` to pass
            ``device=None`` to ``MegaTTS3DiTInfer``.
    """
    device = None
    if device_id is not None:
        device = torch.device(f'cuda:{device_id}')
    infer_pipe = MegaTTS3DiTInfer(device=device)

    while True:
        task = input_queue.get()
        try:
            # Unpack inside the try: a malformed task must be answered with
            # None like any other failure. Unpacking outside would kill this
            # worker and leave the caller blocked on output_queue.get().
            inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task
            convert_to_wav(inp_audio_path)
            # The converted file is expected next to the input, with a .wav
            # extension (presumably written by convert_to_wav - confirm).
            wav_path = os.path.splitext(inp_audio_path)[0] + '.wav'
            cut_wav(wav_path, max_len=28)
            with open(wav_path, 'rb') as file:
                file_content = file.read()
            resource_context = infer_pipe.preprocess(file_content, latent_file=inp_npy_path)
            wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
            output_queue.put(wav_bytes)
        except Exception as e:
            # Top-level boundary of the worker: log, report failure, keep serving.
            traceback.print_exc()
            print(task, str(e))
            output_queue.put(None)
|
|
|
|
def main(inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w, processes, input_queue, output_queue):
    """Submit one synthesis request to the worker and wait for its reply.

    The six request fields are logged, packed into a tuple and put on
    ``input_queue``; the call then blocks until an item arrives on
    ``output_queue``.

    Args:
        inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w: Request
            fields, forwarded unchanged as the task tuple.
        processes: Not used by this function.
        input_queue: Queue the task is put on.
        output_queue: Queue the result is read from.

    Returns:
        The item read from ``output_queue``; ``None`` signals a failed task
        (an empty line is printed in that case).
    """
    task = (inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w)
    print("Push task to the inp queue |", *task)
    input_queue.put(task)

    result = output_queue.get()
    if result is None:
        print("")
    return result
|
|
|
|
if __name__ == '__main__':
    # Script entry point: start the worker process(es), then serve the
    # Gradio interface that feeds them through a pair of manager queues.
    mp.set_start_method('spawn', force=True)
    mp_manager = mp.Manager()

    # An unset or empty CUDA_VISIBLE_DEVICES yields devices = None, in which
    # case each worker is started with device_id=None.
    visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
    devices = visible_devices.split(",") if visible_devices != '' else None

    num_workers = 1
    input_queue = mp_manager.Queue()
    output_queue = mp_manager.Queue()
    processes = []

    print("Start open workers")
    for i in range(num_workers):
        # Spread workers round-robin over the positions in the device list.
        device_id = i % len(devices) if devices is not None else None
        p = mp.Process(
            target=model_worker,
            args=(input_queue, output_queue, device_id),
        )
        p.start()
        processes.append(p)

    request_handler = partial(
        main,
        processes=processes,
        input_queue=input_queue,
        output_queue=output_queue,
    )
    api_interface = gr.Interface(
        fn=request_handler,
        inputs=[
            gr.Audio(type="filepath", label="Upload .wav"),
            gr.File(type="filepath", label="Upload .npy"),
            "text",
            gr.Number(label="infer timestep", value=32),
            gr.Number(label="Intelligibility Weight", value=1.4),
            gr.Number(label="Similarity Weight", value=3.0),
        ],
        outputs=[gr.Audio(label="Synthesized Audio")],
        title="MegaTTS3",
        description=(
            "Upload a speech clip as a reference for timbre, "
            "upload the pre-extracted latent file, "
            "input the target text, and receive the cloned voice."
        ),
        concurrency_limit=1,
    )
    api_interface.launch(server_name='0.0.0.0', server_port=7929, debug=True)

    # Wait for the workers once the server call returns.
    for p in processes:
        p.join()
|
|