|
|
|
|
| import os
|
|
|
# Job parameters are handed over by the launching process via environment
# variables.
inp_text = os.environ.get("inp_text")  # transcript list read further below
inp_wav_dir = os.environ.get("inp_wav_dir")
exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")  # this worker's shard index (string)
all_parts = os.environ.get("all_parts")  # total number of shards (string)

# Copy the device mask over before `import torch` below so it applies to this
# process. Only do so when the launcher provided a value: assigning None into
# os.environ raises TypeError.
_cuda_visible_devices = os.environ.get("_CUDA_VISIBLE_DEVICES")
if _cuda_visible_devices is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = _cuda_visible_devices

opt_dir = os.environ.get("opt_dir")  # output directory
bert_pretrained_dir = os.environ.get("bert_pretrained_dir")

import torch  # noqa: E402 -- deliberately imported after the device mask is set
import ast  # noqa: E402

# The flag arrives as a Python literal string such as "True"/"False".
# ast.literal_eval parses the same literals the former eval() call accepted,
# but cannot execute arbitrary code taken from the environment.
is_half = (
    bool(ast.literal_eval(os.environ.get("is_half", "True")))
    and torch.cuda.is_available()
)

version = os.environ.get('version', None)
|
| import sys, numpy as np, traceback, pdb
|
| import os.path
|
| from glob import glob
|
| from tqdm import tqdm
|
| from text.cleaner import clean_text
|
| from transformers import AutoModelForMaskedLM, AutoTokenizer
|
| import numpy as np
|
| from tools.my_utils import clean_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| from time import time as ttime
|
| import shutil
|
|
|
|
|
def my_save(fea, path):
    """Serialize ``fea`` with ``torch.save`` and place the result at ``path``.

    The object is first written to a temporary file in the current working
    directory and then moved to ``path``. Presumably this indirection works
    around ``torch.save`` failing on some destination paths (e.g. non-ASCII
    ones) -- TODO confirm. The temporary name combines the current timestamp
    with this worker's ``i_part`` shard index, presumably so that parallel
    shard workers sharing a working directory do not collide.

    Args:
        fea: object to serialize (a feature tensor in this script).
        path: final destination file path.

    Raises:
        Whatever ``torch.save`` or ``shutil.move`` raises; the temporary file
        is removed before the exception propagates.
    """
    tmp_path = "%s%s.pth" % (ttime(), i_part)
    try:
        torch.save(fea, tmp_path)
        # Move straight to ``path``. Rebuilding it as "<dirname>/<basename>"
        # turned a bare file name (empty dirname) into "/<name>", i.e. a write
        # at the filesystem root.
        shutil.move(tmp_path, path)
    except BaseException:
        # Don't leave a stray temporary file behind in the working directory.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
