Update app.py
Browse files
app.py
CHANGED
|
@@ -9,12 +9,14 @@ import numpy as np
|
|
| 9 |
|
| 10 |
import ChatTTS
|
| 11 |
|
|
|
|
|
|
|
|
|
|
| 12 |
print("loading ChatTTS model...")
|
| 13 |
chat = ChatTTS.Chat()
|
| 14 |
chat.load_models()
|
| 15 |
|
| 16 |
|
| 17 |
-
|
| 18 |
def generate_seed():
|
| 19 |
new_seed = random.randint(1, 100000000)
|
| 20 |
return {
|
|
@@ -23,7 +25,7 @@ def generate_seed():
|
|
| 23 |
}
|
| 24 |
|
| 25 |
@spaces.GPU
|
| 26 |
-
def
|
| 27 |
|
| 28 |
torch.manual_seed(audio_seed_input)
|
| 29 |
rand_spk = torch.randn(768)
|
|
@@ -57,7 +59,67 @@ def generate_audio(text, temperature, top_P, top_K, audio_seed_input, text_seed_
|
|
| 57 |
sample_rate = 24000
|
| 58 |
text_data = text[0] if isinstance(text, list) else text
|
| 59 |
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
|
| 63 |
with gr.Blocks() as demo:
|
|
@@ -83,7 +145,7 @@ with gr.Blocks() as demo:
|
|
| 83 |
|
| 84 |
generate_button = gr.Button("Generate")
|
| 85 |
|
| 86 |
-
text_output = gr.Textbox(label="Refined Text", interactive=False)
|
| 87 |
audio_output = gr.Audio(label="Output Audio")
|
| 88 |
|
| 89 |
generate_audio_seed.click(generate_seed,
|
|
@@ -96,7 +158,7 @@ with gr.Blocks() as demo:
|
|
| 96 |
|
| 97 |
generate_button.click(generate_audio,
|
| 98 |
inputs=[text_input, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, text_seed_input, refine_text_checkbox, refine_text_input],
|
| 99 |
-
outputs=
|
| 100 |
|
| 101 |
parser = argparse.ArgumentParser(description='ChatTTS demo Launch')
|
| 102 |
parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Server name')
|
|
|
|
| 9 |
|
| 10 |
import ChatTTS
|
| 11 |
|
| 12 |
+
import se_extractor
|
| 13 |
+
from api import BaseSpeakerTTS, ToneColorConverter
|
| 14 |
+
|
| 15 |
# Load the ChatTTS model once at module import so every request reuses it.
print("loading ChatTTS model...")
chat = ChatTTS.Chat()
chat.load_models()
|
| 18 |
|
| 19 |
|
|
|
|
| 20 |
def generate_seed():
|
| 21 |
new_seed = random.randint(1, 100000000)
|
| 22 |
return {
|
|
|
|
| 25 |
}
|
| 26 |
|
| 27 |
@spaces.GPU
|
| 28 |
+
def chat_tts(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, output_path=None):
|
| 29 |
|
| 30 |
torch.manual_seed(audio_seed_input)
|
| 31 |
rand_spk = torch.randn(768)
|
|
|
|
| 59 |
sample_rate = 24000
|
| 60 |
text_data = text[0] if isinstance(text, list) else text
|
| 61 |
|
| 62 |
+
if output_path is None:
|
| 63 |
+
return [(sample_rate, audio_data), text_data]
|
| 64 |
+
else:
|
| 65 |
+
soundfile.write(output_path, audio_data, sample_rate)
|
| 66 |
+
|
| 67 |
+
# --- OpenVoice: base speaker TTS + tone-color converter ---------------------
device = 'cuda:0'
# device = "cpu"  # uncomment to run without a GPU

ckpt_base_en = 'checkpoints/base_speakers/EN'
ckpt_converter_en = 'checkpoints/converter'

# Base English speaker TTS (used for the non-default voice styles).
base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base_en}/config.json', device=device)
base_speaker_tts.load_ckpt(f'{ckpt_base_en}/checkpoint.pth')

# Converter that re-voices generated audio toward a target speaker embedding.
tone_color_converter = ToneColorConverter(f'{ckpt_converter_en}/config.json', device=device)
tone_color_converter.load_ckpt(f'{ckpt_converter_en}/checkpoint.pth')
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def generate_audio(text, audio_ref, style_mode, temperature, top_P, top_K,
                   audio_seed_input, text_seed_input, refine_text_flag,
                   refine_text_input):
    """Synthesize *text* to speech, then re-voice it to match *audio_ref*.

    For ``style_mode == "default"`` the base audio is produced by the ChatTTS
    pipeline (``chat_tts``); for any other style it is produced by the
    OpenVoice ``base_speaker_tts`` with that style name.  In both cases the
    OpenVoice tone-color converter then shifts the result toward the speaker
    embedding extracted from the reference clip.

    Args:
        text: Text to synthesize.
        audio_ref: Path to the reference speaker audio clip.
        style_mode: ``"default"`` for ChatTTS, otherwise an OpenVoice style
            name (e.g. a named speaker style) — passed through as ``speaker=``.
        temperature, top_P, top_K, audio_seed_input, text_seed_input,
        refine_text_flag, refine_text_input: forwarded to ``chat_tts`` for the
            default style.

    Returns:
        Path of the converted output file, ``"output.wav"``.
    """
    # Select the source speaker embedding matching the requested style.
    se_file = 'en_default_se.pth' if style_mode == "default" else 'en_style_se.pth'
    source_se = torch.load(f'{ckpt_base_en}/{se_file}').to(device)

    # Extract the target tone-color embedding from the reference clip.
    target_se, audio_name = se_extractor.get_se(
        audio_ref, tone_color_converter, target_dir='processed', vad=True)

    save_path = "output.wav"
    src_path = "tmp.wav"

    # Run the base speaker TTS, writing the intermediate audio to src_path.
    if style_mode == "default":
        # Bug fix: the original call passed `text` twice and placed the
        # positional `src_path` AFTER `output_path=None`, which is a
        # SyntaxError; the intent is to write the ChatTTS output to src_path.
        chat_tts(text, temperature, top_P, top_K, audio_seed_input,
                 text_seed_input, refine_text_flag, refine_text_input,
                 output_path=src_path)
    else:
        base_speaker_tts.tts(text, src_path, speaker=style_mode,
                             language='English', speed=0.9)

    # Run the tone color converter on the intermediate audio.
    encode_message = "@MyShell"
    tone_color_converter.convert(
        audio_src_path=src_path,
        src_se=source_se,
        tgt_se=target_se,
        output_path=save_path,
        message=encode_message)

    return "output.wav"
|
| 123 |
|
| 124 |
|
| 125 |
with gr.Blocks() as demo:
|
|
|
|
| 145 |
|
| 146 |
generate_button = gr.Button("Generate")
|
| 147 |
|
| 148 |
+
#text_output = gr.Textbox(label="Refined Text", interactive=False)
|
| 149 |
audio_output = gr.Audio(label="Output Audio")
|
| 150 |
|
| 151 |
generate_audio_seed.click(generate_seed,
|
|
|
|
| 158 |
|
| 159 |
# NOTE(review): `generate_audio` now takes (text, audio_ref, style_mode, ...)
# but this `inputs` list supplies neither a reference-audio nor a style
# component, so the click handler will fail with an argument mismatch at
# runtime — TODO: add the missing widgets to `inputs` (their definitions are
# not visible in this chunk).
generate_button.click(generate_audio,
                      inputs=[text_input, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, text_seed_input, refine_text_checkbox, refine_text_input],
                      outputs=audio_output)
|
| 162 |
|
| 163 |
parser = argparse.ArgumentParser(description='ChatTTS demo Launch')
|
| 164 |
parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Server name')
|