| | """Extract Mel spectrograms with teacher forcing.""" |
| |
|
import argparse
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from TTS.config import load_config
from TTS.tts.datasets import TTSDataset, load_tts_samples
from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters

use_cuda = torch.cuda.is_available()

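# NOTE: `c` (the loaded config), `meta_data`, and `speaker_manager` are module-level
# globals: `c` is assigned in the `__main__` block and the other two in `main()`,
# so all are set before any of the functions below run.
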
def setup_loader(ap, r, verbose=False):
    tokenizer, _ = TTSTokenizer.init_from_config(c)
    dataset = TTSDataset(
        outputs_per_step=r,
        compute_linear_spec=False,
        samples=meta_data,
        tokenizer=tokenizer,
        ap=ap,
        batch_group_size=0,
        min_text_len=c.min_text_len,
        max_text_len=c.max_text_len,
        min_audio_len=c.min_audio_len,
        max_audio_len=c.max_audio_len,
        phoneme_cache_path=c.phoneme_cache_path,
        precompute_num_workers=0,
        use_noise_augment=False,
        verbose=verbose,
        speaker_id_mapping=speaker_manager.name_to_id if c.use_speaker_embedding else None,
        d_vector_mapping=speaker_manager.embeddings if c.use_d_vector_file else None,
    )

    if c.use_phonemes and c.compute_input_seq_cache:
        # precompute phoneme sequences to get reliable text lengths up front
        dataset.compute_input_seq(c.num_loader_workers)
    # sort/filter samples; must run regardless of phoneme caching
    dataset.preprocess_samples()

    loader = DataLoader(
        dataset,
        batch_size=c.batch_size,
        shuffle=False,
        collate_fn=dataset.collate_fn,
        drop_last=False,
        sampler=None,
        num_workers=c.num_loader_workers,
        pin_memory=False,
    )
    return loader

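
# `set_filename` derives all per-utterance output paths (quantized wav, mel,
# Griffin-Lim reconstruction, raw wav copy) from the source wav's basename and
# creates the target directories.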
def set_filename(wav_path, out_path):
    wav_file = os.path.basename(wav_path)
    # `splitext` keeps dots inside the name intact, unlike `split(".")[0]`
    file_name = os.path.splitext(wav_file)[0]
    os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "wav"), exist_ok=True)
    wavq_path = os.path.join(out_path, "quant", file_name)
    mel_path = os.path.join(out_path, "mel", file_name)
    wav_gl_path = os.path.join(out_path, "wav_gl", file_name + ".wav")
    wav_path = os.path.join(out_path, "wav", file_name + ".wav")
    return file_name, wavq_path, mel_path, wav_gl_path, wav_path

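
# `format_data` unpacks a collated batch and moves the tensors to the GPU when one
# is available.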
def format_data(data):
    # setup input data
    text_input = data["token_id"]
    text_lengths = data["token_id_lengths"]
    mel_input = data["mel"]
    mel_lengths = data["mel_lengths"]
    item_idx = data["item_idxs"]
    d_vectors = data["d_vectors"]
    speaker_ids = data["speaker_ids"]
    attn_mask = data["attns"]
    avg_text_length = torch.mean(text_lengths.float())
    avg_spec_length = torch.mean(mel_lengths.float())

    # dispatch data to GPU
    if use_cuda:
        text_input = text_input.cuda(non_blocking=True)
        text_lengths = text_lengths.cuda(non_blocking=True)
        mel_input = mel_input.cuda(non_blocking=True)
        mel_lengths = mel_lengths.cuda(non_blocking=True)
        if speaker_ids is not None:
            speaker_ids = speaker_ids.cuda(non_blocking=True)
        if d_vectors is not None:
            d_vectors = d_vectors.cuda(non_blocking=True)
        if attn_mask is not None:
            attn_mask = attn_mask.cuda(non_blocking=True)
    return (
        text_input,
        text_lengths,
        mel_input,
        mel_lengths,
        speaker_ids,
        d_vectors,
        avg_text_length,
        avg_spec_length,
        attn_mask,
        item_idx,
    )

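
# `inference` runs one teacher-forced forward pass: GlowTTS decodes with MAS-derived
# alignments (`inference_with_MAS`), while Tacotron models use their regular forward
# pass; Tacotron's linear postnet output is additionally converted back to a mel.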
@torch.no_grad()
def inference(
    model_name,
    model,
    ap,
    text_input,
    text_lengths,
    mel_input,
    mel_lengths,
    speaker_ids=None,
    d_vectors=None,
):
    if model_name == "glow_tts":
        speaker_c = None
        if speaker_ids is not None:
            speaker_c = speaker_ids
        elif d_vectors is not None:
            speaker_c = d_vectors
        outputs = model.inference_with_MAS(
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids},
        )
        model_output = outputs["model_outputs"]
        model_output = model_output.detach().cpu().numpy()

    elif "tacotron" in model_name:
        aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
        outputs = model(text_input, text_lengths, mel_input, mel_lengths, aux_input)
        postnet_outputs = outputs["model_outputs"]
        # Tacotron's postnet predicts linear spectrograms; convert them back to mels
        if model_name == "tacotron":
            mel_specs = []
            postnet_outputs = postnet_outputs.data.cpu().numpy()
            for b in range(postnet_outputs.shape[0]):
                postnet_output = postnet_outputs[b]
                mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T))
            model_output = torch.stack(mel_specs).cpu().numpy()
        elif model_name == "tacotron2":
            model_output = postnet_outputs.detach().cpu().numpy()
    else:
        raise ValueError(f"Unsupported model: {model_name}")
    return model_output


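# `extract_spectrograms` iterates over the loader, runs teacher-forced inference, and
# writes one `.npy` mel (plus optional quantized/raw/Griffin-Lim audio) per utterance,
# along with a pipe-separated metadata file mapping source wavs to extracted mels.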
def extract_spectrograms(
    data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metadata_name="metadata.txt"
):
    model.eval()
    export_metadata = []
    for data in tqdm(data_loader, total=len(data_loader)):
        # format data
        (
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            speaker_ids,
            d_vectors,
            _,
            _,
            _,
            item_idx,
        ) = format_data(data)

        model_output = inference(
            c.model.lower(),
            model,
            ap,
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            speaker_ids,
            d_vectors,
        )

        for idx in range(text_input.shape[0]):
            wav_file_path = item_idx[idx]
            wav = ap.load_wav(wav_file_path)
            _, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path)

            # quantize and save wav
            if quantized_wav:
                wavq = ap.quantize(wav)
                np.save(wavq_path, wavq)

            # save TTS mel (trim padding frames, store as [n_mels, T])
            mel = model_output[idx]
            mel_length = mel_lengths[idx]
            mel = mel[:mel_length, :].T
            np.save(mel_path, mel)

            export_metadata.append([wav_file_path, mel_path])
            if save_audio:
                ap.save_wav(wav, wav_path)

            if debug:
                print("Audio for debug saved at:", wav_gl_path)
                wav = ap.inv_melspectrogram(mel)
                ap.save_wav(wav, wav_gl_path)

    with open(os.path.join(output_path, metadata_name), "w", encoding="utf-8") as f:
        for data in export_metadata:
            f.write(f"{data[0]}|{data[1]}.npy\n")


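# `main` wires everything together: audio processor, dataset samples, optional
# speaker manager, model restore, and the extraction loop.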
def main(args):
    global meta_data, speaker_manager

    # Audio processor
    ap = AudioProcessor(**c.audio)

    # load data instances
    meta_data_train, meta_data_eval = load_tts_samples(
        c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
    )

    # use eval and training partitions
    meta_data = meta_data_train + meta_data_eval

    # init speaker manager
    if c.use_speaker_embedding:
        speaker_manager = SpeakerManager(data_items=meta_data)
    elif c.use_d_vector_file:
        speaker_manager = SpeakerManager(d_vectors_file_path=c.d_vector_file)
    else:
        speaker_manager = None

    # setup model
    model = setup_model(c)

    # restore model
    model.load_checkpoint(c, args.checkpoint_path, eval=True)

    if use_cuda:
        model.cuda()

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)
    # set r (reduction factor); GlowTTS has no decoder reduction
    r = 1 if c.model.lower() == "glow_tts" else model.decoder.r
    own_loader = setup_loader(ap, r, verbose=True)

    extract_spectrograms(
        own_loader,
        model,
        ap,
        args.output_path,
        quantized_wav=args.quantized,
        save_audio=args.save_audio,
        debug=args.debug,
        metadata_name="metadata.txt",
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
    parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True)
    parser.add_argument("--output_path", type=str, help="Path to save mel specs.", required=True)
    parser.add_argument("--debug", default=False, action="store_true", help="Save Griffin-Lim audio for debugging.")
    parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files.")
    parser.add_argument("--quantized", action="store_true", help="Save quantized audio files.")
    # `type=bool` would treat any non-empty string as True; parse the flag explicitly
    parser.add_argument(
        "--eval", type=lambda x: str(x).lower() in {"1", "true", "yes"}, default=True, help="Compute the eval split."
    )
    args = parser.parse_args()

    c = load_config(args.config_path)
    # disable silence trimming so the extracted mels align with the full-length wavs
    c.audio.trim_silence = False
    main(args)
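
# Example invocation (paths are illustrative):
#   python extract_tts_spectrograms.py \
#       --config_path /path/to/config.json \
#       --checkpoint_path /path/to/checkpoint.pth \
#       --output_path /path/to/output_dir \
#       --save_audio --quantized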