{ "cells": [ { "cell_type": "markdown", "id": "ed07e3c2", "metadata": {}, "source": [ "# FastPitch Adapter Finetuning\n", "\n", "This notebook is a guide to running the FastPitch adapter fine-tuning pipeline. It contains the following sections:\n", "1. **Fine-tune FastPitch on adaptation data**: fine-tune pre-trained multi-speaker FastPitch for a new speaker\n", "* Dataset Preparation: download the dataset and extract manifest files (the audio should total more than 15 minutes).\n", "* Preprocessing: add absolute audio paths to the manifests and extract supplementary data.\n", "* **Model Setting: transform the pre-trained checkpoint into an adapter-compatible checkpoint and precompute the speaker embedding**\n", "* Training: fine-tune the frozen multi-speaker FastPitch with trainable adapters.\n", "2. **Fine-tune HiFiGAN on adaptation data**: fine-tune a vocoder for the fine-tuned multi-speaker FastPitch\n", "* Dataset Preparation: extract mel-spectrograms from the fine-tuned FastPitch.\n", "* Training: fine-tune HiFiGAN on the mel-spectrograms generated from the adaptation data.\n", "3. **Inference**: generate speech from the adapted FastPitch\n", "* Load Model: load the pre-trained multi-speaker FastPitch with **fine-tuned adapters**.\n", "* Output Audio: generate audio files." ] }, { "cell_type": "markdown", "id": "772e7404", "metadata": {}, "source": [ "# License\n", "\n", "> Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved.\n", "> \n", "> Licensed under the Apache License, Version 2.0 (the \"License\");\n", "> you may not use this file except in compliance with the License.\n", "> You may obtain a copy of the License at\n", "> \n", "> http://www.apache.org/licenses/LICENSE-2.0\n", "> \n", "> Unless required by applicable law or agreed to in writing, software\n", "> distributed under the License is distributed on an \"AS IS\" BASIS,\n", "> WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "> See the License for the specific language governing permissions and\n", "> limitations under the License." ] }, { "cell_type": "code", "execution_count": null, "id": "8f799aa0", "metadata": {}, "outputs": [], "source": [ "\"\"\"\n", "You can either run this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n", "Instructions for setting up Colab are as follows:\n", "1. Open a new Python 3 notebook.\n", "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n", "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", "4. Run this cell to set up dependencies.\n", "\"\"\"\n", "# # If you're using Colab and not running locally, uncomment and run this cell.\n", "# BRANCH = 'main'\n", "# !apt-get install sox libsndfile1 ffmpeg\n", "# !pip install wget unidecode pynini==2.1.4 scipy==1.7.3\n", "# !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n", "\n", "# # Download local version of NeMo scripts. 
If you are running locally and want to use your own local NeMo code,\n", "# # comment out the below lines and set `code_dir` to your local path.\n", "code_dir = 'NeMoTTS' \n", "!git clone https://github.com/NVIDIA/NeMo.git {code_dir}" ] }, { "cell_type": "code", "execution_count": null, "id": "0a4d3371", "metadata": {}, "outputs": [], "source": [ "!wandb login #PASTE_WANDB_APIKEY_HERE" ] }, { "cell_type": "markdown", "id": "b73283fc", "metadata": {}, "source": [ "## Set finetuning params\n", "\n", "This notebook expects a pretrained model to finetune. If you already have a pretrained multispeaker checkpoint, set its path in the next block. You can also pretrain a multispeaker adapter checkpoint using the [FastPitch_MultiSpeaker_Pretraining tutorial](https://github.com/NVIDIA/NeMo/blob/main/tutorials/tts/FastPitch_MultiSpeaker_Pretraining.ipynb)." ] }, { "cell_type": "code", "execution_count": null, "id": "25d94e3a", "metadata": {}, "outputs": [], "source": [ "# .nemo files for your pre-trained FastPitch and HiFiGAN\n", "pretrained_fastpitch_checkpoint = \"\"\n", "finetuned_hifigan_on_multispeaker_checkpoint = \"\"\n", "use_ipa = False  # Set to False when using ARPABET, True when using IPA." 
] }, { "cell_type": "code", "execution_count": null, "id": "79cb9932", "metadata": {}, "outputs": [], "source": [ "sample_rate = 44100\n", "# Store all manifests and audio files\n", "data_dir = 'NeMoTTS_dataset'\n", "# Store all supplementary files\n", "supp_dir = \"NeMoTTS_sup_data\"\n", "# Store all training logs\n", "logs_dir = \"NeMoTTS_logs\"\n", "# Store all mel-spectrograms for vocoder training\n", "mels_dir = \"NeMoTTS_mels\"" ] }, { "cell_type": "code", "execution_count": null, "id": "ec7fed4e", "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", "import shutil\n", "import nemo\n", "import torch\n", "import numpy as np\n", "\n", "from pathlib import Path\n", "from tqdm import tqdm" ] }, { "cell_type": "code", "execution_count": null, "id": "f815deff", "metadata": {}, "outputs": [], "source": [ "os.makedirs(code_dir, exist_ok=True)\n", "code_dir = os.path.abspath(code_dir)\n", "os.makedirs(data_dir, exist_ok=True)\n", "data_dir = os.path.abspath(data_dir)\n", "os.makedirs(supp_dir, exist_ok=True)\n", "supp_dir = os.path.abspath(supp_dir)\n", "os.makedirs(logs_dir, exist_ok=True)\n", "logs_dir = os.path.abspath(logs_dir)\n", "os.makedirs(mels_dir, exist_ok=True)\n", "mels_dir = os.path.abspath(mels_dir)" ] }, { "cell_type": "markdown", "id": "539e8f0d", "metadata": {}, "source": [ "# 1. Fine-tune FastPitch on adaptation data" ] }, { "cell_type": "markdown", "id": "270ed53f", "metadata": {}, "source": [ "## a. Data Preparation\n", "For this tutorial, we use a small part of the VCTK dataset with a new target speaker (p267). In general, the adaptation audio should total more than 15 minutes." 
] }, { "cell_type": "code", "execution_count": null, "id": "21ce4a34", "metadata": {}, "outputs": [], "source": [ "!cd {data_dir} && wget https://vctk-subset.s3.amazonaws.com/vctk_subset.tar.gz && tar zxf vctk_subset.tar.gz" ] }, { "cell_type": "code", "execution_count": null, "id": "2d5edbe5", "metadata": {}, "outputs": [], "source": [ "manidir = f\"{data_dir}/vctk_subset\"\n", "!ls {manidir}" ] }, { "cell_type": "code", "execution_count": null, "id": "c1de2249", "metadata": {}, "outputs": [], "source": [ "train_manifest = os.path.abspath(os.path.join(manidir, 'train.json'))\n", "valid_manifest = os.path.abspath(os.path.join(manidir, 'dev.json'))" ] }, { "cell_type": "markdown", "id": "e657c830", "metadata": {}, "source": [ "## b. Preprocessing" ] }, { "cell_type": "markdown", "id": "4d0076d4", "metadata": {}, "source": [ "### Add absolute file path in manifest\n", "We use absolute paths for `audio_filepath` so that the audio can be located during training." ] }, { "cell_type": "code", "execution_count": null, "id": "7ccb5fb6", "metadata": {}, "outputs": [], "source": [ "from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest" ] }, { "cell_type": "code", "execution_count": null, "id": "23dc1ba6", "metadata": {}, "outputs": [], "source": [ "train_data = read_manifest(train_manifest)\n", "for m in train_data: m['audio_filepath'] = os.path.abspath(os.path.join(manidir, m['audio_filepath']))\n", "write_manifest(train_manifest, train_data)\n", "\n", "valid_data = read_manifest(valid_manifest)\n", "for m in valid_data: m['audio_filepath'] = os.path.abspath(os.path.join(manidir, m['audio_filepath']))\n", "write_manifest(valid_manifest, valid_data)" ] }, { "cell_type": "markdown", "id": "b852072b", "metadata": {}, "source": [ "### Extract Supplementary Data\n", "\n", "As mentioned in the [FastPitch and MixerTTS training tutorial](https://github.com/NVIDIA/NeMo/blob/main/tutorials/tts/FastPitch_MixerTTS_Training.ipynb), to accelerate and stabilize our 
training, we also need to extract the pitch of every audio file and estimate pitch statistics (mean, std, min, and max). To do this, we only need to iterate over our data once with the `extract_sup_data.py` script." ] }, { "cell_type": "code", "execution_count": null, "id": "f6bdd226", "metadata": {}, "outputs": [], "source": [ "!cd {code_dir} && python scripts/dataset_processing/tts/extract_sup_data.py \\\n", " manifest_filepath={train_manifest} \\\n", " sup_data_path={supp_dir} \\\n", " dataset.sample_rate={sample_rate} \\\n", " dataset.n_fft=2048 \\\n", " dataset.win_length=2048 \\\n", " dataset.hop_length=512" ] }, { "cell_type": "markdown", "id": "fdae4e4e", "metadata": {}, "source": [ "After running the command above, you will see a new folder `NeMoTTS_sup_data/pitch` and a printout of pitch statistics like the one below. Pass these values to the FastPitch training configuration in the Training section below.\n", "```bash\n", "PITCH_MEAN=175.48513793945312, PITCH_STD=42.3786735534668\n", "PITCH_MIN=65.4063949584961, PITCH_MAX=270.8517761230469\n", "```" ] }, { "cell_type": "code", "execution_count": null, "id": "ac8fae15", "metadata": {}, "outputs": [], "source": [ "!cd {code_dir} && python scripts/dataset_processing/tts/extract_sup_data.py \\\n", " manifest_filepath={valid_manifest} \\\n", " sup_data_path={supp_dir} \\\n", " dataset.sample_rate={sample_rate} \\\n", " dataset.n_fft=2048 \\\n", " dataset.win_length=2048 \\\n", " dataset.hop_length=512" ] }, { "cell_type": "markdown", "id": "c9f98c86", "metadata": {}, "source": [ "## c. 
Model Setting\n", "### Transform pre-trained checkpoint to adapter-compatible checkpoint" ] }, { "cell_type": "code", "execution_count": null, "id": "fd8c66fb", "metadata": {}, "outputs": [], "source": [ "from nemo.collections.tts.models import FastPitchModel\n", "from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer\n", "from nemo.core import adapter_mixins\n", "from omegaconf import DictConfig, OmegaConf, open_dict" ] }, { "cell_type": "code", "execution_count": null, "id": "ff535c8f", "metadata": {}, "outputs": [], "source": [ "def update_model_config_to_support_adapter(config) -> DictConfig:\n", " with open_dict(config):\n", " enc_adapter_metadata = adapter_mixins.get_registered_adapter(config.input_fft._target_)\n", " if enc_adapter_metadata is not None:\n", " config.input_fft._target_ = enc_adapter_metadata.adapter_class_path\n", "\n", " dec_adapter_metadata = adapter_mixins.get_registered_adapter(config.output_fft._target_)\n", " if dec_adapter_metadata is not None:\n", " config.output_fft._target_ = dec_adapter_metadata.adapter_class_path\n", "\n", " pitch_predictor_adapter_metadata = adapter_mixins.get_registered_adapter(config.pitch_predictor._target_)\n", " if pitch_predictor_adapter_metadata is not None:\n", " config.pitch_predictor._target_ = pitch_predictor_adapter_metadata.adapter_class_path\n", "\n", " duration_predictor_adapter_metadata = adapter_mixins.get_registered_adapter(config.duration_predictor._target_)\n", " if duration_predictor_adapter_metadata is not None:\n", " config.duration_predictor._target_ = duration_predictor_adapter_metadata.adapter_class_path\n", "\n", " aligner_adapter_metadata = adapter_mixins.get_registered_adapter(config.alignment_module._target_)\n", " if aligner_adapter_metadata is not None:\n", " config.alignment_module._target_ = aligner_adapter_metadata.adapter_class_path\n", "\n", " return config" ] }, { "cell_type": "code", "execution_count": null, "id": "4f457111", "metadata": {}, 
"outputs": [], "source": [ "spec_model = FastPitchModel.restore_from(pretrained_fastpitch_checkpoint).eval().cuda()\n", "spec_model.cfg = update_model_config_to_support_adapter(spec_model.cfg)" ] }, { "cell_type": "markdown", "id": "ef40def3", "metadata": {}, "source": [ "### Precompute Speaker Embedding\n", "Get all GST speaker embeddings from training data, take average, and save as `precomputed_emb` in FastPitch" ] }, { "cell_type": "code", "execution_count": null, "id": "30664bcb", "metadata": {}, "outputs": [], "source": [ "wave_model = WaveformFeaturizer(sample_rate=sample_rate)\n", "train_data = read_manifest(train_manifest)\n", "\n", "spk_embs = [] \n", "for data in train_data:\n", " with torch.no_grad():\n", " audio = wave_model.process(data['audio_filepath'])\n", " audio_length = torch.tensor(audio.shape[0]).long()\n", " audio = audio.unsqueeze(0).to(device=spec_model.device)\n", " audio_length = audio_length.unsqueeze(0).to(device=spec_model.device)\n", " spec_ref, spec_ref_lens = spec_model.preprocessor(input_signal=audio, length=audio_length)\n", " spk_emb = spec_model.fastpitch.get_speaker_embedding(batch_size=spec_ref.shape[0],\n", " speaker=None,\n", " reference_spec=spec_ref,\n", " reference_spec_lens=spec_ref_lens)\n", "\n", " spk_embs.append(spk_emb.squeeze().cpu())\n", "\n", "spk_embs = torch.stack(spk_embs, dim=0)\n", "spk_emb = torch.mean(spk_embs, dim=0)\n", "spk_emb_dim = spk_emb.shape[0]\n", "\n", "with open_dict(spec_model.cfg):\n", " spec_model.cfg.speaker_encoder.precomputed_embedding_dim = spec_model.cfg.symbols_embedding_dim\n", "\n", "spec_model.fastpitch.speaker_encoder.overwrite_precomputed_emb(spk_emb)" ] }, { "cell_type": "code", "execution_count": null, "id": "43001c75", "metadata": {}, "outputs": [], "source": [ "spec_model.save_to('Pretrained-FastPitch.nemo')\n", "shutil.copyfile(finetuned_hifigan_on_multispeaker_checkpoint, \"Pretrained-HifiGan.nemo\")\n", "pretrained_fastpitch_checkpoint = 
os.path.abspath(\"Pretrained-FastPitch.nemo\")\n", "finetuned_hifigan_on_multispeaker_checkpoint = os.path.abspath(\"Pretrained-HifiGan.nemo\")" ] }, { "cell_type": "markdown", "id": "42915e02", "metadata": {}, "source": [ "## d. Training" ] }, { "cell_type": "code", "execution_count": null, "id": "884bc2d0", "metadata": {}, "outputs": [], "source": [ "phone_dict_name = \"ipa_cmudict-0.7b_nv23.01.txt\" if use_ipa else \"cmudict-0.7b_nv22.10\"\n", "phoneme_dict_path = os.path.abspath(os.path.join(code_dir, \"scripts\", \"tts_dataset_files\", phone_dict_name))\n", "heteronyms_path = os.path.abspath(os.path.join(code_dir, \"scripts\", \"tts_dataset_files\", \"heteronyms-052722\"))\n", "\n", "# Copy the PITCH_MEAN and PITCH_STD printed for train_manifest in the preprocessing step to override the pitch_mean and pitch_std configs below.\n", "PITCH_MEAN=175.48513793945312\n", "PITCH_STD=42.3786735534668\n", "\n", "config_filename = \"fastpitch_align_ipa_adapter.yaml\" if use_ipa else \"fastpitch_align_44100_adapter.yaml\"" ] }, { "cell_type": "markdown", "id": "6f04fc86", "metadata": {}, "source": [ "### Important notes\n", "* `+init_from_nemo_model`: initialize with a multi-speaker FastPitch checkpoint\n", "* `model.speaker_encoder.precomputed_embedding_dim={spk_emb_dim}`: use the precomputed speaker embedding\n", "* `~model.speaker_encoder.lookup_module`: we use the precomputed speaker embedding, so we remove the pre-trained lookup speaker embedding\n", "* `~model.speaker_encoder.gst_module`: we use the precomputed speaker embedding, so we remove the pre-trained GST speaker embedding\n", "* Other optional arguments based on your preference:\n", " * batch_size\n", " * exp_manager\n", " * trainer\n", " * model.unfreeze_aligner=true\n", " * model.unfreeze_duration_predictor=true\n", " * model.unfreeze_pitch_predictor=true" ] }, { "cell_type": "code", "execution_count": null, "id": "7ae8383a", "metadata": { "scrolled": false }, "outputs": [], "source": [ "# Normally 200 epochs\n", "!cd 
{code_dir} && python examples/tts/fastpitch_finetune_adapters.py \\\n", "--config-name={config_filename} \\\n", "+init_from_nemo_model={pretrained_fastpitch_checkpoint} \\\n", "train_dataset={train_manifest} \\\n", "validation_datasets={valid_manifest} \\\n", "sup_data_types=\"['align_prior_matrix', 'pitch', 'energy']\" \\\n", "sup_data_path={supp_dir} \\\n", "pitch_mean={PITCH_MEAN} \\\n", "pitch_std={PITCH_STD} \\\n", "model.speaker_encoder.precomputed_embedding_dim={spk_emb_dim} \\\n", "~model.speaker_encoder.lookup_module \\\n", "~model.speaker_encoder.gst_module \\\n", "model.train_ds.dataloader_params.batch_size=8 \\\n", "model.validation_ds.dataloader_params.batch_size=8 \\\n", "+model.text_tokenizer.add_blank_at=True \\\n", "model.optim.name=adam \\\n", "model.optim.lr=1e-3 \\\n", "model.optim.sched.warmup_steps=0 \\\n", "+model.optim.sched.min_lr=1e-4 \\\n", "exp_manager.exp_dir={logs_dir} \\\n", "+exp_manager.create_wandb_logger=True \\\n", "+exp_manager.wandb_logger_kwargs.name=\"tutorial-FastPitch-finetune-adaptation\" \\\n", "+exp_manager.wandb_logger_kwargs.project=\"NeMo\" \\\n", "+exp_manager.checkpoint_callback_params.save_top_k=-1 \\\n", "trainer.max_epochs=20 \\\n", "trainer.check_val_every_n_epoch=10 \\\n", "trainer.log_every_n_steps=1 \\\n", "trainer.devices=1 \\\n", "trainer.strategy=ddp \\\n", "trainer.precision=32" ] }, { "cell_type": "code", "execution_count": null, "id": "39d3074c", "metadata": {}, "outputs": [], "source": [ "# e.g. NeMoTTS_logs/FastPitch/Y-M-D_H-M-S/checkpoints/FastPitch.nemo\n", "# e.g. 
NeMoTTS_logs/FastPitch/Y-M-D_H-M-S/checkpoints/adapters.pt\n", "last_checkpoint_dir = sorted(list([i for i in (Path(logs_dir) / \"FastPitch\").iterdir() if i.is_dir()]))[-1] / \"checkpoints\"\n", "finetuned_fastpitch_checkpoint = list(last_checkpoint_dir.glob('*.nemo'))[0]\n", "finetuned_adapter_checkpoint = list(last_checkpoint_dir.glob('adapters.pt'))[0]\n", "print(finetuned_fastpitch_checkpoint)\n", "print(finetuned_adapter_checkpoint)" ] }, { "cell_type": "markdown", "id": "9e9a1f45", "metadata": {}, "source": [ "# 2. Fine-tune HiFiGAN on adaptation data" ] }, { "cell_type": "markdown", "id": "deec135f", "metadata": {}, "source": [ "## a. Dataset Preparation\n", "Generate mel-spectrograms for HiFiGAN training." ] }, { "cell_type": "code", "execution_count": null, "id": "1aecaa68", "metadata": { "scrolled": false }, "outputs": [], "source": [ "!cd {code_dir} \\\n", "&& python scripts/dataset_processing/tts/resynthesize_dataset.py \\\n", "--model-path={finetuned_fastpitch_checkpoint} \\\n", "--input-json-manifest={train_manifest} \\\n", "--input-sup-data-path={supp_dir} \\\n", "--output-folder={mels_dir} \\\n", "--device=\"cuda:0\" \\\n", "--batch-size=1 \\\n", "--num-workers=1 \\\n", "&& python scripts/dataset_processing/tts/resynthesize_dataset.py \\\n", "--model-path={finetuned_fastpitch_checkpoint} \\\n", "--input-json-manifest={valid_manifest} \\\n", "--input-sup-data-path={supp_dir} \\\n", "--output-folder={mels_dir} \\\n", "--device=\"cuda:0\" \\\n", "--batch-size=1 \\\n", "--num-workers=1" ] }, { "cell_type": "code", "execution_count": null, "id": "6a153ea0", "metadata": {}, "outputs": [], "source": [ "train_manifest_mel = f\"{mels_dir}/train_mel.json\"\n", "valid_manifest_mel = f\"{mels_dir}/dev_mel.json\"" ] }, { "cell_type": "markdown", "id": "b05cd550", "metadata": {}, "source": [ "## b. 
Training" ] }, { "cell_type": "code", "execution_count": null, "id": "e5d5f281", "metadata": { "scrolled": false }, "outputs": [], "source": [ "# Normally 500 epochs\n", "!cd {code_dir} && python examples/tts/hifigan_finetune.py \\\n", "--config-name=hifigan_44100.yaml \\\n", "train_dataset={train_manifest_mel} \\\n", "validation_datasets={valid_manifest_mel} \\\n", "+init_from_nemo_model={finetuned_hifigan_on_multispeaker_checkpoint} \\\n", "model.train_ds.dataloader_params.batch_size=32 \\\n", "model.optim.lr=0.0001 \\\n", "model/train_ds=train_ds_finetune \\\n", "model/validation_ds=val_ds_finetune \\\n", "+trainer.max_epochs=50 \\\n", "trainer.check_val_every_n_epoch=5 \\\n", "trainer.devices=-1 \\\n", "trainer.strategy='ddp_find_unused_parameters_true' \\\n", "trainer.precision=16 \\\n", "exp_manager.exp_dir={logs_dir} \\\n", "exp_manager.create_wandb_logger=True \\\n", "exp_manager.wandb_logger_kwargs.name=\"tutorial-HiFiGAN-finetune-multispeaker\" \\\n", "exp_manager.wandb_logger_kwargs.project=\"NeMo\"" ] }, { "cell_type": "code", "execution_count": null, "id": "9c1c42f3", "metadata": {}, "outputs": [], "source": [ "# e.g. NeMoTTS_logs/HifiGan/Y-M-D_H-M-S/checkpoints/HifiGan.nemo\n", "last_checkpoint_dir = sorted(list([i for i in (Path(logs_dir) / \"HifiGan\").iterdir() if i.is_dir()]))[-1] / \"checkpoints\"\n", "finetuned_hifigan_on_adaptation_checkpoint = list(last_checkpoint_dir.glob('*.nemo'))[0]\n", "finetuned_hifigan_on_adaptation_checkpoint" ] }, { "cell_type": "markdown", "id": "0665ac78", "metadata": {}, "source": [ "# 3. Inference" ] }, { "cell_type": "code", "execution_count": null, "id": "5f4afb24", "metadata": {}, "outputs": [], "source": [ "from nemo.collections.tts.models import HifiGanModel\n", "import IPython.display as ipd\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "id": "0d9ff309", "metadata": {}, "source": [ "## a. 
Load Model" ] }, { "cell_type": "code", "execution_count": null, "id": "81e4dee0", "metadata": {}, "outputs": [], "source": [ "# Load from pretrained FastPitch and finetuned adapter\n", "# spec_model = FastPitchModel.restore_from(pretrained_fastpitch_checkpoint)\n", "# spec_model.load_adapters(finetuned_adapter_checkpoint)\n", "\n", "# Load from finetuned FastPitch\n", "spec_model = FastPitchModel.restore_from(finetuned_fastpitch_checkpoint)\n", "spec_model = spec_model.eval().cuda()" ] }, { "cell_type": "code", "execution_count": null, "id": "1eaef8be", "metadata": {}, "outputs": [], "source": [ "# HiFiGAN\n", "vocoder_model = HifiGanModel.restore_from(finetuned_hifigan_on_adaptation_checkpoint).eval().cuda()" ] }, { "cell_type": "markdown", "id": "837bdbab", "metadata": {}, "source": [ "## b. Output Audio" ] }, { "cell_type": "code", "execution_count": null, "id": "fef139cb", "metadata": {}, "outputs": [], "source": [ "def gen_spectrogram(text, spec_gen_model):\n", " parsed = spec_gen_model.parse(text)\n", " with torch.no_grad(): \n", " spectrogram = spec_gen_model.generate_spectrogram(tokens=parsed)\n", " return spectrogram\n", " \n", "def synth_audio(vocoder_model, spectrogram): \n", " with torch.no_grad(): \n", " audio = vocoder_model.convert_spectrogram_to_audio(spec=spectrogram)\n", " if isinstance(audio, torch.Tensor):\n", " audio = audio.to('cpu').numpy()\n", " return audio" ] }, { "cell_type": "code", "execution_count": null, "id": "b98ac280", "metadata": {}, "outputs": [], "source": [ "# Validation Audio\n", "num_val = 3\n", "val_records = []\n", "with open(valid_manifest, \"r\") as f:\n", " for i, line in enumerate(f):\n", " val_records.append(json.loads(line))\n", " if len(val_records) >= num_val:\n", " break" ] }, { "cell_type": "code", "execution_count": null, "id": "b17446f9", "metadata": {}, "outputs": [], "source": [ "for i, val_record in enumerate(val_records):\n", " spec_pred = gen_spectrogram(val_record['text'], spec_model)\n", " audio_gen = 
synth_audio(vocoder_model, spec_pred)\n", "\n", " audio_gt = ipd.Audio(val_record['audio_filepath'], rate=sample_rate)\n", " audio_gen = ipd.Audio(audio_gen, rate=sample_rate)\n", " \n", " print(\"------\")\n", " print(f\"Text: {val_record['text']}\")\n", " print('Ground Truth Audio')\n", " ipd.display(audio_gt)\n", " print('Synthesized Audio')\n", " ipd.display(audio_gen)\n", " plt.imshow(spec_pred[0].to('cpu').numpy(), origin=\"lower\", aspect=\"auto\")\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "f8f525d1", "metadata": {}, "outputs": [], "source": [ "print(f\"Finetuned FastPitch: {finetuned_fastpitch_checkpoint}\")\n", "print(f\"Finetuned HiFi-Gan: {finetuned_hifigan_on_adaptation_checkpoint}\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.0" } }, "nbformat": 4, "nbformat_minor": 5 }