File size: 5,532 Bytes
7b7f715
84859a3
 
7b7f715
b3f61ad
 
84859a3
7b7f715
84859a3
 
 
7b7f715
 
84859a3
50de85b
7b7f715
 
 
 
 
84859a3
7b7f715
 
 
 
 
 
84859a3
dc3a626
 
 
 
 
 
 
84859a3
 
50de85b
7b7f715
84859a3
11365f7
84859a3
7b7f715
84859a3
7b7f715
84859a3
 
 
 
 
7b7f715
 
 
 
84859a3
7b7f715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc3a626
7b7f715
 
 
dc3a626
 
 
7b7f715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50de85b
7b7f715
b3f61ad
 
 
 
 
 
 
7b7f715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc3a626
7b7f715
 
 
 
 
 
dc3a626
7f7b0fa
1a03e80
dc3a626
 
 
 
 
 
 
 
7b7f715
 
fdbb4ea
7b7f715
 
 
 
 
 
 
 
 
dc3a626
7b7f715
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import io
import json
import logging
import tempfile

from flask import Flask, request, jsonify, send_file
from torch import no_grad, LongTensor
import soundfile as sf

from text import text_to_sequence, _clean_text
import utils
import ONNXVITS_infer

app = Flask(__name__)
# Only let the 'numba' logger through at WARNING and above.
logging.getLogger('numba').setLevel(logging.WARNING)

# Model descriptors: display title, checkpoint path, hyper-parameter config
# path and the directory holding the exported ONNX networks.
TRILINGUAL = dict(
    title="Trilingual",
    model_path="./pretrained_models/G_trilingual.pth",
    config_path="./configs/uma_trilingual.json",
    onnx_dir="./ONNX_net/G_trilingual/",
)

JAPANESE = dict(
    title="Japanese",
    model_path="./pretrained_models/G_jp.pth",
    config_path="./configs/uma87.json",
    onnx_dir="./ONNX_net/G_jp/",
)

# Request `lang` code -> tag placed on both sides of the text before
# synthesis; "MIX" adds no tag (the caller supplies tags inline).
language_marks = dict(JA="[JA]", ZH="[ZH]", ENG="[EN]", MIX="")

# Filled by load_models(), one entry per descriptor in models_info.
models_tts = []
# Descriptors to load, in load order.
models_info = [TRILINGUAL, JAPANESE]
# Lower-case URL segment -> descriptor (used by the speakers route).
MODEL = dict(japanese=JAPANESE, trilingual=TRILINGUAL)

def load_models():
    """Build every model listed in ``models_info`` and register it in ``models_tts``.

    For each descriptor:
    - the hyper-parameters are read from its config file;
    - an ONNX-backed ``SynthesizerTrn`` is constructed from them;
    - the checkpoint weights are loaded and the model is put in eval mode.

    Entries are appended in ``models_info`` order. Each entry is a dict with the
    keys "name", "model", "hps" and "speaker_ids".
    """
    for spec in models_info:
        params = utils.get_hparams_from_file(spec['config_path'])
        data_cfg = params.data
        synth = ONNXVITS_infer.SynthesizerTrn(
            len(params.symbols),
            data_cfg.filter_length // 2 + 1,
            params.train.segment_size // data_cfg.hop_length,
            n_speakers=data_cfg.n_speakers,
            ONNX_dir=spec["onnx_dir"],
            **params.model
        )
        utils.load_checkpoint(spec['model_path'], synth, None)
        synth.eval()
        entry = {
            "name": spec["title"],
            "model": synth,
            "hps": params,
            "speaker_ids": params.speakers,
        }
        models_tts.append(entry)

load_models()

