pluviouse commited on
Commit
7b7f715
·
verified ·
1 Parent(s): 79a5afc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -222
app.py CHANGED
@@ -1,243 +1,143 @@
1
- import argparse
2
- import json
3
- import os
4
- import re
5
  import tempfile
6
  import logging
 
7
 
8
- logging.getLogger('numba').setLevel(logging.WARNING)
9
- import librosa
10
- import numpy as np
11
- import torch
12
  from torch import no_grad, LongTensor
13
- import commons
14
  import utils
15
- import gradio as gr
16
- import gradio.utils as gr_utils
17
- import gradio.processing_utils as gr_processing_utils
18
  import ONNXVITS_infer
19
- import models
20
- from text import text_to_sequence, _clean_text
21
- from text.symbols import symbols
22
- from mel_processing import spectrogram_torch
23
- import psutil
24
- from datetime import datetime
25
-
26
- language_marks = {
27
- "Japanese": "",
28
- "日本語": "[JA]",
29
- "简体中文": "[ZH]",
30
- "English": "[EN]",
31
- "Mix": "",
32
- }
33
-
34
- limitation = os.getenv("SYSTEM") == "spaces" # limit text and audio length in huggingface spaces
35
-
36
-
37
- def create_tts_fn(model, hps, speaker_ids):
38
- def tts_fn(text, speaker, language, speed, is_symbol):
39
- if language is not None:
40
- text = language_marks[language] + text + language_marks[language]
41
- speaker_id = speaker_ids[speaker]
42
- stn_tst = get_text(text, hps, is_symbol)
43
- with no_grad():
44
- x_tst = stn_tst.unsqueeze(0)
45
- x_tst_lengths = LongTensor([stn_tst.size(0)])
46
- sid = LongTensor([speaker_id])
47
- audio = model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8,
48
- length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
49
- del stn_tst, x_tst, x_tst_lengths, sid
50
- return "Success", (hps.data.sampling_rate, audio)
51
-
52
- return tts_fn
53
-
54
-
55
- def create_vc_fn(model, hps, speaker_ids):
56
- def vc_fn(original_speaker, target_speaker, input_audio):
57
- if input_audio is None:
58
- return "You need to upload an audio", None
59
- sampling_rate, audio = input_audio
60
- duration = audio.shape[0] / sampling_rate
61
-
62
- audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32)
63
- if len(audio.shape) > 1:
64
- audio = librosa.to_mono(audio.transpose(1, 0))
65
- if sampling_rate != hps.data.sampling_rate:
66
- audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=hps.data.sampling_rate)
67
- with no_grad():
68
- y = torch.FloatTensor(audio)
69
- y = y.unsqueeze(0)
70
- spec = spectrogram_torch(y, hps.data.filter_length,
71
- hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
72
- center=False)
73
- spec_lengths = LongTensor([spec.size(-1)])
74
- sid_src = LongTensor([original_speaker_id])
75
- sid_tgt = LongTensor([target_speaker_id])
76
- audio = model.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt)[0][
77
- 0, 0].data.cpu().float().numpy()
78
- del y, spec, spec_lengths, sid_src, sid_tgt
79
- return "Success", (hps.data.sampling_rate, audio)
80
-
81
- return vc_fn
82
-
83
-
84
- def get_text(text, hps, is_symbol):
85
- text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
86
- if hps.data.add_blank:
87
- text_norm = commons.intersperse(text_norm, 0)
88
- text_norm = LongTensor(text_norm)
89
- return text_norm
90
-
91
 
92
- def create_to_symbol_fn(hps):
93
- def to_symbol_fn(is_symbol_input, input_text, temp_text):
94
- return (_clean_text(input_text, hps.data.text_cleaners), input_text) if is_symbol_input \
95
- else (temp_text, temp_text)
96
 
97
- return to_symbol_fn
 
 
 
 
 
98
 
 
 
 
 
 
 
99
 
100
  models_tts = []
