|
|
from flask import Flask, request, jsonify, send_file |
|
|
import tempfile |
|
|
import logging |
|
|
import json |
|
|
from text import text_to_sequence, _clean_text |
|
|
|
|
|
from torch import no_grad, LongTensor |
|
|
import soundfile as sf |
|
|
import utils |
|
|
import ONNXVITS_infer |
|
|
|
|
|
# Flask application object serving the TTS endpoints defined below.
app = Flask(__name__)

# Numba (pulled in by the audio stack) logs very verbosely; keep it quiet.
logging.getLogger('numba').setLevel(logging.WARNING)
|
|
|
|
|
# Descriptor for the trilingual voice model: checkpoint weights,
# hyper-parameter config and exported ONNX graph directory.
TRILINGUAL = {
    "title": "Trilingual",
    "model_path": "./pretrained_models/G_trilingual.pth",
    "config_path": "./configs/uma_trilingual.json",
    "onnx_dir": "./ONNX_net/G_trilingual/"
}

# Descriptor for the Japanese-only voice model.
JAPANESE = {
    "title": "Japanese",
    "model_path": "./pretrained_models/G_jp.pth",
    "config_path": "./configs/uma87.json",
    "onnx_dir": "./ONNX_net/G_jp/"
}

# Language tags wrapped around the input text before cleaning.
# "MIX" maps to an empty mark — presumably the text then carries its own
# inline [JA]/[ZH]/[EN] tags; verify against the text cleaners.
language_marks = {
    "JA": "[JA]",
    "ZH": "[ZH]",
    "ENG": "[EN]",
    "MIX": "",
}

# Loaded model bundles; populated by load_models() at startup.
models_tts = []
# Descriptors that load_models() instantiates, in load order.
models_info = [
    TRILINGUAL,
    JAPANESE
]
# URL-path key -> descriptor, used by the /<model>/speakers route.
MODEL = { "japanese": JAPANESE, "trilingual": TRILINGUAL }
|
|
|
|
|
def load_models():
    """Instantiate every configured TTS model and register it in ``models_tts``.

    For each descriptor in ``models_info``: read the hyper-parameters from
    its JSON config, build a SynthesizerTrn around the exported ONNX graph,
    load the checkpoint weights, switch to eval mode, and append the
    ready-to-serve bundle (name, model, hps, speaker-id map).
    """
    for entry in models_info:
        hparams = utils.get_hparams_from_file(entry["config_path"])
        net = ONNXVITS_infer.SynthesizerTrn(
            len(hparams.symbols),
            hparams.data.filter_length // 2 + 1,
            hparams.train.segment_size // hparams.data.hop_length,
            n_speakers=hparams.data.n_speakers,
            ONNX_dir=entry["onnx_dir"],
            **hparams.model,
        )
        utils.load_checkpoint(entry["model_path"], net, None)
        net.eval()
        bundle = {
            "name": entry["title"],
            "model": net,
            "hps": hparams,
            "speaker_ids": hparams.speakers,
        }
        models_tts.append(bundle)
|
|
|
|
|
# Load all models once at import time so the routes below can serve requests.
load_models()
|
|
|
|
|
def get_text(text, hps, is_symbol):
    """Convert *text* into a LongTensor of symbol ids ready for inference.

    When *is_symbol* is true the input is treated as raw symbols and no
    cleaners run; otherwise the cleaners named in the model config apply.
    If the config sets ``add_blank``, a 0 id is interspersed between every
    pair of symbol ids.
    """
    if is_symbol:
        cleaners = []
    else:
        cleaners = hps.data.text_cleaners
    sequence = text_to_sequence(text, hps.symbols, cleaners)
    if hps.data.add_blank:
        from commons import intersperse
        sequence = intersperse(sequence, 0)
    return LongTensor(sequence)
|
|
|
|
|
def tts_process(text, speaker, speed, model_data, is_symbol, language = None):
    """Synthesize *text* with the given model bundle.

    When *language* is given, the matching mark from ``language_marks`` is
    wrapped around the text so the cleaners route it correctly.

    Returns ``(audio, sampling_rate)`` where *audio* is a float numpy array.
    """
    net = model_data["model"]
    hps = model_data["hps"]
    sid_value = model_data["speaker_ids"][speaker]

    if language is not None:
        mark = language_marks[language]
        text = mark + text + mark

    seq = get_text(text, hps, is_symbol)
    with no_grad():
        batch = seq.unsqueeze(0)
        lengths = LongTensor([seq.size(0)])
        sid = LongTensor([sid_value])
        result = net.infer(
            batch,
            lengths,
            sid=sid,
            noise_scale=0.667,
            noise_scale_w=0.8,
            length_scale=1.0 / speed,
        )
        audio = result[0][0, 0].data.cpu().float().numpy()
    return audio, hps.data.sampling_rate
