Update app.py
Browse files
app.py
CHANGED
|
@@ -25,6 +25,13 @@ JAPANESE = {
|
|
| 25 |
"onnx_dir": "./ONNX_net/G_jp/"
|
| 26 |
}
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
models_tts = []
|
| 29 |
models_info = [
|
| 30 |
TRILINGUAL,
|
|
@@ -62,10 +69,13 @@ def get_text(text, hps, is_symbol):
|
|
| 62 |
text_norm = intersperse(text_norm, 0)
|
| 63 |
return LongTensor(text_norm)
|
| 64 |
|
| 65 |
-
def tts_process(text, speaker, speed, model_data, is_symbol):
|
| 66 |
model = model_data["model"]
|
| 67 |
hps = model_data["hps"]
|
| 68 |
speaker_id = model_data["speaker_ids"][speaker]
|
|
|
|
|
|
|
|
|
|
| 69 |
stn_tst = get_text(text, hps, is_symbol)
|
| 70 |
with no_grad():
|
| 71 |
x_tst = stn_tst.unsqueeze(0)
|
|
@@ -110,16 +120,24 @@ def generate(model):
|
|
| 110 |
speed = float(data.get("speed", 1.0))
|
| 111 |
is_symbol = data.get("is_symbol", False)
|
| 112 |
speaker_id = data.get("speaker_id")
|
| 113 |
-
|
| 114 |
if not text:
|
| 115 |
return jsonify({"error": "Missing parameter 'text'"}), 400
|
| 116 |
|
| 117 |
model_data = get_model_data(model)
|
| 118 |
if not model_data:
|
| 119 |
return jsonify({"error": "Model not found"}), 404
|
| 120 |
-
|
| 121 |
speaker_ids = { str(id): speaker for speaker, id in model_data["speaker_ids"].items() }
|
| 122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
if not speaker:
|
| 124 |
if speaker_id is not None:
|
| 125 |
speaker = speaker_ids.get(str(speaker_id), None)
|
|
@@ -132,7 +150,7 @@ def generate(model):
|
|
| 132 |
return jsonify({"error": f"Speaker `{speaker}` not found"}), 404
|
| 133 |
|
| 134 |
try:
|
| 135 |
-
audio, sampling_rate = tts_process(text, speaker, speed, model_data, is_symbol)
|
| 136 |
temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
|
| 137 |
sf.write(temp_wav.name, audio, sampling_rate, format="wav")
|
| 138 |
temp_wav.close()
|
|
|
|
| 25 |
"onnx_dir": "./ONNX_net/G_jp/"
|
| 26 |
}
|
| 27 |
|
| 28 |
+
language_marks = {
|
| 29 |
+
"JA": "[JA]",
|
| 30 |
+
"ZH": "[ZH]",
|
| 31 |
+
"ENG": "[EN]",
|
| 32 |
+
"MIX": "",
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
models_tts = []
|
| 36 |
models_info = [
|
| 37 |
TRILINGUAL,
|
|
|
|
| 69 |
text_norm = intersperse(text_norm, 0)
|
| 70 |
return LongTensor(text_norm)
|
| 71 |
|
| 72 |
+
def tts_process(text, speaker, speed, model_data, is_symbol, language = None):
|
| 73 |
model = model_data["model"]
|
| 74 |
hps = model_data["hps"]
|
| 75 |
speaker_id = model_data["speaker_ids"][speaker]
|
| 76 |
+
if language is not None:
|
| 77 |
+
text = language_marks[language] + text + language_marks[language]
|
| 78 |
+
|
| 79 |
stn_tst = get_text(text, hps, is_symbol)
|
| 80 |
with no_grad():
|
| 81 |
x_tst = stn_tst.unsqueeze(0)
|
|
|
|
| 120 |
speed = float(data.get("speed", 1.0))
|
| 121 |
is_symbol = data.get("is_symbol", False)
|
| 122 |
speaker_id = data.get("speaker_id")
|
| 123 |
+
language = data.get("lang")
|
| 124 |
if not text:
|
| 125 |
return jsonify({"error": "Missing parameter 'text'"}), 400
|
| 126 |
|
| 127 |
model_data = get_model_data(model)
|
| 128 |
if not model_data:
|
| 129 |
return jsonify({"error": "Model not found"}), 404
|
| 130 |
+
|
| 131 |
speaker_ids = { str(id): speaker for speaker, id in model_data["speaker_ids"].items() }
|
| 132 |
|
| 133 |
+
if language is not None:
|
| 134 |
+
is_ja = model.lower() == "japanese"
|
| 135 |
+
if is_ja:
|
| 136 |
+
language = None
|
| 137 |
+
elif not is_ja and language_marks.get(language) is None:
|
| 138 |
+
return jsonify({ "error": "language not available", "language": language_marks.keys() })
|
| 139 |
+
|
| 140 |
+
|
| 141 |
if not speaker:
|
| 142 |
if speaker_id is not None:
|
| 143 |
speaker = speaker_ids.get(str(speaker_id), None)
|
|
|
|
| 150 |
return jsonify({"error": f"Speaker `{speaker}` not found"}), 404
|
| 151 |
|
| 152 |
try:
|
| 153 |
+
audio, sampling_rate = tts_process(text, speaker, speed, model_data, is_symbol, language)
|
| 154 |
temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
|
| 155 |
sf.write(temp_wav.name, audio, sampling_rate, format="wav")
|
| 156 |
temp_wav.close()
|