def get_text(text, hps, is_symbol):
    """Convert *text* into a ``LongTensor`` of symbol ids for the model.

    When *is_symbol* is true, no text cleaners are applied. Otherwise the cleaners
    named in ``hps.data.text_cleaners`` are used.

    If ``hps.data.add_blank`` is set, the id sequence is passed through
    ``commons.intersperse`` with 0. Presumably this inserts a blank id between
    symbols; confirm against ``commons``.
    """
    cleaners = [] if is_symbol else hps.data.text_cleaners
    sequence = text_to_sequence(text, hps.symbols, cleaners)
    if not hps.data.add_blank:
        return LongTensor(sequence)
    from commons import intersperse
    return LongTensor(intersperse(sequence, 0))

def tts_process(text, speaker, speed, model_data, is_symbol, language=None):
    """Synthesize *text* with one loaded model entry.

    Args:
        text: input text, or a symbol string when *is_symbol* is true.
        speaker: key into ``model_data["speaker_ids"]``.
        speed: speed factor; the model is run with ``length_scale = 1 / speed``.
        model_data: an entry of ``models_tts`` (uses "model", "hps", "speaker_ids").
        is_symbol: when true, text cleaners are skipped (see ``get_text``).
        language: optional key of ``language_marks``. Its tag is placed on both
            sides of *text*.

    Returns:
        ``(audio, sampling_rate)``. *audio* is the first output of
        ``model.infer`` at ``[0, 0]``, as a float32 numpy array.
        *sampling_rate* comes from ``hps.data.sampling_rate``.
    """
    net = model_data["model"]
    hparams = model_data["hps"]
    sid_value = model_data["speaker_ids"][speaker]
    if language is not None:
        tag = language_marks[language]
        text = tag + text + tag

    token_ids = get_text(text, hparams, is_symbol)
    with no_grad():
        # Add a batch dimension of 1 and pass the matching length/speaker tensors.
        batch = token_ids.unsqueeze(0)
        batch_lengths = LongTensor([token_ids.size(0)])
        speaker_tensor = LongTensor([sid_value])
        outputs = net.infer(
            batch,
            batch_lengths,
            sid=speaker_tensor,
            noise_scale=0.667,
            noise_scale_w=0.8,
            length_scale=1.0 / speed,
        )
        waveform = outputs[0][0, 0].data.cpu().float().numpy()
    return waveform, hparams.data.sampling_rate

