import io
import json
import logging
import tempfile

from flask import Flask, request, jsonify, send_file
import soundfile as sf
from torch import no_grad, LongTensor

import utils
import ONNXVITS_infer
|
|
| app = Flask(__name__) |
| logging.getLogger('numba').setLevel(logging.WARNING) |
|
|
| TRILINGUAL = { |
| "title": "Trilingual", |
| "model_path": "./pretrained_models/G_trilingual.pth", |
| "config_path": "./configs/uma_trilingual.json", |
| "onnx_dir": "./ONNX_net/G_trilingual/" |
| } |
|
|
| JAPANESE = { |
| "title": "Japanese", |
| "model_path": "./pretrained_models/G_jp.pth", |
| "config_path": "./configs/uma87.json", |
| "onnx_dir": "./ONNX_net/G_jp/" |
| } |
|
|
| models_tts = [] |
| models_info = [ |
| TRILINGUAL, |
| JAPANESE |
| ] |
| MODEL = { "japanese": JAPANESE, "trilingual": TRILINGUAL } |
|
|
| def load_models(): |
| for info in models_info: |
| hps = utils.get_hparams_from_file(info['config_path']) |
| model = ONNXVITS_infer.SynthesizerTrn( |
| len(hps.symbols), |
| hps.data.filter_length // 2 + 1, |
| hps.train.segment_size // hps.data.hop_length, |
| n_speakers=hps.data.n_speakers, |
| ONNX_dir=info["onnx_dir"], |
| **hps.model |
| ) |
| utils.load_checkpoint(info['model_path'], model, None) |
| model.eval() |
| models_tts.append({ |
| "name": info["title"], |
| "model": model, |
| "hps": hps, |
| "speaker_ids": hps.speakers |
| }) |
|
|
| load_models() |
|
|
| def get_text(text, hps, is_symbol): |
| from text import text_to_sequence |
| text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners) |
| if hps.data.add_blank: |
| from commons import intersperse |
| text_norm = intersperse(text_norm, 0) |
| return LongTensor(text_norm) |
|
|
| def tts_process(text, speaker, speed, model_data, is_symbol): |
| model = model_data["model"] |
| hps = model_data["hps"] |
| speaker_id = model_data["speaker_ids"][speaker] |
| stn_tst = get_text(text, hps, is_symbol) |
| with no_grad(): |
| x_tst = stn_tst.unsqueeze(0) |
| x_tst_lengths = LongTensor([stn_tst.size(0)]) |
| sid = LongTensor([speaker_id]) |
| audio = model.infer( |
| x_tst, x_tst_lengths, sid=sid, |
| noise_scale=0.667, noise_scale_w=0.8, |
| length_scale=1.0 / speed |
| )[0][0, 0].data.cpu().float().numpy() |
| return audio, hps.data.sampling_rate |
|
|
| def read_json(path): |
| with open(path, "r") as f: |
| return json.loads(f.read()) |
|
|
|
|
| def get_model_data(model): |
| return next((m for m in models_tts if m["name"].lower() == model.lower()), None) |
| |
| @app.route("/") |
| def index(): |
| return jsonify({"status": "OK" }) |
|
|
| @app.route("/<model>/speakers", methods=["GET"]) |
| def speakers(model): |
| global MODEL |
| model = model.lower() |
| model_info = MODEL.get(model, None) |
| |
| if model_info is None: |
| return jsonify({ "error": f"Model not found for `{model}`"}), 404 |
| |
| config = read_json(model_info["config_path"]) |
| return jsonify({"model_name": model_info["title"], "speakers": config["speakers"] }) |
|
|
| @app.route("/<model>/generate", methods=["POST", "GET"]) |
| def generate(model): |
| data = request.json if request.method == "POST" else request.args |
| text = data.get("text") |
| speaker = data.get("speaker") |
| speed = float(data.get("speed", 1.0)) |
| is_symbol = data.get("is_symbol", False) |
| speaker_id = data.get("speaker_id") |
| |
| if not text: |
| return jsonify({"error": "Missing parameter 'text'"}), 400 |
|
|
| model_data = get_model_data(model) |
| if not model_data: |
| return jsonify({"error": "Model not found"}), 404 |
| |
| speaker_ids = { str(id): speaker for speaker, id in model_data["speaker_ids"].items() } |
| |
| if not speaker: |
| if speaker_id is not None: |
| speaker = speaker_ids.get(str(speaker_id), None) |
| if not speaker: |
| return jsonify({"error": f"Speaker ID `{speaker_id}` not found"}), 404 |
| else: |
| return jsonify({"error": "Missing 'speaker' or 'speaker_id'"}), 400 |
|
|
| if speaker not in model_data["speaker_ids"]: |
| return jsonify({"error": f"Speaker `{speaker}` not found"}), 404 |
|
|
| try: |
| audio, sampling_rate = tts_process(text, speaker, speed, model_data, is_symbol) |
| temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") |
| sf.write(temp_wav.name, audio, sampling_rate, format="wav") |
| temp_wav.close() |
| return send_file(temp_wav.name, as_attachment=True, download_name="output.wav") |
| except Exception as e: |
| print(e) |
| return jsonify({"error": str(e)}), 500 |
|
|
| if __name__ == "__main__": |
| app.run(host="0.0.0.0", port=7860, debug=True) |