| import argparse |
| import os |
| from pathlib import Path |
|
|
| import logging |
| import re_matching |
|
|
# Third-party libraries that are noisy at INFO level: only surface their warnings.
for _quiet_name in ("numba", "markdown_it", "urllib3", "matplotlib"):
    logging.getLogger(_quiet_name).setLevel(logging.WARNING)
del _quiet_name

# Root logging configuration for this script: INFO and above, pipe-separated fields.
logging.basicConfig(
    format="| %(name)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)

logger = logging.getLogger(__name__)
|
|
| import librosa |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import Dataset |
| from torch.utils.data import DataLoader, Dataset |
| from tqdm import tqdm |
| from clap_wrapper import get_clap_audio_feature, get_clap_text_feature |
|
|
| import uuid |
| from flask import Flask, request, jsonify, render_template_string |
| from flask_cors import CORS |
|
|
| import gradio as gr |
|
|
| import utils |
| from config import config |
|
|
| import torch |
| import commons |
| from text import cleaned_text_to_sequence, get_bert |
| from text.cleaner import clean_text |
| import utils |
|
|
| from models import SynthesizerTrn |
| from text.symbols import symbols |
| import sys |
| from scipy.io.wavfile import write |
| from threading import Thread |
|
|
# Module-level synthesizer; assigned by get_net_g() in the __main__ section
# and reused by infer()/inferAPI()/loadmodel().
net_g = None

# Choose the compute device: first CUDA GPU, else Apple MPS on macOS, else CPU.
if torch.cuda.is_available():
    device = "cuda:0"
elif sys.platform == "darwin" and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
|
|
| |
# Group name -> member names. The __main__ section builds one outer Gradio tab
# per group and one inner tab per member. Each member name is also used as the
# default speaker and as the image file name.
BandList = {
    "PoppinParty": ["香澄", "有咲", "たえ", "りみ", "沙綾"],
    "Afterglow": ["蘭", "モカ", "ひまり", "巴", "つぐみ"],
    "HelloHappyWorld": ["こころ", "美咲", "薫", "花音", "はぐみ"],
    "PastelPalettes": ["彩", "日菜", "千聖", "イヴ", "麻弥"],
    "Roselia": ["友希那", "紗夜", "リサ", "燐子", "あこ"],
    "RaiseASuilen": ["レイヤ", "ロック", "ますき", "チュチュ", "パレオ"],
    "Morfonica": ["ましろ", "瑠唯", "つくし", "七深", "透子"],
    "MyGo": ["燈", "愛音", "そよ", "立希", "楽奈"],
    "AveMujica": ["祥子", "睦", "海鈴", "にゃむ", "初華"],
    "圣翔音乐学园": [
        "華戀", "光", "香子", "雙葉", "真晝", "純那", "克洛迪娜", "真矢", "奈奈",
    ],
    "凛明馆女子学校": ["珠緒", "壘", "文", "悠悠子", "一愛"],
    "弗隆提亚艺术学校": ["艾露", "艾露露", "菈樂菲", "司", "靜羽"],
    "西克菲尔特音乐学院": ["晶", "未知留", "八千代", "栞", "美帆"],
}
|
|
def get_net_g(model_path: str, device: str, hps):
    """Build a ``SynthesizerTrn`` from ``hps``, load ``model_path`` into it.

    The network is moved to ``device`` and switched to eval mode before the
    checkpoint is loaded (optimizer state skipped).

    Args:
        model_path: Path of the checkpoint passed to ``utils.load_checkpoint``.
        device: Torch device string the model is moved to.
        hps: Hyper-parameters; ``hps.data``, ``hps.train`` and ``hps.model``
            supply the constructor arguments.

    Returns:
        The loaded synthesizer.
    """
    data_cfg = hps.data
    synth = SynthesizerTrn(
        len(symbols),
        data_cfg.filter_length // 2 + 1,
        hps.train.segment_size // data_cfg.hop_length,
        n_speakers=data_cfg.n_speakers,
        **hps.model,
    ).to(device)
    synth.eval()
    utils.load_checkpoint(model_path, synth, None, skip_optimizer=True)
    return synth
|
|
def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
    """Turn raw text into the phone/tone/language ids and BERT features for ``net_g``.

    Args:
        text: Input sentence.
        language_str: "ZH", "JP" or "EN"; selects which of the three BERT
            slots receives the real features.
        hps: Hyper-parameters; ``hps.data.add_blank`` enables blank interspersing.
        device: Forwarded to ``get_bert``.
        style_text: Optional auxiliary text forwarded to ``get_bert``; an empty
            string is treated as ``None``.
        style_weight: Mixing weight forwarded to ``get_bert``.

    Returns:
        ``(bert, ja_bert, en_bert, phone, tone, language)``. The slot matching
        ``language_str`` holds the ``get_bert`` output. The other two are
        ``torch.randn(1024, len(phone))`` placeholders. The last three are
        ``torch.LongTensor`` values.

    Raises:
        ValueError: ``language_str`` is not ZH, JP or EN.
        AssertionError: BERT feature length does not match the phone count.
    """
    if style_text == "":
        style_text = None
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

    if hps.data.add_blank:
        # commons.intersperse(..., 0) presumably inserts token 0 around every
        # symbol (TODO confirm). word2ph is rescaled to match: each word doubles
        # and the first word absorbs one extra (leading) position.
        phone = commons.intersperse(phone, 0)
        tone = commons.intersperse(tone, 0)
        language = commons.intersperse(language, 0)
        word2ph[:] = [count * 2 for count in word2ph]
        word2ph[0] += 1

    bert_ori = get_bert(
        norm_text, word2ph, language_str, device, style_text, style_weight
    )
    del word2ph
    assert bert_ori.shape[-1] == len(phone), phone

    if language_str not in ("ZH", "JP", "EN"):
        raise ValueError("language_str should be ZH, JP or EN")
    # Slots are filled in ZH, JP, EN order; only the non-selected ones call
    # torch.randn, so random draws happen in that fixed order.
    bert, ja_bert, en_bert = [
        bert_ori if slot == language_str else torch.randn(1024, len(phone))
        for slot in ("ZH", "JP", "EN")
    ]

    assert bert.shape[-1] == len(
        phone
    ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"

    return (
        bert,
        ja_bert,
        en_bert,
        torch.LongTensor(phone),
        torch.LongTensor(tone),
        torch.LongTensor(language),
    )
|
|
|
|
def _synthesize(
    text,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    sid,
    style_text=None,
    style_weight=0.7,
):
    """Run the global ``net_g`` on ``text`` for speaker ``sid``; shared by both front ends.

    The language passed to ``get_text`` is "JP" when ``is_japanese(text)`` is
    true, otherwise "ZH". ``sid`` must be a key of ``hps.data.spk2id``.

    Returns:
        ``net_g.infer(...)[0][0, 0]`` moved to the CPU as a float32 NumPy array.
    """
    language = 'JP' if is_japanese(text) else 'ZH'
    bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
        text,
        language,
        hps,
        device,
        style_text=style_text,
        style_weight=style_weight,
    )
    with torch.no_grad():
        # Add a batch dimension of 1 and move everything to the model device.
        x_tst = phones.to(device).unsqueeze(0)
        tones = tones.to(device).unsqueeze(0)
        lang_ids = lang_ids.to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        ja_bert = ja_bert.to(device).unsqueeze(0)
        en_bert = en_bert.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        del phones
        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
        audio = (
            net_g.infer(
                x_tst,
                x_tst_lengths,
                speakers,
                tones,
                lang_ids,
                bert,
                ja_bert,
                en_bert,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )[0][0, 0]
            .data.cpu()
            .float()
            .numpy()
        )
        # Drop device tensors before emptying the CUDA cache so memory is released.
        del (
            x_tst,
            tones,
            lang_ids,
            bert,
            x_tst_lengths,
            speakers,
            ja_bert,
            en_bert,
        )
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return audio


def infer(
    text,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    sid,
    style_text=None,
    style_weight=0.7,
):
    """Gradio callback: synthesize ``text`` and return it for a ``gr.Audio`` output.

    Returns:
        ``(hps.data.sampling_rate, audio)``. ``audio`` is the synthesized
        waveform converted by ``gr.processing_utils.convert_to_16_bit_wav``.
    """
    audio = _synthesize(
        text,
        sdp_ratio,
        noise_scale,
        noise_scale_w,
        length_scale,
        sid,
        style_text=style_text,
        style_weight=style_weight,
    )
    return (hps.data.sampling_rate, gr.processing_utils.convert_to_16_bit_wav(audio))


def inferAPI(
    text,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    sid,
    style_text=None,
    style_weight=0.7,
):
    """Flask helper: synthesize ``text`` into a new WAV file in the working directory.

    Returns:
        The file name (``temp<uuid4>.wav``). The caller is responsible for
        deleting the file.
    """
    audio = _synthesize(
        text,
        sdp_ratio,
        noise_scale,
        noise_scale_w,
        length_scale,
        sid,
        style_text=style_text,
        style_weight=style_weight,
    )
    unique_filename = f"temp{uuid.uuid4()}.wav"
    # Write at the model's own sampling rate (the rate infer() reports). A
    # hard-coded 44100 would mis-pitch output for configs with another rate.
    write(unique_filename, hps.data.sampling_rate, audio)
    return unique_filename
|
|
def is_japanese(string):
    """Return True if *string* contains at least one Hiragana or Katakana character.

    The check covers the Hiragana (U+3040-U+309F) and Katakana (U+30A0-U+30FF)
    Unicode blocks, both ends inclusive. The previous exclusive upper bound
    skipped U+30FF. Kanji-only text shares code points with Chinese and is
    therefore not detected; ``infer``/``inferAPI`` then treat it as "ZH".
    """
    return any(0x3040 <= ord(ch) <= 0x30FF for ch in string)
|
|
def loadmodel(model):
    """Load checkpoint *model* into the global ``net_g``.

    Used both by the Gradio "载入模型" button (its return value is shown in the
    status box) and by the Flask ``tts`` handler.

    Returns:
        "success" if the checkpoint loaded, "error" otherwise. The failure is
        logged with its traceback instead of being silently swallowed.
    """
    global last_model
    try:
        net_g.eval()
        utils.load_checkpoint(model, net_g, None, skip_optimizer=True)
    except Exception:
        # Exception rather than a bare except, so KeyboardInterrupt/SystemExit
        # still propagate.
        logger.exception("Failed to load model checkpoint %s", model)
        return "error"
    # Keep the Flask handler's cache of the loaded checkpoint in sync. Otherwise
    # a load triggered from Gradio would leave tts() believing the old model is
    # still loaded.
    last_model = model
    return "success"
|
|
# Flask application serving the HTTP TTS API. flask_cors.CORS is applied with
# its default settings so browser pages on other origins can call it.
Flaskapp = Flask(__name__)
CORS(app=Flaskapp)
@Flaskapp.route('/')
def tts():
    """HTTP TTS endpoint (``GET /``) returning synthesized speech as WAV bytes.

    Query parameters:
        speaker, text: Required. If either is missing, an HTML page embedding
            the Gradio UI (http://127.0.0.1:7860) is returned instead.
        sdp_ratio, noise_scale, noise_scale_w, length_scale, style_weight:
            Floats forwarded to ``inferAPI`` (defaults 0.2, 0.6, 0.8, 1, 0.7).
        style_text: Auxiliary text forwarded to ``inferAPI`` (default 'happy').
        is_chat: When 'true', repeating the previous text returns one second
            of silence instead of synthesizing it again.
        model: Checkpoint path; must be one of ``modelPaths``.

    Returns:
        ``(wav_bytes, 200, headers)`` with ``Content-Type: audio/wav`` and a
        ``Text`` header holding the generated file name ('blank.wav' for the
        silence reply). Otherwise a JSON error with status 400 or 500.
    """
    global last_text, last_model
    speaker = request.args.get('speaker')
    text = request.args.get('text')
    try:
        sdp_ratio = float(request.args.get('sdp_ratio', 0.2))
        noise_scale = float(request.args.get('noise_scale', 0.6))
        noise_scale_w = float(request.args.get('noise_scale_w', 0.8))
        length_scale = float(request.args.get('length_scale', 1))
        style_weight = float(request.args.get('style_weight', 0.7))
    except ValueError:
        return jsonify(error="numeric parameters must be valid numbers"), 400
    style_text = request.args.get('style_text', 'happy')
    is_chat = request.args.get('is_chat', 'false').lower() == 'true'
    model = request.args.get('model', modelPaths[-1])

    if not speaker or not text:
        return render_template_string("""
<!DOCTYPE html>
<html>
<head>
<title>TTS API Documentation</title>
</head>
<body>
<iframe src="http://127.0.0.1:7860" style="width:100%; height:100vh; border:none;"></iframe>
</body>
</html>
""")

    # ``model`` comes straight from the query string and is handed to
    # utils.load_checkpoint. NOTE(review): that presumably deserializes the file
    # (torch.load / pickle), so only checkpoints discovered at startup are accepted.
    if model not in modelPaths:
        return jsonify(error="unknown model"), 400
    if speaker not in hps.data.spk2id.keys():
        return jsonify(error="unknown speaker"), 400

    if model != last_model:
        if loadmodel(model) != "success":
            # Leave last_model untouched so the next request retries the load
            # instead of assuming the new weights are in place.
            return jsonify(error="failed to load model"), 500
        last_model = model

    if is_chat and text == last_text:
        # Repeated chat line: reply with one second of silence. A per-request
        # file name keeps concurrent requests from overwriting or deleting each
        # other's file.
        header_text = 'blank.wav'
        sampling_rate = hps.data.sampling_rate
        unique_filename = f"blank{uuid.uuid4()}.wav"
        write(unique_filename, sampling_rate, np.zeros(sampling_rate, dtype=np.int16))
    else:
        unique_filename = inferAPI(
            text,
            sdp_ratio=sdp_ratio,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale,
            sid=speaker,
            style_text=style_text,
            style_weight=style_weight,
        )
        # Remember the text only after synthesis succeeded, so a failed request
        # is not later answered with silence.
        last_text = text
        header_text = unique_filename

    try:
        with open(unique_filename, 'rb') as wav_file:
            wav_bytes = wav_file.read()
    finally:
        os.remove(unique_filename)
    headers = {
        'Content-Type': 'audio/wav',
        # Header values are passed as str; bytes values are not portable.
        'Text': header_text,
    }
    return wav_bytes, 200, headers
|
|
def gradio_interface():
    """Thread target: launch the module-level Gradio ``app`` with ``share=True``.

    Returns whatever ``app.launch`` returns.
    """
    launch_result = app.launch(share=True)
    return launch_result
|
|
if __name__ == "__main__":
    # Every file under the models directory is offered as a checkpoint. The list
    # is sorted so the default checkpoint (the last entry) does not depend on
    # the filesystem's os.walk ordering.
    modelPaths = sorted(
        os.path.join(dirpath, filename)
        for dirpath, _dirnames, filenames in os.walk('Data/Data/V23/models/')
        for filename in filenames
    )
    if not modelPaths:
        raise SystemExit("No model checkpoints found under Data/Data/V23/models/")
    hps = utils.get_hparams_from_file('Data/Data/V23/configs/config.json')
    net_g = get_net_g(model_path=modelPaths[-1], device=device, hps=hps)
    speaker_ids = hps.data.spk2id
    speakers = list(speaker_ids.keys())
    # State shared with the Flask handler tts(): the last synthesized text (for
    # chat de-duplication) and the checkpoint currently loaded into net_g.
    last_text = ""
    last_model = modelPaths[-1]

    with gr.Blocks() as app:
        # One outer tab per group in BandList, one inner tab per member.
        for band, members in BandList.items():
            with gr.TabItem(band):
                for name in members:
                    with gr.TabItem(name):
                        with gr.Row():
                            with gr.Column():
                                with gr.Row():
                                    gr.Markdown(
                                        '<div align="center">'
                                        f'<img style="width:auto;height:400px;" src="https://mahiruoshi-bangdream-bert-vits2.hf.space/file/image/{name}.png">'
                                        '</div>'
                                    )
                                length_scale = gr.Slider(
                                    minimum=0.1, maximum=2, value=1, step=0.01, label="语速调节"
                                )
                                with gr.Accordion(label="参数设定", open=False):
                                    sdp_ratio = gr.Slider(
                                        minimum=0, maximum=1, value=0.5, step=0.01, label="SDP/DP混合比"
                                    )
                                    noise_scale = gr.Slider(
                                        minimum=0.1, maximum=2, value=0.6, step=0.01, label="感情调节"
                                    )
                                    noise_scale_w = gr.Slider(
                                        minimum=0.1, maximum=2, value=0.667, step=0.01, label="音素长度"
                                    )
                                    speaker = gr.Dropdown(
                                        choices=speakers, value=name, label="说话人"
                                    )
                                with gr.Accordion(label="切换模型", open=False):
                                    # Default to the checkpoint actually loaded at startup.
                                    modelstrs = gr.Dropdown(
                                        label="模型",
                                        choices=modelPaths,
                                        value=modelPaths[-1],
                                        type="value",
                                    )
                                    btnMod = gr.Button("载入模型")
                                    statusa = gr.TextArea()
                                    btnMod.click(loadmodel, inputs=[modelstrs], outputs=[statusa])
                            with gr.Column():
                                text = gr.TextArea(
                                    label="输入纯日语或者中文",
                                    placeholder="输入纯日语或者中文",
                                    value="为什么要演奏春日影!",
                                )
                                style_text = gr.Textbox(label="辅助文本")
                                style_weight = gr.Slider(
                                    minimum=0,
                                    maximum=1,
                                    value=0.7,
                                    step=0.1,
                                    label="Weight",
                                    info="主文本和辅助文本的bert混合比率,0表示仅主文本,1表示仅辅助文本",
                                )
                                btn = gr.Button("点击生成", variant="primary")
                                audio_output = gr.Audio(label="Output Audio")
                        # Wire this tab's widgets to infer(); argument order
                        # matches infer's positional parameters.
                        btn.click(
                            infer,
                            inputs=[
                                text,
                                sdp_ratio,
                                noise_scale,
                                noise_scale_w,
                                length_scale,
                                speaker,
                                style_text,
                                style_weight,
                            ],
                            outputs=[audio_output],
                        )

    # Serve the Gradio UI and the Flask API (port 5000) from separate threads.
    api_thread = Thread(target=Flaskapp.run, args=("0.0.0.0", 5000))
    gradio_thread = Thread(target=gradio_interface)
    gradio_thread.start()
    print("推理页面已开启!")
    api_thread.start()
    print("api页面已开启!运行在5000端口")