101
- models_vc = []
102
  models_info = [
103
- {
104
- "title": "Trilingual",
105
- "languages": ['日本語', '简体中文', 'English', 'Mix'],
106
- "description": """
107
- This model is trained on a mix up of Umamusume, Genshin Impact, Sanoba Witch & VCTK voice data to learn multilanguage.
108
- All characters can speak English, Chinese & Japanese.\n\n
109
- To mix multiple languages in a single sentence, wrap the corresponding part with language tokens
110
- ([JA] for Japanese, [ZH] for Chinese, [EN] for English), as shown in the examples.\n\n
111
- 这个模型在赛马娘,原神,魔女的夜宴以及VCTK数据集上混合训练以学习多种语言。
112
- 所有角色均可说中日英三语。\n\n
113
- 若需要在同一个句子中混合多种语言,使用相应的语言标记包裹句子。
114
- (日语用[JA], 中文用[ZH], 英文用[EN]),参考Examples中的示例。
115
- """,
116
- "model_path": "./pretrained_models/G_trilingual.pth",
117
- "config_path": "./configs/uma_trilingual.json",
118
- "examples": [['你好,训练员先生,很高兴见到你。', '草上飞 Grass Wonder (Umamusume Pretty Derby)', '简体中文', 1, False],
119
- ['To be honest, I have no idea what to say as examples.', '派蒙 Paimon (Genshin Impact)', 'English',
120
- 1, False],
121
- ['授業中に出しだら,学校生活終わるですわ。', '綾地 寧々 Ayachi Nene (Sanoba Witch)', '日本語', 1, False],
122
- ['[JA]こんにちわ。[JA][ZH]你好![ZH][EN]Hello![EN]', '綾地 寧々 Ayachi Nene (Sanoba Witch)', 'Mix', 1, False]],
123
- "onnx_dir": "./ONNX_net/G_trilingual/"
124
- },
125
- {
126
- "title": "Japanese",
127
- "languages": ["Japanese"],
128
- "description": """
129
- This model contains 87 characters from Umamusume: Pretty Derby, Japanese only.\n\n
130
- 这个模型包含赛马娘的所有87名角色,只能合成日语。
131
- """,
132
- "model_path": "./pretrained_models/G_jp.pth",
133
- "config_path": "./configs/uma87.json",
134
- "examples": [['お疲れ様です,トレーナーさん。', '无声铃鹿 Silence Suzuka (Umamusume Pretty Derby)', 'Japanese', 1, False],
135
- ['張り切っていこう!', '北部玄驹 Kitasan Black (Umamusume Pretty Derby)', 'Japanese', 1, False],
136
- ['何でこんなに慣れでんのよ,私のほが先に好きだっだのに。', '草上飞 Grass Wonder (Umamusume Pretty Derby)', 'Japanese', 1, False],
137
- ['授業中に出しだら,学校生活終わるですわ。', '目白麦昆 Mejiro Mcqueen (Umamusume Pretty Derby)', 'Japanese', 1, False],
138
- ['お帰りなさい,お兄様!', '米浴 Rice Shower (Umamusume Pretty Derby)', 'Japanese', 1, False],
139
- ['私の処女をもらっでください!', '米浴 Rice Shower (Umamusume Pretty Derby)', 'Japanese', 1, False]],
140
- "onnx_dir": "./ONNX_net/G_jp/"
141
- },
142
  ]
 
143
 
144
- if __name__ == "__main__":
145
- parser = argparse.ArgumentParser()
146
- parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
147
- args = parser.parse_args()
148
  for info in models_info:
149
- name = info['title']
150
- lang = info['languages']
151
- examples = info['examples']
152
- config_path = info['config_path']
153
- model_path = info['model_path']
154
- description = info['description']
155
- onnx_dir = info["onnx_dir"]
156
- hps = utils.get_hparams_from_file(config_path)
157
  model = ONNXVITS_infer.SynthesizerTrn(
158
  len(hps.symbols),
159
  hps.data.filter_length // 2 + 1,
160
  hps.train.segment_size // hps.data.hop_length,
161
  n_speakers=hps.data.n_speakers,
162
- ONNX_dir=onnx_dir,
163
- **hps.model)
164
- utils.load_checkpoint(model_path, model, None)
 
165
  model.eval()
