{ "cells": [ { "cell_type": "markdown", "id": "afd8cdc9", "metadata": {}, "source": [ "# FastPitch MultiSpeaker Pretraining\n", "\n", "This notebook shows how to run the FastPitch multi-speaker pretraining pipeline. It contains the following sections:\n", "1. **Pre-train FastPitch on multi-speaker data**: pre-train a multi-speaker FastPitch.\n", "   * Dataset Preparation: download the dataset and extract the manifest files.\n", "   * Preprocessing: add absolute audio paths to the manifests, calibrate speaker ids to start from 0, and extract supplementary data.\n", "   * Training: pre-train a multi-speaker FastPitch.\n", "2. **Fine-tune HiFiGAN on multi-speaker data**: fine-tune a vocoder for the pre-trained multi-speaker FastPitch.\n", "   * Dataset Preparation: generate mel-spectrograms with the pre-trained FastPitch.\n", "   * Training: fine-tune HiFiGAN on the generated mel-spectrograms.\n", "3. **Inference**: generate speech with the pre-trained multi-speaker FastPitch.\n", "   * Load Model: load the pre-trained FastPitch and the fine-tuned HiFiGAN.\n", "   * Output Audio: synthesize and play back audio." ] }, { "cell_type": "markdown", "id": "4fc9c6b9", "metadata": {}, "source": [ "# License\n", "> Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n", "> \n", "> Licensed under the Apache License, Version 2.0 (the \"License\");\n", "> you may not use this file except in compliance with the License.\n", "> You may obtain a copy of the License at\n", "> \n", "> http://www.apache.org/licenses/LICENSE-2.0\n", "> \n", "> Unless required by applicable law or agreed to in writing, software\n", "> distributed under the License is distributed on an \"AS IS\" BASIS,\n", "> WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "> See the License for the specific language governing permissions and\n", "> limitations under the License." 
] }, { "cell_type": "code", "execution_count": null, "id": "b81f6c14", "metadata": {}, "outputs": [], "source": [ "\"\"\"\n", "You can either run this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n", "Instructions for setting up Colab are as follows:\n", "1. Open a new Python 3 notebook.\n", "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n", "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", "4. Run this cell to set up dependencies.\n", "\"\"\"\n", "# BRANCH = 'main'\n", "# # If you're using Colab and not running locally, uncomment and run this cell.\n", "# !apt-get install sox libsndfile1 ffmpeg\n", "# !pip install wget unidecode pynini==2.1.4 scipy==1.7.3\n", "# !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n", "\n", "# # Download a local copy of the NeMo scripts. If you are running locally and want to use your own local NeMo code,\n", "# # comment out the lines below and set `code_dir` to your local path.\n", "code_dir = 'NeMoTTS' \n", "!git clone https://github.com/NVIDIA/NeMo.git {code_dir}" ] }, { "cell_type": "code", "execution_count": null, "id": "f2f1e3ac", "metadata": {}, "outputs": [], "source": [ "!wandb login #PASTE_WANDB_APIKEY_HERE" ] }, { "cell_type": "code", "execution_count": null, "id": "1acd141d", "metadata": {}, "outputs": [], "source": [ "sample_rate = 44100\n", "# Directory for the manifests and audio files\n", "data_dir = 'NeMoTTS_dataset'\n", "# Directory for supplementary data (e.g. pitch)\n", "supp_dir = \"NeMoTTS_sup_data\"\n", "# Directory for training logs and checkpoints\n", "logs_dir = \"NeMoTTS_logs\"\n", "# Directory for the mel-spectrograms used in vocoder training\n", "mels_dir = \"NeMoTTS_mels\"" ] }, { "cell_type": "code", "execution_count": null, "id": "7b54c45e", "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", "import nemo\n", "import torch\n", "import 
numpy as np\n", "\n", "from pathlib import Path\n", "from tqdm import tqdm" ] }, { "cell_type": "code", "execution_count": null, "id": "a119994b", "metadata": {}, "outputs": [], "source": [ "os.makedirs(code_dir, exist_ok=True)\n", "code_dir = os.path.abspath(code_dir)\n", "os.makedirs(data_dir, exist_ok=True)\n", "data_dir = os.path.abspath(data_dir)\n", "os.makedirs(supp_dir, exist_ok=True)\n", "supp_dir = os.path.abspath(supp_dir)\n", "os.makedirs(logs_dir, exist_ok=True)\n", "logs_dir = os.path.abspath(logs_dir)\n", "os.makedirs(mels_dir, exist_ok=True)\n", "mels_dir = os.path.abspath(mels_dir)" ] }, { "cell_type": "markdown", "id": "dbb3ac0e", "metadata": {}, "source": [ "# 1. Pre-train FastPitch on multi-speaker data" ] }, { "cell_type": "markdown", "id": "095a1fca", "metadata": {}, "source": [ "## a. Dataset Preparation\n", "For this tutorial, we use a subset of the VCTK dataset with 5 speakers (p225-p229). The original recordings are sampled at 48 kHz; the audio used in this tutorial is downsampled to 44.1 kHz.\n", "You can read more about the dataset [here](https://datashare.ed.ac.uk/handle/10283/2950)." ] }, { "cell_type": "code", "execution_count": null, "id": "69b17b07", "metadata": {}, "outputs": [], "source": [ "!cd {data_dir} && wget https://vctk-subset.s3.amazonaws.com/vctk_subset_multispeaker.tar.gz && tar zxf vctk_subset_multispeaker.tar.gz" ] }, { "cell_type": "code", "execution_count": null, "id": "a65e7938", "metadata": {}, "outputs": [], "source": [ "manidir = f\"{data_dir}/vctk_subset_multispeaker\"\n", "!ls {manidir}" ] }, { "cell_type": "code", "execution_count": null, "id": "08b27b92", "metadata": {}, "outputs": [], "source": [ "train_manifest = os.path.abspath(os.path.join(manidir, 'train.json'))\n", "valid_manifest = os.path.abspath(os.path.join(manidir, 'dev.json'))" ] }, { "cell_type": "markdown", "id": "7cbf24d6", "metadata": {}, "source": [ "## b. 
Preprocessing" ] }, { "cell_type": "markdown", "id": "cae8567d", "metadata": {}, "source": [ "### Add absolute audio path in manifest\n", "We convert each `audio_filepath` to an absolute path so that the audio files can be found during training regardless of the working directory." ] }, { "cell_type": "code", "execution_count": null, "id": "71d2fe63", "metadata": {}, "outputs": [], "source": [ "from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest" ] }, { "cell_type": "code", "execution_count": null, "id": "dc51398c", "metadata": {}, "outputs": [], "source": [ "train_data = read_manifest(train_manifest)\n", "for m in train_data: m['audio_filepath'] = os.path.abspath(os.path.join(manidir, m['audio_filepath']))\n", "write_manifest(train_manifest, train_data)\n", "\n", "valid_data = read_manifest(valid_manifest)\n", "for m in valid_data: m['audio_filepath'] = os.path.abspath(os.path.join(manidir, m['audio_filepath']))\n", "write_manifest(valid_manifest, valid_data)" ] }, { "cell_type": "markdown", "id": "678bb37c", "metadata": {}, "source": [ "### Calibrate speaker id to start from 0\n", "We re-index the speaker ids to consecutive integers starting from 0, so that each id can directly index a speaker look-up table with one row per speaker (`n_speakers` rows). The original speaker label is kept in `old_speaker`." 
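] }, { "cell_type": "markdown", "id": "3f9a2c1e", "metadata": {}, "source": [ "The cell below is a minimal, self-contained sketch of why 0-based, contiguous speaker ids matter. It uses hypothetical toy records rather than the VCTK manifests, and a plain Python list stands in for the learned speaker embedding table (conceptually a table with `n_speakers` rows, similar to a `torch.nn.Embedding`). The speaker labels are sorted so that the label-to-id assignment does not depend on set iteration order." ] }, { "cell_type": "code", "execution_count": null, "id": "7d4b8e20", "metadata": {}, "outputs": [], "source": [ "# Hypothetical toy records standing in for manifest entries (not taken from VCTK).\n", "records = [{\"speaker\": \"p227\"}, {\"speaker\": \"p225\"}, {\"speaker\": \"p229\"}, {\"speaker\": \"p225\"}]\n", "\n", "# Sorting makes the label -> id mapping deterministic; ids are contiguous in [0, n_speakers).\n", "speaker2id = {s: i for i, s in enumerate(sorted({r[\"speaker\"] for r in records}))}\n", "n_speakers = len(speaker2id)\n", "\n", "# Hypothetical stand-in for a learned embedding table: one row per speaker.\n", "embedding_dim = 3\n", "table = [[float(row)] * embedding_dim for row in range(n_speakers)]\n", "\n", "# Every re-indexed id is a valid row index, so the lookup never goes out of range.\n", "# (Raw integer labels such as 225-229 used directly would need a table with at least 230 rows.)\n", "ids = [speaker2id[r[\"speaker\"]] for r in records]\n", "embeddings = [table[i] for i in ids]\n", "\n", "print(speaker2id)  # {'p225': 0, 'p227': 1, 'p229': 2}\n", "print(ids)  # [1, 0, 2, 0]"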
] }, { "cell_type": "code", "execution_count": null, "id": "594c6f2d", "metadata": {}, "outputs": [], "source": [ "# Run this cell only once: it overwrites `speaker` in place with the new 0-based id.\n", "train_data = read_manifest(train_manifest)\n", "# sorted() makes the speaker -> id mapping deterministic across runs.\n", "speaker2id = {s: _id for _id, s in enumerate(sorted(set(m['speaker'] for m in train_data)))}\n", "for m in train_data: m['old_speaker'], m['speaker'] = m['speaker'], speaker2id[m['speaker']]\n", "write_manifest(train_manifest, train_data)\n", "\n", "valid_data = read_manifest(valid_manifest)\n", "for m in valid_data: m['old_speaker'], m['speaker'] = m['speaker'], speaker2id[m['speaker']]\n", "write_manifest(valid_manifest, valid_data)" ] }, { "cell_type": "markdown", "id": "15b6cc65", "metadata": {}, "source": [ "### Extract Supplementary Data\n", "\n", "As described in the [FastPitch and MixerTTS training tutorial](https://github.com/NVIDIA/NeMo/blob/main/tutorials/tts/FastPitch_MixerTTS_Training.ipynb), to accelerate and stabilize training we also need to extract the pitch of every audio file and estimate the pitch statistics (mean, std, min, and max). The `extract_sup_data.py` script does this in a single pass over the data." ] }, { "cell_type": "code", "execution_count": null, "id": "c3728ac9", "metadata": {}, "outputs": [], "source": [ "!cd {code_dir} && python scripts/dataset_processing/tts/extract_sup_data.py \\\n", " manifest_filepath={train_manifest} \\\n", " sup_data_path={supp_dir} \\\n", " dataset.sample_rate={sample_rate} \\\n", " dataset.n_fft=2048 \\\n", " dataset.win_length=2048 \\\n", " dataset.hop_length=512" ] }, { "cell_type": "markdown", "id": "effd9182", "metadata": {}, "source": [ "After running the command above, you will see a new folder `NeMoTTS_sup_data/pitch` and a printout of the pitch statistics like the one below. These values need to be passed to the FastPitch training configuration. 
We do this in the Training section below.\n", "```bash\n", "PITCH_MEAN=140.84278869628906, PITCH_STD=50.97673034667969\n", "PITCH_MIN=65.4063949584961, PITCH_MAX=285.3046875\n", "```" ] }, { "cell_type": "code", "execution_count": null, "id": "37e54cd4", "metadata": {}, "outputs": [], "source": [ "!cd {code_dir} && python scripts/dataset_processing/tts/extract_sup_data.py \\\n", " manifest_filepath={valid_manifest} \\\n", " sup_data_path={supp_dir} \\\n", " dataset.sample_rate={sample_rate} \\\n", " dataset.n_fft=2048 \\\n", " dataset.win_length=2048 \\\n", " dataset.hop_length=512" ] }, { "cell_type": "markdown", "id": "82d2c99d", "metadata": {}, "source": [ "* If you want to compute the pitch mean and std for each speaker, you can use the `compute_speaker_stats.py` script:\n", "```bash\n", "!cd {code_dir} && python scripts/dataset_processing/tts/compute_speaker_stats.py \\\n", " --manifest_path={train_manifest} \\\n", " --sup_data_path={supp_dir} \\\n", " --pitch_stats_path={data_dir}/pitch_stats.json\n", "```" ] }, { "cell_type": "markdown", "id": "a7c8dfb6", "metadata": {}, "source": [ "## c. 
Training" ] }, { "cell_type": "code", "execution_count": null, "id": "e378a792", "metadata": {}, "outputs": [], "source": [ "phoneme_dict_path = os.path.abspath(os.path.join(code_dir, \"scripts\", \"tts_dataset_files\", \"cmudict-0.7b_nv22.10\"))\n", "heteronyms_path = os.path.abspath(os.path.join(code_dir, \"scripts\", \"tts_dataset_files\", \"heteronyms-052722\"))\n", "\n", "# Copy the PITCH_MEAN and PITCH_STD printed for train_manifest in the previous step; they override the pitch_mean and pitch_std configs below.\n", "PITCH_MEAN=140.84278869628906\n", "PITCH_STD=50.97673034667969" ] }, { "cell_type": "markdown", "id": "a90ddfb3", "metadata": {}, "source": [ "### Important notes\n", "* `sup_data_types=\"['align_prior_matrix', 'pitch', 'speaker_id', 'reference_audio']\"`\n", "    * **speaker_id**: each sample carries the index of its speaker (starting from 0) in the input.\n", "    * **reference_audio**: each sample has a reference audio clip (from the same speaker) in the input.\n", " \n", "* `model.speaker_encoder.lookup_module.n_speakers`\n", "    * if you use **model.speaker_encoder.lookup_module**, set `n_speakers` to the number of speakers so that the lookup table has one entry per speaker.\n", "\n", "* `condition_types=\"['add', 'concat', 'layernorm']\"`\n", "    * the operation(s) used to inject the conditions into each module (e.g. 
input_fft/output_fft/duration_predictor/pitch_predictor/alignment_module)\n", "    * **add**: add the conditions to the module input\n", "    * **concat**: concatenate the conditions to the module input\n", "    * **layernorm**: scale and shift the layernorm outputs based on the conditions\n", " \n", "* Other default arguments in the config:\n", "    * `model.speaker_encoder.lookup_module`: the model creates a lookup table that maps a speaker id to a speaker embedding.\n", "    * `model.speaker_encoder.gst_module`: the model uses global style tokens (GST) to extract speaker information from the reference audio.\n", "\n", "* Other optional arguments you can adjust:\n", "    * batch_size\n", "    * max_duration\n", "    * min_duration\n", "    * exp_manager\n", "    * trainer" ] }, { "cell_type": "code", "execution_count": null, "id": "ac22f3a8", "metadata": {}, "outputs": [], "source": [ "# Normally 200 epochs\n", "!(cd {code_dir} && python examples/tts/fastpitch.py \\\n", "--config-name=fastpitch_align_44100_adapter.yaml \\\n", "+init_from_pretrained_model=\"tts_en_fastpitch\" \\\n", "train_dataset={train_manifest} \\\n", "validation_datasets={valid_manifest} \\\n", "sup_data_types=\"['align_prior_matrix', 'pitch', 'speaker_id', 'reference_audio']\" \\\n", "sup_data_path={supp_dir} \\\n", "pitch_mean={PITCH_MEAN} \\\n", "pitch_std={PITCH_STD} \\\n", "phoneme_dict_path={phoneme_dict_path} \\\n", "heteronyms_path={heteronyms_path} \\\n", "model.speaker_encoder.lookup_module.n_speakers=5 \\\n", "model.input_fft.condition_types=\"['add', 'layernorm']\" \\\n", "model.output_fft.condition_types=\"['add', 'layernorm']\" \\\n", "model.duration_predictor.condition_types=\"['add', 'layernorm']\" \\\n", "model.pitch_predictor.condition_types=\"['add', 'layernorm']\" \\\n", "model.alignment_module.condition_types=\"['add']\" \\\n", "model.train_ds.dataloader_params.batch_size=8 \\\n", "model.validation_ds.dataloader_params.batch_size=8 \\\n", 
"model.train_ds.dataset.max_duration=20 \\\n", "model.validation_ds.dataset.max_duration=20 \\\n", 
"model.validation_ds.dataset.min_duration=0.1 \\\n", "exp_manager.exp_dir={logs_dir} \\\n", "+exp_manager.create_wandb_logger=True \\\n", "+exp_manager.wandb_logger_kwargs.name=\"tutorial-FastPitch-pretrain-multispeaker\" \\\n", "+exp_manager.wandb_logger_kwargs.project=\"NeMo\" \\\n", "trainer.max_epochs=20 \\\n", "trainer.check_val_every_n_epoch=20 \\\n", "trainer.log_every_n_steps=1 \\\n", "trainer.devices=-1 \\\n", "trainer.strategy=ddp \\\n", "trainer.precision=32 \\\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "b6fc98a5", "metadata": {}, "outputs": [], "source": [ "# e.g. NeMoTTS_logs/FastPitch/Y-M-D_H-M-S/checkpoints/FastPitch.nemo\n", "last_checkpoint_dir = sorted(list([i for i in (Path(logs_dir) / \"FastPitch\").iterdir() if i.is_dir()]))[-1] / \"checkpoints\"\n", "pretrained_fastpitch_checkpoint = os.path.abspath(list(last_checkpoint_dir.glob('*.nemo'))[0])\n", "print(pretrained_fastpitch_checkpoint)" ] }, { "cell_type": "markdown", "id": "b175f755", "metadata": {}, "source": [ "# 2. Fine-tune HiFiGAN on multi-speaker data" ] }, { "cell_type": "markdown", "id": "5749a0b8", "metadata": {}, "source": [ "## a. Dataset Preparation\n", "Generate mel-spectrograms for HiFiGAN training." 
] }, { "cell_type": "code", "execution_count": null, "id": "3d77bda9", "metadata": {}, "outputs": [], "source": [ "!cd {code_dir} \\\n", "&& python scripts/dataset_processing/tts/resynthesize_dataset.py \\\n", "--model-path={pretrained_fastpitch_checkpoint} \\\n", "--input-json-manifest={train_manifest} \\\n", "--input-sup-data-path={supp_dir} \\\n", "--output-folder={mels_dir} \\\n", "--device=\"cuda:0\" \\\n", "--batch-size=1 \\\n", "--num-workers=1 \\\n", "&& python scripts/dataset_processing/tts/resynthesize_dataset.py \\\n", "--model-path={pretrained_fastpitch_checkpoint} \\\n", "--input-json-manifest={valid_manifest} \\\n", "--input-sup-data-path={supp_dir} \\\n", "--output-folder={mels_dir} \\\n", "--device=\"cuda:0\" \\\n", "--batch-size=1 \\\n", "--num-workers=1" ] }, { "cell_type": "code", "execution_count": null, "id": "8c9159a1", "metadata": {}, "outputs": [], "source": [ "train_manifest_mel = f\"{mels_dir}/train_mel.json\"\n", "valid_manifest_mel = f\"{mels_dir}/dev_mel.json\"" ] }, { "cell_type": "markdown", "id": "24653f24", "metadata": {}, "source": [ "## b. 
Training" ] }, { "cell_type": "code", "execution_count": null, "id": "fadc0410", "metadata": {}, "outputs": [], "source": [ "# Normally 100 epochs\n", "!cd {code_dir} && python examples/tts/hifigan_finetune.py \\\n", "--config-name=hifigan_44100.yaml \\\n", "train_dataset={train_manifest_mel} \\\n", "validation_datasets={valid_manifest_mel} \\\n", "+init_from_pretrained_model=\"tts_en_hifitts_hifigan_ft_fastpitch\" \\\n", "model.train_ds.dataloader_params.batch_size=32 \\\n", "model.optim.lr=0.0001 \\\n", "model/train_ds=train_ds_finetune \\\n", "model/validation_ds=val_ds_finetune \\\n", "+trainer.max_epochs=5 \\\n", "trainer.check_val_every_n_epoch=5 \\\n", "trainer.devices=1 \\\n", "trainer.strategy='ddp_find_unused_parameters_true' \\\n", "trainer.precision=16 \\\n", "exp_manager.exp_dir={logs_dir} \\\n", "exp_manager.create_wandb_logger=True \\\n", "exp_manager.wandb_logger_kwargs.name=\"tutorial-HiFiGAN-finetune-multispeaker\" \\\n", "exp_manager.wandb_logger_kwargs.project=\"NeMo\"" ] }, { "cell_type": "code", "execution_count": null, "id": "864fe5ba", "metadata": {}, "outputs": [], "source": [ "# e.g. NeMoTTS_logs/HifiGan/Y-M-D_H-M-S/checkpoints/HifiGan.nemo\n", "last_checkpoint_dir = sorted(list([i for i in (Path(logs_dir) / \"HifiGan\").iterdir() if i.is_dir()]))[-1] / \"checkpoints\"\n", "finetuned_hifigan_on_multispeaker_checkpoint = os.path.abspath(list(last_checkpoint_dir.glob('*.nemo'))[0])\n", "finetuned_hifigan_on_multispeaker_checkpoint" ] }, { "cell_type": "markdown", "id": "e04540b6", "metadata": {}, "source": [ "# 3. 
Inference" ] }, { "cell_type": "code", "execution_count": null, "id": "fdf662f7", "metadata": {}, "outputs": [], "source": [ "from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer\n", "from nemo.collections.tts.models import FastPitchModel\n", "from nemo.collections.tts.models import HifiGanModel\n", "from collections import defaultdict\n", "import IPython.display as ipd\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "id": "270a3264", "metadata": {}, "source": [ "## a. Load Model" ] }, { "cell_type": "code", "execution_count": null, "id": "01315a66", "metadata": {}, "outputs": [], "source": [ "wave_model = WaveformFeaturizer(sample_rate=sample_rate)" ] }, { "cell_type": "code", "execution_count": null, "id": "536c8fdc", "metadata": {}, "outputs": [], "source": [ "# FastPitch\n", "spec_model = FastPitchModel.restore_from(pretrained_fastpitch_checkpoint).eval().cuda()" ] }, { "cell_type": "code", "execution_count": null, "id": "a2ace7c4", "metadata": {}, "outputs": [], "source": [ "# HiFiGAN\n", "vocoder_model = HifiGanModel.restore_from(finetuned_hifigan_on_multispeaker_checkpoint).eval().cuda()" ] }, { "cell_type": "markdown", "id": "cf4a42fa", "metadata": {}, "source": [ "## b. 
Output Audio" ] }, { "cell_type": "code", "execution_count": null, "id": "1b376468", "metadata": {}, "outputs": [], "source": [ "def gt_spectrogram(audio_path, wave_model, spec_gen_model):\n", "    # Compute the ground-truth mel-spectrogram of an audio file with the FastPitch preprocessor.\n", "    features = wave_model.process(audio_path, trim=False)\n", "    audio, audio_length = features, torch.tensor(features.shape[0]).long()\n", "    audio = audio.unsqueeze(0).to(device=spec_gen_model.device)\n", "    audio_length = audio_length.unsqueeze(0).to(device=spec_gen_model.device)\n", "    with torch.no_grad():\n", "        spectrogram, spec_len = spec_gen_model.preprocessor(input_signal=audio, length=audio_length)\n", "    return spectrogram, spec_len\n", "\n", "def gen_spectrogram(text, spec_gen_model, speaker, reference_spec, reference_spec_lens):\n", "    # Generate a mel-spectrogram for `text`, conditioned on a speaker id and a reference spectrogram.\n", "    parsed = spec_gen_model.parse(text)\n", "    speaker = torch.tensor([speaker]).long().to(device=spec_gen_model.device)\n", "    with torch.no_grad():\n", "        spectrogram = spec_gen_model.generate_spectrogram(tokens=parsed,\n", "                                                          speaker=speaker,\n", "                                                          reference_spec=reference_spec,\n", "                                                          reference_spec_lens=reference_spec_lens)\n", "    return spectrogram\n", "\n", "def synth_audio(vocoder_model, spectrogram):\n", "    # Convert a mel-spectrogram to a waveform with the vocoder.\n", "    with torch.no_grad():\n", "        audio = vocoder_model.convert_spectrogram_to_audio(spec=spectrogram)\n", "    if isinstance(audio, torch.Tensor):\n", "        audio = audio.to('cpu').numpy()\n", "    return audio" ] }, { "cell_type": "code", "execution_count": null, "id": "f93f73a6", "metadata": {}, "outputs": [], "source": [ "# Reference audio: group the training records by speaker id\n", "reference_records = []\n", "with open(train_manifest, \"r\") as f:\n", "    for line in f:\n", "        reference_records.append(json.loads(line))\n", "\n", "speaker_to_index = defaultdict(list)\n", "for i, d in enumerate(reference_records):\n", "    speaker_to_index[d.get('speaker', None)].append(i)\n", "\n", "# Validation audio: take the first num_val records\n", "num_val = 3\n", "val_records = []\n", "with open(valid_manifest, \"r\") as f:\n", "    for line in f:\n", "        val_records.append(json.loads(line))\n", "        if len(val_records) >= num_val:\n", "            break" ] },
{ "cell_type": "code", "execution_count": null, "id": "77590752", "metadata": {}, "outputs": [], "source": [ "for i, val_record in enumerate(val_records):\n", "    # Use the first training utterance of the same speaker as the reference audio.\n", "    reference_record = reference_records[speaker_to_index[val_record['speaker']][0]]\n", "    reference_spec, reference_spec_lens = gt_spectrogram(reference_record['audio_filepath'], wave_model, spec_model)\n", "    reference_spec = reference_spec.to(spec_model.device)\n", "    spec_pred = gen_spectrogram(val_record['text'],\n", "                                spec_model,\n", "                                speaker=val_record['speaker'],\n", "                                reference_spec=reference_spec,\n", "                                reference_spec_lens=reference_spec_lens)\n", "\n", "    audio_gen = synth_audio(vocoder_model, spec_pred)\n", "\n", "    audio_ref = ipd.Audio(reference_record['audio_filepath'], rate=sample_rate)\n", "    audio_gt = ipd.Audio(val_record['audio_filepath'], rate=sample_rate)\n", "    audio_gen = ipd.Audio(audio_gen, rate=sample_rate)\n", "\n", "    print(\"------\")\n", "    print(f\"Text: {val_record['text']}\")\n", "    print('Reference Audio')\n", "    ipd.display(audio_ref)\n", "    print('Ground Truth Audio')\n", "    ipd.display(audio_gt)\n", "    print('Synthesized Audio')\n", "    ipd.display(audio_gen)\n", "    plt.imshow(spec_pred[0].to('cpu').numpy(), origin=\"lower\", aspect=\"auto\")\n", "    plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "8cd156e4", "metadata": {}, "outputs": [], "source": [ "print(f\"FastPitch checkpoint: {pretrained_fastpitch_checkpoint}\")\n", "print(f\"HiFiGAN checkpoint: {finetuned_hifigan_on_multispeaker_checkpoint}\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": 
"python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" } }, "nbformat": 4, "nbformat_minor": 5 }