import os
import numpy as np
from scipy.io.wavfile import write
import torch
import random
from importlib import import_module
from omegaconf import OmegaConf
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "./speech_tokenization/UVITS"))
import utils
from models.speech_tokenization.UVITS.models import SynthesizerTrn
from text import text_to_sequence
from my_synthesis.my_synthesis_for_speech_unit_sequence_recombination import get_U2S_config_checkpoint_file
sys.path.append(os.path.join(os.path.dirname(__file__), "./speech_tokenization/SPIRAL_L2_BN_FSQ_CTC"))
from my_extract_unit_for_speech.extract_unit_construct_wav_unit_text import \
get_S2U_ckpt_config_path, sample_extract_unit, batch_extract_unit
from nemo.collections.asr.models.spec2vec.vq_ctc_finetune import VQCTCFinetuneModel
from nemo.utils import logging
#################
# S2U
#################
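# Speech-to-unit (S2U): encode a raw waveform into a sequence of discrete speech
# units with the SPIRAL-FSQ-CTC tokenizer, rendered as "<|speech_N|>" tokens.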
def load_config(config=None):
    """Import a Python config module by its path and wrap it as a structured OmegaConf object."""
    assert config is not None, "a config module path is required"
    print("Config: ", config)
    cfg_module = import_module(config.replace('/', '.'))
    cfg = OmegaConf.structured(cfg_module.cfg)
    OmegaConf.set_struct(cfg, True)  # forbid access to fields not defined in the config
    return cfg
def load_S2U_model(ckpt_path, config_path, model_name):
    """Build the speech-to-unit (S2U) model and restore its weights from a checkpoint."""
    assert model_name in ['SPIRAL-FSQ-CTC']
    cfg = load_config(config=config_path)
    cfg.model.pretrain_chkpt_path = None  # skip the pretraining path; the full checkpoint is restored below
    model = VQCTCFinetuneModel(cfg=cfg.model, trainer=None).eval()
    model = model.to(dtype=torch.float32)
    checkpoint = torch.load(os.path.join(os.path.dirname(__file__), ckpt_path), map_location='cpu')
    missing_keys, unexpected_keys = model.load_state_dict(checkpoint['state_dict'], strict=False)
    if missing_keys:
        logging.warning('Missing keys: {}'.format(missing_keys))
    if unexpected_keys:
        logging.warning('Unexpected keys: {}'.format(unexpected_keys))
    return model
def s2u_extract_unit_demo(model, wav_path, model_name, reduced=True):
    """Extract the discrete unit sequence for a single wav file and render it as a
    string of "<|speech_N|>" tokens. `reduced` selects the reduced (deduplicated)
    unit sequence over the unreduced one."""
    assert model_name in ['SPIRAL-FSQ-CTC']
    wav_file_list = [wav_path]
    extracted_wav_file_list, skipped_wav_file_list, unreduced_unit_sequence_list, reduced_unit_sequence_list = \
        batch_extract_unit(wav_file_list, model, max_chunk=960000)
    target_unit_sequence_list = reduced_unit_sequence_list if reduced else unreduced_unit_sequence_list
    if len(extracted_wav_file_list) != 0:
        target_unit_sequence = target_unit_sequence_list[0]
    else:
        # the batch path skipped this file; fall back to per-sample extraction
        wav_file = skipped_wav_file_list[0]
        unreduced_unit_sequence, reduced_unit_sequence = sample_extract_unit(wav_file, model)
        target_unit_sequence = reduced_unit_sequence if reduced else unreduced_unit_sequence
    return "".join(["<|speech_{}|>".format(each) for each in target_unit_sequence.split(" ")])
#################
# U2S
#################
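# Unit-to-speech (U2S): decode a discrete unit sequence back into a waveform with
# a VITS-style SynthesizerTrn, conditioned on a style-centroid embedding that
# encodes gender / emotion / speed / pitch.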
def load_condition_centroid(condition2style_centroid_file):
    """Parse the condition-to-style-centroid index file and preload each centroid embedding.

    Returns two dicts keyed by condition string: one mapping to the centroid file
    path, the other to the loaded embedding as a FloatTensor with added batch and
    length dimensions."""
    with open(os.path.join(os.path.dirname(__file__), condition2style_centroid_file), 'r') as f:
        line_list = [line.rstrip('\n') for line in f]
    assert line_list[0] == 'condition|style_centroid_file'
    condition2style_centroid_file_dict, condition2style_centroid_embedding_dict = {}, {}
    for line in line_list[1:]:
        condition, style_centroid_file = line.split('|')
        condition2style_centroid_file_dict[condition] = style_centroid_file
        style_centroid_embedding = np.load(os.path.join(os.path.dirname(__file__), style_centroid_file))
        style_centroid_embedding = torch.FloatTensor(style_centroid_embedding).unsqueeze(1).unsqueeze(0)
        condition2style_centroid_embedding_dict[condition] = style_centroid_embedding
    return condition2style_centroid_file_dict, condition2style_centroid_embedding_dict
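# Expected layout of condition2style_centroid.txt (the paths and conditions below
# are illustrative, not actual entries; the header row is asserted above, and the
# condition keys match the f-string built in __main__):
#   condition|style_centroid_file
#   gender-female_emotion-happy_speed-normal_pitch-normal|style_centroids/female_happy_normal_normal.npy
#   gender-male_emotion-sad_speed-slow_pitch-low|style_centroids/male_sad_slow_low.npy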
def load_U2S_config(model_config_file):
    """Load the U2S hyperparameters and record the size of the 4096-unit symbol set."""
    hps = utils.get_hparams_from_file(os.path.join(os.path.dirname(__file__), model_config_file))
    from text.symbols import symbols_with_4096 as symbols
    hps.num_symbols = len(symbols)
    return hps
def load_U2S_model(model_config_file, model_checkpoint_file, unit_type):
    """Build the unit-to-speech (U2S) synthesizer and restore its weights from a checkpoint."""
    hps = utils.get_hparams_from_file(os.path.join(os.path.dirname(__file__), model_config_file))
    from text.symbols import symbols_with_4096 as symbols
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model)
    net_g.eval()
    utils.load_checkpoint(os.path.join(os.path.dirname(__file__), model_checkpoint_file), net_g, None)
    return net_g, hps
def synthesis(unit_sequence, style_embedding, hps, net_g, output_wav_file='output.wav'):
    """Synthesize a waveform from a space-separated unit sequence, optionally
    conditioned on a style embedding, and save it to `output_wav_file`."""
    device = next(net_g.parameters()).device  # assume the whole synthesizer lives on a single device
    logging.info("Generating audio on {}".format(device))
    with torch.no_grad():
        unit_sequence = text_to_sequence(unit_sequence, hps.data.text_cleaners)
        unit_sequence = torch.LongTensor(unit_sequence).unsqueeze(0).to(device)
        unit_lengths = torch.LongTensor([unit_sequence.size(1)]).to(device)
        if style_embedding is not None:
            style_embedding = style_embedding.to(device)
        # noise_scale, noise_scale_w, and length_scale are the usual VITS-style
        # inference knobs; 0.667 / 0.8 / 1.0 are common defaults
        audio = net_g.synthesis_from_content_unit_style_embedding(
            unit_sequence, unit_lengths, style_embedding,
            noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0, 0].data.cpu().float().numpy()
    write(output_wav_file, hps.data.sampling_rate, audio)
    print(f'Synthesized sample saved as {output_wav_file}')
    return audio
if __name__ == "__main__":
#################
# NPU
#################
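    # Optional Ascend NPU support: torch_npu's transfer_to_npu shim redirects
    # CUDA-style calls (e.g. the .cuda() calls below) to the NPU; without
    # torch_npu installed the import simply fails and execution stays on CUDA.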
try:
import torch_npu
from torch_npu.npu import amp
from torch_npu.contrib import transfer_to_npu
        print('Successfully imported torch_npu')
except Exception as e:
print(e)
############
# S2U
############
reduced = True
reduced_mark = 'reduced' if reduced else 'unreduced'
unit_type = '40ms_multilingual_8888'
S2U_model_name = 'SPIRAL-FSQ-CTC'
S2U_ckpt_path, S2U_config_path = get_S2U_ckpt_config_path(unit_type)
S2U_model = load_S2U_model(S2U_ckpt_path, S2U_config_path, S2U_model_name)
S2U_model = S2U_model.cuda()
wav_file = "./examples/s2u/example.wav"
speech_unit = s2u_extract_unit_demo(S2U_model, wav_file, model_name=S2U_model_name, reduced=reduced)
print(speech_unit)
############
# U2S
############
condition2style_centroid_file = "./speech_tokenization/condition_style_centroid/condition2style_centroid.txt"
condition2style_centroid_file_dict, condition2style_centroid_embedding_dict = load_condition_centroid(condition2style_centroid_file)
unit_type = '40ms_multilingual_8888_xujing_cosyvoice_FT'
U2S_config_file, U2S_checkpoint_file = get_U2S_config_checkpoint_file(unit_type)
net_g, hps = load_U2S_model(U2S_config_file, U2S_checkpoint_file, unit_type)
net_g = net_g.cuda()
    # strip the "<|speech_N|>" wrapping to recover a space-separated unit string
    content_unit = speech_unit.replace('<|speech_', '').replace('|>', ' ').strip()
emotion = random.choice(['angry', 'disgusted', 'fearful', 'happy', 'neutral', 'sad', 'surprised'])
speed = random.choice(['normal', 'fast', 'slow'])
pitch = random.choice(['normal', 'high', 'low'])
gender = random.choice(['female', 'male'])
condition = f'gender-{gender}_emotion-{emotion}_speed-{speed}_pitch-{pitch}'
style_centroid_file = condition2style_centroid_file_dict[condition]
style_centroid_embedding = condition2style_centroid_embedding_dict[condition]
output_wav_file = f'./examples/u2s/{condition}_output.wav'
synthesis(content_unit, style_centroid_embedding, hps, net_g, output_wav_file)