import os
from scipy.io.wavfile import write
from text import text_to_sequence, _clean_text
from models import SynthesizerTrn
import utils
import commons
import sys
import re
from torch import no_grad, LongTensor
import logging
from flask import Flask, request, send_file, jsonify
import uuid
import subprocess
import ffmpeg
from io import BytesIO

app = Flask(__name__)
# Keep non-ASCII speaker names readable in JSON responses. Older Flask reads
# the JSON_AS_ASCII config key; newer Flask reads it from the JSON provider.
app.config['JSON_AS_ASCII'] = False
if hasattr(app, 'json'):
    app.json.ensure_ascii = False

logging.getLogger('numba').setLevel(logging.WARNING)


class Voice:
    """A VITS checkpoint loaded for inference, plus text-to-speech helpers.

    Args:
        model: checkpoint path handed to ``utils.load_checkpoint``.
        config: JSON config path handed to ``utils.get_hparams_from_file``.
        out_path: output directory, stored as ``self.out_path`` for callers;
            ``generate`` keeps all audio in memory and does not write there.
    """

    def __init__(self, model, config, out_path=None):
        self.out_path = out_path
        self.hps_ms = utils.get_hparams_from_file(config)
        # Optional config entries fall back to these defaults when absent.
        self.n_speakers = (self.hps_ms.data.n_speakers
                           if 'n_speakers' in self.hps_ms.data.keys() else 0)
        n_symbols = len(self.hps_ms.symbols) if 'symbols' in self.hps_ms.keys() else 0
        self.speakers = self.hps_ms.speakers if 'speakers' in self.hps_ms.keys() else ['0']
        self.emotion_embedding = (self.hps_ms.data.emotion_embedding
                                  if 'emotion_embedding' in self.hps_ms.data.keys() else False)
        self.net_g_ms = SynthesizerTrn(
            n_symbols,
            self.hps_ms.data.filter_length // 2 + 1,
            self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
            n_speakers=self.n_speakers,
            emotion_embedding=self.emotion_embedding,
            **self.hps_ms.model)
        self.net_g_ms.eval()
        utils.load_checkpoint(model, self.net_g_ms)

    def generate(self, text, speaker_id, format):
        """Synthesize ``text`` as speaker ``speaker_id``.

        Inline labels in ``text`` are stripped before synthesis:
        ``[LENGTH=x]`` (length scale, default 1), ``[NOISE=x]`` (noise scale,
        default 0.667), ``[NOISEW=x]`` (noise deviation, default 0.8) and
        ``[CLEANED]`` (text is already cleaned; skip the config's cleaners).

        Args:
            text: input text, including any language tags the model expects.
            speaker_id: integer speaker index.
            format: ``'ogg'`` for Ogg output; any other value yields WAV.
                (Name kept for compatibility even though it shadows a builtin.)

        Returns:
            ``(file_obj, mimetype, download_name)``; ``file_obj`` is an
            in-memory ``BytesIO`` positioned at its start.

        Raises:
            ValueError: a label value is not numeric, or ``speaker_id`` is out
                of range for a multi-speaker model.
            NotImplementedError: the model uses ``emotion_embedding``, which
                this wrapper has no inference path for.
        """
        if self.emotion_embedding:
            # Previously this case skipped inference and crashed on an unbound local.
            raise NotImplementedError('emotion_embedding models are not supported')
        if self.n_speakers > 0 and not 0 <= speaker_id < self.n_speakers:
            raise ValueError(
                f'speaker id must be between 0 and {self.n_speakers - 1}, got {speaker_id}')

        length_scale, text = self.get_label_value(text, 'LENGTH', 1, 'length scale')
        noise_scale, text = self.get_label_value(text, 'NOISE', 0.667, 'noise scale')
        noise_scale_w, text = self.get_label_value(text, 'NOISEW', 0.8, 'deviation of noise')
        cleaned, text = self.get_label(text, 'CLEANED')
        stn_tst = self.get_text(text, self.hps_ms, cleaned=cleaned)

        with no_grad():
            x_tst = stn_tst.unsqueeze(0)
            x_tst_lengths = LongTensor([stn_tst.size(0)])
            sid = LongTensor([speaker_id])
            # NOTE(review): [0][0, 0] takes the first returned tensor and its first
            # batch/channel entry; assumes a (batch, channel, samples) layout.
            audio = self.net_g_ms.infer(
                x_tst, x_tst_lengths, sid=sid, noise_scale=noise_scale,
                noise_scale_w=noise_scale_w, length_scale=length_scale,
            )[0][0, 0].data.cpu().float().numpy()

        # uuid4 is random; uuid1 would embed the host MAC address in the file name.
        file_name = str(uuid.uuid4())
        wav_buffer = BytesIO()
        write(wav_buffer, self.hps_ms.data.sampling_rate, audio)
        if format == 'ogg':
            return self._wav_to_ogg(wav_buffer.getvalue()), "audio/ogg", file_name + ".ogg"
        wav_buffer.seek(0)
        return wav_buffer, "audio/wav", file_name + ".wav"

    @staticmethod
    def _wav_to_ogg(wav_bytes):
        """Transcode WAV bytes to Ogg through ffmpeg stdin/stdout; return a BytesIO.

        Piping avoids the temporary .wav/.ogg files that were never cleaned up.
        Raises ``ffmpeg.Error`` if the ffmpeg process fails.
        """
        ogg_bytes, _ = (
            ffmpeg.input('pipe:', format='wav')
            .output('pipe:', format='ogg')
            .run(input=wav_bytes, capture_stdout=True, capture_stderr=True)
        )
        return BytesIO(ogg_bytes)

    def run_script(self, file_path):
        """Convert the audio file at ``file_path`` to ``.ogg`` beside it, then delete the source.

        Returns the path of the ``.ogg`` file.
        """
        # splitext only strips the final extension; split('.') broke on dotted paths.
        out_path = os.path.splitext(file_path)[0] + ".ogg"
        ffmpeg.input(file_path).output(out_path).run()
        # os.remove instead of `rm` through a shell: no quoting/injection issues.
        os.remove(file_path)
        return out_path

    def get_text(self, text, hps, cleaned=False):
        """Convert ``text`` to a LongTensor of symbol ids.

        The config's text cleaners are skipped when ``cleaned`` is true. When
        ``hps.data.add_blank`` is set, id 0 is interspersed via ``commons.intersperse``.
        """
        if cleaned:
            text_norm = text_to_sequence(text, hps.symbols, [])
        else:
            text_norm = text_to_sequence(text, hps.symbols, hps.data.text_cleaners)
        if hps.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        return LongTensor(text_norm)

    def get_label_value(self, text, label, default, warning_name='value'):
        """Extract the first ``[label=<number>]`` tag from ``text``.

        Returns ``(value, text_without_that_tag)``, or ``(default, text)``
        when no tag is present.

        Raises:
            ValueError: the tag's value is not a valid float. This replaces the
                former ``sys.exit(1)``, which aborted the server mid-request.
        """
        pattern = rf'\[{label}=(.+?)\]'
        match = re.search(pattern, text)
        if match is None:
            return default, text
        try:
            value = float(match.group(1))
        except ValueError:
            raise ValueError(f'Invalid {warning_name}!') from None
        return value, re.sub(pattern, '', text, count=1)

    def ex_return(self, text, escape=False):
        """Return ``text``, unicode-escaped when ``escape`` is true."""
        if escape:
            return text.encode('unicode_escape').decode()
        return text

    def return_speakers(self, escape=False, limit=100):
        """Return one ``"<id>\\t<name>"`` string per speaker.

        Returns None when there are more than ``limit`` speakers; pass
        ``limit=None`` to always get the full list.
        """
        if limit is not None and len(self.speakers) > limit:
            return None
        return [self.ex_return(str(idx) + '\t' + name, escape)
                for idx, name in enumerate(self.speakers)]

    def get_label(self, text, label):
        """Return ``(True, text_without_tags)`` if ``[label]`` occurs in ``text``, else ``(False, text)``."""
        tag = f'[{label}]'
        if tag in text:
            return True, text.replace(tag, '')
        return False, text


