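# Speech tokenization demo: speech-to-unit (S2U) extraction with a SPIRAL-FSQ-CTC
# model, followed by unit-to-speech (U2S) synthesis with the UVITS SynthesizerTrn
# conditioned on per-condition style-centroid embeddings.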
import os
import numpy as np
from scipy.io.wavfile import write
import torch
import random
from importlib import import_module
from omegaconf import OmegaConf
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "./speech_tokenization/UVITS"))
import utils
from models.speech_tokenization.UVITS.models import SynthesizerTrn
from text import text_to_sequence
from my_synthesis.my_synthesis_for_speech_unit_sequence_recombination import get_U2S_config_checkpoint_file

sys.path.append(os.path.join(os.path.dirname(__file__), "./speech_tokenization/SPIRAL_L2_BN_FSQ_CTC"))
from my_extract_unit_for_speech.extract_unit_construct_wav_unit_text import \
    get_S2U_ckpt_config_path, sample_extract_unit, batch_extract_unit
from nemo.collections.asr.models.spec2vec.vq_ctc_finetune import VQCTCFinetuneModel
from nemo.utils import logging
#################
# S2U
#################
def load_config(config=None):
    if config is not None:
        print("Config: ", config)
        cfg_module = import_module(config.replace('/', '.'))
        cfg = OmegaConf.structured(cfg_module.cfg)
        OmegaConf.set_struct(cfg, True)
        return cfg
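# Usage sketch (module path hypothetical): load_config('configs/s2u_fsq') imports
# `configs.s2u_fsq` and wraps its `cfg` attribute in a struct-locked OmegaConf object.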
def load_S2U_model(ckpt_path, config_path, model_name):
    assert model_name in ['SPIRAL-FSQ-CTC']
    cfg = load_config(config=config_path)
    cfg.model.pretrain_chkpt_path = None  # skip the pretraining checkpoint; weights are restored below
    model = VQCTCFinetuneModel(cfg=cfg.model, trainer=None).eval()
    model = model.to(dtype=torch.float32)
    checkpoint = torch.load(os.path.join(os.path.dirname(__file__), ckpt_path), map_location='cpu')
    # load non-strictly and surface any key mismatches instead of failing
    missing_keys, unexpected_keys = model.load_state_dict(checkpoint['state_dict'], strict=False)
    if missing_keys:
        logging.warning('Missing Keys: {}'.format(missing_keys))
    if unexpected_keys:
        logging.warning('Unexpected Keys: {}'.format(unexpected_keys))
    return model
def s2u_extract_unit_demo(model, wav_path, model_name, reduced=True):
    assert model_name in ['SPIRAL-FSQ-CTC']
    wav_file_list = [wav_path]
    extracted_wav_file_list, skipped_wav_file_list, unreduced_unit_sequence_list, reduced_unit_sequence_list = batch_extract_unit(wav_file_list, model, max_chunk=960000)
    target_unit_sequence_list = reduced_unit_sequence_list if reduced else unreduced_unit_sequence_list
    if len(extracted_wav_file_list) != 0:
        target_unit_sequence = target_unit_sequence_list[0]
    else:
        # the file was skipped by batch extraction; fall back to per-sample extraction
        wav_file = skipped_wav_file_list[0]
        unreduced_unit_sequence, reduced_unit_sequence = sample_extract_unit(wav_file, model)
        target_unit_sequence = reduced_unit_sequence if reduced else unreduced_unit_sequence
    return "".join(["<|speech_{}|>".format(each) for each in target_unit_sequence.split(" ")])
#################
# U2S
#################
def load_condition_centroid(condition2style_centroid_file):
    with open(os.path.join(os.path.dirname(__file__), condition2style_centroid_file), 'r') as f:
        line_list = [line.replace('\n', '') for line in f]
    assert line_list[0] == 'condition|style_centroid_file'
    condition2style_centroid_file_dict, condition2style_centroid_embedding_dict = {}, {}
    for line in line_list[1:]:
        condition, style_centroid_file = line.split('|')
        condition2style_centroid_file_dict[condition] = style_centroid_file
        style_centroid_embedding = np.load(os.path.join(os.path.dirname(__file__), style_centroid_file))
        style_centroid_embedding = torch.FloatTensor(style_centroid_embedding).unsqueeze(1).unsqueeze(0)
        condition2style_centroid_embedding_dict[condition] = style_centroid_embedding
    return condition2style_centroid_file_dict, condition2style_centroid_embedding_dict
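# The centroid index is pipe-delimited with a fixed header; a row might look
# like this (path hypothetical):
#   gender-female_emotion-neutral_speed-normal_pitch-normal|./centroids/female_neutral.npy
# Each style-centroid .npy (assumed 1-D, shape (dim,)) is reshaped to (1, dim, 1)
# before being handed to the synthesizer.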
def load_U2S_config(model_config_file):
    hps = utils.get_hparams_from_file(os.path.join(os.path.dirname(__file__), model_config_file))
    from text.symbols import symbols_with_4096 as symbols
    hps.num_symbols = len(symbols)
    return hps

def load_U2S_model(model_config_file, model_checkpoint_file, unit_type):
    # load model
    hps = utils.get_hparams_from_file(os.path.join(os.path.dirname(__file__), model_config_file))
    from text.symbols import symbols_with_4096 as symbols
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model)
    net_g.eval()
    utils.load_checkpoint(os.path.join(os.path.dirname(__file__), model_checkpoint_file), net_g, None)
    return net_g, hps
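# SynthesizerTrn is the UVITS generator. Its input embedding is sized by
# symbols_with_4096 (presumably the 4096-unit vocabulary plus special symbols);
# the spectrogram and speaker settings come from the model's hparams file.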
def synthesis(unit_sequence, style_embedding, hps, net_g, output_wav_file='output.wav'):
    # synthesize speech
    device = next(net_g.parameters()).device  # assumes the synthesizer sits on a single device
    logging.info("Generating audio on {}".format(device))
    with torch.no_grad():
        unit_sequence = text_to_sequence(unit_sequence, hps.data.text_cleaners)
        unit_sequence = torch.LongTensor(unit_sequence)
        unit_sequence = unit_sequence.unsqueeze(0).to(device)
        unit_lengths = torch.LongTensor([unit_sequence.size(1)]).to(device)
        if style_embedding is not None:
            style_embedding = style_embedding.to(device)
        audio = net_g.synthesis_from_content_unit_style_embedding(
            unit_sequence, unit_lengths, style_embedding,
            noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0, 0].data.cpu().float().numpy()
    write(output_wav_file, hps.data.sampling_rate, audio)
    print(f'synthesized sample is saved as {output_wav_file}')
    return audio
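# Note: synthesis() expects space-separated unit IDs ("12 7 4095", hypothetical),
# not the <|speech_N|> serialization; see the content_unit conversion below.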
if __name__ == "__main__":
    #################
    # NPU
    #################
    # optional Ascend NPU support: transfer_to_npu remaps CUDA calls to the NPU
    try:
        import torch_npu
        from torch_npu.npu import amp
        from torch_npu.contrib import transfer_to_npu
        print('Successfully imported torch_npu')
    except Exception as e:
        print(e)

    ############
    # S2U
    ############
    reduced = True
    reduced_mark = 'reduced' if reduced else 'unreduced'
    unit_type = '40ms_multilingual_8888'
    S2U_model_name = 'SPIRAL-FSQ-CTC'
    S2U_ckpt_path, S2U_config_path = get_S2U_ckpt_config_path(unit_type)
    S2U_model = load_S2U_model(S2U_ckpt_path, S2U_config_path, S2U_model_name)
    S2U_model = S2U_model.cuda()

    wav_file = "./examples/s2u/example.wav"
    speech_unit = s2u_extract_unit_demo(S2U_model, wav_file, model_name=S2U_model_name, reduced=reduced)
    print(speech_unit)
    ############
    # U2S
    ############
    condition2style_centroid_file = "./speech_tokenization/condition_style_centroid/condition2style_centroid.txt"
    condition2style_centroid_file_dict, condition2style_centroid_embedding_dict = load_condition_centroid(condition2style_centroid_file)

    unit_type = '40ms_multilingual_8888_xujing_cosyvoice_FT'
    U2S_config_file, U2S_checkpoint_file = get_U2S_config_checkpoint_file(unit_type)
    net_g, hps = load_U2S_model(U2S_config_file, U2S_checkpoint_file, unit_type)
    net_g = net_g.cuda()

    content_unit = speech_unit.replace('<|speech_', '').replace('|>', ' ').strip()
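    # e.g. "<|speech_12|><|speech_7|>" -> "12 7" (IDs hypothetical)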
    # sample a random style condition and look up its centroid embedding
    emotion = random.choice(['angry', 'disgusted', 'fearful', 'happy', 'neutral', 'sad', 'surprised'])
    speed = random.choice(['normal', 'fast', 'slow'])
    pitch = random.choice(['normal', 'high', 'low'])
    gender = random.choice(['female', 'male'])
    condition = f'gender-{gender}_emotion-{emotion}_speed-{speed}_pitch-{pitch}'
    style_centroid_file = condition2style_centroid_file_dict[condition]
    style_centroid_embedding = condition2style_centroid_embedding_dict[condition]

    output_wav_file = f'./examples/u2s/{condition}_output.wav'
    synthesis(content_unit, style_centroid_embedding, hps, net_g, output_wav_file)
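    # End to end, the demo round-trips ./examples/s2u/example.wav: its content
    # units are re-synthesized under the randomly sampled style, so the output
    # keeps the spoken content while the voice comes from the style centroid.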