# tts / app.py
# Last commit: b3f61ad "Update app.py" (author: pluviouse)
import io
import json
import logging
import math
import os
import tempfile

import soundfile as sf
from flask import Flask, jsonify, request, send_file
from torch import LongTensor, no_grad

from text import _clean_text, text_to_sequence
import utils
import ONNXVITS_infer
# The Flask application that every route in this module is registered on.
app = Flask(import_name=__name__)
# Raise the "numba" logger's threshold to WARNING so its lower-level records are dropped.
logging.getLogger(name="numba").setLevel(level=logging.WARNING)
# Model descriptors: "title" is the name models are looked up by (case-insensitively),
# "model_path" the checkpoint loaded into the model, "config_path" the JSON
# hyper-parameter file, and "onnx_dir" the directory handed to SynthesizerTrn.
TRILINGUAL = dict(
    title="Trilingual",
    model_path="./pretrained_models/G_trilingual.pth",
    config_path="./configs/uma_trilingual.json",
    onnx_dir="./ONNX_net/G_trilingual/",
)
JAPANESE = dict(
    title="Japanese",
    model_path="./pretrained_models/G_jp.pth",
    config_path="./configs/uma87.json",
    onnx_dir="./ONNX_net/G_jp/",
)
# Request "lang" code -> marker wrapped around the text before synthesis;
# "MIX" wraps the text in nothing.
language_marks = dict(JA="[JA]", ZH="[ZH]", ENG="[EN]", MIX="")
# Filled by load_models(): one dict per loaded model (name, model, hps, speaker_ids).
models_tts = []
# Descriptors to load, in load order.
models_info = [TRILINGUAL, JAPANESE]
# Lower-case route segment -> descriptor, used by the speakers endpoint.
MODEL = dict(japanese=JAPANESE, trilingual=TRILINGUAL)
def load_models():
    """Construct every model described in ``models_info`` and append it to ``models_tts``.

    For each descriptor: read hyper-parameters from ``config_path``, build an
    ``ONNXVITS_infer.SynthesizerTrn`` pointed at ``onnx_dir``, load weights from
    ``model_path`` via ``utils.load_checkpoint`` (third argument ``None``,
    presumably "no optimizer" -- confirm), switch the model to eval mode, and
    register a dict with keys ``name``, ``model``, ``hps`` and ``speaker_ids``.
    """
    for info in models_info:
        hps = utils.get_hparams_from_file(info["config_path"])
        n_vocab = len(hps.symbols)
        spec_channels = hps.data.filter_length // 2 + 1
        segment_frames = hps.train.segment_size // hps.data.hop_length
        net = ONNXVITS_infer.SynthesizerTrn(
            n_vocab,
            spec_channels,
            segment_frames,
            n_speakers=hps.data.n_speakers,
            ONNX_dir=info["onnx_dir"],
            **hps.model,
        )
        utils.load_checkpoint(info["model_path"], net, None)
        net.eval()
        models_tts.append(
            dict(name=info["title"], model=net, hps=hps, speaker_ids=hps.speakers)
        )


# Runs at import time, so the models are loaded before any request is served.
load_models()
def get_text(text, hps, is_symbol):
    """Encode ``text`` as a LongTensor of symbol ids using ``hps.symbols``.

    The cleaners listed in ``hps.data.text_cleaners`` are applied unless
    ``is_symbol`` is truthy, in which case no cleaners are passed. When
    ``hps.data.add_blank`` is set, the id sequence is first passed through
    ``commons.intersperse(..., 0)``.
    """
    cleaners = hps.data.text_cleaners if not is_symbol else []
    sequence = text_to_sequence(text, hps.symbols, cleaners)
    if not hps.data.add_blank:
        return LongTensor(sequence)
    from commons import intersperse
    return LongTensor(intersperse(sequence, 0))