# Usage example:
#     voice = Voice("path/to/model.pth", "path/to/config.json")
#
# Paths are resolved from this file's absolute location: os.path.dirname(__file__)
# alone is '' when the script is launched by a relative name, which previously
# turned "/output/" and "/Model/..." into filesystem-root paths.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

out_path = os.path.join(BASE_DIR, "output")

model_zh = os.path.join(BASE_DIR, "Model", "Nene_Nanami_Rong_Tang", "1374_epochs.pth")
config_zh = os.path.join(BASE_DIR, "Model", "Nene_Nanami_Rong_Tang", "config.json")
voice_zh = Voice(model_zh, config_zh, out_path)

model_ja = os.path.join(BASE_DIR, "Model", "Zero_no_tsukaima", "1158_epochs.pth")
config_ja = os.path.join(BASE_DIR, "Model", "Zero_no_tsukaima", "config.json")
voice_ja = Voice(model_ja, config_ja, out_path)


def _speakers_response(voice):
    """Return all speakers of ``voice`` as a JSON list of "id<TAB>name" strings."""
    # limit=None: with the default cap, return_speakers yields None for large
    # models, and Flask cannot turn None into a response.
    return jsonify(voice.return_speakers(escape=False, limit=None))


def _synthesize(voice, lang_tag, default_speaker_id):
    """Handle a synthesis request: validate the query args, run ``voice``, send audio.

    Query args: ``text`` (required), ``id`` (integer, defaults to
    ``default_speaker_id``), and ``format`` (``'ogg'`` or anything else for WAV).
    Invalid input yields HTTP 400 instead of an unhandled exception (HTTP 500).
    """
    text = request.args.get("text")
    if not text:
        return "missing required query parameter: text", 400
    try:
        speaker_id = int(request.args.get("id", default_speaker_id))
    except ValueError:
        return "query parameter 'id' must be an integer", 400
    audio_format = request.args.get("format", "wav")
    tagged_text = f"[{lang_tag}]{text}[{lang_tag}]"
    try:
        output, mimetype, file_name = voice.generate(tagged_text, speaker_id, audio_format)
    except ValueError as err:
        return str(err), 400
    return send_file(path_or_file=output, mimetype=mimetype, download_name=file_name)


@app.route('/api/')
def index():
    return "usage:/api/zh?text=text&id=3&format=wav"


@app.route('/api/ja/speakers')
def voice_speakers_ja():
    return _speakers_response(voice_ja)


@app.route('/api/ja', methods=["GET"])
def api_voice_ja():
    # Previously passed generate()'s whole tuple to send_file and always failed.
    return _synthesize(voice_ja, "JA", 0)


@app.route('/api/zh/speakers')
def voice_speakers_zh():
    return _speakers_response(voice_zh)


@app.route('/api/zh', methods=["GET"])
def api_voice_zh():
    return _synthesize(voice_zh, "ZH", 3)


if __name__ == '__main__':
    # Listens on all interfaces; use host='127.0.0.1' to serve only locally.
    # Debug mode stays off: the Werkzeug debugger permits arbitrary code
    # execution by anyone who can reach the server.
    app.run(host='0.0.0.0', port=23456)