def read_json(path):
    """Load and return the JSON document stored at *path*.

    The file is decoded as UTF-8 explicitly. The model configs contain
    non-ASCII (e.g. Japanese) speaker names, and relying on the platform's
    default encoding would mis-decode or reject them on non-UTF-8 locales.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def get_model_data(model):
    """Return the first ``models_tts`` entry whose name matches *model*.

    The comparison is case-insensitive. Returns None when no entry matches.
    """
    for entry in models_tts:
        if entry["name"].lower() == model.lower():
            return entry
    return None
   
@app.route("/")
def index():
    """Health-check endpoint; always responds with ``{"status": "OK"}``."""
    return jsonify(status="OK")

@app.route("/tosymbol", methods=["GET", "POST"])
def to_symbol():
    """Return the output of ``_clean_text`` for the ``text`` parameter.

    GET reads ``text`` from the query string. POST reads it from a JSON object
    body. Responds 400 with a JSON error when ``text`` is absent, including when
    the POST body is missing, is not JSON, or is not a JSON object.
    """
    if request.method == "GET":
        text = request.args.get("text")
    else:
        # silent=True returns None instead of raising on a missing/invalid body,
        # so the client gets this endpoint's own 400 message.
        payload = request.get_json(silent=True)
        text = payload.get("text") if isinstance(payload, dict) else None
    if text is None:
        return jsonify({ "error": "text is required"}), 400
    # NOTE(review): `_clean_text` is called here with the text only; confirm its
    # signature in the `text` package. Some VITS variants also require a list of
    # cleaner names.
    return _clean_text(text)

@app.route("/<model>/speakers", methods=["GET"])
def speakers(model):
    """List the speakers declared in a model's config file.

    *model* is lower-cased and looked up in ``MODEL``. On success the response is
    ``{"model_name": <title>, "speakers": <config["speakers"]>}``. An unknown
    model yields a 404 JSON error.
    """
    key = model.lower()
    info = MODEL.get(key)
    if info is None:
        return jsonify({ "error": f"Model not found for `{key}`"}), 404

    speaker_table = read_json(info["config_path"])["speakers"]
    return jsonify({"model_name": info["title"], "speakers": speaker_table})

def _parse_bool(value):
    """Interpret a request flag that may arrive as a JSON bool or a query string.

    Query-string values are always strings, so plain truthiness would treat
    "false" or "0" as True. Only common affirmative spellings count as true here.
    """
    if isinstance(value, str):
        return value.strip().lower() in ("1", "true", "yes", "on")
    return bool(value)


def _parse_speed(value):
    """Return *value* as a positive finite float, or None if it is not one."""
    try:
        speed = float(value)
    except (TypeError, ValueError):
        return None
    # Rejects 0 (division by zero in tts_process), negatives, NaN and inf.
    if not 0.0 < speed < float("inf"):
        return None
    return speed


@app.route("/<model>/generate", methods=["POST", "GET"])
def generate(model):
    """Synthesize speech and return it as a WAV attachment named ``output.wav``.

    Parameters are read from the JSON body (POST) or the query string (GET):
    - ``text`` (required).
    - ``speaker`` or ``speaker_id``.
    - ``speed`` (default 1.0, must be a positive number).
    - ``is_symbol`` (boolean flag).
    - ``lang`` (a key of ``language_marks``; ignored for the Japanese model).

    Errors are returned as JSON with status:
    - 400 for a missing or invalid parameter or an unsupported language;
    - 404 for an unknown model or speaker;
    - 500 when synthesis fails.
    """
    if request.method == "POST":
        # silent=True yields None instead of raising on a missing/malformed body,
        # so callers get this endpoint's own 400 messages.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
    else:
        data = request.args
    text = data.get("text")
    speaker = data.get("speaker")
    is_symbol = _parse_bool(data.get("is_symbol", False))
    speaker_id = data.get("speaker_id")
    language = data.get("lang")
    if not text:
        return jsonify({"error": "Missing parameter 'text'"}), 400

    speed = _parse_speed(data.get("speed", 1.0))
    if speed is None:
        return jsonify({"error": "Parameter 'speed' must be a positive number"}), 400

    model_data = get_model_data(model)
    if not model_data:
        return jsonify({"error": "Model not found"}), 404

    if language is not None:
        if model.lower() == "japanese":
            # The Japanese-only model is run without language tags.
            language = None
        elif not isinstance(language, str) or language not in language_marks:
            # list(): a dict keys view is not JSON serializable.
            return jsonify({"error": "language not available", "language": list(language_marks)}), 400

    if not speaker:
        if speaker_id is None:
            return jsonify({"error": "Missing 'speaker' or 'speaker_id'"}), 400
        # Reverse lookup: stringified numeric id -> speaker name.
        names_by_id = {str(sid): name for name, sid in model_data["speaker_ids"].items()}
        speaker = names_by_id.get(str(speaker_id))
        if not speaker:
            return jsonify({"error": f"Speaker ID `{speaker_id}` not found"}), 404

    if speaker not in model_data["speaker_ids"]:
        return jsonify({"error": f"Speaker `{speaker}` not found"}), 404

    try:
        audio, sampling_rate = tts_process(text, speaker, speed, model_data, is_symbol, language)
        # Encode in memory so no temporary file is left behind on disk.
        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, audio, sampling_rate, format="wav")
        wav_buffer.seek(0)
        return send_file(wav_buffer, as_attachment=True, download_name="output.wav")
    except Exception as e:
        # Request boundary: keep the traceback in the server log, answer with JSON.
        app.logger.exception("Speech generation failed")
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    # The server binds to all interfaces, so the Werkzeug interactive debugger
    # (which allows arbitrary code execution) must not be enabled by default.
    # Flask.run() still honours FLASK_DEBUG=1 when debugging locally.
    app.run(host="0.0.0.0", port=7860)