Spaces:
Runtime error
Runtime error
| import argparse | |
| import datetime as dt | |
| import warnings | |
| from pathlib import Path | |
| import ffmpeg | |
| import gradio as gr | |
| import IPython.display as ipd | |
| import joblib as jl | |
| import numpy as np | |
| import soundfile as sf | |
| import torch | |
| from tqdm.auto import tqdm | |
| from diff_ttsg.hifigan.config import v1 | |
| from diff_ttsg.hifigan.denoiser import Denoiser | |
| from diff_ttsg.hifigan.env import AttrDict | |
| from diff_ttsg.hifigan.models import Generator as HiFiGAN | |
| from diff_ttsg.models.diff_ttsg import Diff_TTSG | |
| from diff_ttsg.text import cmudict, sequence_to_text, text_to_sequence | |
| from diff_ttsg.text.symbols import symbols | |
| from diff_ttsg.utils.model import denormalize | |
| from diff_ttsg.utils.utils import intersperse, plot_tensor | |
| from pymo.preprocessing import MocapParameterizer | |
| from pymo.viz_tools import render_mp4 | |
| from pymo.writers import BVHWriter | |
# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoints and resources (relative paths, resolved against the working directory).
DIFF_TTSG_CHECKPOINT = "diff_ttsg_checkpoint.ckpt"  # passed to load_model()
HIFIGAN_CHECKPOINT = "g_02500000"  # passed to load_vocoder()
MOTION_PIPELINE = "diff_ttsg/resources/data_pipe.expmap_86.1328125fps.sav"  # joblib file
CMU_DICT_PATH = "diff_ttsg/resources/cmu_dictionary"  # pronunciation dictionary

# Folder that receives the synthesised artefacts.
OUTPUT_FOLDER = "synth_output"
| # Model loading tools | |
def load_model(checkpoint_path):
    """Restore a Diff-TTSG model from ``checkpoint_path`` in evaluation mode.

    The checkpoint is mapped onto the module-level ``device`` while loading.
    """
    diff_ttsg = Diff_TTSG.load_from_checkpoint(
        checkpoint_path,
        map_location=device,
    )
    # Inference only: put the module in evaluation mode before handing it out.
    diff_ttsg.eval()
    return diff_ttsg
| # Vocoder loading tools | |
def load_vocoder(checkpoint_path):
    """Build the HiFi-GAN generator and load its weights from ``checkpoint_path``.

    The generator is created from the ``v1`` config, moved to the module-level
    ``device``, put in evaluation mode and stripped of weight normalisation.
    """
    generator = HiFiGAN(AttrDict(v1)).to(device)
    # NOTE: torch.load unpickles the file, so only trusted checkpoints belong here.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # The generator weights are stored under the 'generator' key.
    generator.load_state_dict(checkpoint['generator'])
    generator.eval()
    generator.remove_weight_norm()
    return generator
| # Setup text preprocessing | |
# Pronunciation dictionary handed to text_to_sequence() by process_text().
cmu = cmudict.CMUDict(CMU_DICT_PATH)


def process_text(text: str):
    """Turn raw ``text`` into the tensors the synthesiser consumes.

    Returns a dict with:
        'x_orig':    the unmodified input string,
        'x':         LongTensor of symbol ids with a leading axis of size 1,
        'x_lengths': LongTensor holding the size of the last axis of ``x``,
        'x_phones':  the id sequence mapped back to text via ``sequence_to_text``.
    """
    symbol_ids = text_to_sequence(text, dictionary=cmu)
    # len(symbols) is the filler value passed to intersperse
    # (presumably a blank/padding id placed between symbols - confirm).
    interspersed = intersperse(symbol_ids, len(symbols))

    x = torch.LongTensor(interspersed).to(device)[None]  # [None] adds the leading axis
    x_lengths = torch.LongTensor([x.shape[-1]]).to(device)
    x_phones = sequence_to_text(x.squeeze(0).tolist())

    return {'x_orig': text, 'x': x, 'x_lengths': x_lengths, 'x_phones': x_phones}