166
- speaker_ids = hps.speakers
167
- speakers = list(hps.speakers.keys())
168
- models_tts.append((name, description, speakers, lang, examples,
169
- hps.symbols, create_tts_fn(model, hps, speaker_ids),
170
- create_to_symbol_fn(hps)))
171
- models_vc.append((name, description, speakers, create_vc_fn(model, hps, speaker_ids)))
172
- app = gr.Blocks()
173
- with app:
174
- gr.Markdown("# English & Chinese & Japanese Anime TTS\n\n"
175
- "![visitor badge](https://visitor-badge.glitch.me/badge?page_id=Plachta.VITS-Umamusume-voice-synthesizer)\n\n"
176
- "Including Japanese TTS & Trilingual TTS, speakers are all anime characters. \n\n包含一个纯日语TTS和一个中日英三语TTS模型,主要为二次元角色。\n\n"
177
- "If you have any suggestions or bug reports, feel free to open discussion in [Community](https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/discussions).\n\n"
178
- "若有bug反馈或建议,请在[Community](https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/discussions)下开启一个新的Discussion。 \n\n"
179
- )
180
- with gr.Tabs():
181
- with gr.TabItem("TTS"):
182
- with gr.Tabs():
183
- for i, (name, description, speakers, lang, example, symbols, tts_fn, to_symbol_fn) in enumerate(
184
- models_tts):
185
- with gr.TabItem(name):
186
- gr.Markdown(description)
187
- with gr.Row():
188
- with gr.Column():
189
- textbox = gr.TextArea(label="Text",
190
- placeholder="Type your sentence here (Maximum 150 words)",
191
- value="こんにちわ。", elem_id=f"tts-input")
192
- with gr.Accordion(label="Phoneme Input", open=False):
193
- temp_text_var = gr.Variable()
194
- symbol_input = gr.Checkbox(value=False, label="Symbol input")
195
- symbol_list = gr.Dataset(label="Symbol list", components=[textbox],
196
- samples=[[x] for x in symbols],
197
- elem_id=f"symbol-list")
198
- symbol_list_json = gr.Json(value=symbols, visible=False)
199
- symbol_input.change(to_symbol_fn,
200
- [symbol_input, textbox, temp_text_var],
201
- [textbox, temp_text_var])
202
- symbol_list.click(None, [symbol_list, symbol_list_json], textbox,
203
- _js=f"""
204
- (i, symbols, text) => {{
205
- let root = document.querySelector("body > gradio-app");
206
- if (root.shadowRoot != null)
207
- root = root.shadowRoot;
208
- let text_input = root.querySelector("#tts-input").querySelector("textarea");
209
- let startPos = text_input.selectionStart;
210
- let endPos = text_input.selectionEnd;
211
- let oldTxt = text_input.value;
212
- let result = oldTxt.substring(0, startPos) + symbols[i] + oldTxt.substring(endPos);
213
- text_input.value = result;
214
- let x = window.scrollX, y = window.scrollY;
215
- text_input.focus();
216
- text_input.selectionStart = startPos + symbols[i].length;
217
- text_input.selectionEnd = startPos + symbols[i].length;
218
- text_input.blur();
219
- window.scrollTo(x, y);
220
- text = text_input.value;
221
- return text;
222
- }}""")
223
- # select character
224
- char_dropdown = gr.Dropdown(choices=speakers, value=speakers[0], label='character')
225
- language_dropdown = gr.Dropdown(choices=lang, value=lang[0], label='language')
226
- duration_slider = gr.Slider(minimum=0.1, maximum=5, value=1, step=0.1,
227
- label='速度 Speed')
228
- with gr.Column():
229
- text_output = gr.Textbox(label="Message")
230
- audio_output = gr.Audio(label="Output Audio", elem_id="tts-audio")
231
- btn = gr.Button("Generate!")
232
- btn.click(tts_fn,
233
- inputs=[textbox, char_dropdown, language_dropdown, duration_slider,
234
- symbol_input],
235
- outputs=[text_output, audio_output])
236
- gr.Examples(
237
- examples=example,
238
- inputs=[textbox, char_dropdown, language_dropdown,
239
- duration_slider, symbol_input],
240
- outputs=[text_output, audio_output],
241
- fn=tts_fn
242
- )
243
- app.queue(concurrency_count=3).launch(show_api=False, share=args.share)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify, send_file
 
 
 
2
  import tempfile
3
  import logging
4
+ import json
5
 
 
 
 
 
6
  from torch import no_grad, LongTensor
7
+ import soundfile as sf
8
  import utils
 
 
 
9
  import ONNXVITS_infer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ app = Flask(__name__)
12
+ logging.getLogger('numba').setLevel(logging.WARNING)
 
 
13
 
14
+ TRILANGUAL = {
15
+ "title": "Trilingual",
16
+ "model_path": "./pretrained_models/G_trilingual.pth",
17
+ "config_path": "./configs/uma_trilingual.json",
18
+ "onnx_dir": "./ONNX_net/G_trilingual/"
19
+ }
20
 
21
+ JAPANESE = {
22
+ "title": "Japanese",
23
+ "model_path": "./pretrained_models/G_jp.pth",
24
+ "config_path": "./configs/uma87.json",
25
+ "onnx_dir": "./ONNX_net/G_jp/"
26
+ }
27
 
28
  models_tts = []
 
29
  models_info = [
30
+ TRILANGUAL,
31
+ JAPANESE
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  ]
33
+ MODEL = { "japanese": JAPANESE, "trilangual": TRILANGUAL }
34
 
35
+ def load_models():
 
 
 
