Spaces:
Runtime error
Runtime error
Commit
·
b772f7c
1
Parent(s):
7e90749
feat: add support for gpu
Browse files
app.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
import time
|
| 3 |
import gradio as gr
|
| 4 |
import utils
|
|
@@ -6,14 +5,16 @@ import commons
|
|
| 6 |
from models import SynthesizerTrn
|
| 7 |
from text import text_to_sequence
|
| 8 |
from torch import no_grad, LongTensor
|
|
|
|
| 9 |
|
| 10 |
hps_ms = utils.get_hparams_from_file(r'./model/config.json')
|
|
|
|
| 11 |
net_g_ms = SynthesizerTrn(
|
| 12 |
len(hps_ms.symbols),
|
| 13 |
hps_ms.data.filter_length // 2 + 1,
|
| 14 |
hps_ms.train.segment_size // hps_ms.data.hop_length,
|
| 15 |
n_speakers=hps_ms.data.n_speakers,
|
| 16 |
-
**hps_ms.model)
|
| 17 |
_ = net_g_ms.eval()
|
| 18 |
speakers = hps_ms.speakers
|
| 19 |
model, optimizer, learning_rate, epochs = utils.load_checkpoint(r'./model/G_953000.pth', net_g_ms, None)
|
|
@@ -30,7 +31,7 @@ def vits(text, language, speaker_id, noise_scale, noise_scale_w, length_scale):
|
|
| 30 |
if not len(text):
|
| 31 |
return "输入文本不能为空!", None, None
|
| 32 |
text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
|
| 33 |
-
if len(text) >
|
| 34 |
return f"输入文字过长!{len(text)}>100", None, None
|
| 35 |
if language == 0:
|
| 36 |
text = f"[ZH]{text}[ZH]"
|
|
@@ -44,7 +45,7 @@ def vits(text, language, speaker_id, noise_scale, noise_scale_w, length_scale):
|
|
| 44 |
x_tst_lengths = LongTensor([stn_tst.size(0)])
|
| 45 |
speaker_id = LongTensor([speaker_id])
|
| 46 |
audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=speaker_id, noise_scale=noise_scale, noise_scale_w=noise_scale_w,
|
| 47 |
-
length_scale=length_scale)[0][0, 0].data.float().numpy()
|
| 48 |
|
| 49 |
return "生成成功!", (22050, audio), f"生成耗时 {round(time.perf_counter()-start, 2)} s"
|
| 50 |
|
|
@@ -116,8 +117,8 @@ if __name__ == '__main__':
|
|
| 116 |
download = gr.Button("Download Audio")
|
| 117 |
btn.click(vits, inputs=[input_text, lang, sid, ns, nsw, ls], outputs=[o1, o2, o3], api_name="generate")
|
| 118 |
download.click(None, [], [], _js=download_audio_js.format())
|
| 119 |
-
btn2.click(search_speaker, inputs=[search], outputs=[sid]
|
| 120 |
-
lang.change(change_lang, inputs=[lang], outputs=[ns, nsw, ls]
|
| 121 |
with gr.TabItem("可用人物一览"):
|
| 122 |
gr.Radio(label="Speaker", choices=speakers, interactive=False, type="index")
|
| 123 |
-
app.queue(concurrency_count=1).launch()
|
|
|
|
|
|
|
| 1 |
import time
|
| 2 |
import gradio as gr
|
| 3 |
import utils
|
|
|
|
| 5 |
from models import SynthesizerTrn
|
| 6 |
from text import text_to_sequence
|
| 7 |
from torch import no_grad, LongTensor
|
| 8 |
+
import torch
|
| 9 |
|
| 10 |
hps_ms = utils.get_hparams_from_file(r'./model/config.json')
|
| 11 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 12 |
net_g_ms = SynthesizerTrn(
|
| 13 |
len(hps_ms.symbols),
|
| 14 |
hps_ms.data.filter_length // 2 + 1,
|
| 15 |
hps_ms.train.segment_size // hps_ms.data.hop_length,
|
| 16 |
n_speakers=hps_ms.data.n_speakers,
|
| 17 |
+
**hps_ms.model).to(device)
|
| 18 |
_ = net_g_ms.eval()
|
| 19 |
speakers = hps_ms.speakers
|
| 20 |
model, optimizer, learning_rate, epochs = utils.load_checkpoint(r'./model/G_953000.pth', net_g_ms, None)
|
|
|
|
| 31 |
if not len(text):
|
| 32 |
return "输入文本不能为空!", None, None
|
| 33 |
text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
|
| 34 |
+
if len(text) > 500:
|
| 35 |
return f"输入文字过长!{len(text)}>100", None, None
|
| 36 |
if language == 0:
|
| 37 |
text = f"[ZH]{text}[ZH]"
|
|
|
|
| 45 |
x_tst_lengths = LongTensor([stn_tst.size(0)])
|
| 46 |
speaker_id = LongTensor([speaker_id])
|
| 47 |
audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=speaker_id, noise_scale=noise_scale, noise_scale_w=noise_scale_w,
|
| 48 |
+
length_scale=length_scale)[0][0, 0].data.cpu().float().numpy()
|
| 49 |
|
| 50 |
return "生成成功!", (22050, audio), f"生成耗时 {round(time.perf_counter()-start, 2)} s"
|
| 51 |
|
|
|
|
| 117 |
download = gr.Button("Download Audio")
|
| 118 |
btn.click(vits, inputs=[input_text, lang, sid, ns, nsw, ls], outputs=[o1, o2, o3], api_name="generate")
|
| 119 |
download.click(None, [], [], _js=download_audio_js.format())
|
| 120 |
+
btn2.click(search_speaker, inputs=[search], outputs=[sid])
|
| 121 |
+
lang.change(change_lang, inputs=[lang], outputs=[ns, nsw, ls])
|
| 122 |
with gr.TabItem("可用人物一览"):
|
| 123 |
gr.Radio(label="Speaker", choices=speakers, interactive=False, type="index")
|
| 124 |
+
app.queue(concurrency_count=1).launch()
|