| # Setup motion visualisation | |
# Motion pipeline object restored from disk; to_bvh() calls its inverse_transform().
# NOTE(review): joblib.load unpickles the file - only load pipeline files you trust.
motion_pipeline = jl.load(MOTION_PIPELINE)
# Writer used by save_to_folder() to dump the motion as a .bvh file.
bvh_writer = BVHWriter()
# Parameterizer created with "position"; to_stick_video() applies it before rendering.
mocap_params = MocapParameterizer("position")
| ## Load models | |
# Everything below is created once at import time and shared by all requests.
model = load_model(DIFF_TTSG_CHECKPOINT)
vocoder = load_vocoder(HIFIGAN_CHECKPOINT)
# The denoiser is constructed from the vocoder, so it has to come after it.
denoiser = Denoiser(vocoder, mode='zeros')
| # Synthesis functions | |
def synthesise(text, mel_timestep, motion_timestep, length_scale, mel_temp, motion_temp):
    """Run the Diff-TTSG model on ``text`` and return its outputs as one dict.

    ``mel_timestep`` / ``motion_timestep`` are forwarded as the number of
    reverse-denoising steps and ``mel_temp`` / ``motion_temp`` as the sampling
    temperatures, each keyed by 'mel' and 'motion'. ``length_scale`` is passed
    through unchanged. The returned dict is the model output merged with the
    entries produced by ``process_text``.
    """
    text_processed = process_text(text)

    started_at = dt.datetime.now()
    output = model.synthesise(
        text_processed['x'],
        text_processed['x_lengths'],
        n_timesteps={'mel': mel_timestep, 'motion': motion_timestep},
        temperature={'mel': mel_temp, 'motion': motion_temp},
        stoc=False,
        spk=None,
        length_scale=length_scale,
    )
    elapsed = (dt.datetime.now() - started_at).total_seconds()

    # Real-time factor: wall-clock time relative to the generated duration
    # (presumably 256 samples per mel frame at 22050 Hz - confirm hop size).
    rtf = elapsed * 22050 / (output["mel"].shape[-1] * 256)
    print(f'RTF: {rtf}')

    # Fold the text-side entries into the model output so callers get one dict.
    output.update(text_processed)
    return output
def to_waveform(mel, vocoder):
    """Convert a mel spectrogram into a denoised waveform held on the CPU.

    Args:
        mel: Mel-spectrogram tensor in the layout ``vocoder`` expects.
        vocoder: Callable that maps ``mel`` to an audio tensor.

    Returns:
        The audio tensor after clamping to [-1, 1], passing through the
        module-level ``denoiser``, moving to the CPU and squeezing singleton
        dimensions. The result never requires grad.
    """
    # Pure inference: without no_grad the vocoder forward pass records an
    # autograd graph (wasted memory) and the result can require grad, which
    # blocks the implicit NumPy conversion done when the audio is written out.
    with torch.no_grad():
        audio = vocoder(mel).clamp(-1, 1)
        audio = denoiser(audio.squeeze(0)).cpu().squeeze()
    return audio
