import argparse
import importlib
import logging
import os
from argparse import RawTextHelpFormatter

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from trainer.io import load_checkpoint

from src.utils.TTS.config import load_config
from src.utils.TTS.tts.datasets.TTSDataset import TTSDataset
from src.utils.TTS.tts.models import setup_model
from src.utils.TTS.tts.utils.text.characters import make_symbols, phonemes, symbols
from src.utils.TTS.utils.audio import AudioProcessor
from src.utils.TTS.utils.generic_utils import ConsoleFormatter, setup_logger


if __name__ == "__main__":
    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
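
    # define and parse the command line interface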
    parser = argparse.ArgumentParser(
        description="""Extract attention masks from trained Tacotron/Tacotron2 models.
These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n"""
        """Each attention mask is written next to the input wav file with a "_attn.npy" suffix.
(e.g. path/bla.wav (wav file) --> path/bla_attn.npy (attention mask))\n"""
        """
Example run:
    CUDA_VISIBLE_DEVICES="0" python TTS/bin/compute_attention_masks.py
        --model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth
        --config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
        --dataset_metafile metadata.csv
        --data_path /root/LJSpeech-1.1/
        --batch_size 32
        --dataset ljspeech
        --use_cuda
""",
        formatter_class=RawTextHelpFormatter,
    )
| parser.add_argument("--model_path", type=str, required=True, help="Path to Tacotron/Tacotron2 model file ") |
| parser.add_argument( |
| "--config_path", |
| type=str, |
| required=True, |
| help="Path to Tacotron/Tacotron2 config file.", |
| ) |
| parser.add_argument( |
| "--dataset", |
| type=str, |
| default="", |
| required=True, |
| help="Target dataset processor name from src.utils.TTS.tts.dataset.preprocess.", |
| ) |
|
|
| parser.add_argument( |
| "--dataset_metafile", |
| type=str, |
| default="", |
| required=True, |
| help="Dataset metafile inclusing file paths with transcripts.", |
| ) |
| parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overwrites config.json.") |
| parser.add_argument("--use_cuda", action=argparse.BooleanOptionalAction, default=False, help="enable/disable cuda.") |
|
|
| parser.add_argument( |
| "--batch_size", default=16, type=int, help="Batch size for the model. Use batch_size=1 if you have no CUDA." |
| ) |
| args = parser.parse_args() |
|
|
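
    # load the model config and set up the audio processor from it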
    C = load_config(args.config_path)
    ap = AudioProcessor(**C.audio)
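
    # if a character set is defined in the config, replace the default symbols and phonemes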
| if "characters" in C.keys(): |
| symbols, phonemes = make_symbols(**C.characters) |
|
|
| |
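
    # build the model and restore the checkpoint weights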
    num_chars = len(phonemes) if C.use_phonemes else len(symbols)
    model = setup_model(C)
    model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True)
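
    # load the dataset items with the selected formatter and build the dataset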
    preprocessor = importlib.import_module("src.utils.TTS.tts.datasets.formatters")
    preprocessor = getattr(preprocessor, args.dataset)
    meta_data = preprocessor(args.data_path, args.dataset_metafile)
    dataset = TTSDataset(
        model.decoder.r,
        C.text_cleaner,
        compute_linear_spec=False,
        ap=ap,
        meta_data=meta_data,
        characters=C.characters if "characters" in C.keys() else None,
        add_blank=C["add_blank"] if "add_blank" in C.keys() else False,
        use_phonemes=C.use_phonemes,
        phoneme_cache_path=C.phoneme_cache_path,
        phoneme_language=C.phoneme_language,
        enable_eos_bos=C.enable_eos_bos_chars,
    )

    dataset.sort_and_filter_items(C.get("sort_by_audio_len", False))
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=4,
        collate_fn=dataset.collate_fn,
        shuffle=False,
        drop_last=False,
    )
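
    # compute attention masks batch by batch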
    file_paths = []
    with torch.no_grad():
        for data in tqdm(loader):
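            # unpack the batch tensors (index 2 is skipped; it is not needed here)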
            text_input = data[0]
            text_lengths = data[1]
            linear_input = data[3]
            mel_input = data[4]
            mel_lengths = data[5]
            stop_targets = data[6]
            item_idxs = data[7]
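
            # move the tensors to GPU if requested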
            if args.use_cuda:
                text_input = text_input.cuda()
                text_lengths = text_lengths.cuda()
                mel_input = mel_input.cuda()
                mel_lengths = mel_lengths.cuda()
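
            # teacher-forced forward pass to obtain the attention alignments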
            model_outputs = model.forward(text_input, text_lengths, mel_input)

            alignments = model_outputs["alignments"].detach()
            for idx, alignment in enumerate(alignments):
                item_idx = item_idxs[idx]
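                # expand the alignment along the decoder axis to undo the reduction factor r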
                alignment = (
                    torch.nn.functional.interpolate(
                        alignment.transpose(0, 1).unsqueeze(0),
                        scale_factor=model.decoder.r,
                        mode="nearest",
                    )
                    .squeeze(0)
                    .transpose(0, 1)
                )
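                # strip the batch padding and move to CPU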
                alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy()
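                # derive the output file path from the wav file path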
                wav_file_name = os.path.basename(item_idx)
                align_file_name = os.path.splitext(wav_file_name)[0] + "_attn.npy"
                file_path = item_idx.replace(wav_file_name, align_file_name)
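                # save the attention mask next to the wav file and record the pair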
                wav_file_abs_path = os.path.abspath(item_idx)
                file_abs_path = os.path.abspath(file_path)
                file_paths.append([wav_file_abs_path, file_abs_path])
                np.save(file_path, alignment)
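
    # write a metafile listing each wav file and its attention mask path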
    metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")
    with open(metafile, "w", encoding="utf-8") as f:
        for p in file_paths:
            f.write(f"{p[0]}|{p[1]}\n")
    print(f" >> Metafile created: {metafile}")