|
|
|
|
|
def read_json(path):
    """Parse the JSON file at *path* and return the decoded object.

    Opens the file with an explicit UTF-8 encoding — the previous version
    used the platform default, which breaks on non-ASCII speaker names under
    e.g. a cp1252 locale — and streams it straight into the parser with
    ``json.load`` instead of reading the whole file into a string first.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
|
|
|
|
def get_model_data(model):
    """Look up a loaded model bundle by its title, case-insensitively.

    Returns the bundle dict from ``models_tts``, or None when no model with
    that name has been loaded.
    """
    wanted = model.lower()
    for bundle in models_tts:
        if bundle["name"].lower() == wanted:
            return bundle
    return None
|
|
|
|
|
@app.route("/") |
|
|
def index(): |
|
|
return jsonify({"status": "OK" }) |
|
|
|
|
|
@app.route("/tosymbol", methods=["GET", "POST"]) |
|
|
def to_symbol(): |
|
|
text = request.args.get("text") if request.method == "GET" else request.json.get("text") |
|
|
if text is None: |
|
|
return jsonify({ "error": "text is required"}), 400 |
|
|
return _clean_text(text) |
|
|
|
|
|
@app.route("/<model>/speakers", methods=["GET"]) |
|
|
def speakers(model): |
|
|
global MODEL |
|
|
model = model.lower() |
|
|
model_info = MODEL.get(model, None) |
|
|
|
|
|
if model_info is None: |
|
|
return jsonify({ "error": f"Model not found for `{model}`"}), 404 |
|
|
|
|
|
config = read_json(model_info["config_path"]) |
|
|
return jsonify({"model_name": model_info["title"], "speakers": config["speakers"] }) |
|
|
|
|
|
@app.route("/<model>/generate", methods=["POST", "GET"]) |
|
|
def generate(model): |
|
|
data = request.json if request.method == "POST" else request.args |
|
|
text = data.get("text") |
|
|
speaker = data.get("speaker") |
|
|
speed = float(data.get("speed", 1.0)) |
|
|
is_symbol = data.get("is_symbol", False) |
|
|
speaker_id = data.get("speaker_id") |
|
|
language = data.get("lang") |
|
|
if not text: |
|
|
return jsonify({"error": "Missing parameter 'text'"}), 400 |
|
|
|
|
|
model_data = get_model_data(model) |
|
|
if not model_data: |
|
|
return jsonify({"error": "Model not found"}), 404 |
|
|
|
|
|
speaker_ids = { str(id): speaker for speaker, id in model_data["speaker_ids"].items() } |
|
|
|
|
|
if language is not None: |
|
|
is_ja = model.lower() == "japanese" |
|
|
if is_ja: |
|
|
language = None |
|
|
elif not is_ja and language_marks.get(language) is None: |
|
|
return jsonify({ "error": "language not available", "language": language_marks.keys() }) |
|
|
|
|
|
|
|
|
if not speaker: |
|
|
if speaker_id is not None: |
|
|
speaker = speaker_ids.get(str(speaker_id), None) |
|
|
if not speaker: |
|
|
return jsonify({"error": f"Speaker ID `{speaker_id}` not found"}), 404 |
|
|
else: |
|
|
return jsonify({"error": "Missing 'speaker' or 'speaker_id'"}), 400 |
|
|
|
|
|
if speaker not in model_data["speaker_ids"]: |
|
|
return jsonify({"error": f"Speaker `{speaker}` not found"}), 404 |
|
|
|
|
|
try: |
|
|
audio, sampling_rate = tts_process(text, speaker, speed, model_data, is_symbol, language) |
|
|
temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") |
|
|
sf.write(temp_wav.name, audio, sampling_rate, format="wav") |
|
|
temp_wav.close() |
|
|
return send_file(temp_wav.name, as_attachment=True, download_name="output.wav") |
|
|
except Exception as e: |
|
|
print(e) |
|
|
return jsonify({"error": str(e)}), 500 |
|
|
|
|
|
if __name__ == "__main__":
    # Listen on all interfaces on port 7860 (the port Gradio/HF Spaces
    # conventionally use — presumably the deployment target).
    # NOTE(review): debug=True enables the Werkzeug debugger and reloader,
    # which allows arbitrary code execution if exposed to untrusted traffic;
    # disable it for any non-local deployment.
    app.run(host="0.0.0.0", port=7860, debug=True)