def to_bvh(motion):
    """Map generated motion features back through the motion pipeline.

    The tensor is moved to the CPU, its leading axis is squeezed away and it is
    transposed before being wrapped in a one-element list for
    ``motion_pipeline.inverse_transform``. Warnings raised meanwhile are muted.
    Returns whatever ``inverse_transform`` returns.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        batch = [motion.cpu().squeeze(0).T]
        return motion_pipeline.inverse_transform(batch)
def save_to_folder(filename: str, output: dict, folder: str):
    """Write the mel, waveform and BVH entries of ``output`` into ``folder``.

    The folder (including parents) is created when missing. Three files named
    after ``filename`` are produced: the mel array via ``np.save`` (NumPy adds
    the ``.npy`` suffix), a 22050 Hz ``PCM_24`` ``.wav`` and a ``.bvh`` file.
    """
    target_dir = Path(folder)
    target_dir.mkdir(exist_ok=True, parents=True)

    mel_array = output['mel'].cpu().numpy()
    np.save(target_dir / f'{filename}', mel_array)

    wav_path = target_dir / f'{filename}.wav'
    sf.write(wav_path, output['waveform'], 22050, 'PCM_24')

    bvh_path = target_dir / f'{filename}.bvh'
    with open(bvh_path, 'w') as bvh_file:
        bvh_writer.write(output['bvh'], bvh_file)
def to_stick_video(filename, bvh, folder):
    """Render ``bvh`` as a stick-figure video at ``<folder>/<filename>.mp4``.

    The folder is created when missing. The motion goes through
    ``mocap_params.fit_transform`` (with warnings muted) and the first result
    is handed to ``render_mp4``.
    """
    target_dir = Path(folder)
    target_dir.mkdir(exist_ok=True, parents=True)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        positions = mocap_params.fit_transform([bvh])

    print(f"rendering {filename} ...")
    video_path = target_dir / f'{filename}.mp4'
    render_mp4(positions[0], video_path, axis_scale=200)
def combine_audio_video(filename: str, folder: str):
    """Mux ``<filename>.mp4`` and ``<filename>.wav`` into ``<filename>_audio.mp4``.

    All three files live in ``folder``, which is created when missing. An
    existing output file is overwritten.
    """
    print("Combining audio and video")
    target_dir = Path(folder)
    target_dir.mkdir(exist_ok=True, parents=True)

    video_stream = ffmpeg.input(str(target_dir / f'{filename}.mp4'))
    audio_stream = ffmpeg.input(str(target_dir / f'{filename}.wav'))
    output_filename = target_dir / f'{filename}_audio.mp4'

    # One video stream plus one audio stream, joined into a single output.
    joined = ffmpeg.concat(video_stream, audio_stream, v=1, a=1)
    joined.output(str(output_filename)).run(overwrite_output=True)
    print(f"Final output with audio: {output_filename}")
def run(text, output, mel_timestep, motion_timestep, length_scale, mel_temp, motion_temp):
    """Synthesise speech and motion for ``text`` and return the UI updates.

    The incoming ``output`` value is ignored and replaced by the fresh
    synthesis result. Artefacts are saved under ``OUTPUT_FOLDER`` with the
    base name 'temp'.

    Returns a 6-tuple: the result dict, the phonetised text, the mel plot,
    the motion plot, the path of the saved wav file, and a component update
    with ``interactive=True``.
    """
    print("Running synthesis")
    output = synthesise(text, mel_timestep, motion_timestep, length_scale, mel_temp, motion_temp)
    output['waveform'] = to_waveform(output['mel'], vocoder)
    # to_bvh returns a sequence; the first element is kept.
    output['bvh'] = to_bvh(output['motion'])[0]
    save_to_folder('temp', output, OUTPUT_FOLDER)

    mel_plot = plot_tensor(output['mel'].squeeze().cpu().numpy())
    motion_plot = plot_tensor(output['motion'].squeeze().cpu().numpy())
    wav_path = str(Path(OUTPUT_FOLDER) / 'temp.wav')
    enable_component = gr.update(interactive=True)

    return (output, output['x_phones'], mel_plot, motion_plot, wav_path, enable_component)
def visualize_it(output):
    """Render the stick-figure video for ``output`` and return its path.

    Uses the 'temp' files in ``OUTPUT_FOLDER``: the motion in ``output['bvh']``
    is rendered to video, then combined with the saved audio.
    """
    base_name = 'temp'
    to_stick_video(base_name, output['bvh'], OUTPUT_FOLDER)
    combine_audio_video(base_name, OUTPUT_FOLDER)
    final_video = Path(OUTPUT_FOLDER) / 'temp_audio.mp4'
    return str(final_video)
# Gradio UI: text input, sampling controls, synthesis results and an optional
# stick-figure video.
# NOTE(review): the nesting below was reconstructed from a listing without
# indentation - confirm the Box/Row layout against the deployed app.
with gr.Blocks() as demo:
    # Holds the dict returned by run() so visualize_it() can reuse it.
    output = gr.State(value=None)

    # --- Header and text input ---
    with gr.Box():
        with gr.Row():
            gr.Markdown("# Diff-TTSG: Denoising probabilistic integrated speech and gesture synthesis")
        with gr.Row():
            gr.Markdown("### Read more about it at: [https://shivammehta25.github.io/Diff-TTSG/](https://shivammehta25.github.io/Diff-TTSG/)")
        with gr.Row():
            gr.Markdown("# Text Input")
        with gr.Row():
            gr.Markdown("Enter , to insert pause and ; for breathing pause.")
        with gr.Row():
            gr.Markdown("It is recommended to give spaces between punctuations and words.")
        with gr.Row():
            text = gr.Textbox(label="Text Input")
        with gr.Row():
            # Clickable sample sentences that fill the text box.
            examples = gr.Examples(examples=[
                "Hello world ! This is a demo of Diff T T S G .",
                "And the train stopped, The door opened. I got out first, then Jack Kane got out, Ronan got out, Louise got out.",
            ], inputs=[text])

    # --- Sampling controls forwarded to run() ---
    with gr.Box():
        with gr.Row():
            gr.Markdown("### Hyper parameters")
        with gr.Row():
            mel_timestep = gr.Slider(label="Number of timesteps (mel)", minimum=0, maximum=1000, step=1, value=50, interactive=True)
            motion_timestep = gr.Slider(label="Number of timesteps (motion)", minimum=0, maximum=1000, step=1, value=500, interactive=True)
            length_scale = gr.Slider(label="Length scale (Speaking rate)", minimum=0.01, maximum=3.0, step=0.05, value=1.15, interactive=True)
            mel_temp = gr.Slider(label="Sampling temperature (mel)", minimum=0.01, maximum=5.0, step=0.05, value=1.3, interactive=True)
            motion_temp = gr.Slider(label="Sampling temperature (motion)", minimum=0.01, maximum=5.0, step=0.05, value=1.5, interactive=True)
        synth_btn = gr.Button("Synthesise")

    # --- Read-only results ---
    with gr.Box():
        with gr.Row():
            gr.Markdown("### Phonetised text")
        with gr.Row():
            phonetised_text = gr.Textbox(label="Phonetised text", interactive=False)

    with gr.Box():
        with gr.Row():
            mel_spectrogram = gr.Image(interactive=False, label="Mel spectrogram")
            motion_representation = gr.Image(interactive=False, label="Motion representation")
        with gr.Row():
            audio = gr.Audio(interactive=False, label="Audio")

    # --- Optional video rendering ---
    with gr.Box():
        with gr.Row():
            gr.Markdown("### Generate stick figure visualisation")
        with gr.Row():
            gr.Markdown("(This will take a while)")
        with gr.Row():
            # Starts disabled; run() returns an update that makes it interactive.
            visualize = gr.Button("Visualize", interactive=False)
        with gr.Row():
            video = gr.Video(label="Video", interactive=False)

    # The order of inputs/outputs must match run()'s parameters and return tuple.
    synth_btn.click(
        fn=run,
        inputs=[
            text,
            output,
            mel_timestep,
            motion_timestep,
            length_scale,
            mel_temp,
            motion_temp
        ],
        outputs=[
            output,
            phonetised_text,
            mel_spectrogram,
            motion_representation,
            audio,
            # video,
            visualize
        ], api_name="diff_ttsg")

    # The video is produced separately from the stored synthesis result.
    visualize.click(
        fn=visualize_it,
        inputs=[output],
        outputs=[video],
    )

# NOTE(review): the positional 1 is presumably the queue's concurrency limit -
# confirm against the installed Gradio version.
demo.queue(1)
demo.launch()