"""Flask HTTP API serving the ONNX VITS text-to-speech models configured below.

Endpoints:
    GET        /                    health check
    GET|POST   /tosymbol            run the text cleaner on ``text``
    GET        /<model>/speakers    list speakers from the model's config file
    GET|POST   /<model>/generate    synthesize ``text`` and return a WAV file
"""
import io
import json
import logging
import math
import os
import tempfile

from flask import Flask, request, jsonify, send_file
from torch import no_grad, LongTensor
import soundfile as sf

from text import text_to_sequence, _clean_text
import utils
import ONNXVITS_infer

app = Flask(__name__)
# Silence numba's DEBUG/INFO log output.
logging.getLogger('numba').setLevel(logging.WARNING)

TRILINGUAL = {
    "title": "Trilingual",
    "model_path": "./pretrained_models/G_trilingual.pth",
    "config_path": "./configs/uma_trilingual.json",
    "onnx_dir": "./ONNX_net/G_trilingual/",
}

JAPANESE = {
    "title": "Japanese",
    "model_path": "./pretrained_models/G_jp.pth",
    "config_path": "./configs/uma87.json",
    "onnx_dir": "./ONNX_net/G_jp/",
}

# Request "lang" value -> tag placed on both sides of the input text.
# "MIX" adds no tag (presumably the caller embeds its own tags — TODO confirm).
language_marks = {
    "JA": "[JA]",
    "ZH": "[ZH]",
    "ENG": "[EN]",
    "MIX": "",
}

# Filled by load_models(): one dict per entry of models_info.
models_tts = []
models_info = [
    TRILINGUAL,
    JAPANESE,
]

# Lower-case URL model name -> model info dict.
MODEL = {
    "japanese": JAPANESE,
    "trilingual": TRILINGUAL,
}

# String values (case-insensitive) accepted as true for boolean request flags.
_TRUTHY_STRINGS = frozenset({"1", "true", "yes", "on"})


def load_models():
    """Build, load and register every model listed in ``models_info``.

    Each model is appended to ``models_tts`` as a dict with keys ``name``
    (the info's title), ``model``, ``hps`` and ``speaker_ids``
    (``hps.speakers``, used elsewhere as a speaker-name -> id mapping).
    """
    for info in models_info:
        hps = utils.get_hparams_from_file(info["config_path"])
        model = ONNXVITS_infer.SynthesizerTrn(
            len(hps.symbols),
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            ONNX_dir=info["onnx_dir"],
            **hps.model,
        )
        utils.load_checkpoint(info["model_path"], model, None)
        model.eval()
        models_tts.append({
            "name": info["title"],
            "model": model,
            "hps": hps,
            "speaker_ids": hps.speakers,
        })


load_models()


def get_text(text, hps, is_symbol):
    """Convert ``text`` into a LongTensor of symbol ids for the model.

    When ``is_symbol`` is true no cleaners are applied (the text is taken as
    symbols already); otherwise ``hps.data.text_cleaners`` are used. When
    ``hps.data.add_blank`` is set, the sequence is passed through
    ``commons.intersperse(..., 0)``.
    """
    text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
    if hps.data.add_blank:
        from commons import intersperse
        text_norm = intersperse(text_norm, 0)
    return LongTensor(text_norm)


def tts_process(text, speaker, speed, model_data, is_symbol, language=None):
    """Synthesize ``text`` with one loaded model.

    Args:
        text: input text (or symbols when ``is_symbol`` is true).
        speaker: speaker name; must be a key of ``model_data["speaker_ids"]``.
        speed: speaking-rate factor; the model's length_scale is ``1.0 / speed``.
        model_data: one entry of ``models_tts``.
        is_symbol: skip text cleaners (see ``get_text``).
        language: optional key of ``language_marks``; its tag is placed on
            both sides of ``text``.

    Returns:
        ``(audio, sampling_rate)``: the float numpy array taken from the
        model output and ``hps.data.sampling_rate``.
    """
    model = model_data["model"]
    hps = model_data["hps"]
    speaker_id = model_data["speaker_ids"][speaker]
    if language is not None:
        mark = language_marks[language]
        text = mark + text + mark
    stn_tst = get_text(text, hps, is_symbol)
    with no_grad():
        x_tst = stn_tst.unsqueeze(0)
        x_tst_lengths = LongTensor([stn_tst.size(0)])
        sid = LongTensor([speaker_id])
        audio = model.infer(
            x_tst,
            x_tst_lengths,
            sid=sid,
            noise_scale=0.667,
            noise_scale_w=0.8,
            length_scale=1.0 / speed,
        )[0][0, 0].data.cpu().float().numpy()
    return audio, hps.data.sampling_rate


def read_json(path):
    """Load and return the JSON document stored at ``path`` (decoded as UTF-8)."""
    # Explicit UTF-8: the configs may contain non-ASCII speaker names, and the
    # platform default encoding (e.g. cp1252 on Windows) would fail on them.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def get_model_data(model):
    """Return the ``models_tts`` entry named ``model`` (case-insensitive), or None."""
    wanted = model.lower()
    return next((m for m in models_tts if m["name"].lower() == wanted), None)