36
  for info in models_info:
37
+ hps = utils.get_hparams_from_file(info['config_path'])
 
 
 
 
 
 
 
38
  model = ONNXVITS_infer.SynthesizerTrn(
39
  len(hps.symbols),
40
  hps.data.filter_length // 2 + 1,
41
  hps.train.segment_size // hps.data.hop_length,
42
  n_speakers=hps.data.n_speakers,
43
+ ONNX_dir=info["onnx_dir"],
44
+ **hps.model
45
+ )
46
+ utils.load_checkpoint(info['model_path'], model, None)
47
  model.eval()
48
+ models_tts.append({
49
+ "name": info["title"],
50
+ "model": model,
51
+ "hps": hps,
52
+ "speaker_ids": hps.speakers
53
+ })
54
+
55
+ load_models()
56
+
57
+ def get_text(text, hps, is_symbol):
58
+ from text import text_to_sequence
59
+ text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
60
+ if hps.data.add_blank:
61
+ from commons import intersperse
62
+ text_norm = intersperse(text_norm, 0)
63
+ return LongTensor(text_norm)
64
+
65
+ def tts_process(text, speaker, speed, model_data, is_symbol):
66
+ model = model_data["model"]
67
+ hps = model_data["hps"]
68
+ speaker_id = model_data["speaker_ids"][speaker]
69
+ stn_tst = get_text(text, hps, is_symbol)
70
+ with no_grad():
71
+ x_tst = stn_tst.unsqueeze(0)
72
+ x_tst_lengths = LongTensor([stn_tst.size(0)])
73
+ sid = LongTensor([speaker_id])
74
+ audio = model.infer(
75
+ x_tst, x_tst_lengths, sid=sid,
76
+ noise_scale=0.667, noise_scale_w=0.8,
77
+ length_scale=1.0 / speed
78
+ )[0][0, 0].data.cpu().float().numpy()
79
+ return audio, hps.data.sampling_rate
80
+
81
+ def read_json(path):
82
+ with open(path, "r") as f:
83
+ return json.loads(f.read())
84
+
85
+
86
+ def get_model_data(model):
87
+ return next((m for m in models_tts if m["name"].lower() == model.lower()), None)
88
+
89
+ @app.route("/")
90
+ def index():
91
+ return jsonify({ status: "OK" })
92
+
93
+ @app.route("/<model>/speakers", methods=["GET"])
94
+ def speakers(model):
95
+ global MODEL
96
+ model = model.lower()
97
+ model_info = MODEL.get(model, None)
98
+
99
+ if model_info is None:
100
+ return jsonify({ "error": f"Model not found for `{model}`"}), 404
101
+
102
+ config = read_json(model_info["config_path"])
103
+ return jsonify({"model_name": model_info["title"], "speakers": config["speakers"] })
104
+
105
+ @app.route("/<model>/generate", methods=["POST", "GET"])
106
+ def generate(model):
107
+ data = request.json if request.method == "POST" else request.args
108
+ text = data.get("text")
109
+ speaker = data.get("speaker")
110
+ speed = float(data.get("speed", 1.0))
111
+ is_symbol = data.get("is_symbol", False)
112
+ speaker_id = data.get("speaker_id")
113
+
114
+ if not text:
115
+ return jsonify({"error": "Missing parameter 'text'"}), 400
116
+
117
+ model_data = get_model_data(model)
118
+ if not model_data:
119
+ return jsonify({"error": "Model not found"}), 404
120
+
121
+ if not speaker:
122
+ if speaker_id is not None:
123
+ speaker = next((k for k, v in model_data["speaker_ids"].items() if str(v) == speaker_id), None)
124
+ if not speaker:
125
+ return jsonify({"error": f"Speaker ID `{speaker_id}` not found"}), 404
126
+ else:
127
+ return jsonify({"error": "Missing 'speaker' or 'speaker_id'"}), 400
128
+
129
+ if speaker not in model_data["speaker_ids"]:
130
+ return jsonify({"error": f"Speaker `{speaker}` not found"}), 404
131
+
132
+ try:
133
+ audio, sampling_rate = tts_process(text, speaker, speed, model_data, is_symbol)
134
+ temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
135
+ sf.write(temp_wav.name, audio, sampling_rate, format="wav")
136
+ temp_wav.close()
137
+ return send_file(temp_wav.name, as_attachment=True, download_name="output.wav")
138
+ except Exception as e:
139
+ print(e)
140
+ return jsonify({"error": str(e)}), 500
141
+
142
+ if __name__ == "__main__":
143
+ app.run(host="0.0.0.0", port=7860, debug=True)