def tts_process(text, speaker, speed, model_data, is_symbol, language=None):
    """Synthesize ``text`` with one loaded model and return ``(audio, sampling_rate)``.

    Args:
        text: input text (or symbols when ``is_symbol`` is truthy).
        speaker: key into ``model_data["speaker_ids"]``.
        speed: speaking-rate factor; the model is run with ``length_scale = 1 / speed``.
        model_data: an entry of ``models_tts`` (keys ``model``, ``hps``, ``speaker_ids``).
        is_symbol: forwarded to ``get_text`` to skip text cleaners.
        language: optional key of ``language_marks``; its marker is placed on
            both sides of the text.

    Returns:
        The first channel of the first batch item of ``model.infer(...)[0]`` as a
        float NumPy array, plus ``hps.data.sampling_rate``.
    """
    net = model_data["model"]
    hps = model_data["hps"]
    target_speaker = model_data["speaker_ids"][speaker]
    if language is not None:
        mark = language_marks[language]
        text = mark + text + mark
    tokens = get_text(text, hps, is_symbol)
    with no_grad():
        batch = tokens.unsqueeze(0)
        batch_lengths = LongTensor([tokens.size(0)])
        speaker_tensor = LongTensor([target_speaker])
        outputs = net.infer(
            batch,
            batch_lengths,
            sid=speaker_tensor,
            noise_scale=0.667,
            noise_scale_w=0.8,
            length_scale=1.0 / speed,
        )
        audio = outputs[0][0, 0].data.cpu().float().numpy()
    return audio, hps.data.sampling_rate