def _request_params():
    """Return request parameters: the JSON object body for POST, else the query string.

    A POST whose body is missing, not JSON, or not a JSON object yields an
    empty dict, so callers can answer with their own 400 errors instead of
    crashing on ``None``.
    """
    if request.method == "POST":
        data = request.get_json(silent=True)
        return data if isinstance(data, dict) else {}
    return request.args


def _parse_bool(value):
    """Interpret a boolean request flag.

    Strings are true only for "1"/"true"/"yes"/"on" (case-insensitive), so a
    query string like ``is_symbol=false`` is false; other values use ``bool()``.
    """
    if isinstance(value, str):
        return value.strip().lower() in _TRUTHY_STRINGS
    return bool(value)


@app.route("/")
def index():
    """Health check."""
    return jsonify({"status": "OK"})


@app.route("/tosymbol", methods=["GET", "POST"])
def to_symbol():
    """Return ``text`` passed through ``_clean_text``; 400 if ``text`` is missing."""
    text = _request_params().get("text")
    if text is None:
        return jsonify({"error": "text is required"}), 400
    # NOTE(review): confirm text._clean_text accepts a single argument here;
    # upstream VITS defines it as _clean_text(text, cleaner_names).
    return _clean_text(text)


@app.route("/<model>/speakers", methods=["GET"])
def speakers(model):
    """List the speakers declared in the config file of ``model`` (404 if unknown)."""
    model = model.lower()
    model_info = MODEL.get(model)
    if model_info is None:
        return jsonify({"error": f"Model not found for `{model}`"}), 404
    config = read_json(model_info["config_path"])
    return jsonify({"model_name": model_info["title"], "speakers": config["speakers"]})


@app.route("/<model>/generate", methods=["POST", "GET"])
def generate(model):
    """Synthesize speech with ``model`` and return it as a WAV attachment.

    Parameters (JSON object body for POST, query string for GET):
        text: required input text.
        speaker / speaker_id: one is required; ``speaker`` wins if both given.
        speed: positive finite number, default 1.0.
        is_symbol: boolean flag, default false.
        lang: key of ``language_marks``; ignored for the Japanese model.

    Errors are JSON ``{"error": ...}`` with status 400 (missing/invalid
    parameter), 404 (unknown model or speaker) or 500 (synthesis failure).
    """
    data = _request_params()
    text = data.get("text")
    speaker = data.get("speaker")
    is_symbol = _parse_bool(data.get("is_symbol", False))
    speaker_id = data.get("speaker_id")
    language = data.get("lang")

    if not text:
        return jsonify({"error": "Missing parameter 'text'"}), 400

    try:
        speed = float(data.get("speed", 1.0))
    except (TypeError, ValueError):
        return jsonify({"error": "Parameter 'speed' must be a number"}), 400
    # speed becomes length_scale = 1.0 / speed: reject 0, negatives, nan and inf.
    if not math.isfinite(speed) or speed <= 0:
        return jsonify({"error": "Parameter 'speed' must be a positive number"}), 400

    model_data = get_model_data(model)
    if not model_data:
        return jsonify({"error": "Model not found"}), 404

    if language is not None:
        if model.lower() == "japanese":
            # The Japanese-only model is always used without language tags.
            language = None
        elif not isinstance(language, str) or language not in language_marks:
            return jsonify({
                "error": "language not available",
                "language": list(language_marks),
            }), 400

    if not speaker:
        if speaker_id is None:
            return jsonify({"error": "Missing 'speaker' or 'speaker_id'"}), 400
        # Reverse lookup keyed by str(id) so both int and str ids match.
        names_by_id = {str(sid): name for name, sid in model_data["speaker_ids"].items()}
        speaker = names_by_id.get(str(speaker_id))
        if not speaker:
            return jsonify({"error": f"Speaker ID `{speaker_id}` not found"}), 404

    if not isinstance(speaker, str) or speaker not in model_data["speaker_ids"]:
        return jsonify({"error": f"Speaker `{speaker}` not found"}), 404

    try:
        audio, sampling_rate = tts_process(text, speaker, speed, model_data, is_symbol, language)
        # Encode in memory so no temporary .wav file is left on disk per request.
        wav = io.BytesIO()
        sf.write(wav, audio, sampling_rate, format="wav")
        wav.seek(0)
        return send_file(wav, mimetype="audio/wav", as_attachment=True, download_name="output.wav")
    except Exception as e:
        # Request boundary: log the full traceback, report the message to the client.
        app.logger.exception("TTS synthesis failed")
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    # The Werkzeug interactive debugger permits arbitrary code execution, so it
    # is opt-in (TTS_DEBUG=1) rather than enabled while bound to all interfaces.
    app.run(host="0.0.0.0", port=7860, debug=os.environ.get("TTS_DEBUG") == "1")