|
|
|
|
|
# Output manifest for this shard. It is written only at the very end, so its
# existence means the shard has already been processed and all work is skipped.
txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
if not os.path.exists(txt_path):
    bert_dir = "%s/3-bert" % (opt_dir)
    os.makedirs(opt_dir, exist_ok=True)
    os.makedirs(bert_dir, exist_ok=True)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Fail fast with the offending path rather than a less direct error from
    # from_pretrained.
    if not os.path.exists(bert_pretrained_dir):
        raise FileNotFoundError(bert_pretrained_dir)
    tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
    bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
    if is_half:
        bert_model = bert_model.half()
    bert_model = bert_model.to(device)

    def get_bert_feature(text, word2ph):
        """Return phone-level BERT features for ``text``.

        Runs ``text`` through the BERT model, takes the third-from-last hidden
        layer of the first (only) batch item, drops its first and last token
        positions (presumably the special begin/end tokens -- TODO confirm for
        this tokenizer) and repeats row ``i`` ``word2ph[i]`` times.
        Assumes the tokenizer yields exactly one token per character of
        ``text``.

        Args:
            text: normalized text; ``word2ph`` must have one entry per
                character.
            word2ph: number of phones for each character of ``text``.

        Returns:
            CPU tensor whose last dimension is ``sum(word2ph)`` (the phone
            axis); the first dimension is the model's hidden size.

        Raises:
            ValueError: if ``len(word2ph) != len(text)``.
        """
        # Explicit check instead of ``assert`` so it also holds under -O;
        # done before the forward pass to avoid wasted compute.
        if len(word2ph) != len(text):
            raise ValueError(
                "word2ph has %d entries but text has %d characters"
                % (len(word2ph), len(text))
            )
        with torch.no_grad():
            inputs = tokenizer(text, return_tensors="pt")
            inputs = {key: value.to(device) for key, value in inputs.items()}
            outputs = bert_model(**inputs, output_hidden_states=True)
            char_features = outputs["hidden_states"][-3][0].cpu()[1:-1]

        phone_level_feature = torch.cat(
            [char_features[i].repeat(count, 1) for i, count in enumerate(word2ph)],
            dim=0,
        )
        return phone_level_feature.T

    def process(data, res):
        """Clean each transcript entry and append the result to ``res``.

        For every ``(name, text, lan)`` in ``data``: reduce ``name`` to its base
        file name, run ``clean_text`` on ``text`` and, for "zh" entries whose
        ``<bert_dir>/<name>.pt`` does not exist yet, compute and save BERT
        features there. On success appends
        ``[name, space-joined phones, word2ph, norm_text]`` to ``res``.
        Entries that raise are reported with a traceback and skipped.
        """
        for name, text, lan in data:
            try:
                name = os.path.basename(clean_path(name))
                print(name)
                phones, word2ph, norm_text = clean_text(
                    text.replace("%", "-").replace("¥", ","), lan, version
                )
                path_bert = "%s/%s.pt" % (bert_dir, name)
                if lan == "zh" and not os.path.exists(path_bert):
                    bert_feature = get_bert_feature(norm_text, word2ph)
                    # Features must line up one-to-one with phones.
                    if bert_feature.shape[-1] != len(phones):
                        raise ValueError(
                            "BERT feature length %d != phone count %d"
                            % (bert_feature.shape[-1], len(phones))
                        )
                    my_save(bert_feature, path_bert)
                res.append([name, " ".join(phones), word2ph, norm_text])
            except Exception:
                # Best effort: report the failing entry and carry on.
                print(name, text, traceback.format_exc())

    todo = []
    res = []
    with open(inp_text, "r", encoding="utf8") as f:
        lines = f.read().strip("\n").split("\n")

    # Every accepted spelling of a language tag -> canonical language code.
    language_v1_to_language_v2 = {
        "ZH": "zh",
        "zh": "zh",
        "JP": "ja",
        "jp": "ja",
        "JA": "ja",
        "ja": "ja",
        "EN": "en",
        "en": "en",
        "En": "en",
        "KO": "ko",
        "Ko": "ko",
        "ko": "ko",
        "yue": "yue",
        "YUE": "yue",
        "Yue": "yue",
    }

    # This worker takes every `all_parts`-th line, starting at line `i_part`.
    for line in lines[int(i_part) :: int(all_parts)]:
        try:
            # Expected line format: wav_name|speaker|language|text
            wav_name, spk_name, language, text = line.split("|")
            if language in language_v1_to_language_v2:
                todo.append([wav_name, text, language_v1_to_language_v2[language]])
            else:
                print(f"\033[33m[Warning] The {language = } of {wav_name} is not supported for training.\033[0m")
        except Exception:
            print(line, traceback.format_exc())

    process(todo, res)
    opt = [
        "%s\t%s\t%s\t%s" % (name, phones, word2ph, norm_text)
        for name, phones, word2ph, norm_text in res
    ]
    with open(txt_path, "w", encoding="utf8") as f:
        f.write("\n".join(opt) + "\n")
|
|
|