def read_json(path):
    """Parse and return the JSON document stored at ``path``.

    The file is decoded as UTF-8 explicitly. Without ``encoding=``, ``open`` uses
    the locale's preferred encoding (e.g. cp1252 on Windows), and any non-ASCII
    content in the config (such as speaker names) could then fail to decode.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def get_model_data(model):
    """Return the first ``models_tts`` entry whose name equals ``model`` case-insensitively.

    Returns ``None`` when no loaded model matches.
    """
    for entry in models_tts:
        if entry["name"].lower() == model.lower():
            return entry
    return None
@app.route("/")
def index():
    """Health-check endpoint; always answers ``{"status": "OK"}``."""
    payload = {"status": "OK"}
    return jsonify(payload)
@app.route("/tosymbol", methods=["GET", "POST"])
def to_symbol():
    """Return ``_clean_text`` applied to the request's ``text``.

    ``text`` is read from the query string for GET and from the JSON object body
    for POST. Responds 400 with a JSON error when it is missing, including when a
    POST has no JSON body or the body is not a JSON object.
    """
    if request.method == "GET":
        text = request.args.get("text")
    else:
        # silent=True yields None instead of raising (415/400) on a missing or
        # malformed JSON body, so such requests fall through to the 400 below.
        payload = request.get_json(silent=True)
        text = payload.get("text") if isinstance(payload, dict) else None
    if text is None:
        return jsonify({ "error": "text is required"}), 400
    # NOTE(review): upstream VITS defines `_clean_text(text, cleaner_names)`; if this
    # project's `text` package does too, this one-argument call raises TypeError.
    # Confirm the signature before relying on this endpoint.
    return _clean_text(text)
@app.route("/<model>/speakers", methods=["GET"])
def speakers(model):
    """List the speakers declared in a model's JSON config.

    ``model`` is lower-cased and looked up in ``MODEL``. On success, responds with
    the model's title and the config's ``speakers`` value; otherwise responds 404.
    """
    name = model.lower()
    info = MODEL.get(name)
    if info is not None:
        cfg = read_json(info["config_path"])
        return jsonify({"model_name": info["title"], "speakers": cfg["speakers"]})
    return jsonify({ "error": f"Model not found for `{name}`"}), 404
def _parse_bool(value, default=False):
    """Interpret a request parameter as a boolean.

    Query-string values are always strings, and ``bool("false")`` is ``True``.
    Strings therefore count as true only for "1", "true", "yes" or "on"
    (case-insensitive, surrounding whitespace ignored). ``None`` yields
    ``default``, and any other value (e.g. a JSON boolean) uses ``bool(value)``.
    """
    if value is None:
        return default
    if isinstance(value, str):
        return value.strip().lower() in ("1", "true", "yes", "on")
    return bool(value)


@app.route("/<model>/generate", methods=["POST", "GET"])
def generate(model):
    """Synthesize speech with the model named ``model`` and return it as a WAV download.

    Parameters come from the JSON object body for POST and from the query string for GET:

    * ``text`` (required): text to synthesize.
    * ``speaker`` or ``speaker_id`` (one required): a speaker name, or the id that
      name maps to in the model's speaker table.
    * ``speed`` (default 1.0): must be a finite number > 0; the model is run with
      ``length_scale = 1 / speed``.
    * ``is_symbol`` (default false): when true, text cleaners are skipped.
    * ``lang``: a key of ``language_marks``; ignored for the Japanese model.

    Returns ``output.wav`` as an attachment. Bad or missing parameters give 400,
    an unknown model or speaker gives 404, and a synthesis failure gives 500.
    """
    if request.method == "POST":
        # silent=True returns None instead of raising on a missing or malformed
        # JSON body; treat that as "no parameters" so the checks below answer 400.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
    else:
        data = request.args

    text = data.get("text")
    speaker = data.get("speaker")
    speaker_id = data.get("speaker_id")
    language = data.get("lang")
    # Query-string flags arrive as strings ("false" is truthy), so parse explicitly.
    is_symbol = _parse_bool(data.get("is_symbol"), default=False)

    if not text:
        return jsonify({"error": "Missing parameter 'text'"}), 400

    try:
        speed = float(data.get("speed", 1.0))
    except (TypeError, ValueError):
        return jsonify({"error": "Parameter 'speed' must be a number"}), 400
    # tts_process uses length_scale = 1.0 / speed, so 0, negatives, NaN and inf are invalid.
    if not (math.isfinite(speed) and speed > 0):
        return jsonify({"error": "Parameter 'speed' must be a positive number"}), 400

    model_data = get_model_data(model)
    if not model_data:
        return jsonify({"error": "Model not found"}), 404

    if language is not None:
        if model.lower() == "japanese":
            # The Japanese model is run without language markers.
            language = None
        elif language not in language_marks:
            # list(): jsonify cannot serialize a dict_keys view.
            return jsonify({"error": "language not available", "language": list(language_marks)}), 400

    if not speaker:
        if speaker_id is None:
            return jsonify({"error": "Missing 'speaker' or 'speaker_id'"}), 400
        # Reverse lookup id -> name. Ids are compared as strings so a query-string
        # "3" and a JSON 3 both match.
        names_by_id = {str(spk_id): name for name, spk_id in model_data["speaker_ids"].items()}
        speaker = names_by_id.get(str(speaker_id))
        if not speaker:
            return jsonify({"error": f"Speaker ID `{speaker_id}` not found"}), 404
    if speaker not in model_data["speaker_ids"]:
        return jsonify({"error": f"Speaker `{speaker}` not found"}), 404

    try:
        audio, sampling_rate = tts_process(text, speaker, speed, model_data, is_symbol, language)
        # Encode in memory; a NamedTemporaryFile(delete=False) would leave one
        # file on disk per request.
        buffer = io.BytesIO()
        sf.write(buffer, audio, sampling_rate, format="wav")
    except Exception as e:
        logging.exception("TTS generation failed")
        return jsonify({"error": str(e)}), 500
    buffer.seek(0)
    return send_file(buffer, mimetype="audio/wav", as_attachment=True, download_name="output.wav")
if __name__ == "__main__":
    # Werkzeug's interactive debugger allows arbitrary code execution, so it must not
    # be on by default for a server bound to all interfaces. Opt in with FLASK_DEBUG=1.
    # The port can be overridden with PORT and defaults to 7860.
    debug = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
    port = int(os.environ.get("PORT", "7860"))
    app.run(host="0.0.0.0", port=port, debug=debug)