diff --git a/.gitattributes b/.gitattributes
index b777d1e942755f20b38ceaf0bec6b5810442960f..f5f1e05ff10aa3f835c7ebdb8e4f7d71c03c4df4 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -35,3 +35,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
AuxiliaryASR/Data/train_list.csv filter=lfs diff=lfs merge=lfs -text
AuxiliaryASR/Data/train_list_plus.csv filter=lfs diff=lfs merge=lfs -text
+stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/train.log filter=lfs diff=lfs merge=lfs -text
+stts_48khz/StyleTTS2_48khz/Utils/JDC/bst_rmvpe_48k.t7 filter=lfs diff=lfs merge=lfs -text
+stts_48khz/StyleTTS2_48khz/infer.ipynb filter=lfs diff=lfs merge=lfs -text
diff --git a/stts_48khz/StyleTTS2_48khz/.gitignore b/stts_48khz/StyleTTS2_48khz/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..bee8a64b79a99590d5303307144172cfe824fbf7
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/.gitignore
@@ -0,0 +1 @@
+__pycache__
diff --git a/stts_48khz/StyleTTS2_48khz/Colab/StyleTTS2_Demo_LJSpeech.ipynb b/stts_48khz/StyleTTS2_48khz/Colab/StyleTTS2_Demo_LJSpeech.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..52e1c191fe65e09b1c8c853fb0f39b9679a82785
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Colab/StyleTTS2_Demo_LJSpeech.ipynb
@@ -0,0 +1,486 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "view-in-github"
+ },
+ "source": [
+    ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "nm653VK4CG9F"
+ },
+ "source": [
+ "### Install packages and download models"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "gciBKMqCCLvT"
+ },
+ "outputs": [],
+ "source": [
+ "%%shell\n",
+ "git clone https://github.com/yl4579/StyleTTS2.git\n",
+ "cd StyleTTS2\n",
+ "pip install SoundFile torchaudio munch torch pydub pyyaml librosa nltk matplotlib accelerate transformers phonemizer einops einops-exts tqdm typing-extensions git+https://github.com/resemble-ai/monotonic_align.git\n",
+ "sudo apt-get install espeak-ng\n",
+ "git-lfs clone https://huggingface.co/yl4579/StyleTTS2-LJSpeech\n",
+ "mv StyleTTS2-LJSpeech/Models ."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "OAA8lx-XCQnM"
+ },
+ "source": [
+ "### Load models"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "m0XRpbxSCSix"
+ },
+ "outputs": [],
+ "source": [
+ "%cd StyleTTS2\n",
+ "\n",
+ "import torch\n",
+ "torch.manual_seed(0)\n",
+ "torch.backends.cudnn.benchmark = False\n",
+ "torch.backends.cudnn.deterministic = True\n",
+ "\n",
+ "import random\n",
+ "random.seed(0)\n",
+ "\n",
+ "import numpy as np\n",
+ "np.random.seed(0)\n",
+ "\n",
+ "import nltk\n",
+ "nltk.download('punkt')\n",
+ "\n",
+ "# load packages\n",
+ "import time\n",
+ "import random\n",
+ "import yaml\n",
+ "from munch import Munch\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "from torch import nn\n",
+ "import torch.nn.functional as F\n",
+ "import torchaudio\n",
+ "import librosa\n",
+ "from nltk.tokenize import word_tokenize\n",
+ "\n",
+ "from models import *\n",
+ "from utils import *\n",
+ "from text_utils import TextCleaner\n",
+ "textclenaer = TextCleaner()\n",
+ "\n",
+ "%matplotlib inline\n",
+ "\n",
+ "device = 'cpu' if torch.cuda.is_available() else 'cpu'\n",
+ "\n",
+ "to_mel = torchaudio.transforms.MelSpectrogram(\n",
+ " n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n",
+ "mean, std = -4, 4\n",
+ "\n",
+ "def length_to_mask(lengths):\n",
+ " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n",
+ " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n",
+ " return mask\n",
+ "\n",
+ "def preprocess(wave):\n",
+ " wave_tensor = torch.from_numpy(wave).float()\n",
+ " mel_tensor = to_mel(wave_tensor)\n",
+ " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n",
+ " return mel_tensor\n",
+ "\n",
+ "def compute_style(ref_dicts):\n",
+ " reference_embeddings = {}\n",
+ " for key, path in ref_dicts.items():\n",
+ " wave, sr = librosa.load(path, sr=24000)\n",
+ " audio, index = librosa.effects.trim(wave, top_db=30)\n",
+ " if sr != 24000:\n",
+ " audio = librosa.resample(audio, sr, 24000)\n",
+ " mel_tensor = preprocess(audio).to(device)\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " ref = model.style_encoder(mel_tensor.unsqueeze(1))\n",
+ " reference_embeddings[key] = (ref.squeeze(1), audio)\n",
+ "\n",
+ " return reference_embeddings\n",
+ "\n",
+ "# load phonemizer\n",
+ "import phonemizer\n",
+ "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True, words_mismatch='ignore')\n",
+ "\n",
+ "config = yaml.safe_load(open(\"Models/Kaede-san/config.yml\"))\n",
+ "\n",
+ "# load pretrained ASR model\n",
+ "ASR_config = config.get('ASR_config', False)\n",
+ "ASR_path = config.get('ASR_path', False)\n",
+ "text_aligner = load_ASR_models(ASR_path, ASR_config)\n",
+ "\n",
+ "# load pretrained F0 model\n",
+ "F0_path = config.get('F0_path', False)\n",
+ "pitch_extractor = load_F0_models(F0_path)\n",
+ "\n",
+ "# load BERT model\n",
+ "from Utils.PLBERT.util import load_plbert\n",
+ "BERT_path = config.get('PLBERT_dir', False)\n",
+ "plbert = load_plbert(BERT_path)\n",
+ "\n",
+ "model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)\n",
+ "_ = [model[key].eval() for key in model]\n",
+ "_ = [model[key].to(device) for key in model]\n",
+ "\n",
+ "params_whole = torch.load(\"Models/LJSpeech/epoch_2nd_00100.pth\", map_location='cpu')\n",
+ "params = params_whole['net']\n",
+ "\n",
+ "for key in model:\n",
+ " if key in params:\n",
+ " print('%s loaded' % key)\n",
+ " try:\n",
+ " model[key].load_state_dict(params[key])\n",
+ " except:\n",
+ " from collections import OrderedDict\n",
+ " state_dict = params[key]\n",
+ " new_state_dict = OrderedDict()\n",
+ " for k, v in state_dict.items():\n",
+ " name = k[7:] # remove `module.`\n",
+ " new_state_dict[name] = v\n",
+ " # load params\n",
+ " model[key].load_state_dict(new_state_dict, strict=False)\n",
+ "# except:\n",
+ "# _load(params[key], model[key])\n",
+ "_ = [model[key].eval() for key in model]\n",
+ "\n",
+ "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule\n",
+ "\n",
+ "sampler = DiffusionSampler(\n",
+ " model.diffusion.diffusion,\n",
+ " sampler=ADPM2Sampler(),\n",
+ " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n",
+ " clamp=False\n",
+ ")\n",
+ "\n",
+ "def inference(text, noise, diffusion_steps=5, embedding_scale=1):\n",
+ " text = text.strip()\n",
+ " text = text.replace('\"', '')\n",
+ " ps = global_phonemizer.phonemize([text])\n",
+ " ps = word_tokenize(ps[0])\n",
+ " ps = ' '.join(ps)\n",
+ "\n",
+ " tokens = textclenaer(ps)\n",
+ " tokens.insert(0, 0)\n",
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)\n",
+ " text_mask = length_to_mask(input_lengths).to(tokens.device)\n",
+ "\n",
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n",
+ "\n",
+ " s_pred = sampler(noise,\n",
+ " embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,\n",
+ " embedding_scale=embedding_scale).squeeze(0)\n",
+ "\n",
+ " s = s_pred[:, 128:]\n",
+ " ref = s_pred[:, :128]\n",
+ "\n",
+ " d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)\n",
+ "\n",
+ " x, _ = model.predictor.lstm(d)\n",
+ " duration = model.predictor.duration_proj(x)\n",
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
+ "\n",
+ " pred_dur[-1] += 5\n",
+ "\n",
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
+ " c_frame = 0\n",
+ " for i in range(pred_aln_trg.size(0)):\n",
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
+ " c_frame += int(pred_dur[i].data)\n",
+ "\n",
+ " # encode prosody\n",
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
+ " out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)),\n",
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
+ "\n",
+ " return out.squeeze().cpu().numpy()\n",
+ "\n",
+ "def LFinference(text, s_prev, noise, alpha=0.7, diffusion_steps=5, embedding_scale=1):\n",
+ " text = text.strip()\n",
+ " text = text.replace('\"', '')\n",
+ " ps = global_phonemizer.phonemize([text])\n",
+ " ps = word_tokenize(ps[0])\n",
+ " ps = ' '.join(ps)\n",
+ "\n",
+ " tokens = textclenaer(\"so↑nna to↓kini, jo↓kaʔta jo↑otoka, ge↓ŋkiga mo↑ɾaeɾʔʔte i↑ʔte mo↑ɾaeɾɯto, so↑ɾedakede ta↑içeɴ↓sanante ɸɯ↑kito↓ndʑa i↑ma↓sɯ. a↑idoɾɯo ja↓ʔte i↑tejo↓kaʔtaʔte o↑moe↓ɾɯ n↓desɯjo. ko↑no ka↑itoo, sɯ↑ko↓ɕi zɯ↑ɾɯ↓ikamo ɕi↑ɾemase↓ŋkedo, wa↑taɕino ço↑ntoono ki↑motɕide↓sɯ.\")\n",
+ " tokens.insert(0, 0)\n",
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)\n",
+ " text_mask = length_to_mask(input_lengths).to(tokens.device)\n",
+ "\n",
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n",
+ "\n",
+ " s_pred = sampler(noise,\n",
+ " embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,\n",
+ " embedding_scale=embedding_scale).squeeze(0)\n",
+ "\n",
+ " if s_prev is not None:\n",
+ " # convex combination of previous and current style\n",
+ " s_pred = alpha * s_prev + (1 - alpha) * s_pred\n",
+ "\n",
+ " s = s_pred[:, 128:]\n",
+ " ref = s_pred[:, :128]\n",
+ "\n",
+ " d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)\n",
+ "\n",
+ " x, _ = model.predictor.lstm(d)\n",
+ " duration = model.predictor.duration_proj(x)\n",
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
+ "\n",
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
+ " c_frame = 0\n",
+ " for i in range(pred_aln_trg.size(0)):\n",
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
+ " c_frame += int(pred_dur[i].data)\n",
+ "\n",
+ " # encode prosody\n",
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
+ " out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)),\n",
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
+ "\n",
+ " return out.squeeze().cpu().numpy(), s_pred"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vuCbS0gdArgJ"
+ },
+ "source": [
+ "### Synthesize speech"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "id": "7Ud1Y-kbBPTw"
+ },
+ "outputs": [],
+ "source": [
+ "# @title Input Text { display-mode: \"form\" }\n",
+ "# synthesize a text\n",
+ "text = \"StyleTTS 2 is a text-to-speech model that leverages style diffusion and adversarial training with large speech language models to achieve human-level text-to-speech synthesis.\" # @param {type:\"string\"}\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "TM2NjuM7B6sz"
+ },
+ "source": [
+ "#### Basic synthesis (5 diffusion steps)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "KILqC-V-Ay5e"
+ },
+ "outputs": [],
+ "source": [
+ "start = time.time()\n",
+ "noise = torch.randn(1,1,256).to(device)\n",
+ "wav = inference(text, noise, diffusion_steps=5, embedding_scale=1)\n",
+ "rtf = (time.time() - start) / (len(wav) / 24000)\n",
+ "print(f\"RTF = {rtf:5f}\")\n",
+ "import IPython.display as ipd\n",
+ "display(ipd.Audio(wav, rate=24000))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "oZk9o-EzCBVx"
+ },
+ "source": [
+ "#### With higher diffusion steps (more diverse)\n",
+ "Since the sampler is ancestral, the higher the stpes, the more diverse the samples are, with the cost of slower synthesis speed."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "9_OHtzMbB9gL"
+ },
+ "outputs": [],
+ "source": [
+ "start = time.time()\n",
+ "noise = torch.randn(1,1,256).to(device)\n",
+ "wav = inference(text, noise, diffusion_steps=10, embedding_scale=1)\n",
+ "rtf = (time.time() - start) / (len(wav) / 24000)\n",
+ "print(f\"RTF = {rtf:5f}\")\n",
+ "import IPython.display as ipd\n",
+ "display(ipd.Audio(wav, rate=24000))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NyDACd-0CaqL"
+ },
+ "source": [
+ "### Speech expressiveness\n",
+ "The following section recreates the samples shown in [Section 6](https://styletts2.github.io/#emo) of the demo page."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "cRkS5VWxCck4"
+ },
+ "source": [
+ "#### With embedding_scale=1\n",
+ "This is the classifier-free guidance scale. The higher the scale, the more conditional the style is to the input text and hence more emotional."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "H5g5RO-mCbZB"
+ },
+ "outputs": [],
+ "source": [
+ "texts = {}\n",
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
+ "\n",
+ "for k,v in texts.items():\n",
+ " noise = torch.randn(1,1,256).to(device)\n",
+ " wav = inference(v, noise, diffusion_steps=10, embedding_scale=1)\n",
+ " print(k + \": \")\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "f4S8TXSpCgpA"
+ },
+ "source": [
+ "#### With embedding_scale=2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "xHHIdeNrCezC"
+ },
+ "outputs": [],
+ "source": [
+ "texts = {}\n",
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
+ "\n",
+ "for k,v in texts.items():\n",
+ " noise = torch.randn(1,1,256).to(device)\n",
+ " wav = inference(v, noise, diffusion_steps=10, embedding_scale=2) # embedding_scale=2 for more pronounced emotion\n",
+ " print(k + \": \")\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "nAh7Tov4CkuH"
+ },
+ "source": [
+ "### Long-form generation\n",
+ "This section includes basic implementation of Algorithm 1 in the paper for consistent longform audio generation. The example passage is taken from [Section 5](https://styletts2.github.io/#long) of the demo page."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "cellView": "form",
+ "id": "IJwUbgvACoDu"
+ },
+ "outputs": [],
+ "source": [
+ "passage = '''If the supply of fruit is greater than the family needs, it may be made a source of income by sending the fresh fruit to the market if there is one near enough, or by preserving, canning, and making jelly for sale. To make such an enterprise a success the fruit and work must be first class. There is magic in the word \"Homemade,\" when the product appeals to the eye and the palate; but many careless and incompetent people have found to their sorrow that this word has not magic enough to float inferior goods on the market. As a rule large canning and preserving establishments are clean and have the best appliances, and they employ chemists and skilled labor. The home product must be very good to compete with the attractive goods that are sent out from such establishments. Yet for first-class homemade products there is a market in all large cities. All first-class grocers have customers who purchase such goods.''' # @param {type:\"string\"}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "nP-7i2QAC0JT"
+ },
+ "outputs": [],
+ "source": [
+ "sentences = passage.split('.') # simple split by comma\n",
+ "wavs = []\n",
+ "s_prev = None\n",
+ "for text in sentences:\n",
+ " if text.strip() == \"\": continue\n",
+ " text += '.' # add it back\n",
+ " noise = torch.randn(1,1,256).to(device)\n",
+ " wav, s_prev = LFinference(text, s_prev, noise, alpha=0.7, diffusion_steps=10, embedding_scale=1.5)\n",
+ " wavs.append(wav)\n",
+ "display(ipd.Audio(np.concatenate(wavs), rate=24000, normalize=False))"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "authorship_tag": "ABX9TyM1x2mx2VnkYNFVlD+DFzmy",
+ "gpuType": "T4",
+ "include_colab_link": true,
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/stts_48khz/StyleTTS2_48khz/Colab/StyleTTS2_Demo_LibriTTS.ipynb b/stts_48khz/StyleTTS2_48khz/Colab/StyleTTS2_Demo_LibriTTS.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..be12546936926dbd0e8a2525c589b76074e18f68
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Colab/StyleTTS2_Demo_LibriTTS.ipynb
@@ -0,0 +1,1218 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+    ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "aAGQPfgYIR23"
+ },
+ "source": [
+ "### Install packages and download models"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "zDPW5uSpISd2",
+ "outputId": "6463ff79-18d5-4071-c6ad-01947beeb368"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+
+ ]
+ }
+ ],
+ "source": [
+ "%%shell\n",
+ "git clone https://github.com/yl4579/StyleTTS2.git\n",
+ "cd StyleTTS2\n",
+ "pip install SoundFile torchaudio munch torch pydub pyyaml librosa nltk matplotlib accelerate transformers phonemizer einops einops-exts tqdm typing-extensions git+https://github.com/resemble-ai/monotonic_align.git\n",
+ "sudo apt-get install espeak-ng\n",
+ "git-lfs clone https://huggingface.co/yl4579/StyleTTS2-LibriTTS\n",
+ "mv StyleTTS2-LibriTTS/Models .\n",
+ "mv StyleTTS2-LibriTTS/reference_audio.zip .\n",
+ "unzip reference_audio.zip\n",
+ "mv reference_audio Demo/reference_audio"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "eJdB_nCOIVIN"
+ },
+ "source": [
+ "### Load models"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "cha8Tr2uJwN0"
+ },
+ "outputs": [],
+ "source": [
+ "import nltk\n",
+ "nltk.download('punkt')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Qoow8Wd8ITtm"
+ },
+ "outputs": [],
+ "source": [
+ "%cd StyleTTS2\n",
+ "\n",
+ "import torch\n",
+ "torch.manual_seed(0)\n",
+ "torch.backends.cudnn.benchmark = False\n",
+ "torch.backends.cudnn.deterministic = True\n",
+ "\n",
+ "import random\n",
+ "random.seed(0)\n",
+ "\n",
+ "import numpy as np\n",
+ "np.random.seed(0)\n",
+ "\n",
+ "# load packages\n",
+ "import time\n",
+ "import random\n",
+ "import yaml\n",
+ "from munch import Munch\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "from torch import nn\n",
+ "import torch.nn.functional as F\n",
+ "import torchaudio\n",
+ "import librosa\n",
+ "from nltk.tokenize import word_tokenize\n",
+ "\n",
+ "from models import *\n",
+ "from utils import *\n",
+ "from text_utils import TextCleaner\n",
+ "textclenaer = TextCleaner()\n",
+ "\n",
+ "%matplotlib inline\n",
+ "\n",
+ "to_mel = torchaudio.transforms.MelSpectrogram(\n",
+ " n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n",
+ "mean, std = -4, 4\n",
+ "\n",
+ "def length_to_mask(lengths):\n",
+ " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n",
+ " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n",
+ " return mask\n",
+ "\n",
+ "def preprocess(wave):\n",
+ " wave_tensor = torch.from_numpy(wave).float()\n",
+ " mel_tensor = to_mel(wave_tensor)\n",
+ " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n",
+ " return mel_tensor\n",
+ "\n",
+ "def compute_style(path):\n",
+ " wave, sr = librosa.load(path, sr=24000)\n",
+ " audio, index = librosa.effects.trim(wave, top_db=30)\n",
+ " if sr != 24000:\n",
+ " audio = librosa.resample(audio, sr, 24000)\n",
+ " mel_tensor = preprocess(audio).to(device)\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " ref_s = model.style_encoder(mel_tensor.unsqueeze(1))\n",
+ " ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))\n",
+ "\n",
+ " return torch.cat([ref_s, ref_p], dim=1)\n",
+ "\n",
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ "\n",
+ "# load phonemizer\n",
+ "import phonemizer\n",
+ "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)\n",
+ "\n",
+ "config = yaml.safe_load(open(\"Models/LibriTTS/config.yml\"))\n",
+ "\n",
+ "# load pretrained ASR model\n",
+ "ASR_config = config.get('ASR_config', False)\n",
+ "ASR_path = config.get('ASR_path', False)\n",
+ "text_aligner = load_ASR_models(ASR_path, ASR_config)\n",
+ "\n",
+ "# load pretrained F0 model\n",
+ "F0_path = config.get('F0_path', False)\n",
+ "pitch_extractor = load_F0_models(F0_path)\n",
+ "\n",
+ "# load BERT model\n",
+ "from Utils.PLBERT.util import load_plbert\n",
+ "BERT_path = config.get('PLBERT_dir', False)\n",
+ "plbert = load_plbert(BERT_path)\n",
+ "\n",
+ "model_params = recursive_munch(config['model_params'])\n",
+ "model = build_model(model_params, text_aligner, pitch_extractor, plbert)\n",
+ "_ = [model[key].eval() for key in model]\n",
+ "_ = [model[key].to(device) for key in model]\n",
+ "\n",
+ "params_whole = torch.load(\"Models/LibriTTS/epochs_2nd_00020.pth\", map_location='cpu')\n",
+ "params = params_whole['net']\n",
+ "\n",
+ "for key in model:\n",
+ " if key in params:\n",
+ " print('%s loaded' % key)\n",
+ " try:\n",
+ " model[key].load_state_dict(params[key])\n",
+ " except:\n",
+ " from collections import OrderedDict\n",
+ " state_dict = params[key]\n",
+ " new_state_dict = OrderedDict()\n",
+ " for k, v in state_dict.items():\n",
+ " name = k[7:] # remove `module.`\n",
+ " new_state_dict[name] = v\n",
+ " # load params\n",
+ " model[key].load_state_dict(new_state_dict, strict=False)\n",
+ "# except:\n",
+ "# _load(params[key], model[key])\n",
+ "_ = [model[key].eval() for key in model]\n",
+ "\n",
+ "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule\n",
+ "\n",
+ "sampler = DiffusionSampler(\n",
+ " model.diffusion.diffusion,\n",
+ " sampler=ADPM2Sampler(),\n",
+ " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n",
+ " clamp=False\n",
+ ")\n",
+ "\n",
+ "def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n",
+ " text = text.strip()\n",
+ " ps = global_phonemizer.phonemize([text])\n",
+ " ps = word_tokenize(ps[0])\n",
+ " ps = ' '.join(ps)\n",
+ " tokens = textclenaer(ps)\n",
+ " tokens.insert(0, 0)\n",
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
+ "\n",
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n",
+ "\n",
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),\n",
+ " embedding=bert_dur,\n",
+ " embedding_scale=embedding_scale,\n",
+ " features=ref_s, # reference from the same speaker as the embedding\n",
+ " num_steps=diffusion_steps).squeeze(1)\n",
+ "\n",
+ "\n",
+ " s = s_pred[:, 128:]\n",
+ " ref = s_pred[:, :128]\n",
+ "\n",
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
+ "\n",
+ " d = model.predictor.text_encoder(d_en,\n",
+ " s, input_lengths, text_mask)\n",
+ "\n",
+ " x, _ = model.predictor.lstm(d)\n",
+ " duration = model.predictor.duration_proj(x)\n",
+ "\n",
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
+ "\n",
+ "\n",
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
+ " c_frame = 0\n",
+ " for i in range(pred_aln_trg.size(0)):\n",
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
+ " c_frame += int(pred_dur[i].data)\n",
+ "\n",
+ " # encode prosody\n",
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
+ " if model_params.decoder.type == \"hifigan\":\n",
+ " asr_new = torch.zeros_like(en)\n",
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
+ " en = asr_new\n",
+ "\n",
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
+ "\n",
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
+ " if model_params.decoder.type == \"hifigan\":\n",
+ " asr_new = torch.zeros_like(asr)\n",
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
+ " asr = asr_new\n",
+ "\n",
+ " out = model.decoder(asr,\n",
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
+ "\n",
+ "\n",
+ " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later\n",
+ "\n",
+ "def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1):\n",
+ " text = text.strip()\n",
+ " ps = global_phonemizer.phonemize([text])\n",
+ " ps = word_tokenize(ps[0])\n",
+ " ps = ' '.join(ps)\n",
+ " ps = ps.replace('``', '\"')\n",
+ " ps = ps.replace(\"''\", '\"')\n",
+ "\n",
+ " tokens = textclenaer(ps)\n",
+ " tokens.insert(0, 0)\n",
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
+ "\n",
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n",
+ "\n",
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),\n",
+ " embedding=bert_dur,\n",
+ " embedding_scale=embedding_scale,\n",
+ " features=ref_s, # reference from the same speaker as the embedding\n",
+ " num_steps=diffusion_steps).squeeze(1)\n",
+ "\n",
+ " if s_prev is not None:\n",
+ " # convex combination of previous and current style\n",
+ " s_pred = t * s_prev + (1 - t) * s_pred\n",
+ "\n",
+ " s = s_pred[:, 128:]\n",
+ " ref = s_pred[:, :128]\n",
+ "\n",
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
+ "\n",
+ " s_pred = torch.cat([ref, s], dim=-1)\n",
+ "\n",
+ " d = model.predictor.text_encoder(d_en,\n",
+ " s, input_lengths, text_mask)\n",
+ "\n",
+ " x, _ = model.predictor.lstm(d)\n",
+ " duration = model.predictor.duration_proj(x)\n",
+ "\n",
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
+ "\n",
+ "\n",
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
+ " c_frame = 0\n",
+ " for i in range(pred_aln_trg.size(0)):\n",
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
+ " c_frame += int(pred_dur[i].data)\n",
+ "\n",
+ " # encode prosody\n",
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
+ " if model_params.decoder.type == \"hifigan\":\n",
+ " asr_new = torch.zeros_like(en)\n",
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
+ " en = asr_new\n",
+ "\n",
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
+ "\n",
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
+ " if model_params.decoder.type == \"hifigan\":\n",
+ " asr_new = torch.zeros_like(asr)\n",
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
+ " asr = asr_new\n",
+ "\n",
+ " out = model.decoder(asr,\n",
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
+ "\n",
+ "\n",
+ " return out.squeeze().cpu().numpy()[..., :-100], s_pred # weird pulse at the end of the model, need to be fixed later\n",
+ "\n",
+ "def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n",
+ " text = text.strip()\n",
+ " ps = global_phonemizer.phonemize([text])\n",
+ " ps = word_tokenize(ps[0])\n",
+ " ps = ' '.join(ps)\n",
+ "\n",
+ " tokens = textclenaer(ps)\n",
+ " tokens.insert(0, 0)\n",
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
+ "\n",
+ " ref_text = ref_text.strip()\n",
+ " ps = global_phonemizer.phonemize([ref_text])\n",
+ " ps = word_tokenize(ps[0])\n",
+ " ps = ' '.join(ps)\n",
+ "\n",
+ " ref_tokens = textclenaer(ps)\n",
+ " ref_tokens.insert(0, 0)\n",
+ " ref_tokens = torch.LongTensor(ref_tokens).to(device).unsqueeze(0)\n",
+ "\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
+ "\n",
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n",
+ "\n",
+ " ref_input_lengths = torch.LongTensor([ref_tokens.shape[-1]]).to(device)\n",
+ " ref_text_mask = length_to_mask(ref_input_lengths).to(device)\n",
+ " ref_bert_dur = model.bert(ref_tokens, attention_mask=(~ref_text_mask).int())\n",
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),\n",
+ " embedding=bert_dur,\n",
+ " embedding_scale=embedding_scale,\n",
+ " features=ref_s, # reference from the same speaker as the embedding\n",
+ " num_steps=diffusion_steps).squeeze(1)\n",
+ "\n",
+ "\n",
+ " s = s_pred[:, 128:]\n",
+ " ref = s_pred[:, :128]\n",
+ "\n",
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
+ "\n",
+ " d = model.predictor.text_encoder(d_en,\n",
+ " s, input_lengths, text_mask)\n",
+ "\n",
+ " x, _ = model.predictor.lstm(d)\n",
+ " duration = model.predictor.duration_proj(x)\n",
+ "\n",
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
+ "\n",
+ "\n",
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
+ " c_frame = 0\n",
+ " for i in range(pred_aln_trg.size(0)):\n",
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
+ " c_frame += int(pred_dur[i].data)\n",
+ "\n",
+ " # encode prosody\n",
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
+ " if model_params.decoder.type == \"hifigan\":\n",
+ " asr_new = torch.zeros_like(en)\n",
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
+ " en = asr_new\n",
+ "\n",
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
+ "\n",
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
+ " if model_params.decoder.type == \"hifigan\":\n",
+ " asr_new = torch.zeros_like(asr)\n",
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
+ " asr = asr_new\n",
+ "\n",
+ " out = model.decoder(asr,\n",
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
+ "\n",
+ "\n",
+ " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "32S6U0LyJbCA"
+ },
+ "source": [
+ "### Synthesize speech"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ehK_0daMJdk_"
+ },
+ "source": [
+ "#### Basic synthesis (5 diffusion steps, seen speakers)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "SJs2x41MJhM-"
+ },
+ "outputs": [],
+ "source": [
+ "text = ''' StyleTTS 2 is a text to speech model that leverages style diffusion and adversarial training with large speech language models to achieve human level text to speech synthesis. ''' # @param {type:\"string\"}\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "xuqIJe-IJb7A"
+ },
+ "outputs": [],
+ "source": [
+ "reference_dicts = {}\n",
+ "reference_dicts['696_92939'] = \"Demo/reference_audio/696_92939_000016_000006.wav\"\n",
+ "reference_dicts['1789_142896'] = \"Demo/reference_audio/1789_142896_000022_000005.wav\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "H3ra3IxJJmF0"
+ },
+ "outputs": [],
+ "source": [
+ "noise = torch.randn(1,1,256).to(device)\n",
+ "for k, path in reference_dicts.items():\n",
+ " ref_s = compute_style(path)\n",
+ " start = time.time()\n",
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1)\n",
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
+ " print(f\"RTF = {rtf:5f}\")\n",
+ " import IPython.display as ipd\n",
+ " print(k + ' Synthesized:')\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
+ " print('Reference:')\n",
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "aB3wUz6yJ-P_"
+ },
+ "source": [
+ "#### With higher diffusion steps (more diverse)\n",
+ "\n",
+ "Since the sampler is ancestral, the higher the stpes, the more diverse the samples are, with the cost of slower synthesis speed."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "lF27XUo4JrKk"
+ },
+ "outputs": [],
+ "source": [
+ "noise = torch.randn(1,1,256).to(device)\n",
+ "for k, path in reference_dicts.items():\n",
+ " ref_s = compute_style(path)\n",
+ " start = time.time()\n",
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=10, embedding_scale=1)\n",
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
+ " print(f\"RTF = {rtf:5f}\")\n",
+ " import IPython.display as ipd\n",
+ " print(k + ' Synthesized:')\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
+ " print(k + ' Reference:')\n",
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "pFT_vmJcKDs1"
+ },
+ "source": [
+ "#### Basic synthesis (5 diffusion steps, unseen speakers)\n",
+ "The following samples are to reproduce samples in [Section 4](https://styletts2.github.io/#libri) of the demo page. All spsakers are unseen during training. You can compare the generated samples to popular zero-shot TTS models like Vall-E and NaturalSpeech 2."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "HvNAeGPEKAWN"
+ },
+ "outputs": [],
+ "source": [
+ "reference_dicts = {}\n",
+ "# format: (path, text)\n",
+ "reference_dicts['1221-135767'] = (\"Demo/reference_audio/1221-135767-0014.wav\", \"Yea, his honourable worship is within, but he hath a godly minister or two with him, and likewise a leech.\")\n",
+ "reference_dicts['5639-40744'] = (\"Demo/reference_audio/5639-40744-0020.wav\", \"Thus did this humane and right minded father comfort his unhappy daughter, and her mother embracing her again, did all she could to soothe her feelings.\")\n",
+ "reference_dicts['908-157963'] = (\"Demo/reference_audio/908-157963-0027.wav\", \"And lay me down in my cold bed and leave my shining lot.\")\n",
+ "reference_dicts['4077-13754'] = (\"Demo/reference_audio/4077-13754-0000.wav\", \"The army found the people in poverty and left them in comparative wealth.\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "mFnyvYp5KAYN"
+ },
+ "outputs": [],
+ "source": [
+ "noise = torch.randn(1,1,256).to(device)\n",
+ "for k, v in reference_dicts.items():\n",
+ " path, text = v\n",
+ " ref_s = compute_style(path)\n",
+ " start = time.time()\n",
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1)\n",
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
+ " print(f\"RTF = {rtf:5f}\")\n",
+ " import IPython.display as ipd\n",
+ " print(k + ' Synthesized: ' + text)\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
+ " print(k + ' Reference:')\n",
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QBZ53BQtKNQ6"
+ },
+ "source": [
+ "### Speech expressiveness\n",
+ "\n",
+ "The following section recreates the samples shown in [Section 6](https://styletts2.github.io/#emo) of the demo page. The speaker reference used is `1221-135767-0014.wav`, which is unseen during training.\n",
+ "\n",
+ "#### With `embedding_scale=1`\n",
+ "This is the classifier-free guidance scale. The higher the scale, the more conditional the style is to the input text and hence more emotional."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "5FwE9CefKQk6"
+ },
+ "outputs": [],
+ "source": [
+ "ref_s = compute_style(\"Demo/reference_audio/1221-135767-0014.wav\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "0CKMI0ZsKUDh"
+ },
+ "outputs": [],
+ "source": [
+ "texts = {}\n",
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
+ "\n",
+ "for k,v in texts.items():\n",
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=1)\n",
+ " print(k + \": \")\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "reemQKVEKWAZ"
+ },
+ "source": [
+ "#### With `embedding_scale=2`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "npIAiAUvKYGv"
+ },
+ "outputs": [],
+ "source": [
+ "texts = {}\n",
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
+ "\n",
+ "for k,v in texts.items():\n",
+ " noise = torch.randn(1,1,256).to(device)\n",
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=2)\n",
+ " print(k + \": \")\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "lqKZaXeYKbrH"
+ },
+ "source": [
+ "#### With `embedding_scale=2, alpha = 0.5, beta = 0.9`\n",
+ "`alpha` and `beta` is the factor to determine much we use the style sampled based on the text instead of the reference. The higher the value of `alpha` and `beta`, the more suitable the style it is to the text but less similar to the reference. Using higher beta makes the synthesized speech more emotional, at the cost of lower similarity to the reference. `alpha` determines the timbre of the speaker while `beta` determines the prosody."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "VjXuRCCWKcdN"
+ },
+ "outputs": [],
+ "source": [
+ "texts = {}\n",
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
+ "\n",
+ "for k,v in texts.items():\n",
+ " noise = torch.randn(1,1,256).to(device)\n",
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.5, beta=0.9, embedding_scale=2)\n",
+ " print(k + \": \")\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xrwYXGh0KiIW"
+ },
+ "source": [
+ "### Zero-shot speaker adaptation\n",
+ "This section recreates the \"Acoustic Environment Maintenance\" and \"Speaker’s Emotion Maintenance\" demo in [Section 4](https://styletts2.github.io/#libri) of the demo page. You can compare the generated samples to popular zero-shot TTS models like Vall-E. Note that the model was trained only on LibriTTS, which is about 250 times fewer data compared to those used to trian Vall-E with similar or better effect for these maintainance."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ETUywHHmKimE"
+ },
+ "source": [
+ "#### Acoustic Environment Maintenance\n",
+ "\n",
+ "Since we want to maintain the acoustic environment in the speaker (timbre), we set `alpha = 0` to make the speaker as close to the reference as possible while only changing the prosody according to the text. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yvjBK3syKnZL"
+ },
+ "outputs": [],
+ "source": [
+ "reference_dicts = {}\n",
+ "# format: (path, text)\n",
+ "reference_dicts['3'] = (\"Demo/reference_audio/3.wav\", \"As friends thing I definitely I've got more male friends.\")\n",
+ "reference_dicts['4'] = (\"Demo/reference_audio/4.wav\", \"Everything is run by computer but you got to know how to think before you can do a computer.\")\n",
+ "reference_dicts['5'] = (\"Demo/reference_audio/5.wav\", \"Then out in LA you guys got a whole another ball game within California to worry about.\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "jclowWp4KomJ"
+ },
+ "outputs": [],
+ "source": [
+ "noise = torch.randn(1,1,256).to(device)\n",
+ "for k, v in reference_dicts.items():\n",
+ " path, text = v\n",
+ " ref_s = compute_style(path)\n",
+ " start = time.time()\n",
+ " wav = inference(text, ref_s, alpha=0.0, beta=0.5, diffusion_steps=5, embedding_scale=1)\n",
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
+ " print(f\"RTF = {rtf:5f}\")\n",
+ " import IPython.display as ipd\n",
+ " print('Synthesized: ' + text)\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
+ " print('Reference:')\n",
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LgIm7M93KqVZ"
+ },
+ "source": [
+ "#### Speaker’s Emotion Maintenance\n",
+ "\n",
+ "Since we want to maintain the emotion in the speaker (prosody), we set `beta = 0.1` to make the speaker as closer to the reference as possible while having some diversity thruogh the slight timbre change."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yzsNoP6oKulL"
+ },
+ "outputs": [],
+ "source": [
+ "reference_dicts = {}\n",
+ "# format: (path, text)\n",
+ "reference_dicts['Anger'] = (\"Demo/reference_audio/anger.wav\", \"We have to reduce the number of plastic bags.\")\n",
+ "reference_dicts['Sleepy'] = (\"Demo/reference_audio/sleepy.wav\", \"We have to reduce the number of plastic bags.\")\n",
+ "reference_dicts['Amused'] = (\"Demo/reference_audio/amused.wav\", \"We have to reduce the number of plastic bags.\")\n",
+ "reference_dicts['Disgusted'] = (\"Demo/reference_audio/disgusted.wav\", \"We have to reduce the number of plastic bags.\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "7h2-9cpfKwr4"
+ },
+ "outputs": [],
+ "source": [
+ "noise = torch.randn(1,1,256).to(device)\n",
+ "for k, v in reference_dicts.items():\n",
+ " path, text = v\n",
+ " ref_s = compute_style(path)\n",
+ " start = time.time()\n",
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.1, diffusion_steps=10, embedding_scale=1)\n",
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
+ " print(f\"RTF = {rtf:5f}\")\n",
+ " import IPython.display as ipd\n",
+ " print(k + ' Synthesized: ' + text)\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
+ " print(k + ' Reference:')\n",
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "aNS82PGwKzgg"
+ },
+ "source": [
+ "### Longform Narration\n",
+ "\n",
+ "This section includes basic implementation of Algorithm 1 in the paper for consistent longform audio generation. The example passage is taken from [Section 5](https://styletts2.github.io/#long) of the demo page."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "qs97nL5HK5DH"
+ },
+ "outputs": [],
+ "source": [
+ "passage = passage = '''If the supply of fruit is greater than the family needs, it may be made a source of income by sending the fresh fruit to the market if there is one near enough, or by preserving, canning, and making jelly for sale. To make such an enterprise a success the fruit and work must be first class. There is magic in the word \"Homemade,\" when the product appeals to the eye and the palate; but many careless and incompetent people have found to their sorrow that this word has not magic enough to float inferior goods on the market. As a rule large canning and preserving establishments are clean and have the best appliances, and they employ chemists and skilled labor. The home product must be very good to compete with the attractive goods that are sent out from such establishments. Yet for first class home made products there is a market in all large cities. All first-class grocers have customers who purchase such goods.''' # @param {type:\"string\"}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "background_save": true
+ },
+ "id": "8Mu9whHYK_1b"
+ },
+ "outputs": [],
+ "source": [
+ "# seen speaker\n",
+ "path = \"Demo/reference_audio/696_92939_000016_000006.wav\"\n",
+ "s_ref = compute_style(path)\n",
+ "sentences = passage.split('.') # simple split by comma\n",
+ "wavs = []\n",
+ "s_prev = None\n",
+ "for text in sentences:\n",
+ " if text.strip() == \"\": continue\n",
+ " text += '.' # add it back\n",
+ "\n",
+ " wav, s_prev = LFinference(text,\n",
+ " s_prev,\n",
+ " s_ref,\n",
+ " alpha = 0.3,\n",
+ " beta = 0.9, # make it more suitable for the text\n",
+ " t = 0.7,\n",
+ " diffusion_steps=10, embedding_scale=1.5)\n",
+ " wavs.append(wav)\n",
+ "print('Synthesized: ')\n",
+ "display(ipd.Audio(np.concatenate(wavs), rate=24000, normalize=False))\n",
+ "print('Reference: ')\n",
+ "display(ipd.Audio(path, rate=24000, normalize=False))"
+ ]
+ },
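+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Optionally, save the concatenated long-form audio to disk. This is a minimal sketch using the SoundFile package installed earlier; the output filename is an arbitrary choice."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import soundfile as sf\n",
+    "# write the concatenated 24 kHz waveform to a WAV file\n",
+    "sf.write('longform_narration.wav', np.concatenate(wavs), 24000)"
+   ]
+  },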
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "81Rh-lgWLB2i"
+ },
+ "source": [
+ "### Style Transfer\n",
+ "\n",
+ "The following section demostrates the style transfer capacity for unseen speakers in [Section 6](https://styletts2.github.io/#emo) of the demo page. For this, we set `alpha=0.5, beta = 0.9` for the most pronounced effects (mostly using the sampled style)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "CtIgr5kOLE9a"
+ },
+ "outputs": [],
+ "source": [
+ "# reference texts to sample styles\n",
+ "\n",
+ "ref_texts = {}\n",
+ "ref_texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
+ "ref_texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
+ "ref_texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
+ "ref_texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "MlA1CbhzLIoI"
+ },
+ "outputs": [],
+ "source": [
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
+ "s_ref = compute_style(path)\n",
+ "\n",
+ "text = \"Yea, his honourable worship is within, but he hath a godly minister or two with him, and likewise a leech.\"\n",
+ "for k,v in ref_texts.items():\n",
+ " wav = STinference(text, s_ref, v, diffusion_steps=10, alpha=0.5, beta=0.9, embedding_scale=1.5)\n",
+ " print(k + \": \")\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "2M0iaXlkLJUQ"
+ },
+ "source": [
+ "### Speech diversity\n",
+ "\n",
+ "This section reproduces samples in [Section 7](https://styletts2.github.io/#var) of the demo page.\n",
+ "\n",
+ "`alpha` and `beta` determine the diversity of the synthesized speech. There are two extreme cases:\n",
+ "- If `alpha = 1` and `beta = 1`, the synthesized speech sounds the most dissimilar to the reference speaker, but it is also the most diverse (each time you synthesize a speech it will be totally different).\n",
+ "- If `alpha = 0` and `beta = 0`, the synthesized speech sounds the most siimlar to the reference speaker, but it is deterministic (i.e., the sampled style is not used for speech synthesis).\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tSxZDvF2LNu4"
+ },
+ "source": [
+ "#### Default setting (`alpha = 0.3, beta=0.7`)\n",
+ "This setting uses 70% of the reference timbre and 30% of the reference prosody and use the diffusion model to sample them based on the text."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "AAomGCDZLIt5"
+ },
+ "outputs": [],
+ "source": [
+ "# unseen speaker\n",
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
+ "ref_s = compute_style(path)\n",
+ "\n",
+ "text = \"How much variation is there?\"\n",
+ "for _ in range(5):\n",
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=1)\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BKrSMdgcLQRP"
+ },
+ "source": [
+ "#### Less diverse setting (`alpha = 0.1, beta=0.3`)\n",
+ "This setting uses 90% of the reference timbre and 70% of the reference prosody. This makes it more similar to the reference speaker at cost of less diverse samples."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Uo7gVmFoLRfm"
+ },
+ "outputs": [],
+ "source": [
+ "# unseen speaker\n",
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
+ "ref_s = compute_style(path)\n",
+ "\n",
+ "text = \"How much variation is there?\"\n",
+ "for _ in range(5):\n",
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.1, beta=0.3, embedding_scale=1)\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "nfQ0Xrg9LStd"
+ },
+ "source": [
+ "#### More diverse setting (`alpha = 0.5, beta=0.95`)\n",
+ "This setting uses 50% of the reference timbre and 5% of the reference prosody (so it uses 100% of the sampled prosody, which makes it more diverse), but this makes it more dissimilar to the reference speaker. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "cPHz4BzVLT_u"
+ },
+ "outputs": [],
+ "source": [
+ "# unseen speaker\n",
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
+ "ref_s = compute_style(path)\n",
+ "\n",
+ "text = \"How much variation is there?\"\n",
+ "for _ in range(5):\n",
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.5, beta=0.95, embedding_scale=1)\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "#### Extreme setting (`alpha = 1, beta=1`)\n",
+ "This setting uses 0% of the reference timbre and prosody and use the diffusion model to sample the entire style. This makes the speaker very dissimilar to the reference speaker."
+ ],
+ "metadata": {
+ "id": "hPKg9eYpL00f"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# unseen speaker\n",
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
+ "ref_s = compute_style(path)\n",
+ "\n",
+ "text = \"How much variation is there?\"\n",
+ "for _ in range(5):\n",
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=1, beta=1, embedding_scale=1)\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ],
+ "metadata": {
+ "id": "Ei-7JOccL0bF"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "#### No variation (`alpha = 0, beta=0`)\n",
+ "This setting uses 100% of the reference timbre and prosody and do not use the diffusion model at all. This makes the speaker very similar to the reference speaker, but there is no variation."
+ ],
+ "metadata": {
+ "id": "FVMPc3bhL3eL"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# unseen speaker\n",
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
+ "ref_s = compute_style(path)\n",
+ "\n",
+ "text = \"How much variation is there?\"\n",
+ "for _ in range(5):\n",
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0, beta=0, embedding_scale=1)\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ],
+ "metadata": {
+ "id": "yh1QZ7uhL4wM"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Extra fun!\n",
+ "\n",
+ "You can record your own voice and clone it using pre-trained StyleTTS 2 model here."
+ ],
+ "metadata": {
+ "id": "T0EvkWrAMBDB"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "#### Run the following cell to record your voice for 5 seconds. Please keep speaking to have the best effect."
+ ],
+ "metadata": {
+ "id": "R985j5QONY8I"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# all imports\n",
+ "from IPython.display import Javascript\n",
+ "from google.colab import output\n",
+ "from base64 import b64decode\n",
+ "\n",
+ "RECORD = \"\"\"\n",
+ "const sleep = time => new Promise(resolve => setTimeout(resolve, time))\n",
+ "const b2text = blob => new Promise(resolve => {\n",
+ " const reader = new FileReader()\n",
+ " reader.onloadend = e => resolve(e.srcElement.result)\n",
+ " reader.readAsDataURL(blob)\n",
+ "})\n",
+ "var record = time => new Promise(async resolve => {\n",
+ " stream = await navigator.mediaDevices.getUserMedia({ audio: true })\n",
+ " recorder = new MediaRecorder(stream)\n",
+ " chunks = []\n",
+ " recorder.ondataavailable = e => chunks.push(e.data)\n",
+ " recorder.start()\n",
+ " await sleep(time)\n",
+ " recorder.onstop = async ()=>{\n",
+ " blob = new Blob(chunks)\n",
+ " text = await b2text(blob)\n",
+ " resolve(text)\n",
+ " }\n",
+ " recorder.stop()\n",
+ "})\n",
+ "\"\"\"\n",
+ "\n",
+ "def record(sec=3):\n",
+ " display(Javascript(RECORD))\n",
+ " s = output.eval_js('record(%d)' % (sec*1000))\n",
+ " b = b64decode(s.split(',')[1])\n",
+ " with open('audio.wav','wb') as f:\n",
+ " f.write(b)\n",
+ " return 'audio.wav' # or webm ?"
+ ],
+ "metadata": {
+ "id": "MWrFs0KWMBpz"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "#### Please run this cell and speak:"
+ ],
+ "metadata": {
+ "id": "z35qXwM0Nhx1"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "print('Speak now for 5 seconds.')\n",
+ "audio = record(sec=5)\n",
+ "import IPython.display as ipd\n",
+ "display(ipd.Audio(audio, rate=24000, normalize=False))"
+ ],
+ "metadata": {
+ "id": "KUEoFyQBMR-8"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "#### Synthesize in your own voice"
+ ],
+ "metadata": {
+ "id": "OQS_7IBpNmM1"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "text = ''' StyleTTS 2 is a text to speech model that leverages style diffusion and adversarial training with large speech language models to achieve human level text to speech synthesis. ''' # @param {type:\"string\"}\n"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "c0I3LY7vM8Ta"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "reference_dicts = {}\n",
+ "reference_dicts['You'] = audio"
+ ],
+ "metadata": {
+ "id": "80eW-pwxNCxu"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "start = time.time()\n",
+ "noise = torch.randn(1,1,256).to(device)\n",
+ "for k, path in reference_dicts.items():\n",
+ " ref_s = compute_style(path)\n",
+ "\n",
+ " wav = inference(text, ref_s, alpha=0.1, beta=0.5, diffusion_steps=5, embedding_scale=1)\n",
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
+ " print('Speaker: ' + k)\n",
+ " import IPython.display as ipd\n",
+ " print('Synthesized:')\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
+ " print('Reference:')\n",
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
+ ],
+ "metadata": {
+ "id": "yIga6MTuNJaN"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": [],
+ "collapsed_sections": [
+ "aAGQPfgYIR23",
+ "eJdB_nCOIVIN",
+ "R985j5QONY8I"
+ ],
+ "authorship_tag": "ABX9TyPQdFTqqVEknEG/ma/HMfU+",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/stts_48khz/StyleTTS2_48khz/Colab/StyleTTS2_Finetune_Demo.ipynb b/stts_48khz/StyleTTS2_48khz/Colab/StyleTTS2_Finetune_Demo.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..facfdadd37771cef3a5f88a5d6028cc37cc4b75b
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Colab/StyleTTS2_Finetune_Demo.ipynb
@@ -0,0 +1,480 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "gpuType": "T4",
+ "authorship_tag": "ABX9TyNiDU9ykIeYxO86Lmuid+ph",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Install packages and download models"
+ ],
+ "metadata": {
+ "id": "yLqBa4uYPrqE"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%shell\n",
+ "git clone https://github.com/yl4579/StyleTTS2.git\n",
+ "cd StyleTTS2\n",
+ "pip install SoundFile torchaudio munch torch pydub pyyaml librosa nltk matplotlib accelerate transformers phonemizer einops einops-exts tqdm typing-extensions git+https://github.com/resemble-ai/monotonic_align.git\n",
+ "sudo apt-get install espeak-ng\n",
+ "git-lfs clone https://huggingface.co/yl4579/StyleTTS2-LibriTTS\n",
+ "mv StyleTTS2-LibriTTS/Models ."
+ ],
+ "metadata": {
+ "id": "H72WF06ZPrTF"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Download dataset (LJSpeech, 200 samples, ~15 minutes of data)\n",
+ "\n",
+ "You can definitely do it with fewer samples. This is just a proof of concept with 200 smaples."
+ ],
+ "metadata": {
+ "id": "G398sL8wPzTB"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%cd StyleTTS2\n",
+ "!rm -rf Data"
+ ],
+ "metadata": {
+ "id": "kJuQUBrEPy5C"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!gdown --id 1vqz26D3yn7OXS2vbfYxfSnpLS6m6tOFP\n",
+ "!unzip Data.zip"
+ ],
+ "metadata": {
+ "id": "mDXW8ZZePuSb"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Change the finetuning config\n",
+ "\n",
+ "Depending on the GPU you got, you may want to change the bacth size, max audio length, epiochs and so on."
+ ],
+ "metadata": {
+ "id": "_AlBQREWU8ud"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "config_path = \"Configs/config_ft.yml\"\n",
+ "\n",
+ "import yaml\n",
+ "config = yaml.safe_load(open(config_path))"
+ ],
+ "metadata": {
+ "id": "7uEITi0hU4I2"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "config['data_params']['root_path'] = \"Data/wavs\"\n",
+ "\n",
+ "config['batch_size'] = 2 # not enough RAM\n",
+ "config['max_len'] = 100 # not enough RAM\n",
+ "config['loss_params']['joint_epoch'] = 110 # we do not do SLM adversarial training due to not enough RAM\n",
+ "\n",
+ "with open(config_path, 'w') as outfile:\n",
+ " yaml.dump(config, outfile, default_flow_style=True)"
+ ],
+ "metadata": {
+ "id": "TPTRgOKSVT4K"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
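+ {
+ "cell_type": "markdown",
+ "source": [
+ "You can lower the epoch count the same way. Below is a minimal sketch, assuming the finetuning config exposes an `epochs` key (as in this repo's `Configs/config_ft.yml`); adjust the key name to whatever your config actually uses."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# optional sketch: shorten training as well (assumes an `epochs` key in the config)\n",
+ "config['epochs'] = 10\n",
+ "with open(config_path, 'w') as outfile:\n",
+ " yaml.dump(config, outfile, default_flow_style=True)"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },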
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Start finetuning\n"
+ ],
+ "metadata": {
+ "id": "uUuB_19NWj2Y"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!python train_finetune.py --config_path ./Configs/config_ft.yml"
+ ],
+ "metadata": {
+ "id": "HZVAD5GKWm-O"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Test the model quality\n",
+ "\n",
+ "Note that this mainly serves as a proof of concept due to RAM limitation of free Colab instances. A lot of settings are suboptimal. In the future when DDP works for train_second.py, we will also add mixed precision finetuning to save time and RAM. You can also add SLM adversarial training run if you have paid Colab services (such as A100 with 40G of RAM)."
+ ],
+ "metadata": {
+ "id": "I0_7wsGkXGfc"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import nltk\n",
+ "nltk.download('punkt')"
+ ],
+ "metadata": {
+ "id": "OPLphjbncE7p"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import torch\n",
+ "torch.manual_seed(0)\n",
+ "torch.backends.cudnn.benchmark = False\n",
+ "torch.backends.cudnn.deterministic = True\n",
+ "\n",
+ "import random\n",
+ "random.seed(0)\n",
+ "\n",
+ "import numpy as np\n",
+ "np.random.seed(0)\n",
+ "\n",
+ "# load packages\n",
+ "import time\n",
+ "import random\n",
+ "import yaml\n",
+ "from munch import Munch\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "from torch import nn\n",
+ "import torch.nn.functional as F\n",
+ "import torchaudio\n",
+ "import librosa\n",
+ "from nltk.tokenize import word_tokenize\n",
+ "\n",
+ "from models import *\n",
+ "from utils import *\n",
+ "from text_utils import TextCleaner\n",
+ "textclenaer = TextCleaner()\n",
+ "\n",
+ "%matplotlib inline\n",
+ "\n",
+ "to_mel = torchaudio.transforms.MelSpectrogram(\n",
+ " n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n",
+ "mean, std = -4, 4\n",
+ "\n",
+ "def length_to_mask(lengths):\n",
+ " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n",
+ " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n",
+ " return mask\n",
+ "\n",
+ "def preprocess(wave):\n",
+ " wave_tensor = torch.from_numpy(wave).float()\n",
+ " mel_tensor = to_mel(wave_tensor)\n",
+ " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n",
+ " return mel_tensor\n",
+ "\n",
+ "def compute_style(path):\n",
+ " wave, sr = librosa.load(path, sr=24000)\n",
+ " audio, index = librosa.effects.trim(wave, top_db=30)\n",
+ " if sr != 24000:\n",
+ " audio = librosa.resample(audio, sr, 24000)\n",
+ " mel_tensor = preprocess(audio).to(device)\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " ref_s = model.style_encoder(mel_tensor.unsqueeze(1))\n",
+ " ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))\n",
+ "\n",
+ " return torch.cat([ref_s, ref_p], dim=1)\n",
+ "\n",
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ "\n",
+ "# load phonemizer\n",
+ "import phonemizer\n",
+ "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)\n",
+ "\n",
+ "config = yaml.safe_load(open(\"Models/LJSpeech/config_ft.yml\"))\n",
+ "\n",
+ "# load pretrained ASR model\n",
+ "ASR_config = config.get('ASR_config', False)\n",
+ "ASR_path = config.get('ASR_path', False)\n",
+ "text_aligner = load_ASR_models(ASR_path, ASR_config)\n",
+ "\n",
+ "# load pretrained F0 model\n",
+ "F0_path = config.get('F0_path', False)\n",
+ "pitch_extractor = load_F0_models(F0_path)\n",
+ "\n",
+ "# load BERT model\n",
+ "from Utils.PLBERT.util import load_plbert\n",
+ "BERT_path = config.get('PLBERT_dir', False)\n",
+ "plbert = load_plbert(BERT_path)\n",
+ "\n",
+ "model_params = recursive_munch(config['model_params'])\n",
+ "model = build_model(model_params, text_aligner, pitch_extractor, plbert)\n",
+ "_ = [model[key].eval() for key in model]\n",
+ "_ = [model[key].to(device) for key in model]"
+ ],
+ "metadata": {
+ "id": "jIIAoDACXJL0"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "files = [f for f in os.listdir(\"Models/LJSpeech/\") if f.endswith('.pth')]\n",
+ "sorted_files = sorted(files, key=lambda x: int(x.split('_')[-1].split('.')[0]))"
+ ],
+ "metadata": {
+ "id": "eKXRAyyzcMpQ"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "params_whole = torch.load(\"Models/LJSpeech/\" + sorted_files[-1], map_location='cpu')\n",
+ "params = params_whole['net']"
+ ],
+ "metadata": {
+ "id": "ULuU9-VDb9Pk"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "for key in model:\n",
+ " if key in params:\n",
+ " print('%s loaded' % key)\n",
+ " try:\n",
+ " model[key].load_state_dict(params[key])\n",
+ " except:\n",
+ " from collections import OrderedDict\n",
+ " state_dict = params[key]\n",
+ " new_state_dict = OrderedDict()\n",
+ " for k, v in state_dict.items():\n",
+ " name = k[7:] # remove `module.`\n",
+ " new_state_dict[name] = v\n",
+ " # load params\n",
+ " model[key].load_state_dict(new_state_dict, strict=False)\n",
+ "# except:\n",
+ "# _load(params[key], model[key])\n",
+ "_ = [model[key].eval() for key in model]"
+ ],
+ "metadata": {
+ "id": "J-U29yIYc2ea"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule"
+ ],
+ "metadata": {
+ "id": "jrPQ_Yrwc3n6"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "sampler = DiffusionSampler(\n",
+ " model.diffusion.diffusion,\n",
+ " sampler=ADPM2Sampler(),\n",
+ " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n",
+ " clamp=False\n",
+ ")"
+ ],
+ "metadata": {
+ "id": "n2CWYNoqc455"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n",
+ " text = text.strip()\n",
+ " ps = global_phonemizer.phonemize([text])\n",
+ " ps = word_tokenize(ps[0])\n",
+ " ps = ' '.join(ps)\n",
+ " tokens = textclenaer(ps)\n",
+ " tokens.insert(0, 0)\n",
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
+ "\n",
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n",
+ "\n",
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),\n",
+ " embedding=bert_dur,\n",
+ " embedding_scale=embedding_scale,\n",
+ " features=ref_s, # reference from the same speaker as the embedding\n",
+ " num_steps=diffusion_steps).squeeze(1)\n",
+ "\n",
+ "\n",
+ " s = s_pred[:, 128:]\n",
+ " ref = s_pred[:, :128]\n",
+ "\n",
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
+ "\n",
+ " d = model.predictor.text_encoder(d_en,\n",
+ " s, input_lengths, text_mask)\n",
+ "\n",
+ " x, _ = model.predictor.lstm(d)\n",
+ " duration = model.predictor.duration_proj(x)\n",
+ "\n",
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
+ "\n",
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
+ " c_frame = 0\n",
+ " for i in range(pred_aln_trg.size(0)):\n",
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
+ " c_frame += int(pred_dur[i].data)\n",
+ "\n",
+ " # encode prosody\n",
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
+ " if model_params.decoder.type == \"hifigan\":\n",
+ " asr_new = torch.zeros_like(en)\n",
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
+ " en = asr_new\n",
+ "\n",
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
+ "\n",
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
+ " if model_params.decoder.type == \"hifigan\":\n",
+ " asr_new = torch.zeros_like(asr)\n",
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
+ " asr = asr_new\n",
+ "\n",
+ " out = model.decoder(asr,\n",
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
+ "\n",
+ "\n",
+ " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later"
+ ],
+ "metadata": {
+ "id": "2x5kVb3nc_eY"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Synthesize speech"
+ ],
+ "metadata": {
+ "id": "O159JnwCc6CC"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "text = '''Maltby and Company would issue warrants on them deliverable to the importer, and the goods were then passed to be stored in neighboring warehouses.\n",
+ "'''"
+ ],
+ "metadata": {
+ "id": "ThciXQ6rc9Eq"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# get a random reference in the training set, note that it doesn't matter which one you use\n",
+ "path = \"Data/wavs/LJ001-0110.wav\"\n",
+ "# this style vector ref_s can be saved as a parameter together with the model weights\n",
+ "ref_s = compute_style(path)"
+ ],
+ "metadata": {
+ "id": "jldPkJyCc83a"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "start = time.time()\n",
+ "wav = inference(text, ref_s, alpha=0.9, beta=0.9, diffusion_steps=10, embedding_scale=1)\n",
+ "rtf = (time.time() - start) / (len(wav) / 24000)\n",
+ "print(f\"RTF = {rtf:5f}\")\n",
+ "import IPython.display as ipd\n",
+ "display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ],
+ "metadata": {
+ "id": "_mIU0jqDdQ-c"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
diff --git a/stts_48khz/StyleTTS2_48khz/Configs/config.yml b/stts_48khz/StyleTTS2_48khz/Configs/config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..511b87714a234f09ee1abb255a1929c3c87f4803
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Configs/config.yml
@@ -0,0 +1,116 @@
+log_dir: "Models/LJSpeech"
+first_stage_path: "first_stage.pth"
+save_freq: 2
+log_interval: 10
+device: "cuda"
+epochs_1st: 200 # number of epochs for first stage training (pre-training)
+epochs_2nd: 100 # number of epochs for second stage training (joint training)
+batch_size: 16
+max_len: 325 # maximum number of frames
+pretrained_model: ""
+second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
+load_only_params: false # set to true if you do not want to load epoch numbers and optimizer parameters
+
+F0_path: "Utils/JDC/bst.t7"
+ASR_config: "Utils/ASR/config.yml"
+ASR_path: "Utils/ASR/epoch_00080.pth"
+PLBERT_dir: 'Utils/PLBERT/'
+
+data_params:
+ train_data: "Data/train_list.txt"
+ val_data: "Data/val_list.txt"
+ root_path: "/local/LJSpeech-1.1/wavs"
+ OOD_data: "Data/OOD_texts.txt"
+ min_length: 50 # keep sampling OOD texts until one of at least this length is obtained
+
+preprocess_params:
+ sr: 24000
+ spect_params:
+ n_fft: 2048
+ win_length: 1200
+ hop_length: 300
+
+model_params:
+ multispeaker: false
+
+ dim_in: 64
+ hidden_dim: 512
+ max_conv_dim: 512
+ n_layer: 3
+ n_mels: 80
+
+ n_token: 178 # number of phoneme tokens
+ max_dur: 50 # maximum duration of a single phoneme
+ style_dim: 128 # style vector size
+
+ dropout: 0.2
+
+ # config for decoder
+ decoder:
+ type: 'istftnet' # either hifigan or istftnet
+ resblock_kernel_sizes: [3,7,11]
+ upsample_rates : [10, 6]
+ upsample_initial_channel: 512
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
+ upsample_kernel_sizes: [20, 12]
+ gen_istft_n_fft: 20
+ gen_istft_hop_size: 5
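+ # sanity check: prod(upsample_rates) × gen_istft_hop_size = 10 × 6 × 5 = 300 = hop_length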
+
+ # speech language model config
+ slm:
+ model: 'microsoft/wavlm-base-plus'
+ sr: 16000 # sampling rate of SLM
+ hidden: 768 # hidden size of SLM
+ nlayers: 13 # number of layers of SLM
+ initial_channel: 64 # initial channels of SLM discriminator head
+
+ # style diffusion model config
+ diffusion:
+ embedding_mask_proba: 0.1
+ # transformer config
+ transformer:
+ num_layers: 3
+ num_heads: 8
+ head_features: 64
+ multiplier: 2
+
+ # diffusion distribution config
+ dist:
+ sigma_data: 0.2 # placeholder; only used when estimate_sigma_data is false
+ estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
+ mean: -3.0
+ std: 1.0
+
+loss_params:
+ lambda_mel: 5. # mel reconstruction loss
+ lambda_gen: 1. # generator loss
+ lambda_slm: 1. # slm feature matching loss
+
+ lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
+ lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
+ TMA_epoch: 50 # TMA starting epoch (1st stage)
+
+ lambda_F0: 1. # F0 reconstruction loss (2nd stage)
+ lambda_norm: 1. # norm reconstruction loss (2nd stage)
+ lambda_dur: 1. # duration loss (2nd stage)
+ lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
+ lambda_sty: 1. # style reconstruction loss (2nd stage)
+ lambda_diff: 1. # score matching loss (2nd stage)
+
+ diff_epoch: 20 # style diffusion starting epoch (2nd stage)
+ joint_epoch: 50 # joint training starting epoch (2nd stage)
+
+optimizer_params:
+ lr: 0.0001 # general learning rate
+ bert_lr: 0.00001 # learning rate for PLBERT
+ ft_lr: 0.00001 # learning rate for acoustic modules
+
+slmadv_params:
+ min_len: 400 # minimum length of samples
+ max_len: 500 # maximum length of samples
+ batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
+ iter: 10 # update the discriminator once every this many generator updates
+ thresh: 5 # gradient norm above which the gradient is scaled
+ scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
+ sig: 1.5 # sigma for differentiable duration modeling
+
\ No newline at end of file
diff --git a/stts_48khz/StyleTTS2_48khz/Configs/config_ft.yml b/stts_48khz/StyleTTS2_48khz/Configs/config_ft.yml
new file mode 100644
index 0000000000000000000000000000000000000000..b842e99dc230e5ff82ee1019fdbcd836a7da7e41
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Configs/config_ft.yml
@@ -0,0 +1,118 @@
+log_dir: "Models/Style_Kanade_48khz_fn"
+save_freq: 1
+log_interval: 10
+device: "cuda"
+epochs: 2
+batch_size: 14
+max_len: 2250 # maximum number of frames
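+# -> (512 / 48000) × (2250 // 2) = 12 sec, following the formula used in the other 48 kHz configs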
+pretrained_model: "/home/austin/disk2/llmvcs/tt/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz_fn/finetune_phase_13999.pth"
+second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
+load_only_params: true # set to true if you do not want to load epoch numbers and optimizer parameters
+
+
+F0_path: "/home/austin/disk1/stts-zs_cleaning/F0_extractor/PitchExtractor/Checkpoint_200k/PE_48khz_epoch_00060.pth"
+ASR_config: "/home/austin/disk2/llmvcs/tt/stylekan/Utils/ASR/config.yml"
+ASR_path: "/home/austin/disk2/llmvcs/tt/AuxiliaryASR/Checkpoint_new_plus/epoch_00070.pth"
+
+PLBERT_dir: '/home/austin/disk2/llmvcs/tt/stylekan/Utils/PLBERT'
+
+data_params:
+ train_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/metadata_cleanest/train_48_pure.csv"
+ val_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/metadata_cleanest/val_48_pure.csv"
+ root_path: ""
+ OOD_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/OOD_LargeScale_.csv"
+ min_length: 50 # keep sampling OOD texts until one of at least this length is obtained
+
+#CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 accelerate launch accelerate_train_finetune.py --config_path ./Configs/config_ft.yml
+#CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 python train_second.py --config_path ./Configs/config_kanade.yml
+preprocess_params:
+ sr: 48000
+ spect_params:
+ n_fft: 2048
+ win_length: 2048
+ hop_length: 512
+
+model_params:
+ multispeaker: true
+
+ dim_in: 64
+ hidden_dim: 512
+ max_conv_dim: 512
+ n_layer: 3
+ n_mels: 80
+
+ n_token: 178 # number of phoneme tokens
+ max_dur: 50 # maximum duration of a single phoneme
+ style_dim: 128 # style vector size
+
+ dropout: 0.2
+
+ # config for decoder
+ decoder:
+ type: 'istftnet' # either hifigan or istftnet
+ resblock_kernel_sizes: [3,7,11]
+ upsample_rates : [16, 8]
+ upsample_initial_channel: 512
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
+ upsample_kernel_sizes: [32, 16]
+ gen_istft_n_fft: 32
+ gen_istft_hop_size: 4
+
+
+ # speech language model config
+ slm:
+ model: 'Respair/Whisper_Large_v2_Encoder_Block' # the SLM itself is hardcoded; change it in losses.py
+ sr: 16000 # sampling rate of SLM
+ hidden: 1280 # hidden size of SLM
+ nlayers: 33 # number of layers of SLM
+ initial_channel: 64 # initial channels of SLM discriminator head
+
+ # style diffusion model config
+ diffusion:
+ embedding_mask_proba: 0.1
+ # transformer config
+ transformer:
+ num_layers: 3
+ num_heads: 8
+ head_features: 64
+ multiplier: 2
+
+ # diffusion distribution config
+ dist:
+ sigma_data: 0.2 # placeholder; only used when estimate_sigma_data is false
+ estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
+ mean: -3.0
+ std: 1.0
+
+loss_params:
+ lambda_mel: 5. # mel reconstruction loss
+ lambda_gen: 1. # generator loss
+ lambda_slm: 1. # slm feature matching loss
+
+ lambda_mono: 1. # monotonic alignment loss (TMA)
+ lambda_s2s: 1. # sequence-to-sequence loss (TMA)
+
+ lambda_F0: 1. # F0 reconstruction loss
+ lambda_norm: 1. # norm reconstruction loss
+ lambda_dur: 1. # duration loss
+ lambda_ce: 20. # duration predictor probability output CE loss
+ lambda_sty: 1. # style reconstruction loss
+ lambda_diff: 1. # score matching loss
+
+ diff_epoch: 0 # style diffusion starting epoch
+ joint_epoch: 30 # joint training starting epoch
+
+optimizer_params:
+ lr: 0.0001 # general learning rate
+ bert_lr: 0.00001 # learning rate for PLBERT
+ ft_lr: 0.0001 # learning rate for acoustic modules
+
+slmadv_params:
+ min_len: 400 # minimum length of samples
+ max_len: 500 # maximum length of samples
+ batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
+ iter: 10 # update the discriminator once every this many generator updates
+ thresh: 5 # gradient norm above which the gradient is scaled
+ scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
+ sig: 1.5 # sigma for differentiable duration modeling
+
diff --git a/stts_48khz/StyleTTS2_48khz/Configs/config_kanade_48khz.yml b/stts_48khz/StyleTTS2_48khz/Configs/config_kanade_48khz.yml
new file mode 100644
index 0000000000000000000000000000000000000000..29ff7272c88ad122273e8245055b6d83cd9628e3
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Configs/config_kanade_48khz.yml
@@ -0,0 +1,125 @@
+log_dir: "Models/Style_Kanade_48khz"
+first_stage_path: ""
+save_freq: 1
+log_interval: 10
+device: "cuda"
+epochs_1st: 25 # number of epochs for first stage training (pre-training)
+epochs_2nd: 15 # number of epochs for second stage training (joint training)
+batch_size: 35
+
+max_len: 560 # maximum number of frames -> (512 / 48000) × (560 // 2) ≈ 2.99 sec
+
+pretrained_model: "/home/austin/disk2/llmvcs/tt/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/Top_ckpt.pth"
+second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
+load_only_params: false # set to true if you do not want to load epoch numbers and optimizer parameters
+
+# F0_path: "/home/ubuntu/STTS_48khz/StyleTTS2-48khz/Utils/JDC/bst_rmvpe_48k.t7"
+# ASR_config: "Utils/ASR/config.yml"
+# ASR_path: "/home/ubuntu/STTS_48khz/StyleTTS2-48khz/Utils/ASR/epoch_00050_48K.pth"
+
+F0_path: "/home/austin/disk1/stts-zs_cleaning/F0_extractor/PitchExtractor/Checkpoint_200k/PE_48khz_epoch_00060.pth"
+ASR_config: "/home/austin/disk2/llmvcs/tt/stylekan/Utils/ASR/config.yml"
+ASR_path: "/home/austin/disk2/llmvcs/tt/AuxiliaryASR/Checkpoint_new_plus/epoch_00070.pth"
+
+PLBERT_dir: '/home/austin/disk2/llmvcs/tt/stylekan/Utils/PLBERT'
+
+data_params:
+ train_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/metadata_cleanest/train_48_pure.csv"
+ val_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/metadata_cleanest/val_48_pure.csv"
+ root_path: ""
+ OOD_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/OOD_LargeScale_.csv"
+ min_length: 50 # keep sampling OOD texts until one of at least this length is obtained
+
+#CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 accelerate launch accelerate_train_second.py --config_path ./Configs/config_kanade_48khz.yml
+#CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 python train_second.py --config_path ./Configs/config_kanade.yml
+#CUDA_VISIBLE_DEVICES=7 accelerate launch accelerate_train_second.py --config_path ./Configs/config_ft.yml
+preprocess_params:
+ sr: 48000
+ spect_params:
+ n_fft: 2048
+ win_length: 2048
+ hop_length: 512
+
+model_params:
+ multispeaker: true
+
+ dim_in: 64
+ hidden_dim: 512
+ max_conv_dim: 512
+ n_layer: 3
+ n_mels: 80
+
+ n_token: 178 # number of phoneme tokens
+ max_dur: 50 # maximum duration of a single phoneme
+ style_dim: 128 # style vector size
+
+ dropout: 0.2
+
+ decoder:
+ type: 'istftnet' # either hifigan or istftnet
+ resblock_kernel_sizes: [3,7,11]
+ upsample_rates : [16, 8]
+ upsample_initial_channel: 512
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
+ upsample_kernel_sizes: [32, 16]
+ gen_istft_n_fft: 32
+ gen_istft_hop_size: 4
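+ # sanity check: prod(upsample_rates) × gen_istft_hop_size = 16 × 8 × 4 = 512 = hop_length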
+
+
+ # speech language model config
+ slm:
+ model: 'Respair/Whisper_Large_v2_Encoder_Block' # the SLM itself is hardcoded; change it in losses.py
+ sr: 16000 # sampling rate of SLM
+ hidden: 1280 # hidden size of SLM
+ nlayers: 33 # number of layers of SLM
+ initial_channel: 64 # initial channels of SLM discriminator head
+
+ # style diffusion model config
+ diffusion:
+ embedding_mask_proba: 0.1
+ # transformer config
+ transformer:
+ num_layers: 3
+ num_heads: 8
+ head_features: 64
+ multiplier: 2
+
+ # diffusion distribution config
+ dist:
+ sigma_data: 0.2 # placeholder; only used when estimate_sigma_data is false
+ estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
+ mean: -3.0
+ std: 1.0
+
+loss_params:
+ lambda_mel: 10. # mel reconstruction loss
+ lambda_gen: 1. # generator loss
+ lambda_slm: 1. # slm feature matching loss
+
+ lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
+ lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
+ TMA_epoch: 3 # TMA starting epoch (1st stage)
+
+ lambda_F0: 1. # F0 reconstruction loss (2nd stage)
+ lambda_norm: 1. # norm reconstruction loss (2nd stage)
+ lambda_dur: 1. # duration loss (2nd stage)
+ lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
+ lambda_sty: 1. # style reconstruction loss (2nd stage)
+ lambda_diff: 1. # score matching loss (2nd stage)
+
+ diff_epoch: 4 # style diffusion starting epoch (2nd stage)
+ joint_epoch: 999 # joint training starting epoch (2nd stage)
+
+optimizer_params:
+ lr: 0.0001 # general learning rate
+ bert_lr: 0.00001 # learning rate for PLBERT
+ ft_lr: 0.00001 # learning rate for acoustic modules
+
+slmadv_params:
+ min_len: 400 # minimum length of samples
+ max_len: 500 # maximum length of samples
+ batch_percentage: .5 # to prevent out of memory, only use 1/2 of the original batch size
+ iter: 20 # update the discriminator once every this many generator updates
+ thresh: 5 # gradient norm above which the gradient is scaled
+ scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
+ sig: 1.5 # sigma for differentiable duration modeling
diff --git a/stts_48khz/StyleTTS2_48khz/Configs/config_kanade_48khz_copy.yml b/stts_48khz/StyleTTS2_48khz/Configs/config_kanade_48khz_copy.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c00785cd71270b7c0b67d653b2cd63a38fb9dd9d
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Configs/config_kanade_48khz_copy.yml
@@ -0,0 +1,124 @@
+log_dir: "Models/Style_Kanade_48khz_test"
+first_stage_path: "/home/austin/disk2/llmvcs/tt/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_1st_00020.pth"
+save_freq: 1
+log_interval: 10
+device: "cuda"
+epochs_1st: 25 # number of epochs for first stage training (pre-training)
+epochs_2nd: 15 # number of epochs for second stage training (joint training)
+batch_size: 2
+
+max_len: 2812 # approximately 15 seconds -> (512 / 48000) × (2812 // 2) ≈ 14.997 sec
+
+pretrained_model: "/home/austin/disk2/llmvcs/tt/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00004.pth"
+second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
+load_only_params: false # set to true if you do not want to load epoch numbers and optimizer parameters
+
+# F0_path: "/home/ubuntu/STTS_48khz/StyleTTS2-48khz/Utils/JDC/bst_rmvpe_48k.t7"
+# ASR_config: "Utils/ASR/config.yml"
+# ASR_path: "/home/ubuntu/STTS_48khz/StyleTTS2-48khz/Utils/ASR/epoch_00050_48K.pth"
+
+F0_path: "/home/austin/disk1/stts-zs_cleaning/F0_extractor/PitchExtractor/Checkpoint_200k/PE_48khz_epoch_00060.pth"
+ASR_config: "/home/austin/disk2/llmvcs/tt/stylekan/Utils/ASR/config.yml"
+ASR_path: "/home/austin/disk2/llmvcs/tt/AuxiliaryASR/Checkpoint_new_plus/epoch_00070.pth"
+
+PLBERT_dir: '/home/austin/disk2/llmvcs/tt/stylekan/Utils/PLBERT'
+
+data_params:
+ train_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/metadata_cleanest/val_48_pure.txt"
+ val_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/metadata_cleanest/val_48_pure.csv"
+ root_path: ""
+ OOD_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/OOD_LargeScale_.csv"
+ min_length: 50 # keep sampling OOD texts until one of at least this length is obtained
+
+#CUDA_VISIBLE_DEVICES=5,6,7 accelerate launch accelerate_train_second.py --config_path ./Configs/config_kanade_48khz_copy.yml
+#CUDA_VISIBLE_DEVICES=6,7 python train_second.py --config_path ./Configs/config_kanade_48khz_copy.yml
+preprocess_params:
+ sr: 48000
+ spect_params:
+ n_fft: 2048
+ win_length: 2048
+ hop_length: 512
+
+model_params:
+ multispeaker: true
+
+ dim_in: 64
+ hidden_dim: 512
+ max_conv_dim: 512
+ n_layer: 3
+ n_mels: 80
+
+ n_token: 178 # number of phoneme tokens
+ max_dur: 50 # maximum duration of a single phoneme
+ style_dim: 128 # style vector size
+
+ dropout: 0.2
+
+ decoder:
+ type: 'istftnet' # either hifigan or istftnet
+ resblock_kernel_sizes: [3,7,11]
+ upsample_rates : [16, 8]
+ upsample_initial_channel: 512
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
+ upsample_kernel_sizes: [32, 16]
+ gen_istft_n_fft: 32
+ gen_istft_hop_size: 4
+
+
+ # speech language model config
+ slm:
+ model: 'Respair/Whisper_Large_v2_Encoder_Block' # the SLM itself is hardcoded; change it in losses.py
+ sr: 16000 # sampling rate of SLM
+ hidden: 1280 # hidden size of SLM
+ nlayers: 33 # number of layers of SLM
+ initial_channel: 64 # initial channels of SLM discriminator head
+
+ # style diffusion model config
+ diffusion:
+ embedding_mask_proba: 0.1
+ # transformer config
+ transformer:
+ num_layers: 3
+ num_heads: 8
+ head_features: 64
+ multiplier: 2
+
+ # diffusion distribution config
+ dist:
+ sigma_data: 0.2 # placeholder; only used when estimate_sigma_data is false
+ estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
+ mean: -3.0
+ std: 1.0
+
+loss_params:
+ lambda_mel: 10. # mel reconstruction loss
+ lambda_gen: 1. # generator loss
+ lambda_slm: 1. # slm feature matching loss
+
+ lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
+ lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
+ TMA_epoch: 3 # TMA starting epoch (1st stage)
+
+ lambda_F0: 1. # F0 reconstruction loss (2nd stage)
+ lambda_norm: 1. # norm reconstruction loss (2nd stage)
+ lambda_dur: 1. # duration loss (2nd stage)
+ lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
+ lambda_sty: 1. # style reconstruction loss (2nd stage)
+ lambda_diff: 1. # score matching loss (2nd stage)
+
+ diff_epoch: 4 # style diffusion starting epoch (2nd stage)
+ joint_epoch: 12 # joint training starting epoch (2nd stage)
+
+optimizer_params:
+ lr: 0.0001 # general learning rate
+ bert_lr: 0.00001 # learning rate for PLBERT
+ ft_lr: 0.00001 # learning rate for acoustic modules
+
+slmadv_params:
+ min_len: 400 # minimum length of samples
+ max_len: 500 # maximum length of samples
+ batch_percentage: 1 # fraction of the original batch size to use (lower it to prevent out of memory)
+ iter: 20 # update the discriminator once every this many generator updates
+ thresh: 5 # gradient norm above which the gradient is scaled
+ scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
+ sig: 1.5 # sigma for differentiable duration modeling
diff --git a/stts_48khz/StyleTTS2_48khz/Demo/Inference_LJSpeech.ipynb b/stts_48khz/StyleTTS2_48khz/Demo/Inference_LJSpeech.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..3a6923e4309b03a156e340e85361edc367b23fbf
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Demo/Inference_LJSpeech.ipynb
@@ -0,0 +1,554 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "9adb7bd1",
+ "metadata": {},
+ "source": [
+ "# StyleTTS 2 Demo (LJSpeech)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6108384d",
+ "metadata": {},
+ "source": [
+ "### Utils"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "96e173bf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "torch.manual_seed(0)\n",
+ "torch.backends.cudnn.benchmark = False\n",
+ "torch.backends.cudnn.deterministic = True\n",
+ "\n",
+ "import random\n",
+ "random.seed(0)\n",
+ "\n",
+ "import numpy as np\n",
+ "np.random.seed(0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "da84c60f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%cd .."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5a3ddcc8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# load packages\n",
+ "import time\n",
+ "import random\n",
+ "import yaml\n",
+ "from munch import Munch\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "from torch import nn\n",
+ "import torch.nn.functional as F\n",
+ "import torchaudio\n",
+ "import librosa\n",
+ "from nltk.tokenize import word_tokenize\n",
+ "\n",
+ "from models import *\n",
+ "from utils import *\n",
+ "from text_utils import TextCleaner\n",
+ "textclenaer = TextCleaner()\n",
+ "\n",
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bbdc04c0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "00ee05e1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "to_mel = torchaudio.transforms.MelSpectrogram(\n",
+ " n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n",
+ "mean, std = -4, 4\n",
+ "\n",
+ "def length_to_mask(lengths):\n",
+ " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n",
+ " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n",
+ " return mask\n",
+ "\n",
+ "def preprocess(wave):\n",
+ " wave_tensor = torch.from_numpy(wave).float()\n",
+ " mel_tensor = to_mel(wave_tensor)\n",
+ " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n",
+ " return mel_tensor\n",
+ "\n",
+ "def compute_style(ref_dicts):\n",
+ " reference_embeddings = {}\n",
+ " for key, path in ref_dicts.items():\n",
+ " wave, sr = librosa.load(path, sr=24000)\n",
+ " audio, index = librosa.effects.trim(wave, top_db=30)\n",
+ " if sr != 24000:\n",
+ " audio = librosa.resample(audio, sr, 24000)\n",
+ " mel_tensor = preprocess(audio).to(device)\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " ref = model.style_encoder(mel_tensor.unsqueeze(1))\n",
+ " reference_embeddings[key] = (ref.squeeze(1), audio)\n",
+ " \n",
+ " return reference_embeddings"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7b9cecbe",
+ "metadata": {},
+ "source": [
+ "### Load models"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "64fc4c0f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# load phonemizer\n",
+ "import phonemizer\n",
+ "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "48e7b644",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config = yaml.safe_load(open(\"Models/LJSpeech/config.yml\"))\n",
+ "\n",
+ "# load pretrained ASR model\n",
+ "ASR_config = config.get('ASR_config', False)\n",
+ "ASR_path = config.get('ASR_path', False)\n",
+ "text_aligner = load_ASR_models(ASR_path, ASR_config)\n",
+ "\n",
+ "# load pretrained F0 model\n",
+ "F0_path = config.get('F0_path', False)\n",
+ "pitch_extractor = load_F0_models(F0_path)\n",
+ "\n",
+ "# load BERT model\n",
+ "from Utils.PLBERT.util import load_plbert\n",
+ "BERT_path = config.get('PLBERT_dir', False)\n",
+ "plbert = load_plbert(BERT_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ffc18cf7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)\n",
+ "_ = [model[key].eval() for key in model]\n",
+ "_ = [model[key].to(device) for key in model]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "64529d5c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "params_whole = torch.load(\"Models/LJSpeech/epoch_2nd_00100.pth\", map_location='cpu')\n",
+ "params = params_whole['net']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "895d9706",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for key in model:\n",
+ " if key in params:\n",
+ " print('%s loaded' % key)\n",
+ " try:\n",
+ " model[key].load_state_dict(params[key])\n",
+ " except:\n",
+ " from collections import OrderedDict\n",
+ " state_dict = params[key]\n",
+ " new_state_dict = OrderedDict()\n",
+ " for k, v in state_dict.items():\n",
+ " name = k[7:] # remove `module.`\n",
+ " new_state_dict[name] = v\n",
+ " # load params\n",
+ " model[key].load_state_dict(new_state_dict, strict=False)\n",
+ "# except:\n",
+ "# _load(params[key], model[key])\n",
+ "_ = [model[key].eval() for key in model]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c1a59db2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e30985ab",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sampler = DiffusionSampler(\n",
+ " model.diffusion.diffusion,\n",
+ " sampler=ADPM2Sampler(),\n",
+ " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n",
+ " clamp=False\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b803110e",
+ "metadata": {},
+ "source": [
+ "### Synthesize speech"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "24655f46",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# synthesize a text\n",
+ "text = ''' StyleTTS 2 is a text-to-speech model that leverages style diffusion and adversarial training with large speech language models to achieve human-level text-to-speech synthesis. '''"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ca57469c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def inference(text, noise, diffusion_steps=5, embedding_scale=1):\n",
+ " text = text.strip()\n",
+ " text = text.replace('\"', '')\n",
+ " ps = global_phonemizer.phonemize([text])\n",
+ " ps = word_tokenize(ps[0])\n",
+ " ps = ' '.join(ps)\n",
+ "\n",
+ " tokens = textclenaer(ps)\n",
+ " tokens.insert(0, 0)\n",
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
+ " \n",
+ " with torch.no_grad():\n",
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)\n",
+ " text_mask = length_to_mask(input_lengths).to(tokens.device)\n",
+ "\n",
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n",
+ "\n",
+ " s_pred = sampler(noise, \n",
+ " embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,\n",
+ " embedding_scale=embedding_scale).squeeze(0)\n",
+ "\n",
+ " s = s_pred[:, 128:]\n",
+ " ref = s_pred[:, :128]\n",
+ "\n",
+ " d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)\n",
+ "\n",
+ " x, _ = model.predictor.lstm(d)\n",
+ " duration = model.predictor.duration_proj(x)\n",
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
+ "\n",
+ " pred_dur[-1] += 5\n",
+ "\n",
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
+ " c_frame = 0\n",
+ " for i in range(pred_aln_trg.size(0)):\n",
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
+ " c_frame += int(pred_dur[i].data)\n",
+ "\n",
+ " # encode prosody\n",
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
+ " out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)), \n",
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
+ " \n",
+ " return out.squeeze().cpu().numpy()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d438ef4f",
+ "metadata": {},
+ "source": [
+ "#### Basic synthesis (5 diffusion steps)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d3d7f7d5",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "start = time.time()\n",
+ "noise = torch.randn(1,1,256).to(device)\n",
+ "wav = inference(text, noise, diffusion_steps=5, embedding_scale=1)\n",
+ "rtf = (time.time() - start) / (len(wav) / 24000)\n",
+ "print(f\"RTF = {rtf:5f}\")\n",
+ "import IPython.display as ipd\n",
+ "display(ipd.Audio(wav, rate=24000))"
+ ]
+ },
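+ {
+ "cell_type": "markdown",
+ "id": "b7e4save0",
+ "metadata": {},
+ "source": [
+ "Optionally, write the synthesized waveform to disk. A minimal sketch using `soundfile` (a dependency of this repo); the output filename is just an example."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b7e4save1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# save the last synthesized `wav` at the model's 24 kHz output rate\n",
+ "import soundfile as sf\n",
+ "sf.write('styletts2_sample.wav', wav, 24000)"
+ ]
+ },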
+ {
+ "cell_type": "markdown",
+ "id": "2d5d9df0",
+ "metadata": {},
+ "source": [
+ "#### With higher diffusion steps (more diverse)\n",
+ "Since the sampler is ancestral, the higher the stpes, the more diverse the samples are, with the cost of slower synthesis speed."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a10129fd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "start = time.time()\n",
+ "noise = torch.randn(1,1,256).to(device)\n",
+ "wav = inference(text, noise, diffusion_steps=10, embedding_scale=1)\n",
+ "rtf = (time.time() - start) / (len(wav) / 24000)\n",
+ "print(f\"RTF = {rtf:5f}\")\n",
+ "import IPython.display as ipd\n",
+ "display(ipd.Audio(wav, rate=24000))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1877ea15",
+ "metadata": {},
+ "source": [
+ "### Speech expressiveness\n",
+ "The following section recreates the samples shown in [Section 6](https://styletts2.github.io/#emo) of the demo page."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4c4777b7",
+ "metadata": {},
+ "source": [
+ "#### With embedding_scale=1\n",
+ "This is the classifier-free guidance scale. The higher the scale, the more conditional the style is to the input text and hence more emotional. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c29ea2f0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "texts = {}\n",
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
+ "\n",
+ "for k,v in texts.items():\n",
+ " noise = torch.randn(1,1,256).to(device)\n",
+ " wav = inference(v, noise, diffusion_steps=10, embedding_scale=1)\n",
+ " print(k + \": \")\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3c89499f",
+ "metadata": {},
+ "source": [
+ "#### With embedding_scale=2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f73be3aa",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "texts = {}\n",
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
+ "\n",
+ "for k,v in texts.items():\n",
+ " noise = torch.randn(1,1,256).to(device)\n",
+ " wav = inference(v, noise, diffusion_steps=10, embedding_scale=2) # embedding_scale=2 for more pronounced emotion\n",
+ " print(k + \": \")\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9320da63",
+ "metadata": {},
+ "source": [
+ "### Long-form generation\n",
+ "This section includes basic implementation of Algorithm 1 in the paper for consistent longform audio generation. The example passage is taken from [Section 5](https://styletts2.github.io/#long) of the demo page. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "cdd4db51",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "passage = '''If the supply of fruit is greater than the family needs, it may be made a source of income by sending the fresh fruit to the market if there is one near enough, or by preserving, canning, and making jelly for sale. To make such an enterprise a success the fruit and work must be first class. There is magic in the word \"Homemade,\" when the product appeals to the eye and the palate; but many careless and incompetent people have found to their sorrow that this word has not magic enough to float inferior goods on the market. As a rule large canning and preserving establishments are clean and have the best appliances, and they employ chemists and skilled labor. The home product must be very good to compete with the attractive goods that are sent out from such establishments. Yet for first-class homemade products there is a market in all large cities. All first-class grocers have customers who purchase such goods.'''"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ebb941c8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def LFinference(text, s_prev, noise, alpha=0.7, diffusion_steps=5, embedding_scale=1):\n",
+ " text = text.strip()\n",
+ " text = text.replace('\"', '')\n",
+ " ps = global_phonemizer.phonemize([text])\n",
+ " ps = word_tokenize(ps[0])\n",
+ " ps = ' '.join(ps)\n",
+ "\n",
+ " tokens = textclenaer(ps)\n",
+ " tokens.insert(0, 0)\n",
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
+ " \n",
+ " with torch.no_grad():\n",
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)\n",
+ " text_mask = length_to_mask(input_lengths).to(tokens.device)\n",
+ "\n",
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n",
+ "\n",
+ " s_pred = sampler(noise, \n",
+ " embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,\n",
+ " embedding_scale=embedding_scale).squeeze(0)\n",
+ " \n",
+ " if s_prev is not None:\n",
+ " # convex combination of previous and current style\n",
+ " s_pred = alpha * s_prev + (1 - alpha) * s_pred\n",
+ " \n",
+ " s = s_pred[:, 128:]\n",
+ " ref = s_pred[:, :128]\n",
+ "\n",
+ " d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)\n",
+ "\n",
+ " x, _ = model.predictor.lstm(d)\n",
+ " duration = model.predictor.duration_proj(x)\n",
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
+ "\n",
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
+ " c_frame = 0\n",
+ " for i in range(pred_aln_trg.size(0)):\n",
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
+ " c_frame += int(pred_dur[i].data)\n",
+ "\n",
+ " # encode prosody\n",
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
+ " out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)), \n",
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
+ " \n",
+ " return out.squeeze().cpu().numpy(), s_pred"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7ca0ef2e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sentences = passage.split('.') # simple split by comma\n",
+ "wavs = []\n",
+ "s_prev = None\n",
+ "for text in sentences:\n",
+ " if text.strip() == \"\": continue\n",
+ " text += '.' # add it back\n",
+ " noise = torch.randn(1,1,256).to(device)\n",
+ " wav, s_prev = LFinference(text, s_prev, noise, alpha=0.7, diffusion_steps=10, embedding_scale=1.5)\n",
+ " wavs.append(wav)\n",
+ "display(ipd.Audio(np.concatenate(wavs), rate=24000, normalize=False))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "NLP",
+ "language": "python",
+ "name": "nlp"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/stts_48khz/StyleTTS2_48khz/Demo/Inference_LibriTTS.ipynb b/stts_48khz/StyleTTS2_48khz/Demo/Inference_LibriTTS.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..3664886e4ecfa6f9de8f07119ee1197b234a6c44
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Demo/Inference_LibriTTS.ipynb
@@ -0,0 +1,1242 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "9adb7bd1",
+ "metadata": {},
+ "source": [
+ "# StyleTTS 2 Demo (LibriTTS) 48khz\n",
+ "\n",
+ "Before you run the following cells, please make sure you have downloaded [reference_audio.zip](https://huggingface.co/yl4579/StyleTTS2-LibriTTS/resolve/main/reference_audio.zip) and unzipped it under the `demo` folder."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6108384d",
+ "metadata": {},
+ "source": [
+ "### Utils"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "96e173bf",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1\n"
+ ]
+ }
+ ],
+ "source": [
+ "import os\n",
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"6\"\n",
+ "import torch\n",
+ "print(torch.cuda.device_count()) # should print: 1\n",
+ "\n",
+ "import torch\n",
+ "torch.manual_seed(0)\n",
+ "torch.backends.cudnn.benchmark = False\n",
+ "torch.backends.cudnn.deterministic = True\n",
+ "\n",
+ "import random\n",
+ "random.seed(0)\n",
+ "\n",
+ "import numpy as np\n",
+ "np.random.seed(0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "da84c60f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/home/austin/disk2/llmvcs/tt/stts_48khz/StyleTTS2_48khz\n"
+ ]
+ }
+ ],
+ "source": [
+ "%cd /home/austin/disk2/llmvcs/tt/stts_48khz/StyleTTS2_48khz/"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "5a3ddcc8",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "177\n"
+ ]
+ }
+ ],
+ "source": [
+ "# load packages\n",
+ "import time\n",
+ "import random\n",
+ "import yaml\n",
+ "from munch import Munch\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "from torch import nn\n",
+ "import torch.nn.functional as F\n",
+ "import torchaudio\n",
+ "import librosa\n",
+ "from nltk.tokenize import word_tokenize\n",
+ "\n",
+ "from models import *\n",
+ "from utils import *\n",
+ "from text_utils import TextCleaner\n",
+ "textclenaer = TextCleaner()\n",
+ "\n",
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "00ee05e1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "to_mel = torchaudio.transforms.MelSpectrogram(\n",
+ " n_mels=80, n_fft=2048, win_length=2048, hop_length=512)\n",
+ "mean, std = -4, 4\n",
+ "\n",
+ "def length_to_mask(lengths):\n",
+ " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n",
+ " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n",
+ " return mask\n",
+ "\n",
+ "def preprocess(wave):\n",
+ " wave_tensor = torch.from_numpy(wave).float()\n",
+ " mel_tensor = to_mel(wave_tensor)\n",
+ " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n",
+ " return mel_tensor\n",
+ "\n",
+ "def compute_style(path):\n",
+ " wave, sr = librosa.load(path, sr=48000)\n",
+ " audio, index = librosa.effects.trim(wave, top_db=30)\n",
+ " if sr != 48000:\n",
+ " audio = librosa.resample(audio, sr, 48000)\n",
+ " mel_tensor = preprocess(audio).to(device)\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " ref_s = model.style_encoder(mel_tensor.unsqueeze(1))\n",
+ " ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))\n",
+ "\n",
+ " return torch.cat([ref_s, ref_p], dim=1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "bbdc04c0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7b9cecbe",
+ "metadata": {},
+ "source": [
+ "### Load models"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "64fc4c0f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# load phonemizer\n",
+ "import phonemizer\n",
+ "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "48e7b644",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config = yaml.safe_load(open(\"/home/austin/disk2/llmvcs/tt/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/config_kanade_48khz.yml\"))\n",
+ "\n",
+ "# load pretrained ASR model\n",
+ "ASR_config = config.get('ASR_config', False)\n",
+ "ASR_path = config.get('ASR_path', False)\n",
+ "text_aligner = load_ASR_models(ASR_path, ASR_config)\n",
+ "\n",
+ "# load pretrained F0 model\n",
+ "F0_path = config.get('F0_path', False)\n",
+ "pitch_extractor = load_F0_models(F0_path)\n",
+ "\n",
+ "# load BERT model\n",
+ "from Utils.PLBERT.util import load_plbert\n",
+ "BERT_path = config.get('PLBERT_dir', False)\n",
+ "plbert = load_plbert(BERT_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "ffc18cf7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_params = recursive_munch(config['model_params'])\n",
+ "model = build_model(model_params, text_aligner, pitch_extractor, plbert)\n",
+ "_ = [model[key].eval() for key in model]\n",
+ "_ = [model[key].to(device) for key in model]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "64529d5c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "params_whole = torch.load(\"/home/austin/disk2/llmvcs/tt/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00006.pth\", map_location='cpu')\n",
+ "params = params_whole['net']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "895d9706",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "bert loaded\n",
+ "bert_encoder loaded\n",
+ "predictor loaded\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "decoder loaded\n",
+ "text_encoder loaded\n",
+ "predictor_encoder loaded\n",
+ "style_encoder loaded\n",
+ "diffusion loaded\n",
+ "text_aligner loaded\n",
+ "pitch_extractor loaded\n",
+ "mpd loaded\n",
+ "msd loaded\n",
+ "wd loaded\n"
+ ]
+ }
+ ],
+ "source": [
+ "for key in model:\n",
+ " if key in params:\n",
+ " print('%s loaded' % key)\n",
+ " try:\n",
+ " model[key].load_state_dict(params[key])\n",
+ " except:\n",
+ " from collections import OrderedDict\n",
+ " state_dict = params[key]\n",
+ " new_state_dict = OrderedDict()\n",
+ " for k, v in state_dict.items():\n",
+ " name = k[7:] # remove `module.`\n",
+ " new_state_dict[name] = v\n",
+ " # load params\n",
+ " model[key].load_state_dict(new_state_dict, strict=False)\n",
+ "# except:\n",
+ "# _load(params[key], model[key])\n",
+ "_ = [model[key].eval() for key in model]"
+ ]
+ },
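+ {
+ "cell_type": "markdown",
+ "id": "1f2a3b4c",
+ "metadata": {},
+ "source": [
+ "The fallback branch above handles checkpoints saved from a `torch.nn.DataParallel`-wrapped model, whose parameter names carry a `module.` prefix. An equivalent sketch of the same renaming (Python 3.9+):\n",
+ "\n",
+ "```python\n",
+ "# strip the DataParallel prefix before load_state_dict\n",
+ "new_state_dict = {k.removeprefix('module.'): v for k, v in state_dict.items()}\n",
+ "```"
+ ]
+ },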
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "e30985ab",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule\n",
+ "sampler = DiffusionSampler(\n",
+ " model.diffusion.diffusion,\n",
+ " sampler=ADPM2Sampler(),\n",
+ " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n",
+ " clamp=False\n",
+ ")"
+ ]
+ },
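+ {
+ "cell_type": "markdown",
+ "id": "2a3b4c5d",
+ "metadata": {},
+ "source": [
+ "Assuming the `KarrasSchedule` here follows Karras et al. (2022), the $N$ noise levels are spaced as\n",
+ "\n",
+ "$$\\sigma_i = \\left(\\sigma_{\\max}^{1/\\rho} + \\frac{i}{N-1}\\left(\\sigma_{\\min}^{1/\\rho} - \\sigma_{\\max}^{1/\\rho}\\right)\\right)^{\\rho}, \\quad i = 0, \\dots, N-1,$$\n",
+ "\n",
+ "so a large $\\rho$ (here 9.0) concentrates steps near $\\sigma_{\\min}$, where fine style detail is resolved."
+ ]
+ },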
+ {
+ "cell_type": "markdown",
+ "id": "b803110e",
+ "metadata": {},
+ "source": [
+ "### Synthesize speech"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "ca57469c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n",
+ " # text = text.strip()\n",
+ " # ps = global_phonemizer.phonemize([text])\n",
+ " # ps = word_tokenize(ps[0])\n",
+ " # ps = ' '.join(ps)\n",
+ " tokens = textclenaer(text)\n",
+ " tokens.insert(0, 0)\n",
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
+ " \n",
+ " with torch.no_grad():\n",
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
+ "\n",
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n",
+ " \n",
+ "\n",
+ "\n",
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), \n",
+ " embedding=bert_dur,\n",
+ " embedding_scale=embedding_scale,\n",
+ " features=ref_s, # reference from the same speaker as the embedding\n",
+ " num_steps=diffusion_steps).squeeze(1)\n",
+ "\n",
+ "\n",
+ " s = s_pred[:, 128:]\n",
+ " ref = s_pred[:, :128]\n",
+ "\n",
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
+ "\n",
+ " d = model.predictor.text_encoder(d_en, \n",
+ " s, input_lengths, text_mask)\n",
+ "\n",
+ " x = model.predictor.lstm(d)\n",
+ " x_mod = model.predictor.prepare_projection(x) # 640 -> 512\n",
+ " duration = model.predictor.duration_proj(x_mod)\n",
+ "\n",
+ "\n",
+ " duration = torch.sigmoid(duration).sum(axis=-1) \n",
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
+ "\n",
+ "\n",
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
+ " c_frame = 0\n",
+ " for i in range(pred_aln_trg.size(0)):\n",
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
+ " c_frame += int(pred_dur[i].data)\n",
+ "\n",
+ " # encode prosody\n",
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
+ " if model_params.decoder.type == \"hifigan\":\n",
+ " asr_new = torch.zeros_like(en)\n",
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
+ " en = asr_new\n",
+ "\n",
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
+ "\n",
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
+ " if model_params.decoder.type == \"hifigan\":\n",
+ " asr_new = torch.zeros_like(asr)\n",
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
+ " asr = asr_new\n",
+ "\n",
+ " out = model.decoder(asr, \n",
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
+ " \n",
+ " \n",
+ " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later"
+ ]
+ },
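+ {
+ "cell_type": "markdown",
+ "id": "3b4c5d6e",
+ "metadata": {},
+ "source": [
+ "Inside `inference`, the predicted per-token durations are expanded into a hard monotonic alignment `pred_aln_trg` of shape (tokens × frames). A minimal standalone sketch of that expansion, with made-up durations:\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "\n",
+ "pred_dur = torch.tensor([2, 3, 1])        # hypothetical per-token frame counts\n",
+ "aln = torch.zeros(len(pred_dur), int(pred_dur.sum()))\n",
+ "frame = 0\n",
+ "for i, d in enumerate(pred_dur):\n",
+ "    aln[i, frame:frame + int(d)] = 1      # token i occupies d consecutive frames\n",
+ "    frame += int(d)\n",
+ "print(aln)\n",
+ "# tensor([[1., 1., 0., 0., 0., 0.],\n",
+ "#         [0., 0., 1., 1., 1., 0.],\n",
+ "#         [0., 0., 0., 0., 0., 1.]])\n",
+ "```"
+ ]
+ },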
+ {
+ "cell_type": "markdown",
+ "id": "d438ef4f",
+ "metadata": {},
+ "source": [
+ "#### Basic synthesis (5 diffusion steps, seen speakers)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "cace9787",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "text = '''soɯ iɯ imi de wa, sendʑoɯgawaɽa çitagi wa, kawaʔta no de mo, koɯsei ɕita no de mo, modoʔta no de mo toɽikaeɕita no de mo nakɯ, maɕite, deɽeta no de mo doɽota no de mo nakɯ.'''\n",
+ "reference_dicts = {}\n",
+ "# reference_dicts['696_92939'] = \"/home/austin/disk1/stts-zs_cleaning/data/moe_soshy/Japanese/imas_split/Syuuko/Syuuko_Events_and_Card/Event/saite_jewel/saite_jewel_chunk90.wav\"\n",
+ "# reference_dicts['1789_142896'] = \"/home/austin/disk1/stts-zs_cleaning/data/moe_soshy/Japanese/imas_split/Kanade/Kanade_Events_and_Card/Kanade_Events/KanLipps/Kanade_lipps_02/Kanade_lipps_02_chunk9.wav\"\n",
+ "# reference_dicts['1789_14289w'] = \"/home/austin/disk1/stts-zs_cleaning/data/moe_soshy/Japanese/sakura_moyu/01/01102220.wav\"\n",
+ "reference_dicts['1789_14289w'] = \"/home/austin/disk1/stts-zs_cleaning/data/moe_soshy/Japanese/monogatari/monogatari_voices/monogatari_split/sakamoto_maya/Sakamoto_Maya_01/Sakamoto_Maya_01_chunk1709.wav\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "16e8ac60",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "NameError",
+ "evalue": "name 'time' is not defined",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m start \u001b[38;5;241m=\u001b[39m \u001b[43mtime\u001b[49m\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m 2\u001b[0m noise \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandn(\u001b[38;5;241m1\u001b[39m,\u001b[38;5;241m1\u001b[39m,\u001b[38;5;241m256\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m k, path \u001b[38;5;129;01min\u001b[39;00m reference_dicts\u001b[38;5;241m.\u001b[39mitems():\n",
+ "\u001b[0;31mNameError\u001b[0m: name 'time' is not defined"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "\n",
+ "start = time.time()\n",
+ "noise = torch.randn(1,1,256).to(device)\n",
+ "for k, path in reference_dicts.items():\n",
+ " ref_s = compute_style(path)\n",
+ " \n",
+ " wav = inference(text, ref_s, alpha=0., beta=0.5, diffusion_steps=10, embedding_scale=2)\n",
+ " rtf = (time.time() - start) / (len(wav) / 48000)\n",
+ " print(f\"RTF = {rtf:5f}\")\n",
+ " import IPython.display as ipd\n",
+ " print(k + ' Synthesized:')\n",
+ " display(ipd.Audio(wav, rate=48000, normalize=False))\n",
+ " print('Reference:')\n",
+ " display(ipd.Audio(path, rate=48000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "14838708",
+ "metadata": {},
+ "source": [
+ "#### With higher diffusion steps (more diverse)\n",
+ "\n",
+ "Since the sampler is ancestral, the higher the stpes, the more diverse the samples are, with the cost of slower synthesis speed."
+ ]
+ },
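+ {
+ "cell_type": "markdown",
+ "id": "4c5d6e7f",
+ "metadata": {},
+ "source": [
+ "Because fresh noise is injected at every ancestral step, repeated calls with identical inputs produce different styles. If reproducible sampling is wanted, a sketch (assuming `text` and `ref_s` are defined as in the neighboring cells):\n",
+ "\n",
+ "```python\n",
+ "torch.manual_seed(0)  # pin the ancestral noise so repeated runs match\n",
+ "wav = inference(text, ref_s, diffusion_steps=10, embedding_scale=1)\n",
+ "```"
+ ]
+ },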
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6fbff03b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "noise = torch.randn(1,1,256).to(device)\n",
+ "for k, path in reference_dicts.items():\n",
+ " ref_s = compute_style(path)\n",
+ " start = time.time()\n",
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=10, embedding_scale=1)\n",
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
+ " print(f\"RTF = {rtf:5f}\")\n",
+ " import IPython.display as ipd\n",
+ " print(k + ' Synthesized:')\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
+ " print(k + ' Reference:')\n",
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7e6867fd",
+ "metadata": {},
+ "source": [
+ "#### Basic synthesis (5 diffusion steps, umseen speakers)\n",
+ "The following samples are to reproduce samples in [Section 4](https://styletts2.github.io/#libri) of the demo page. All spsakers are unseen during training. You can compare the generated samples to popular zero-shot TTS models like Vall-E and NaturalSpeech 2."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f4e8faa0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "reference_dicts = {}\n",
+ "# format: (path, text)\n",
+ "reference_dicts['1221-135767'] = (\"Demo/reference_audio/1221-135767-0014.wav\", \"Yea, his honourable worship is within, but he hath a godly minister or two with him, and likewise a leech.\")\n",
+ "reference_dicts['5639-40744'] = (\"Demo/reference_audio/5639-40744-0020.wav\", \"Thus did this humane and right minded father comfort his unhappy daughter, and her mother embracing her again, did all she could to soothe her feelings.\")\n",
+ "reference_dicts['908-157963'] = (\"Demo/reference_audio/908-157963-0027.wav\", \"And lay me down in my cold bed and leave my shining lot.\")\n",
+ "reference_dicts['4077-13754'] = (\"Demo/reference_audio/4077-13754-0000.wav\", \"The army found the people in poverty and left them in comparative wealth.\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "653f1406",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "noise = torch.randn(1,1,256).to(device)\n",
+ "for k, v in reference_dicts.items():\n",
+ " path, text = v\n",
+ " ref_s = compute_style(path)\n",
+ " start = time.time()\n",
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1)\n",
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
+ " print(f\"RTF = {rtf:5f}\")\n",
+ " import IPython.display as ipd\n",
+ " print(k + ' Synthesized: ' + text)\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
+ " print(k + ' Reference:')\n",
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "141e91b3",
+ "metadata": {},
+ "source": [
+ "### Speech expressiveness\n",
+ "\n",
+ "The following section recreates the samples shown in [Section 6](https://styletts2.github.io/#emo) of the demo page. The speaker reference used is `1221-135767-0014.wav`, which is unseen during training. \n",
+ "\n",
+ "#### With `embedding_scale=1`\n",
+ "This is the classifier-free guidance scale. The higher the scale, the more conditional the style is to the input text and hence more emotional.\n",
+ "\n"
+ ]
+ },
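+ {
+ "cell_type": "markdown",
+ "id": "5d6e7f8a",
+ "metadata": {},
+ "source": [
+ "In standard classifier-free guidance (which `embedding_scale` is assumed to implement here), the conditional and unconditional predictions are blended as\n",
+ "\n",
+ "$$\\hat{v} = v_{\\text{uncond}} + s \\cdot \\left(v_{\\text{cond}} - v_{\\text{uncond}}\\right),$$\n",
+ "\n",
+ "so `embedding_scale = 1` reduces to the plain conditional prediction, and larger values push the sampled style further toward the text condition."
+ ]
+ },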
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "81addda4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ref_s = compute_style(\"Demo/reference_audio/1221-135767-0014.wav\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "be1b2a11",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "texts = {}\n",
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
+ "\n",
+ "for k,v in texts.items():\n",
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=1)\n",
+ " print(k + \": \")\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "96d262b8",
+ "metadata": {},
+ "source": [
+ "#### With `embedding_scale=2`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3e7d40b4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "texts = {}\n",
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
+ "\n",
+ "for k,v in texts.items():\n",
+ " noise = torch.randn(1,1,256).to(device)\n",
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=2)\n",
+ " print(k + \": \")\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "402b2bd6",
+ "metadata": {},
+ "source": [
+ "#### With `embedding_scale=2, alpha = 0.5, beta = 0.9`\n",
+ "`alpha` and `beta` is the factor to determine much we use the style sampled based on the text instead of the reference. The higher the value of `alpha` and `beta`, the more suitable the style it is to the text but less similar to the reference. Using higher beta makes the synthesized speech more emotional, at the cost of lower similarity to the reference. `alpha` determines the timbre of the speaker while `beta` determines the prosody. "
+ ]
+ },
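+ {
+ "cell_type": "markdown",
+ "id": "6e7f8a9b",
+ "metadata": {},
+ "source": [
+ "Concretely, matching the mixing lines inside `inference` (the first 128 style dimensions carry timbre, the last 128 prosody):\n",
+ "\n",
+ "$$\\text{timbre} = \\alpha\\, \\text{timbre}_{\\text{pred}} + (1-\\alpha)\\, \\text{timbre}_{\\text{ref}}, \\qquad \\text{prosody} = \\beta\\, \\text{prosody}_{\\text{pred}} + (1-\\beta)\\, \\text{prosody}_{\\text{ref}}.$$"
+ ]
+ },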
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "599de5d5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "texts = {}\n",
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
+ "\n",
+ "for k,v in texts.items():\n",
+ " noise = torch.randn(1,1,256).to(device)\n",
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.5, beta=0.9, embedding_scale=2)\n",
+ " print(k + \": \")\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "48548866",
+ "metadata": {},
+ "source": [
+ "### Zero-shot speaker adaptation\n",
+ "This section recreates the \"Acoustic Environment Maintenance\" and \"Speaker’s Emotion Maintenance\" demo in [Section 4](https://styletts2.github.io/#libri) of the demo page. You can compare the generated samples to popular zero-shot TTS models like Vall-E. Note that the model was trained only on LibriTTS, which is about 250 times fewer data compared to those used to trian Vall-E with similar or better effect for these maintainance. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "23e81572",
+ "metadata": {},
+ "source": [
+ "#### Acoustic Environment Maintenance\n",
+ "\n",
+ "Since we want to maintain the acoustic environment in the speaker (timbre), we set `alpha = 0` to make the speaker as closer to the reference as possible while only changing the prosody according to the text. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8087bccb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "reference_dicts = {}\n",
+ "# format: (path, text)\n",
+ "reference_dicts['3'] = (\"Demo/reference_audio/3.wav\", \"As friends thing I definitely I've got more male friends.\")\n",
+ "reference_dicts['4'] = (\"Demo/reference_audio/4.wav\", \"Everything is run by computer but you got to know how to think before you can do a computer.\")\n",
+ "reference_dicts['5'] = (\"Demo/reference_audio/5.wav\", \"Then out in LA you guys got a whole another ball game within California to worry about.\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1e99c200",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "noise = torch.randn(1,1,256).to(device)\n",
+ "for k, v in reference_dicts.items():\n",
+ " path, text = v\n",
+ " ref_s = compute_style(path)\n",
+ " start = time.time()\n",
+ " wav = inference(text, ref_s, alpha=0.0, beta=0.5, diffusion_steps=5, embedding_scale=1)\n",
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
+ " print(f\"RTF = {rtf:5f}\")\n",
+ " import IPython.display as ipd\n",
+ " print('Synthesized: ' + text)\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
+ " print('Reference:')\n",
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7d56505d",
+ "metadata": {},
+ "source": [
+ "#### Speaker’s Emotion Maintenance\n",
+ "\n",
+ "Since we want to maintain the emotion in the speaker (prosody), we set `beta = 0.1` to make the speaker as closer to the reference as possible while having some diversity thruogh the slight timbre change."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f90179e7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "reference_dicts = {}\n",
+ "# format: (path, text)\n",
+ "reference_dicts['Anger'] = (\"Demo/reference_audio/anger.wav\", \"We have to reduce the number of plastic bags.\")\n",
+ "reference_dicts['Sleepy'] = (\"Demo/reference_audio/sleepy.wav\", \"We have to reduce the number of plastic bags.\")\n",
+ "reference_dicts['Amused'] = (\"Demo/reference_audio/amused.wav\", \"We have to reduce the number of plastic bags.\")\n",
+ "reference_dicts['Disgusted'] = (\"Demo/reference_audio/disgusted.wav\", \"We have to reduce the number of plastic bags.\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2e6bdfed",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "noise = torch.randn(1,1,256).to(device)\n",
+ "for k, v in reference_dicts.items():\n",
+ " path, text = v\n",
+ " ref_s = compute_style(path)\n",
+ " start = time.time()\n",
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.1, diffusion_steps=10, embedding_scale=1)\n",
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
+ " print(f\"RTF = {rtf:5f}\")\n",
+ " import IPython.display as ipd\n",
+ " print(k + ' Synthesized: ' + text)\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
+ " print(k + ' Reference:')\n",
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "37ae3963",
+ "metadata": {},
+ "source": [
+ "### Longform Narration\n",
+ "\n",
+ "This section includes basic implementation of Algorithm 1 in the paper for consistent longform audio generation. The example passage is taken from [Section 5](https://styletts2.github.io/#long) of the demo page."
+ ]
+ },
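+ {
+ "cell_type": "markdown",
+ "id": "7f8a9b0c",
+ "metadata": {},
+ "source": [
+ "The consistency mechanism in `LFinference` below is a convex combination of the previous sentence's style with the freshly sampled one,\n",
+ "\n",
+ "$$s_k = t\\, s_{k-1} + (1-t)\\, \\tilde{s}_k,$$\n",
+ "\n",
+ "where $\\tilde{s}_k$ is the style sampled for sentence $k$ and $t$ (here 0.7) controls how strongly the style carries over between sentences."
+ ]
+ },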
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "a1a38079",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1):\n",
+ " # text = text.strip()\n",
+ " # ps = global_phonemizer.phonemize([text])\n",
+ " # ps = word_tokenize(ps[0])\n",
+ " # ps = ' '.join(ps)\n",
+ " # ps = ps.replace('``', '\"')\n",
+ " # ps = ps.replace(\"''\", '\"')\n",
+ "\n",
+ " tokens = textclenaer(text)\n",
+ " tokens.insert(0, 0)\n",
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
+ " \n",
+ " with torch.no_grad():\n",
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
+ "\n",
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n",
+ "\n",
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), \n",
+ " embedding=bert_dur,\n",
+ " embedding_scale=embedding_scale,\n",
+ " features=ref_s, # reference from the same speaker as the embedding\n",
+ " num_steps=diffusion_steps).squeeze(1)\n",
+ " \n",
+ " if s_prev is not None:\n",
+ " # convex combination of previous and current style\n",
+ " s_pred = t * s_prev + (1 - t) * s_pred\n",
+ " \n",
+ " s = s_pred[:, 128:]\n",
+ " ref = s_pred[:, :128]\n",
+ " \n",
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
+ "\n",
+ " s_pred = torch.cat([ref, s], dim=-1)\n",
+ "\n",
+ " d = model.predictor.text_encoder(d_en, \n",
+ " s, input_lengths, text_mask)\n",
+ "\n",
+ " x = model.predictor.lstm(d)\n",
+ " x_mod = model.predictor.prepare_projection(x) # 640 -> 512\n",
+ " duration = model.predictor.duration_proj(x_mod)\n",
+ "\n",
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
+ "\n",
+ "\n",
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
+ " c_frame = 0\n",
+ " for i in range(pred_aln_trg.size(0)):\n",
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
+ " c_frame += int(pred_dur[i].data)\n",
+ "\n",
+ " # encode prosody\n",
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
+ " if model_params.decoder.type == \"hifigan\":\n",
+ " asr_new = torch.zeros_like(en)\n",
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
+ " en = asr_new\n",
+ "\n",
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
+ "\n",
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
+ " if model_params.decoder.type == \"hifigan\":\n",
+ " asr_new = torch.zeros_like(asr)\n",
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
+ " asr = asr_new\n",
+ "\n",
+ " out = model.decoder(asr, \n",
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
+ " \n",
+ " \n",
+ " return out.squeeze().cpu().numpy()[..., :-100], s_pred # weird pulse at the end of the model, need to be fixed later"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 75,
+ "id": "e9088f7a",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "RuntimeError",
+ "evalue": "The size of tensor a (512) must match the size of tensor b (531) at non-singleton dimension 3",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[75], line 15\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m text\u001b[38;5;241m.\u001b[39mstrip() \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mcontinue\u001b[39;00m\n\u001b[1;32m 13\u001b[0m text \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m,\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;66;03m# add it back\u001b[39;00m\n\u001b[0;32m---> 15\u001b[0m wav, s_prev \u001b[38;5;241m=\u001b[39m \u001b[43mLFinference\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 16\u001b[0m \u001b[43m \u001b[49m\u001b[43ms_prev\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 17\u001b[0m \u001b[43m \u001b[49m\u001b[43ms_ref\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 18\u001b[0m \u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 19\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.3\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# make it more suitable for the text\u001b[39;49;00m\n\u001b[1;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.7\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[43m \u001b[49m\u001b[43mdiffusion_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43membedding_scale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1.5\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 22\u001b[0m wavs\u001b[38;5;241m.\u001b[39mappend(wav)\n\u001b[1;32m 23\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mSynthesized: \u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
+ "Cell \u001b[0;32mIn[19], line 64\u001b[0m, in \u001b[0;36mLFinference\u001b[0;34m(text, s_prev, ref_s, alpha, beta, t, diffusion_steps, embedding_scale)\u001b[0m\n\u001b[1;32m 61\u001b[0m asr_new[:, :, \u001b[38;5;241m1\u001b[39m:] \u001b[38;5;241m=\u001b[39m en[:, :, \u001b[38;5;241m0\u001b[39m:\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 62\u001b[0m en \u001b[38;5;241m=\u001b[39m asr_new\n\u001b[0;32m---> 64\u001b[0m F0_pred, N_pred \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpredictor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mF0Ntrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43men\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 66\u001b[0m asr \u001b[38;5;241m=\u001b[39m (t_en \u001b[38;5;241m@\u001b[39m pred_aln_trg\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device))\n\u001b[1;32m 67\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m model_params\u001b[38;5;241m.\u001b[39mdecoder\u001b[38;5;241m.\u001b[39mtype \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhifigan\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n",
+ "File \u001b[0;32m~/disk2/llmvcs/tt/stts_48khz/StyleTTS2_48khz/models.py:1651\u001b[0m, in \u001b[0;36mProsodyPredictor.F0Ntrain\u001b[0;34m(self, x, s)\u001b[0m\n\u001b[1;32m 1647\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mF0Ntrain\u001b[39m(\u001b[38;5;28mself\u001b[39m, x, s):\n\u001b[0;32m-> 1651\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshared\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtranspose\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1652\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprepare_projection(x)\n\u001b[1;32m 1655\u001b[0m F0 \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\n",
+ "File \u001b[0;32m~/disk2/micromamba/envs/sttszs/lib/python3.11/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/disk2/micromamba/envs/sttszs/lib/python3.11/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+ "File \u001b[0;32m~/disk2/micromamba/envs/sttszs/lib/python3.11/site-packages/xlstm/xlstm_block_stack.py:120\u001b[0m, in \u001b[0;36mxLSTMBlockStack.forward\u001b[0;34m(self, x, **kwargs)\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x: torch\u001b[38;5;241m.\u001b[39mTensor, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m torch\u001b[38;5;241m.\u001b[39mTensor:\n\u001b[1;32m 119\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m block \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mblocks:\n\u001b[0;32m--> 120\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[43mblock\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 122\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpost_blocks_norm(x)\n\u001b[1;32m 124\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n",
+ "File \u001b[0;32m~/disk2/micromamba/envs/sttszs/lib/python3.11/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/disk2/micromamba/envs/sttszs/lib/python3.11/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+ "File \u001b[0;32m~/disk2/micromamba/envs/sttszs/lib/python3.11/site-packages/xlstm/blocks/xlstm_block.py:77\u001b[0m, in \u001b[0;36mxLSTMBlock.forward\u001b[0;34m(self, x, **kwargs)\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x: torch\u001b[38;5;241m.\u001b[39mTensor, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m torch\u001b[38;5;241m.\u001b[39mTensor:\n\u001b[0;32m---> 77\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mxlstm\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mxlstm_norm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 78\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mffn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 79\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mffn(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mffn_norm(x), \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
+ "File \u001b[0;32m~/disk2/micromamba/envs/sttszs/lib/python3.11/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/disk2/micromamba/envs/sttszs/lib/python3.11/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+ "File \u001b[0;32m~/disk2/micromamba/envs/sttszs/lib/python3.11/site-packages/xlstm/blocks/mlstm/layer.py:116\u001b[0m, in \u001b[0;36mmLSTMLayer.forward\u001b[0;34m(self, x, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m k \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mk_proj(x_mlstm_conv_act)\n\u001b[1;32m 114\u001b[0m v \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mv_proj(x_mlstm)\n\u001b[0;32m--> 116\u001b[0m h_tilde_state \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmlstm_cell\u001b[49m\u001b[43m(\u001b[49m\u001b[43mq\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mv\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 118\u001b[0m h_tilde_state_skip \u001b[38;5;241m=\u001b[39m h_tilde_state \u001b[38;5;241m+\u001b[39m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlearnable_skip \u001b[38;5;241m*\u001b[39m x_mlstm_conv_act)\n\u001b[1;32m 120\u001b[0m \u001b[38;5;66;03m# output / z branch\u001b[39;00m\n",
+ "File \u001b[0;32m~/disk2/micromamba/envs/sttszs/lib/python3.11/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/disk2/micromamba/envs/sttszs/lib/python3.11/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+ "File \u001b[0;32m~/disk2/micromamba/envs/sttszs/lib/python3.11/site-packages/xlstm/blocks/mlstm/cell.py:61\u001b[0m, in \u001b[0;36mmLSTMCell.forward\u001b[0;34m(self, q, k, v, **kwargs)\u001b[0m\n\u001b[1;32m 58\u001b[0m fgate_preact \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfgate(if_gate_input) \u001b[38;5;66;03m# (B, S, NH)\u001b[39;00m\n\u001b[1;32m 59\u001b[0m fgate_preact \u001b[38;5;241m=\u001b[39m fgate_preact\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;66;03m# (B, NH, S, 1)#\u001b[39;00m\n\u001b[0;32m---> 61\u001b[0m h_state \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackend_fn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 62\u001b[0m \u001b[43m \u001b[49m\u001b[43mqueries\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mq\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 63\u001b[0m \u001b[43m \u001b[49m\u001b[43mkeys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mk\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 64\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalues\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mv\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 65\u001b[0m \u001b[43m \u001b[49m\u001b[43migate_preact\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43migate_preact\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 66\u001b[0m \u001b[43m \u001b[49m\u001b[43mfgate_preact\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfgate_preact\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 67\u001b[0m \u001b[43m \u001b[49m\u001b[43mlower_triangular_matrix\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 68\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# (B, NH, S, DH)\u001b[39;00m\n\u001b[1;32m 70\u001b[0m h_state_norm \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutnorm(h_state) \u001b[38;5;66;03m# (B, NH, S, DH)\u001b[39;00m\n\u001b[1;32m 71\u001b[0m h_state_norm \u001b[38;5;241m=\u001b[39m h_state_norm\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m)\u001b[38;5;241m.\u001b[39mreshape(B, S, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;66;03m# (B, NH, S, DH) -> (B, S, NH, DH) -> (B, S, H)\u001b[39;00m\n",
+ "File \u001b[0;32m~/disk2/micromamba/envs/sttszs/lib/python3.11/site-packages/xlstm/blocks/mlstm/backends.py:64\u001b[0m, in \u001b[0;36mparallel_stabilized_simple\u001b[0;34m(queries, keys, values, igate_preact, fgate_preact, lower_triangular_matrix, stabilize_rowwise, eps, **kwargs)\u001b[0m\n\u001b[1;32m 61\u001b[0m _log_fg_matrix \u001b[38;5;241m=\u001b[39m rep_log_fgates_cumsum \u001b[38;5;241m-\u001b[39m rep_log_fgates_cumsum\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;66;03m# (B, NH, S+1, S+1)\u001b[39;00m\n\u001b[1;32m 62\u001b[0m \u001b[38;5;66;03m# Causal masking & selection of the correct submatrix, such that forgetgate at timestep t is not applied\u001b[39;00m\n\u001b[1;32m 63\u001b[0m \u001b[38;5;66;03m# to the input at timestep t\u001b[39;00m\n\u001b[0;32m---> 64\u001b[0m log_fg_matrix \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwhere\u001b[49m\u001b[43m(\u001b[49m\u001b[43mltr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_log_fg_matrix\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;28;43mfloat\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43minf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# (B, NH, S, S)\u001b[39;00m\n\u001b[1;32m 66\u001b[0m \u001b[38;5;66;03m# gate decay matrix D (combination of forget gate and input gate)\u001b[39;00m\n\u001b[1;32m 67\u001b[0m log_D_matrix \u001b[38;5;241m=\u001b[39m log_fg_matrix \u001b[38;5;241m+\u001b[39m igate_preact\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;66;03m# (B, NH, S, S)\u001b[39;00m\n",
+ "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (512) must match the size of tensor b (531) at non-singleton dimension 3"
+ ]
+ }
+ ],
+ "source": [
+ "# unseen speaker\n",
+ "passage = '''ohoɕisama, doko kaɽa kita no? kiʔto, oːkina ɯtɕɯɯ no jɯme o miterɯɴ da ne. sono jɯme wa, wataɕitatɕi no kokoɽo ni mo todokɯ joɯ na, totemo fɯkai jɯme naɴ daɽoɯ ne. ohoɕisama ga çikarɯ jorɯ wa, wataɕi mo sono jɯme o mite, aɕita no boɯkeɴ o soɯzoɯ sɯrɯɴ da. dakaɽa, ohoɕisama, zɯʔto iʔɕo ni ite ne.'''\n",
+ "# reference_dicts = {}\n",
+ "# reference_dicts['696_92939'] = \"/home/austin/disk1/stts-zs_cleaning/data/moe_soshy/Japanese/imas_split/Syuuko/Syuuko_Events_and_Card/Event/saite_jewel/saite_jewel_chunk90.wav\"\n",
+ "# reference_dicts['1789_142896'] = \"/home/austin/disk1/stts-zs_cleaning/data/moe_soshy/Japanese/imas_split/Kanade/Kanade_Events_and_Card/Kanade_Events/KanLipps/Kanade_lipps_02/Kanade_lipps_02_chunk9.wav\"\n",
+ "path = \"/home/austin/disk1/stts-zs_cleaning/data/moe_soshy/Japanese/sakura_moyu/01/01001290.wav\"\n",
+ "s_ref = compute_style(path)\n",
+ "sentences = passage.split('.') # simple split by comma\n",
+ "wavs = []\n",
+ "s_prev = None\n",
+ "for text in sentences:\n",
+ " if text.strip() == \"\": continue\n",
+ " text += ',' # add it back\n",
+ " \n",
+ " wav, s_prev = LFinference(text, \n",
+ " s_prev, \n",
+ " s_ref, \n",
+ " alpha = 0.1, \n",
+ " beta = 0.3, # make it more suitable for the text\n",
+ " t = 0.7, \n",
+ " diffusion_steps=10, embedding_scale=1.5)\n",
+ " wavs.append(wav)\n",
+ "print('Synthesized: ')\n",
+ "display(ipd.Audio(np.concatenate(wavs), rate=48000, normalize=False))\n",
+ "print('Reference: ')\n",
+ "display(ipd.Audio(path, rate=48000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7517b657",
+ "metadata": {},
+ "source": [
+ "### Style Transfer\n",
+ "\n",
+ "The following section demostrates the style transfer capacity for unseen speakers in [Section 6](https://styletts2.github.io/#emo) of the demo page. For this, we set `alpha=0.5, beta = 0.9` for the most pronounced effects (mostly using the sampled style). "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ed95d0f7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n",
+ " text = text.strip()\n",
+ " ps = global_phonemizer.phonemize([text])\n",
+ " ps = word_tokenize(ps[0])\n",
+ " ps = ' '.join(ps)\n",
+ "\n",
+ " tokens = textclenaer(ps)\n",
+ " tokens.insert(0, 0)\n",
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
+ " \n",
+ " ref_text = ref_text.strip()\n",
+ " ps = global_phonemizer.phonemize([ref_text])\n",
+ " ps = word_tokenize(ps[0])\n",
+ " ps = ' '.join(ps)\n",
+ "\n",
+ " ref_tokens = textclenaer(ps)\n",
+ " ref_tokens.insert(0, 0)\n",
+ " ref_tokens = torch.LongTensor(ref_tokens).to(device).unsqueeze(0)\n",
+ " \n",
+ " \n",
+ " with torch.no_grad():\n",
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
+ "\n",
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n",
+ " \n",
+ " ref_input_lengths = torch.LongTensor([ref_tokens.shape[-1]]).to(device)\n",
+ " ref_text_mask = length_to_mask(ref_input_lengths).to(device)\n",
+ " ref_bert_dur = model.bert(ref_tokens, attention_mask=(~ref_text_mask).int())\n",
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), \n",
+ " embedding=bert_dur,\n",
+ " embedding_scale=embedding_scale,\n",
+ " features=ref_s, # reference from the same speaker as the embedding\n",
+ " num_steps=diffusion_steps).squeeze(1)\n",
+ "\n",
+ "\n",
+ " s = s_pred[:, 128:]\n",
+ " ref = s_pred[:, :128]\n",
+ "\n",
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
+ "\n",
+ " d = model.predictor.text_encoder(d_en, \n",
+ " s, input_lengths, text_mask)\n",
+ "\n",
+ " x, _ = model.predictor.lstm(d)\n",
+ " duration = model.predictor.duration_proj(x)\n",
+ "\n",
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
+ "\n",
+ "\n",
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
+ " c_frame = 0\n",
+ " for i in range(pred_aln_trg.size(0)):\n",
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
+ " c_frame += int(pred_dur[i].data)\n",
+ "\n",
+ " # encode prosody\n",
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
+ " if model_params.decoder.type == \"hifigan\":\n",
+ " asr_new = torch.zeros_like(en)\n",
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
+ " en = asr_new\n",
+ "\n",
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
+ "\n",
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
+ " if model_params.decoder.type == \"hifigan\":\n",
+ " asr_new = torch.zeros_like(asr)\n",
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
+ " asr = asr_new\n",
+ "\n",
+ " out = model.decoder(asr, \n",
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
+ " \n",
+ " \n",
+ " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ec3f0da4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# reference texts to sample styles\n",
+ "\n",
+ "ref_texts = {}\n",
+ "ref_texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
+ "ref_texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
+ "ref_texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
+ "ref_texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6d0a3825",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
+ "s_ref = compute_style(path)\n",
+ "\n",
+ "text = \"Yea, his honourable worship is within, but he hath a godly minister or two with him, and likewise a leech.\"\n",
+ "for k,v in ref_texts.items():\n",
+ " wav = STinference(text, s_ref, v, diffusion_steps=10, alpha=0.5, beta=0.9, embedding_scale=1.5)\n",
+ " print(k + \": \")\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6750aed9",
+ "metadata": {},
+ "source": [
+ "### Speech diversity\n",
+ "\n",
+ "This section reproduces samples in [Section 7](https://styletts2.github.io/#var) of the demo page. \n",
+ "\n",
+ "`alpha` and `beta` determine the diversity of the synthesized speech. There are two extreme cases:\n",
+ "- If `alpha = 1` and `beta = 1`, the synthesized speech sounds the most dissimilar to the reference speaker, but it is also the most diverse (each time you synthesize a speech it will be totally different). \n",
+ "- If `alpha = 0` and `beta = 0`, the synthesized speech sounds the most siimlar to the reference speaker, but it is deterministic (i.e., the sampled style is not used for speech synthesis). \n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f6ae0aa5",
+ "metadata": {},
+ "source": [
+ "#### Default setting (`alpha = 0.3, beta=0.7`)\n",
+ "This setting uses 70% of the reference timbre and 30% of the reference prosody and use the diffusion model to sample them based on the text. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "36dc0148",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# unseen speaker\n",
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
+ "ref_s = compute_style(path)\n",
+ "\n",
+ "text = \"How much variation is there?\"\n",
+ "for _ in range(5):\n",
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=1)\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bf9ef421",
+ "metadata": {},
+ "source": [
+ "#### Less diverse setting (`alpha = 0.1, beta=0.3`)\n",
+ "This setting uses 90% of the reference timbre and 70% of the reference prosody. This makes it more similar to the reference speaker at cost of less diverse samples. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9ba406bd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# unseen speaker\n",
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
+ "ref_s = compute_style(path)\n",
+ "\n",
+ "text = \"How much variation is there?\"\n",
+ "for _ in range(5):\n",
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.1, beta=0.3, embedding_scale=1)\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a38fe464",
+ "metadata": {},
+ "source": [
+ "#### More diverse setting (`alpha = 0.5, beta=0.95`)\n",
+ "This setting uses 50% of the reference timbre and 5% of the reference prosody (so it uses 100% of the sampled prosody, which makes it more diverse), but this makes it more dissimilar to the reference speaker. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5f25bf94",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# unseen speaker\n",
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
+ "ref_s = compute_style(path)\n",
+ "\n",
+ "text = \"How much variation is there?\"\n",
+ "for _ in range(5):\n",
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.5, beta=0.95, embedding_scale=1)\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "21c3a071",
+ "metadata": {},
+ "source": [
+ "#### Extreme setting (`alpha = 1, beta=1`)\n",
+ "This setting uses 0% of the reference timbre and prosody and use the diffusion model to sample the entire style. This makes the speaker very dissimilar to the reference speaker. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fff8bab1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# unseen speaker\n",
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
+ "ref_s = compute_style(path)\n",
+ "\n",
+ "text = \"How much variation is there?\"\n",
+ "for _ in range(5):\n",
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=1, beta=1, embedding_scale=1)\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a8741e5a",
+ "metadata": {},
+ "source": [
+ "#### No variation (`alpha = 0, beta=0`)\n",
+ "This setting uses 0% of the reference timbre and prosody and use the diffusion model to sample the entire style. This makes the speaker very similar to the reference speaker, but there is no variation. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e55dd281",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# unseen speaker\n",
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
+ "ref_s = compute_style(path)\n",
+ "\n",
+ "text = \"How much variation is there?\"\n",
+ "for _ in range(5):\n",
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0, beta=0, embedding_scale=1)\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d5e86423",
+ "metadata": {},
+ "source": [
+ "### Extra fun!\n",
+ "\n",
+ "Here we clone some of the authors' voice of the StyleTTS 2 papers with a few seconds of the recording in the wild. None of the voices is in the dataset and all authors agreed to have their voices cloned here."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6f558314",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "text = ''' StyleTTS 2 is a text to speech model that leverages style diffusion and adversarial training with large speech language models to achieve human level text to speech synthesis. '''"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "caa5747c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "reference_dicts = {}\n",
+ "reference_dicts['Yinghao'] = \"Demo/reference_audio/Yinghao.wav\"\n",
+ "reference_dicts['Gavin'] = \"Demo/reference_audio/Gavin.wav\"\n",
+ "reference_dicts['Vinay'] = \"Demo/reference_audio/Vinay.wav\"\n",
+ "reference_dicts['Nima'] = \"Demo/reference_audio/Nima.wav\""
+ ]
+ },
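+ {
+ "cell_type": "markdown",
+ "id": "rtf-note",
+ "metadata": {},
+ "source": [
+ "The next cell also reports the real-time factor (RTF): synthesis time divided by the duration of the generated audio, so values below 1 mean faster-than-real-time synthesis."
+ ]
+ },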
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "44a4cea1",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "start = time.time()\n",
+ "noise = torch.randn(1,1,256).to(device)\n",
+ "for k, path in reference_dicts.items():\n",
+ " ref_s = compute_style(path)\n",
+ " \n",
+ " wav = inference(text, ref_s, alpha=0.1, beta=0.5, diffusion_steps=5, embedding_scale=1)\n",
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
+ " print('Speaker: ' + k)\n",
+ " import IPython.display as ipd\n",
+ " print('Synthesized:')\n",
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
+ " print('Reference:')\n",
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/stts_48khz/StyleTTS2_48khz/LICENSE b/stts_48khz/StyleTTS2_48khz/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..0c7bc2e109e023323ea3146c306b392bfa3ce614
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Aaron (Yinghao) Li
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade/config_kanade.yml b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade/config_kanade.yml
new file mode 100644
index 0000000000000000000000000000000000000000..24c87d5c8302e7c980029bd1c15c5a62c58ad81e
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade/config_kanade.yml
@@ -0,0 +1,124 @@
+log_dir: "Models/Style_Kanade"
+first_stage_path: "/home/austin/disk2/llmvcs/tt/stylekan/Models/Style_Kanade/epoch_1st_00013.pth"
+save_freq: 1
+log_interval: 10
+device: "cuda"
+epochs_1st: 25 # number of epochs for first stage training (pre-training)
+epochs_2nd: 15 # number of epochs for second stage training (joint training)
+batch_size: 24
+max_len: 560 # maximum number of frames
+pretrained_model: "/home/austin/disk2/llmvcs/tt/stylekan/Models/Style_Kanade/succ_epoch_2nd_00002.pth"
+second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
+load_only_params: false # set to true if you do not want to load epoch numbers and optimizer parameters
+
+# F0_path: "/home/ubuntu/STTS_48khz/StyleTTS2-48khz/Utils/JDC/bst_rmvpe_48k.t7"
+# ASR_config: "Utils/ASR/config.yml"
+# ASR_path: "/home/ubuntu/STTS_48khz/StyleTTS2-48khz/Utils/ASR/epoch_00050_48K.pth"
+
+F0_path: "/home/austin/disk2/llmvcs/tt/stylekan/Utils/JDC/bst.t7"
+ASR_config: "/home/austin/disk2/llmvcs/tt/stylekan/Utils/ASR/config.yml"
+ASR_path: "/home/austin/disk2/llmvcs/tt/stylekan/Utils/ASR/bst_00080.pth"
+
+PLBERT_dir: 'Utils/PLBERT/'
+
+data_params:
+ train_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/filtered_train_list.csv"
+ val_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/mg_valid.txt"
+ root_path: ""
+ OOD_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/OOD_LargeScale_.csv"
+  min_length: 50 # keep sampling OOD texts until a text of at least this length is obtained
+
+
+preprocess_params:
+ sr: 24000
+ spect_params:
+ n_fft: 2048
+ win_length: 1200
+ hop_length: 300
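+    # => sr / hop_length = 24000 / 300 = 80 mel frames per second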
+
+model_params:
+ multispeaker: true
+
+ dim_in: 64
+ hidden_dim: 512
+ max_conv_dim: 512
+ n_layer: 3
+ n_mels: 80
+
+ n_token: 178 # number of phoneme tokens
+ max_dur: 50 # maximum duration of a single phoneme
+ style_dim: 128 # style vector size
+
+ dropout: 0.2
+
+ decoder:
+ type: 'istftnet' # either hifigan or istftnet
+ resblock_kernel_sizes: [3,7,11]
+ upsample_rates : [10, 6]
+ upsample_initial_channel: 512
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
+ upsample_kernel_sizes: [20, 12]
+ gen_istft_n_fft: 20
+ gen_istft_hop_size: 5
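+    # Note, derived from the values above: total upsampling is
+    # prod(upsample_rates) * gen_istft_hop_size = 10 * 6 * 5 = 300 samples per
+    # frame, which matches preprocess_params hop_length: 300.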
+
+ # speech language model config
+ slm:
+    model: 'Respair/Whisper_Large_v2_Encoder_Block' # the model itself is hardcoded; change it in losses.py
+ sr: 16000 # sampling rate of SLM
+ hidden: 1280 # hidden size of SLM
+ nlayers: 33 # number of layers of SLM
+ initial_channel: 64 # initial channels of SLM discriminator head
+
+ # style diffusion model config
+ diffusion:
+ embedding_mask_proba: 0.1
+ # transformer config
+ transformer:
+ num_layers: 3
+ num_heads: 8
+ head_features: 64
+ multiplier: 2
+
+ # diffusion distribution config
+ dist:
+ sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
+ estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
+ mean: -3.0
+ std: 1.0
+
+loss_params:
+ lambda_mel: 10. # mel reconstruction loss
+ lambda_gen: 1. # generator loss
+ lambda_slm: 1. # slm feature matching loss
+
+ lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
+ lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
+ TMA_epoch: 9 # TMA starting epoch (1st stage)
+
+ lambda_F0: 1. # F0 reconstruction loss (2nd stage)
+ lambda_norm: 1. # norm reconstruction loss (2nd stage)
+ lambda_dur: 1. # duration loss (2nd stage)
+ lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
+ lambda_sty: 1. # style reconstruction loss (2nd stage)
+ lambda_diff: 1. # score matching loss (2nd stage)
+
+ diff_epoch: 2 # style diffusion starting epoch (2nd stage)
+ joint_epoch: 6 # joint training starting epoch (2nd stage)
+
+optimizer_params:
+ lr: 0.0001 # general learning rate
+ bert_lr: 0.00001 # learning rate for PLBERT
+ ft_lr: 0.00001 # learning rate for acoustic modules
+
+slmadv_params:
+ min_len: 400 # minimum length of samples
+ max_len: 500 # maximum length of samples
+ batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
+  iter: 20 # update the discriminator once every this many generator updates
+ thresh: 5 # gradient norm above which the gradient is scaled
+ scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
+ sig: 1.5 # sigma for differentiable duration modeling
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade/tensorboard/events.out.tfevents.1728511195.node-1.1421403.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade/tensorboard/events.out.tfevents.1728511195.node-1.1421403.0
new file mode 100644
index 0000000000000000000000000000000000000000..2e9e97d8004b3a39195d2d9102d5dd4bd3888a9d
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade/tensorboard/events.out.tfevents.1728511195.node-1.1421403.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b345ee45d60a91882d9371bbc7e7f2e516d966f6ea5b1c9152a12fb4b3c5e7f2
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade/train.log b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade/train.log
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/2nd_phase_165885.pth b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/2nd_phase_165885.pth
new file mode 100644
index 0000000000000000000000000000000000000000..519d9aba1101c44697062ca1eb12c5f79db35348
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/2nd_phase_165885.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5d88946c2ad6951102438d4377a0de03e8e932cdb921a96491c8302d2ccee062
+size 2634096094
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/2nd_phase_65527.pth b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/2nd_phase_65527.pth
new file mode 100644
index 0000000000000000000000000000000000000000..9cbd10ae7f32bddbfdbe4240651bfd2bc766d69c
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/2nd_phase_65527.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d0234f6214df09eb84fccc4da8a32abbcafa66e538443f50ab0248547f7fb536
+size 2056542305
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/2nd_phase_last.pth b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/2nd_phase_last.pth
new file mode 100644
index 0000000000000000000000000000000000000000..9fffdf9b2d4092295070a2b414286968c27f8f96
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/2nd_phase_last.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:837323ec6d6a23319c3e80c169595a13010a4ed783015c4b727e7e26bae2928e
+size 2056545874
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/DP_epoch_2nd_00004.pth b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/DP_epoch_2nd_00004.pth
new file mode 100644
index 0000000000000000000000000000000000000000..1856f077866eaa6a17e66d6c185714ed733d0def
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/DP_epoch_2nd_00004.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ebade56e9e2aab901b8c80811a807201b34a64eb56dfc267e84e71a275f09f48
+size 1522665000
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/NO_SLM_epoch_2nd_00009.pth b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/NO_SLM_epoch_2nd_00009.pth
new file mode 100644
index 0000000000000000000000000000000000000000..e3e5a68e9f2031e7fd44e1e7cb826af3c3d1c18a
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/NO_SLM_epoch_2nd_00009.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:845b43f224edc0ad3a57d3a5bfaa2c293a28cbf3caa8c80a7bb9fa1bb08d0470
+size 2056549032
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/NO_SLM_epoch_2nd_00010.pth b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/NO_SLM_epoch_2nd_00010.pth
new file mode 100644
index 0000000000000000000000000000000000000000..315d5477f0a91d57a32c7f15282e2c9f0896a90c
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/NO_SLM_epoch_2nd_00010.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:79aa6cd825f51004625f7ba8806da65c6fcc95484dacfb7f862982b81b4a0308
+size 2056549032
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728462180.node-1.1003680.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728462180.node-1.1003680.0
new file mode 100644
index 0000000000000000000000000000000000000000..f84d38f6a16c20d841ab4113257e4cb3a0b4a847
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728462180.node-1.1003680.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8f7b7ad158fbb253067e0aac2b0dd0d55aedce480f843a0da1df65c621cba037
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728462294.node-1.1004682.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728462294.node-1.1004682.0
new file mode 100644
index 0000000000000000000000000000000000000000..f6033949b17d15da9d2f278bd1eb03c293106957
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728462294.node-1.1004682.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8fe150cc2d1132dd2eb70b41c013159ef0d7f025f66c81a404da755a4535b8cb
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728462472.node-1.1005638.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728462472.node-1.1005638.0
new file mode 100644
index 0000000000000000000000000000000000000000..49e0189567d546849e7927a10a73a17e727ed675
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728462472.node-1.1005638.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:851ea96b4c47b93d43be6f80acff6230535e6bfbe99c36667589b124441c38cb
+size 2560
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728462951.node-1.1007312.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728462951.node-1.1007312.0
new file mode 100644
index 0000000000000000000000000000000000000000..a8678c814dffd57115a6f3c0f6fa4df17204678b
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728462951.node-1.1007312.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5f59ad7609f349f7396199de23168c05c873f722309d0b0f17830b46c430bb07
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728463094.node-1.1008219.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728463094.node-1.1008219.0
new file mode 100644
index 0000000000000000000000000000000000000000..7dbcb71200b5bfa8b915a67877ea43de57c6da95
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728463094.node-1.1008219.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5df01fc68dfb3b019689711260ed2a9321af4c95fdddbb2a96229d5978396625
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728463336.node-1.1010823.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728463336.node-1.1010823.0
new file mode 100644
index 0000000000000000000000000000000000000000..54d7924362587c72f976306a60efdad8719b7187
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728463336.node-1.1010823.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:83d76a0888871f133e42fc11450319a6561a6495a789b7ae87d47ecd65b4f97a
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728463388.node-1.1011249.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728463388.node-1.1011249.0
new file mode 100644
index 0000000000000000000000000000000000000000..4bb8266296ab66ad92485c6ca2766ee37ef105e8
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728463388.node-1.1011249.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2545722cbdece2a028da4a0a5b270adc17334aab1fbae91501816d21a7464ba3
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728463515.node-1.1013548.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728463515.node-1.1013548.0
new file mode 100644
index 0000000000000000000000000000000000000000..ecb310df68f0e63b1bc92a42f6a1aca62c0bf028
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728463515.node-1.1013548.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d20009b085d787b22a11e53fae855158377954e651cafc1f47cbb9a00b724ad7
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728463957.node-1.1016238.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728463957.node-1.1016238.0
new file mode 100644
index 0000000000000000000000000000000000000000..90fd9d42e495c9ccafc6fa77e923f69d7da717ed
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728463957.node-1.1016238.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:487baec629ff2b35ef5a5dee52804320ff18926173581103b1e8707b68bd0065
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464009.node-1.1016738.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464009.node-1.1016738.0
new file mode 100644
index 0000000000000000000000000000000000000000..017244a0097198725fc92ce3f8a14966f537e102
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464009.node-1.1016738.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:31294dbd8a51f7d1bcc38516e988fb01106b7e3a95b221746b2ef8265e6f2041
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464233.node-1.1019060.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464233.node-1.1019060.0
new file mode 100644
index 0000000000000000000000000000000000000000..34d35a9a4b5d3fddd380fb7ba8b746590c1272c8
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464233.node-1.1019060.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:90c1c0bf4b9a6b7df5234a63b7a053f3566cdc36ff8f767eb38f0385df5adc12
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464354.node-1.1019744.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464354.node-1.1019744.0
new file mode 100644
index 0000000000000000000000000000000000000000..2e4cfd89cfa9f3c960b38505deb3b17c40ad5748
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464354.node-1.1019744.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f5120ccb9594416d07241b12993c9062eaef365a5143f138487b5142f3651fbf
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464586.node-1.1020751.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464586.node-1.1020751.0
new file mode 100644
index 0000000000000000000000000000000000000000..a046550d82fc47a0dbf940bde18588873b2e34b3
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464586.node-1.1020751.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d2f9e3f52c93ab01722243e4b44521e9f56eb045b9a0fc310dbe4451933ab70f
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464707.node-1.1021516.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464707.node-1.1021516.0
new file mode 100644
index 0000000000000000000000000000000000000000..24410e0bf0ef5643756d3f4224c4d8db8194a4c2
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464707.node-1.1021516.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b30a01b19998be89e637bd8a7c708043a48040b16666113b33b220bc75b7a1e8
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464831.node-1.1022361.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464831.node-1.1022361.0
new file mode 100644
index 0000000000000000000000000000000000000000..c72cd8beb17aae3855d5272c53118e8f0dd073fa
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464831.node-1.1022361.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b21a04faa35f646804b73e9dbb99f9959c273c80ee8adeeb8b3d31120575db24
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464900.node-1.1022907.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464900.node-1.1022907.0
new file mode 100644
index 0000000000000000000000000000000000000000..92c4cf5993d0122e1ee70f2aee60ee7b3a10f88c
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728464900.node-1.1022907.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eff57ace3d453046db3a73d86ca648eb0d41e380919dadc739ad58ef6fab3a82
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465001.node-1.1025007.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465001.node-1.1025007.0
new file mode 100644
index 0000000000000000000000000000000000000000..2a20220f0d74d289146b7183fc18325462be1747
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465001.node-1.1025007.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fe554f916ef302f23d632f05cba6257659a3f4690450ab700e91f6da91439a80
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465067.node-1.1026980.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465067.node-1.1026980.0
new file mode 100644
index 0000000000000000000000000000000000000000..b9131b341465771e1630fb683806197c686ca569
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465067.node-1.1026980.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5def090452ecadea26774681015c59eb6156b32b32c18043ee1c9f5a463649a0
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465130.node-1.1028957.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465130.node-1.1028957.0
new file mode 100644
index 0000000000000000000000000000000000000000..4ffa7d87e049728d614f867af21cdd2727af0eb0
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465130.node-1.1028957.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:36986ddd50b5ac71411f0d8c23dd32b29f955e1eb687b626db75c7db9d454b58
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465526.node-1.1031919.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465526.node-1.1031919.0
new file mode 100644
index 0000000000000000000000000000000000000000..e1b41d58ace2cc9c6f513b4a469745e3e4357f4e
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465526.node-1.1031919.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd287d2ae4069eef2ce60bb1f8e4d82ce6f0aefba2ae8b500d40986446cb0357
+size 1324
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465719.node-1.1034258.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465719.node-1.1034258.0
new file mode 100644
index 0000000000000000000000000000000000000000..df58a18d2e270e24bb60069afc725f39e4bd8f9f
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465719.node-1.1034258.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:18a3b4ae4ac7f2b657745d442f650035264c9bd73bcc6946c55c3ad68736b31a
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465773.node-1.1034708.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465773.node-1.1034708.0
new file mode 100644
index 0000000000000000000000000000000000000000..2918a5cac7cace77cc73ee02dbab7ee03f0589e8
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465773.node-1.1034708.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f7a3dedca2882cc634dfc00ec09320fb229b606405d02dd7028fcefc78c10d9f
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465956.node-1.1037028.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465956.node-1.1037028.0
new file mode 100644
index 0000000000000000000000000000000000000000..384dff411839372ebd462170bdf96d1913a1ae15
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728465956.node-1.1037028.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:142f9797269c386eae3311cb4c463edbe4db04085b3211ab6cd2e25774125efd
+size 4414
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466151.node-1.1039623.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466151.node-1.1039623.0
new file mode 100644
index 0000000000000000000000000000000000000000..81fcddaec2e4ae9a07adb5e4896f8eb43211736e
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466151.node-1.1039623.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a6ec80c55b1ac401c4154cec4959b4a6b580413479e9c39ad6527889bbea685
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466402.node-1.1042081.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466402.node-1.1042081.0
new file mode 100644
index 0000000000000000000000000000000000000000..2dd2c18163192cfed352bb2f712edd954c51e09d
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466402.node-1.1042081.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2212f2ad84e8300d7c193b214dae7021fba8336b9e5a8bed28a95d63e67de435
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466494.node-1.1044247.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466494.node-1.1044247.0
new file mode 100644
index 0000000000000000000000000000000000000000..5bd32986ec6e212f4b25ab9361741e58fb668004
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466494.node-1.1044247.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eae22e2ba3ce18c37779e09b95aa6c6b05c73c9884fd725a412b5d35ff73e18a
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466592.node-1.1046287.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466592.node-1.1046287.0
new file mode 100644
index 0000000000000000000000000000000000000000..bc54826f8d6201ab117d3856bb536556164aa4b8
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466592.node-1.1046287.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28504f4d924fd17bd6c5ad2a5622bf78a16dc81a71ad9cb9832a9c8b6ec72e3b
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466648.node-1.1048190.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466648.node-1.1048190.0
new file mode 100644
index 0000000000000000000000000000000000000000..eac829291a1e42c0060b3e53ff910ed5f0d0daf6
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466648.node-1.1048190.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:29375abbff4881e86b811dccec6d07e7bb4c45395e9be31210d760248689cb95
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466722.node-1.1050256.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466722.node-1.1050256.0
new file mode 100644
index 0000000000000000000000000000000000000000..7013a60be7fc37535adfea9351dde205984a41c9
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466722.node-1.1050256.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73a0a5169b6144f9da5dad9b9cde28cdb7f582da193f133b540c778d9959793c
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466835.node-1.1052485.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466835.node-1.1052485.0
new file mode 100644
index 0000000000000000000000000000000000000000..a81eb73556881db2ead7db8bdcb92eb05d5ad549
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728466835.node-1.1052485.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e189732ff4f4bbbd5be8860bf701c24df6a0829339a393ef37accae9ec13d129
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467182.node-1.1055478.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467182.node-1.1055478.0
new file mode 100644
index 0000000000000000000000000000000000000000..2ae1ae08726601d68c09f82294e13084412ca899
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467182.node-1.1055478.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:15672b4ba275caa4e6824df47c9dadd86d95dfd94cee352e8df910225bfd104f
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467254.node-1.1057438.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467254.node-1.1057438.0
new file mode 100644
index 0000000000000000000000000000000000000000..f66645cda60a7a868c4aa89183d33693ad73c00b
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467254.node-1.1057438.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f6926eb90a8a9352a0e3ad2a7bee3fc7e5e0092dae011bff18c45e2aa4558d44
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467319.node-1.1059460.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467319.node-1.1059460.0
new file mode 100644
index 0000000000000000000000000000000000000000..d9207ef951ef77bd86b25b384a335e82f0114f1c
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467319.node-1.1059460.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bc2e88a28f015930cdb557256bf86f91e1ec15b765158390c2d7d7ce15eeed79
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467391.node-1.1061518.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467391.node-1.1061518.0
new file mode 100644
index 0000000000000000000000000000000000000000..5adb8a296faea1abcfb8839357caad494e938f80
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467391.node-1.1061518.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:58cdf47456910a48cf2078656ff06b744334e4f53bc9ea0acb7b0a59df1dc9c8
+size 1942
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467539.node-1.1063858.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467539.node-1.1063858.0
new file mode 100644
index 0000000000000000000000000000000000000000..b9a77d0a27f9b58aaa7ebbe8b8f008990d3a08db
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467539.node-1.1063858.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e43bcd5d1a2fa2797b4ff7a826919deec7a90d5958348a7cd11e576f4fe9d26
+size 706
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467704.node-1.1066066.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467704.node-1.1066066.0
new file mode 100644
index 0000000000000000000000000000000000000000..c1ef54dde294618db716d17968f7daad930610e3
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467704.node-1.1066066.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0c021e3b2755292ff330cf385e1cd94e4618e6e5681cf54bdc27b33833d524e0
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467768.node-1.1068012.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467768.node-1.1068012.0
new file mode 100644
index 0000000000000000000000000000000000000000..83cf991635d84e8c12055f2da78e0e51a98452ea
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467768.node-1.1068012.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:46239d81b667b3817b0eca24c1984bdb36c07c40089b59a5886d357053b9221d
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467944.node-1.1070293.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467944.node-1.1070293.0
new file mode 100644
index 0000000000000000000000000000000000000000..5d423b6db66244db2384a06afab5b396758f420e
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467944.node-1.1070293.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b85175d09ba76e302aae5dfc38de6686e8b1519f016d070fcc8540c0887db232
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467997.node-1.1070845.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467997.node-1.1070845.0
new file mode 100644
index 0000000000000000000000000000000000000000..7ced2e7e6461750ec8c8721a364f370426c0e7e7
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728467997.node-1.1070845.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8eee6545b725873330aa17136b96ccbab040626eb7710e71fdb5e1d4793d6754
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728468148.node-1.1073296.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728468148.node-1.1073296.0
new file mode 100644
index 0000000000000000000000000000000000000000..b2677da1f76327c49a2668b4a43f0d144f38b776
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728468148.node-1.1073296.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:be1272bb28239a5927ce51c0eadc201df9ddb95cff89437c33e958f2f0ee62ce
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728468218.node-1.1074238.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728468218.node-1.1074238.0
new file mode 100644
index 0000000000000000000000000000000000000000..d88aabc77e3c93d6f07fd9fecb0c347d51ef4163
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728468218.node-1.1074238.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a031ad0312a484cc722eeb39e3304f474f6f3b889b3d1200a6aeec7fc2ef8bca
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728468357.node-1.1075256.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728468357.node-1.1075256.0
new file mode 100644
index 0000000000000000000000000000000000000000..47d7e8f4510a07acc6d12e769a3306490b47889a
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728468357.node-1.1075256.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:19b5b263822dc220fb1b3336058441c792e345381e1612bf17d5ed625e1cf30a
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728468524.node-1.1077715.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728468524.node-1.1077715.0
new file mode 100644
index 0000000000000000000000000000000000000000..bef7dd23ef2e2ec9918d6e0c3e4d815c8d60e85f
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728468524.node-1.1077715.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24dd380ecc7a73301b1ecf94c22f29c7eb4ff8d3de2cabbda2139f3146c19b7b
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728468759.node-1.1080401.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728468759.node-1.1080401.0
new file mode 100644
index 0000000000000000000000000000000000000000..78d803552585e4f7e449addc068ffdf102f49469
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728468759.node-1.1080401.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d0ec05ac87d47892081f81192a235cf9cf3e1292a0193630f71ce466c257ad51
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728468929.node-1.1082803.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728468929.node-1.1082803.0
new file mode 100644
index 0000000000000000000000000000000000000000..514479f894112876cc8db45870633261db728d46
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728468929.node-1.1082803.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:301c084688e0c2b38d28dbeece0fee98d9395680b2ed8c500dc6b7788653452b
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469057.node-1.1085179.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469057.node-1.1085179.0
new file mode 100644
index 0000000000000000000000000000000000000000..7d018af39037d85554f916824fa9afe215cccaf6
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469057.node-1.1085179.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8ff367043ba65ca32cc860c68317072361aa508efacc41f0e0b2673896315eb5
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469159.node-1.1087345.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469159.node-1.1087345.0
new file mode 100644
index 0000000000000000000000000000000000000000..a240f4fff4b55f12d3a5c85b46cfd78ec05add1e
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469159.node-1.1087345.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2e4c3ddb353479b95e7d1aae9f1c8949437eb37a69f21ba96dab15325e57c49b
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469244.node-1.1087998.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469244.node-1.1087998.0
new file mode 100644
index 0000000000000000000000000000000000000000..d0c8db430373b46a8383abf1fb408f2a3f5adbba
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469244.node-1.1087998.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e67edadcfde5f78c890b787ca63d4fd9b9a77dd47b55b878ccbb04af381a5fad
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469369.node-1.1088976.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469369.node-1.1088976.0
new file mode 100644
index 0000000000000000000000000000000000000000..e7a1491f24eeff7a37f4782eddb2579c97c25e66
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469369.node-1.1088976.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ad6abdceb726793e19f4ed68c8dfc6026a3b83ee0ec71311712b262bcd659755
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469441.node-1.1090879.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469441.node-1.1090879.0
new file mode 100644
index 0000000000000000000000000000000000000000..fea746d7bd128ab8bf9747dbf23b06163428d1b9
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469441.node-1.1090879.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c67970bd3f0b2c7f3408cf0847f4825a862000da80978a05fad863f59fcece9a
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469624.node-1.1093172.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469624.node-1.1093172.0
new file mode 100644
index 0000000000000000000000000000000000000000..42c1911c67b6b3171290d24e84b61da87a882320
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469624.node-1.1093172.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:590826afdf8242752eb56a97970472483afbbf09891ace12322adac84bee108f
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469744.node-1.1095329.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469744.node-1.1095329.0
new file mode 100644
index 0000000000000000000000000000000000000000..7baf85ed19eac34cf03effb3e43c74251a10c6a6
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469744.node-1.1095329.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d8d3fe6eb049edf942b9f757e1a4ed2e5501e6bf1d346fd1ceff6e97b596c520
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469808.node-1.1097304.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469808.node-1.1097304.0
new file mode 100644
index 0000000000000000000000000000000000000000..77e7445114eac9217101afeb9bb7ad55566a84f9
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469808.node-1.1097304.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e51abb1cd10d97bb7955b3dfc00215f56b7b4dfc02c924c8783c30073a75d023
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469982.node-1.1099829.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469982.node-1.1099829.0
new file mode 100644
index 0000000000000000000000000000000000000000..b589e76d6eb4cfee21360d9781c7173f098074f0
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728469982.node-1.1099829.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ffa9b481f2f712ec0df57444eb79e6133233c86ed277e65cf204e8f61c65f617
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728470641.node-1.1103728.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728470641.node-1.1103728.0
new file mode 100644
index 0000000000000000000000000000000000000000..80ab75732e76250467671b2a86295bb99e822664
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728470641.node-1.1103728.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:045ac92694d9406a3212bf143be893c849eaae7dbdf78c026912b1edf9c68014
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728470714.node-1.1104551.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728470714.node-1.1104551.0
new file mode 100644
index 0000000000000000000000000000000000000000..aa7e8acdca5058becc5af640352b5542cb4810e3
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728470714.node-1.1104551.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b9c9c3f29e7096acb6a63610fba8f9ead5aab336e664ed9a401681d7b9cda094
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728471420.node-1.1109347.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728471420.node-1.1109347.0
new file mode 100644
index 0000000000000000000000000000000000000000..c695fb77db18c2486df0f66f89d86e4f734f1193
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728471420.node-1.1109347.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:92ab7a7c6da0662486a682b41622f41c251eb91317ce5590b1124090454f3b2b
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728471474.node-1.1111520.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728471474.node-1.1111520.0
new file mode 100644
index 0000000000000000000000000000000000000000..5fab1337d256913c9ecd9269325eaa4d3685202e
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728471474.node-1.1111520.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a90be5f92ca4c6ff0a14696813b1a87ef32964bf7451b2a225c06677718e052
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728471556.node-1.1113028.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728471556.node-1.1113028.0
new file mode 100644
index 0000000000000000000000000000000000000000..1f556663e6a3dece80e7f6a244865eeaf8e332f2
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728471556.node-1.1113028.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b3b3663aea1749be50303295e0f9ab52874bf9c6117d1286905ce5969b29fea3
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728471660.node-1.1114010.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728471660.node-1.1114010.0
new file mode 100644
index 0000000000000000000000000000000000000000..c4c784bda7f0723df4e51a08e251933eeafe94ae
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728471660.node-1.1114010.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6abbb454f906fdd2b9890c686ceadfb667c20125f500ade0d45d981cef1a8d45
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728472335.node-1.1118079.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728472335.node-1.1118079.0
new file mode 100644
index 0000000000000000000000000000000000000000..e170a8e0f329de3386cfc3a58c0cc5060283aec6
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728472335.node-1.1118079.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5fa1c23288f18d17cce80b268582d7d8f61b0f32e432b0e9b239aeed41654c0f
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728472408.node-1.1120288.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728472408.node-1.1120288.0
new file mode 100644
index 0000000000000000000000000000000000000000..4a903a36f39a4e364e81e863b8640961fbebe867
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728472408.node-1.1120288.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:67637ef05465a978ada7b69cdcfc8e8b5ff3b70efe2487d2a295c0c3d97660c3
+size 1324
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728472506.node-1.1123424.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728472506.node-1.1123424.0
new file mode 100644
index 0000000000000000000000000000000000000000..99f8644cc86a9e3ac87bdbfb37c8eafa9a7a6e6d
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728472506.node-1.1123424.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:184d97005478b90d652355ef655bed4fe1316e25c4e15fceb156329a8bed53e8
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728473177.node-1.1136121.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728473177.node-1.1136121.0
new file mode 100644
index 0000000000000000000000000000000000000000..7fb8e7102402c40e4f308c0d589847ac290d0f74
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728473177.node-1.1136121.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52f6b34119843dee79241dcc39912902154377c366836a7d51c8de8b5c61709d
+size 1942
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728473486.node-1.1139953.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728473486.node-1.1139953.0
new file mode 100644
index 0000000000000000000000000000000000000000..fa8f1ae29a8a36e9a15576f77374aeecb9bf7476
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728473486.node-1.1139953.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ed7437ff69fde40d4f738db6e7a9fbeef00b97770213b7eddb9d9345a7d21435
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728473618.node-1.1142371.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728473618.node-1.1142371.0
new file mode 100644
index 0000000000000000000000000000000000000000..2f8669492a45db2f61b2a2bdc864743e558ff165
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728473618.node-1.1142371.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c39c9006075adadadd8521de7185fdd4c4b0b222322af9a120903f7b5b10f893
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728473779.node-1.1145021.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728473779.node-1.1145021.0
new file mode 100644
index 0000000000000000000000000000000000000000..1aca3f04e3538ae71ed8a0487d1ff862cec724ae
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728473779.node-1.1145021.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fbcfef02ac2e5af8a1078156c6ea9bf7c611327bd94ca2ced252af143435e01b
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728474255.node-1.1149018.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728474255.node-1.1149018.0
new file mode 100644
index 0000000000000000000000000000000000000000..cc5687e6b25f64ec090ad50140f911fee1b9eff9
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728474255.node-1.1149018.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:57fd22b1bc314f07e9fd27114447a6c9d64c941b07a82c2a444bc4b5edffdabe
+size 555759
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728487185.node-1.1202925.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728487185.node-1.1202925.0
new file mode 100644
index 0000000000000000000000000000000000000000..67d95e84ee01059f566f7dc7dd27120407c182b5
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728487185.node-1.1202925.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f1e308b358ec1231307f87915228e18475dacf5300f627513939e9fde9dd39c3
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728487233.node-1.1203376.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728487233.node-1.1203376.0
new file mode 100644
index 0000000000000000000000000000000000000000..0a85c63ddc8e632463f7c0dd670f399f54e4f339
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728487233.node-1.1203376.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6fb27bda6805f0e56464daba72e54c21addd5089da410c7cd454ddd6a36b11d1
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728487521.node-1.1206002.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728487521.node-1.1206002.0
new file mode 100644
index 0000000000000000000000000000000000000000..c6f9e32c352f3b01223b871307ef8e69906dbaf9
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728487521.node-1.1206002.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d018fcef81e2ab6e0b6771e4268bd39a6b1aea6f8f8b1a23615d91959ab6a0c
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728487719.node-1.1208224.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728487719.node-1.1208224.0
new file mode 100644
index 0000000000000000000000000000000000000000..bacad2942911f3f6fe90730b28ee91afb33e8557
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728487719.node-1.1208224.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0560ae3206a2b07fbcb7136b020f7f429b69304a099a8fa549a4769969c167ec
+size 858
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728487846.node-1.1210278.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728487846.node-1.1210278.0
new file mode 100644
index 0000000000000000000000000000000000000000..82718b603b9ede8fbe27acac3c0c778893835791
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728487846.node-1.1210278.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c612276145519d5ee8850d445101c14bbdf0ee81ab9c394037bc4546a3c526c4
+size 4591746
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728488020.node-1.1212545.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728488020.node-1.1212545.0
new file mode 100644
index 0000000000000000000000000000000000000000..24bcc88a01415c78572be75043131a382651e354
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728488020.node-1.1212545.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f4de8f292b6e9faf0a77c80f8d971913f78ddaf484f6db16d92d8ede454b4d72
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728488090.node-1.1214251.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728488090.node-1.1214251.0
new file mode 100644
index 0000000000000000000000000000000000000000..9f12a25ec9b02a284b5df1d3ea0f4af221a2a0b2
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728488090.node-1.1214251.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f35c2d0bb2b73227e0792615ff3eb72383325cb424946f507799b794f81c5efc
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728488134.node-1.1214738.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728488134.node-1.1214738.0
new file mode 100644
index 0000000000000000000000000000000000000000..2e904820be5a69d7a4f8e5071c9164bf131b624d
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728488134.node-1.1214738.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:85e9da33ed28c08ce54d0db888fcf34fa1de3e2d10005640165640729d7854e5
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728488199.node-1.1215408.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728488199.node-1.1215408.0
new file mode 100644
index 0000000000000000000000000000000000000000..dc0cb12017b0c3f803efe7bde5db48967ec1f2f6
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728488199.node-1.1215408.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd9f8a6bb182fffa20b9f8a6ef8c384f924401e88e6189fed6f37a201d025800
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728488283.node-1.1217607.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728488283.node-1.1217607.0
new file mode 100644
index 0000000000000000000000000000000000000000..762120ed77321e42be492e71b11f225e7faabe99
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728488283.node-1.1217607.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:544957ba29db631d9688e63ac3b74f406569e14ca9b4e3d46399a8fd0cf8f1c1
+size 2560
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728488813.node-1.1225079.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728488813.node-1.1225079.0
new file mode 100644
index 0000000000000000000000000000000000000000..3ec6e491a574f535007ccab5c7c6b94d45ad91a2
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728488813.node-1.1225079.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:210b5e1cdb9749977d9cfa24e0bc4e02485694f4a912724efa5e18f84628290d
+size 5004630
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728489070.node-1.1227900.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728489070.node-1.1227900.0
new file mode 100644
index 0000000000000000000000000000000000000000..e665167341108852a2489bacca8352a7f35d53ea
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728489070.node-1.1227900.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd1242adfe2d72de796ffe9bddad1a32eac501ef158cc9b2d8e305fd92570f9b
+size 2560
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728489196.node-1.1230206.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728489196.node-1.1230206.0
new file mode 100644
index 0000000000000000000000000000000000000000..78da2206d3d39c700129a399896f02c075d7cd16
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728489196.node-1.1230206.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b2d49106bef25e342fbdc3662a9f6da636acc0d4f518dcdb3c8d360bae234a0f
+size 30184
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728490041.node-1.1235498.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728490041.node-1.1235498.0
new file mode 100644
index 0000000000000000000000000000000000000000..3c2f5749c685928a62df0796bc95e77877d97b00
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728490041.node-1.1235498.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f061eb622d4a9f1d98ab385821a8531fc5bf9188587136973824d5c689e24d3
+size 8134
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728490343.node-1.1239201.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728490343.node-1.1239201.0
new file mode 100644
index 0000000000000000000000000000000000000000..63b250a4892f5185547142e2261e12360c592b49
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728490343.node-1.1239201.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bcece26f22a95f7d8ddc5ad8ce8fc582a01a40645eedaeaefa22599f7b407871
+size 79954
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728492553.node-1.1250117.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728492553.node-1.1250117.0
new file mode 100644
index 0000000000000000000000000000000000000000..1694dcac466ad8c7746f267a702a8a9c6913d91d
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728492553.node-1.1250117.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:903b9b6e7d0f7751d3e74bf7ce04b66b87cc47f7009b52ffbf152364f8c53b28
+size 476379
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728503605.node-1.1299336.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728503605.node-1.1299336.0
new file mode 100644
index 0000000000000000000000000000000000000000..55021d217146e4474f92181315ef8805329ecb98
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728503605.node-1.1299336.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:947408fa80fc80f76a404a543586270d92743eea6115daa308911c732b1cb25c
+size 240
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728503690.node-1.1301434.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728503690.node-1.1301434.0
new file mode 100644
index 0000000000000000000000000000000000000000..61e3576b77387bf0acdc07cc54e9664557a7cbb6
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728503690.node-1.1301434.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:93e77b10045c1f87f31d8b253e1ac2f55a1a2f078ea36ae44973e0ea030144db
+size 240
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728504087.node-1.1307695.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728504087.node-1.1307695.0
new file mode 100644
index 0000000000000000000000000000000000000000..a89a9eb973fbc1564629621c8f1c2531fb1ec864
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728504087.node-1.1307695.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2fee1c2fcc9e71d22993e8e86c2e2bd5b5b5170ff4f98f3e806bc0de732b724f
+size 240
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728504151.node-1.1309669.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728504151.node-1.1309669.0
new file mode 100644
index 0000000000000000000000000000000000000000..7da1741a0c61054823605fe57b5affe1d6a2a720
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728504151.node-1.1309669.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c01f6a0c718c35440967c5fefb2306a80be4a9e6cb16b72398e5ef5a643a6d7c
+size 240
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728504204.node-1.1311586.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728504204.node-1.1311586.0
new file mode 100644
index 0000000000000000000000000000000000000000..d9b3af07dbaade7a6bf81f3ccbadc52325ae1c39
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728504204.node-1.1311586.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:200237c7afa9ce9e1ac0a8899a168429abe1accad0a09f71d5386749d3d9c5ee
+size 858
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728511323.node-1.1423575.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728511323.node-1.1423575.0
new file mode 100644
index 0000000000000000000000000000000000000000..ecaaeca9c281cd29e4536d9410360627e8f82ea1
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728511323.node-1.1423575.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0db859c2db2ee907add0424780a44604ec92341fcd669b79c432e390b293a840
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728511408.node-1.1424277.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728511408.node-1.1424277.0
new file mode 100644
index 0000000000000000000000000000000000000000..19fecb8c17995ec285ee85f82155102dece0c2f8
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728511408.node-1.1424277.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dfd82121f93140d6a4fe8a502d23a7f63e1e3c47025a28d7558a8ae8be543fcf
+size 240
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728511536.node-1.1426496.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728511536.node-1.1426496.0
new file mode 100644
index 0000000000000000000000000000000000000000..5e83c720ecf18b37c5455ca3d94324769b32f207
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728511536.node-1.1426496.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f7f7e0aa60e435153833623263e3e7de4637756092df4fea84bf888200b5d5fa
+size 9940151
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728511656.node-1.1428583.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728511656.node-1.1428583.0
new file mode 100644
index 0000000000000000000000000000000000000000..9a53f6fb4c21834d249edfb6116b19909c9f95fa
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728511656.node-1.1428583.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6483db507f7e754db989463f587ddfab962d71af36ff759c0b9b0eb9b59d0b98
+size 13389318
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728511919.node-1.1430987.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728511919.node-1.1430987.0
new file mode 100644
index 0000000000000000000000000000000000000000..173baab7793fa37dde9ea8f659c0b322a58c538c
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728511919.node-1.1430987.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7136da2a05192cbace1d28174fc1552a0385ab718c69a6251c936e2410b36987
+size 2560
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728512128.node-1.1433243.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728512128.node-1.1433243.0
new file mode 100644
index 0000000000000000000000000000000000000000..5c716512f8bf1d362e0c4c56499b650bc5c97c5f
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728512128.node-1.1433243.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3118e7cfa8447d56f16cfa43900ae3b916819b68baed99a5f8523422c03e497b
+size 706
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728512204.node-1.1435043.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728512204.node-1.1435043.0
new file mode 100644
index 0000000000000000000000000000000000000000..61f1e24cf81aeaf0f1853a1120a810ada01e7f41
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728512204.node-1.1435043.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df2ca2a9853eb391f5d26a6fe7b30f8eb1acb13c77f79a31680d77bbcda39eaf
+size 15773699
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728512534.node-1.1440506.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728512534.node-1.1440506.0
new file mode 100644
index 0000000000000000000000000000000000000000..c8987664e146c63c6b305d33d92398df562dc0c4
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728512534.node-1.1440506.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ff809e0a1fcea3f6140cd7851fb97567db8ed35ba34fc5ba227fc4ba6fcb3cc3
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728512594.node-1.1442230.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728512594.node-1.1442230.0
new file mode 100644
index 0000000000000000000000000000000000000000..4c38135aab585323dcb1ed7c5a0809e01c6b0fb5
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728512594.node-1.1442230.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d990c7485f8b75f9d4b573c665ff9de1203de263366bd7b8bebfe2455c518f93
+size 4195864
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728512703.node-1.1444519.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728512703.node-1.1444519.0
new file mode 100644
index 0000000000000000000000000000000000000000..4cef3152b193131e11794b8539f0c507262f32a6
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728512703.node-1.1444519.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b90e0713bac7338b7710ed7d04f9e3b38f083d133b68d56bfc6d80a3e47a394a
+size 13389318
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728512819.node-1.1446634.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728512819.node-1.1446634.0
new file mode 100644
index 0000000000000000000000000000000000000000..adf87e6f4200c50f3bf7646f8be61a3df341b8db
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728512819.node-1.1446634.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d0203f7aba878fc46c47115a4adb32cdb3323e5d51825059af4129a264e962a1
+size 21578343
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728513849.node-1.1453466.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728513849.node-1.1453466.0
new file mode 100644
index 0000000000000000000000000000000000000000..ce34ea794c3e245bba0236457cf958af4df27f83
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/StyleTTS2-Second-Stage/events.out.tfevents.1728513849.node-1.1453466.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7d4d5642faa2f1d14028b5b80fdc9d6c8c4f6f163502b3afac6f72a6b43651a4
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/Top_ckpt.pth b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/Top_ckpt.pth
new file mode 100644
index 0000000000000000000000000000000000000000..c9871d6ed8623a57ec42a6234a50a9201abb5232
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/Top_ckpt.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b369c6b5caacd1c89186bea27f7f999e8c2f6d565fc88d468dcf2a9a33c9efe7
+size 2056549032
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/config_kanade_48khz.yml b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/config_kanade_48khz.yml
new file mode 100644
index 0000000000000000000000000000000000000000..afc0b2496b8dd8d33d03838cdf8c8d0afb991d92
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/config_kanade_48khz.yml
@@ -0,0 +1,24 @@
+{ASR_config: /home/austin/disk2/llmvcs/tt/stylekan/Utils/ASR/config.yml, ASR_path: /home/austin/disk2/llmvcs/tt/AuxiliaryASR/Checkpoint_new_plus/epoch_00070.pth,
+ F0_path: /home/austin/disk1/stts-zs_cleaning/F0_extractor/PitchExtractor/Checkpoint_200k/PE_48khz_epoch_00060.pth,
+ PLBERT_dir: /home/austin/disk2/llmvcs/tt/stylekan/Utils/PLBERT, batch_size: 35,
+ data_params: {OOD_data: /home/austin/disk2/llmvcs/tt/stylekan/Data/OOD_LargeScale_.csv,
+ min_length: 50, root_path: '', train_data: /home/austin/disk2/llmvcs/tt/stylekan/Data/metadata_cleanest/train_48_pure.csv,
+ val_data: /home/austin/disk2/llmvcs/tt/stylekan/Data/metadata_cleanest/val_48_pure.csv},
+ device: cuda, epochs_1st: 25, epochs_2nd: 15, first_stage_path: '', load_only_params: false,
+ log_dir: Models/Style_Kanade_48khz, log_interval: 10, loss_params: {TMA_epoch: 3,
+ diff_epoch: 4, joint_epoch: 999, lambda_F0: 1.0, lambda_ce: 20.0, lambda_diff: 1.0,
+ lambda_dur: 1.0, lambda_gen: 1.0, lambda_mel: 10.0, lambda_mono: 1.0, lambda_norm: 1.0,
+ lambda_s2s: 1.0, lambda_slm: 1.0, lambda_sty: 1.0}, max_len: 560, model_params: {
+ decoder: {gen_istft_hop_size: 4, gen_istft_n_fft: 32, resblock_dilation_sizes: [
+ [1, 3, 5], [1, 3, 5], [1, 3, 5]], resblock_kernel_sizes: [3, 7, 11], type: istftnet,
+ upsample_initial_channel: 512, upsample_kernel_sizes: [32, 16], upsample_rates: [
+ 16, 8]}, diffusion: {dist: {estimate_sigma_data: true, mean: -3.0, sigma_data: 0.30184900420868704,
+ std: 1.0}, embedding_mask_proba: 0.1, transformer: {head_features: 64, multiplier: 2,
+ num_heads: 8, num_layers: 3}}, dim_in: 64, dropout: 0.2, hidden_dim: 512,
+ max_conv_dim: 512, max_dur: 50, multispeaker: true, n_layer: 3, n_mels: 80, n_token: 178,
+ slm: {hidden: 1280, initial_channel: 64, model: Respair/Whisper_Large_v2_Encoder_Block,
+ nlayers: 33, sr: 16000}, sr: 48000, style_dim: 128}, optimizer_params: {bert_lr: 1.0e-05,
+ ft_lr: 1.0e-05, lr: 0.0001}, preprocess_params: {spect_params: {hop_length: 512,
+ n_fft: 2048, win_length: 2048}, sr: 48000}, pretrained_model: /home/austin/disk2/llmvcs/tt/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/Top_ckpt.pth,
+ save_freq: 1, second_stage_load_pretrained: true, slmadv_params: {batch_percentage: 0.5,
+ iter: 20, max_len: 500, min_len: 400, scale: 0.01, sig: 1.5, thresh: 5}}
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00003.pth b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00003.pth
new file mode 100644
index 0000000000000000000000000000000000000000..c3024a44fac48d9fffd501c502adfafb62536c71
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00003.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0c6fafba7ae5563c6db31bfff9ef96c64fda9027771a077e7b5e68b83b813193
+size 1522665000
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00004.pth b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00004.pth
new file mode 100644
index 0000000000000000000000000000000000000000..7fff3d1db34f6167ec08ad3bc8fe48984922ee02
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00004.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0455818056ffbb8e00696f0637084b7bfbbcf6f2930b13da9359ee0637c41b92
+size 2056549032
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00009.pth b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00009.pth
new file mode 100644
index 0000000000000000000000000000000000000000..5eae125b582c9d0256da98850956a4cc1abd1990
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00009.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:78a6978215be4d7f1458d064bc13e69332eded52ebf9ca3fcd69931a0161b92a
+size 2056549032
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00010.pth b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00010.pth
new file mode 100644
index 0000000000000000000000000000000000000000..47ef1292994d166c3afe90f234d0103b233d1826
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00010.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b1f78fd61e5f3d7314dfbbbe0dc7094f88eafc079c8c215372fb9dfecd26a3c0
+size 2608179446
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00011.pth b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00011.pth
new file mode 100644
index 0000000000000000000000000000000000000000..083df15a26d3fde3bd6e905f2ca39e006385227e
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00011.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9269bb3c2f928639363e3cf84efcca9ebbfa4dddf92ca1bd55ec2cfd519a4cf9
+size 2608179446
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00012.pth b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00012.pth
new file mode 100644
index 0000000000000000000000000000000000000000..71d96f648cf57cf0df394c69051834be6b57d47b
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00012.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d235010d51f4019cab22ee8210a3e8091ed80f08e8529b51228fc84f566cf2ed
+size 2608179446
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00013.pth b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00013.pth
new file mode 100644
index 0000000000000000000000000000000000000000..572595973a37cbe984ff53d260cf7f1f4ddc2b73
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/epoch_2nd_00013.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0ff231457c82fb3eabc2c2c2e4b9f2345faf4a882fa58f7e3d2f0fada1a3db9e
+size 2608179446
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728257866.node-1.65748.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728257866.node-1.65748.0
new file mode 100644
index 0000000000000000000000000000000000000000..bda7d197e54dea618495db73638144d9c9ab67bd
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728257866.node-1.65748.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5b1730acd727794e3acaa1a5754d7126c545e88caaee5e75f8c3175706a65f8e
+size 33214666
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728429904.node-1.930170.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728429904.node-1.930170.0
new file mode 100644
index 0000000000000000000000000000000000000000..a88386c2bbb2bb1b29fc12c97c5a9ef18aca3194
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728429904.node-1.930170.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0c8384a6c67a2456024956feed5c5750b6727c3cba6e9bb08fec7f00a5267512
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728430130.node-1.931128.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728430130.node-1.931128.0
new file mode 100644
index 0000000000000000000000000000000000000000..7e31e2b39bf79e07c891c4b2c09daac3d0f00961
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728430130.node-1.931128.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f63cbcad7db880aa92490861b053d81b5fe5f9a2bfcf9efcb939a929cc74daac
+size 2026
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728430270.node-1.933490.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728430270.node-1.933490.0
new file mode 100644
index 0000000000000000000000000000000000000000..d326fd213061b9269b0199fc4ae3ba8b3ee3c791
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728430270.node-1.933490.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e0fc47f20440b0848a5fed51334d5ea19c595cf7323dc04fb1aeba8f06a1b64c
+size 7685297
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728473870.node-1.1145594.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728473870.node-1.1145594.0
new file mode 100644
index 0000000000000000000000000000000000000000..a304cbd5f085a0a09ca9d9daf7434070019b1502
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728473870.node-1.1145594.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:adf46083d4a114e7caa5999bbbec7a4c50dbcc79aefcf1838eeb10b108d5cd76
+size 7194
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728488454.node-1.1220173.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728488454.node-1.1220173.0
new file mode 100644
index 0000000000000000000000000000000000000000..a7a65c8cba9368b4ad2e66ce439acba9d03c1c45
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728488454.node-1.1220173.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cab4783a5e58e69457cfc3e6ae07408f354b14f10fef211f0900b0b05c13044f
+size 716
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728488746.node-1.1223649.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728488746.node-1.1223649.0
new file mode 100644
index 0000000000000000000000000000000000000000..7d8a041040cbf72e9523b9169bf6f06acef624b8
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728488746.node-1.1223649.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f64a26251ae180924ebd92db974506db5f6f2df290ecfc5a10296b151b7e40c4
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728503762.node-1.1303426.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728503762.node-1.1303426.0
new file mode 100644
index 0000000000000000000000000000000000000000..1fd3a8fb789902f1b09b0139955ef010fdeeeb94
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728503762.node-1.1303426.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:634763cc376e8455e7004f281c2dda591846ef6cc2d59bf5d8f6e612d26351fa
+size 1177956
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728503932.node-1.1305451.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728503932.node-1.1305451.0
new file mode 100644
index 0000000000000000000000000000000000000000..0df9ed07d3822d2bb0f8c44278c6390d9d407642
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728503932.node-1.1305451.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7757b5d5393026caf3dc4f3bc80bd2240e238f4433dba3643cbbe0919ba101e1
+size 1143140
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728504347.node-1.1318570.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728504347.node-1.1318570.0
new file mode 100644
index 0000000000000000000000000000000000000000..1fa6b88354c19541f1678b9e5beb2d4e5351665c
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728504347.node-1.1318570.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:08ec66e9de179f66c751b74f6fb4b14abe6e6ca24eb83c19f495e6b02f94947d
+size 1751232
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728504863.node-1.1325263.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728504863.node-1.1325263.0
new file mode 100644
index 0000000000000000000000000000000000000000..cf275699ee0862757d8acddf5e68c70310a43626
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728504863.node-1.1325263.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3cb868df96d5ba80a40f04553dd42349c59360a0d15fcad0699333016a961389
+size 1243492
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505075.node-1.1327494.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505075.node-1.1327494.0
new file mode 100644
index 0000000000000000000000000000000000000000..de2ab0552aa1d8f34b22f9dbd05ebecb86b3130d
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505075.node-1.1327494.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:622e07d2e583706f3328a237439f9013392f650daffb81dd3da8bc556c7f29e9
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505112.node-1.1327855.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505112.node-1.1327855.0
new file mode 100644
index 0000000000000000000000000000000000000000..0105ff89aed9d6209d471f5e5862f5dedca82c9f
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505112.node-1.1327855.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd599799b0d1ba5af998ce99b72e5630065cea307df4181a6c4ad4c2d27360ea
+size 1143140
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505245.node-1.1329787.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505245.node-1.1329787.0
new file mode 100644
index 0000000000000000000000000000000000000000..7f0a4b0843947c1e6152e06c114bc300e26098cc
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505245.node-1.1329787.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e571aced58c070c39d19a199ba82a3a1949bae3f406c6353e1d777b62fb90d0a
+size 20791408
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505370.node-1.1331830.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505370.node-1.1331830.0
new file mode 100644
index 0000000000000000000000000000000000000000..7ab19ddcc06b7f890328d36fcff2360e57c310e4
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505370.node-1.1331830.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:48a2ac543652ec2a66b0eeff0e475c80e5e417eca8a1d98cf0c92969d6f40cba
+size 1143140
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505474.node-1.1333540.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505474.node-1.1333540.0
new file mode 100644
index 0000000000000000000000000000000000000000..90cae509a2e6f33bfd3638737037de16653f01ba
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505474.node-1.1333540.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28bbd72cac5daddea7483aa16fbbfb4c1d93cf3cf15ffd8513999a0a4760b036
+size 12474960
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505580.node-1.1335062.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505580.node-1.1335062.0
new file mode 100644
index 0000000000000000000000000000000000000000..8782a5ee32f8c007316a01ff1c89acea6057e1c3
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505580.node-1.1335062.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b3adbb4a58c09d6f82698b546d1c2fce562de7a1fe8f3cffc5b580b6e989b4b2
+size 1749976
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505712.node-1.1337247.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505712.node-1.1337247.0
new file mode 100644
index 0000000000000000000000000000000000000000..edfeb5857c20544c88276071bd206b238cb6a5b5
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505712.node-1.1337247.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd986ad0ccd5966ce7c51d2b6d04c94130dd61ab3f29b7797903ac4d2a3392dc
+size 1177956
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505819.node-1.1338692.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505819.node-1.1338692.0
new file mode 100644
index 0000000000000000000000000000000000000000..87afc820474f4b70d3784c75a9d9a712ad5e14ff
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505819.node-1.1338692.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac6ac91d60d9a64264772cc8e517c07f216c796ec3735052fd30bc3797add6d8
+size 1177956
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505946.node-1.1340466.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505946.node-1.1340466.0
new file mode 100644
index 0000000000000000000000000000000000000000..aaf5533df6e2f3d0b06a7d0a3cccd6f20049ffb6
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728505946.node-1.1340466.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:54608f7854679c4eefa1a205910e580a1c41f92115385297b2f550f3f26aab81
+size 1610712
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506199.node-1.1343315.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506199.node-1.1343315.0
new file mode 100644
index 0000000000000000000000000000000000000000..dc515ed1f2d382ea0905f358a1087c286490a01c
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506199.node-1.1343315.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3fbdd2b51eab50250f411285313e632d8ed75781bdd94d590b92236b057ffb55
+size 1143140
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506401.node-1.1345821.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506401.node-1.1345821.0
new file mode 100644
index 0000000000000000000000000000000000000000..7588dac6cc544c647de32f05647a450786e9a71b
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506401.node-1.1345821.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4764ecc65f5302445c787ded2039c3eec6aec1e8aad68102b2141d7f03131573
+size 1143140
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506519.node-1.1347829.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506519.node-1.1347829.0
new file mode 100644
index 0000000000000000000000000000000000000000..230e99a1dfbb50a99ef99d0a2ccb91b175f278d6
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506519.node-1.1347829.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:daa3ad31d85ae6394b7e4d960fc1140a9165b75b2176bd2f5962dbcd6dad58cb
+size 1071460
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506607.node-1.1349384.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506607.node-1.1349384.0
new file mode 100644
index 0000000000000000000000000000000000000000..3c706a473a08d2a5024a31c69d627c02ac85e445
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506607.node-1.1349384.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:527ff451ba23333317d9e13dcbe7e9af28781be3b64347ad81c9b63b7e231de7
+size 1177956
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506713.node-1.1350839.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506713.node-1.1350839.0
new file mode 100644
index 0000000000000000000000000000000000000000..007093c3be656a20cb4a6a92ec9f0ba3c190249a
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506713.node-1.1350839.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:736aeb4a498ee3ab82a7c5e6a776b420ea9c2bfe5e7b792949fc2c9b12edf328
+size 1071460
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506793.node-1.1352426.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506793.node-1.1352426.0
new file mode 100644
index 0000000000000000000000000000000000000000..a2b073e58c1154096c03e4c2c8435b4db9542e64
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506793.node-1.1352426.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3a9921c75f861546a3878b876355b68c2ab2a4477a6c45db5063d0c095d49345
+size 1177956
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506891.node-1.1353833.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506891.node-1.1353833.0
new file mode 100644
index 0000000000000000000000000000000000000000..c0a38a0f65d7d091b5b3eaa3a19f988f4ff5bd17
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506891.node-1.1353833.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0356cb32e88010b3dc00d8b7ecdb44b796bf3fb1cec02fdb59e500658cc2b205
+size 1177956
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506962.node-1.1355195.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506962.node-1.1355195.0
new file mode 100644
index 0000000000000000000000000000000000000000..1994e49dde03365e90f832ae0de32c3f3b102dea
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728506962.node-1.1355195.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c3a20820962c223f7f02afbfb56165f1bdc00ad047e4010511274752406838b2
+size 1128804
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507037.node-1.1356391.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507037.node-1.1356391.0
new file mode 100644
index 0000000000000000000000000000000000000000..552bcbba4afc5a644bbe0e090193f1ffb5e8ea03
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507037.node-1.1356391.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b9159905a35d527f13e61f9a34894470a295be38b6569db22711a9b42f5ea2d1
+size 1144396
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507356.node-1.1361327.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507356.node-1.1361327.0
new file mode 100644
index 0000000000000000000000000000000000000000..76ba49323b9d51eb55b44ed4dbb9d63eeea7c04b
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507356.node-1.1361327.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac806d1d61c578527908a770ebb809ded56563207f4cb26d32d5090e6f89c478
+size 1177956
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507459.node-1.1362758.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507459.node-1.1362758.0
new file mode 100644
index 0000000000000000000000000000000000000000..52d637d4553cb813c069154e8e12136d9b35df70
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507459.node-1.1362758.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ddeb910d7285ff265310ba13fc6ea7f3a87c493f740d900928215d430eba2942
+size 1177956
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507617.node-1.1364285.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507617.node-1.1364285.0
new file mode 100644
index 0000000000000000000000000000000000000000..5e8d7fdcb721490b900a9dd4a8f73fd437af40c1
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507617.node-1.1364285.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a9a2224636434867269dc5ef910f8f5d56ac633d76f3f52dfd021a61e12b04a4
+size 1143140
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507736.node-1.1366187.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507736.node-1.1366187.0
new file mode 100644
index 0000000000000000000000000000000000000000..741e2d8efe59fb78fa4a9d0b3f51507703d98481
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507736.node-1.1366187.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a26a166d764e05eab5f774309f3c16e8fd0bd3dffed958cc4fe746151dbdded
+size 1749348
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507863.node-1.1368555.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507863.node-1.1368555.0
new file mode 100644
index 0000000000000000000000000000000000000000..8d948e6a1a66e18673015fe64af3a018995e1c45
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507863.node-1.1368555.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c3bec96a5ae42541682a4795f3498875f1bc5a4b286a80117fa62a7a75e9b03
+size 12474960
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507963.node-1.1370284.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507963.node-1.1370284.0
new file mode 100644
index 0000000000000000000000000000000000000000..87129d9296faa70053ad974c259094f9035ff382
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728507963.node-1.1370284.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c563b4a6aa008d024af4984267b6fdceb405ca404829369d1262ecb9a2821c26
+size 1143140
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508113.node-1.1372610.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508113.node-1.1372610.0
new file mode 100644
index 0000000000000000000000000000000000000000..a98860e02a52b28e2b8eb7aa3205643eafac0f70
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508113.node-1.1372610.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89948ce4ed59a0e063ebb81a76035c300dc369d071d979c67c2203141428f60e
+size 1075556
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508240.node-1.1374446.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508240.node-1.1374446.0
new file mode 100644
index 0000000000000000000000000000000000000000..99cd25326026b5dbb317dd991f100f2b65138fab
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508240.node-1.1374446.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fde13a8a1b7b42aa19bf8044de28b5efc7920fd79585a451f801c869502333e1
+size 1749348
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508377.node-1.1376606.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508377.node-1.1376606.0
new file mode 100644
index 0000000000000000000000000000000000000000..9f129237cc6971ba9ee8e30b72e25041d1438344
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508377.node-1.1376606.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c5da030d75c4898bc784d03f3ff55bea3e09faf77ced9402d8d29d03e805d600
+size 22973570
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508581.node-1.1379646.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508581.node-1.1379646.0
new file mode 100644
index 0000000000000000000000000000000000000000..8fb53b94321dd52b34ab7f5cb4ba6076b9cd2038
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508581.node-1.1379646.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:049e6ff7afe5e7bc245a87143240ced30f830824c15d9842b4a9549dc312acd8
+size 1128804
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508682.node-1.1381228.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508682.node-1.1381228.0
new file mode 100644
index 0000000000000000000000000000000000000000..2cd93c703b3163b642a4d2ff9309d5d88c08cad5
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508682.node-1.1381228.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3b7da240165a3a7b1c290a4c52f425d0f6ce1831eb31a35ab24d0a8cf8ec8a3e
+size 1177956
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508766.node-1.1382835.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508766.node-1.1382835.0
new file mode 100644
index 0000000000000000000000000000000000000000..b968406e16b5a1b435621a9e70a062b053e4cff5
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508766.node-1.1382835.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5fbc6b2ea6298c48e3f9823247fb719efe2448a5dee8bc89fbf881f32e19d1d0
+size 1177956
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508887.node-1.1384844.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508887.node-1.1384844.0
new file mode 100644
index 0000000000000000000000000000000000000000..41a4babe704b2434fb991ae7efd26c7c3fc0c225
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728508887.node-1.1384844.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6fd124176368795f3d038ae951c3dac1c3f13dba85ca455572712a0d57d14c96
+size 1067364
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509025.node-1.1387215.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509025.node-1.1387215.0
new file mode 100644
index 0000000000000000000000000000000000000000..8f8621eeb0ff077451c9aedacf7a3248beafa526
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509025.node-1.1387215.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:11dd6fc2c343e32dbc7f59f91e4415ae6e4131b1455df15f2f7f136b352dc961
+size 1067364
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509194.node-1.1389681.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509194.node-1.1389681.0
new file mode 100644
index 0000000000000000000000000000000000000000..b2c71adfa5bd75278a78769b15f1c5a671ce2b50
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509194.node-1.1389681.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1c188a356a9e46461418396291c81d4834c0db564438daee1f0189faf2236e02
+size 1749348
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509325.node-1.1392092.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509325.node-1.1392092.0
new file mode 100644
index 0000000000000000000000000000000000000000..aef70144b74e593cedb006620e454ab0a500205e
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509325.node-1.1392092.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2a87491c22bcecc80aec602f23667f67024a4c0bc1eba16e0edd2ae4240b1c01
+size 1143140
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509442.node-1.1393983.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509442.node-1.1393983.0
new file mode 100644
index 0000000000000000000000000000000000000000..9b761cb65c2c2f4f522262784b95d4a7ea67fadb
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509442.node-1.1393983.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ca5514d5c2faad97bde1a496be1ff7540914402467f5bdbddcef29cd008e41e1
+size 1177956
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509546.node-1.1395653.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509546.node-1.1395653.0
new file mode 100644
index 0000000000000000000000000000000000000000..dc7425084989b77d5ee3e4f0d48b358bcc48a4a8
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509546.node-1.1395653.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9a87f5d0f3a4828b3cb7c75322a200b1cca937e097e8871a8025e422118abb37
+size 1116516
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509626.node-1.1397060.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509626.node-1.1397060.0
new file mode 100644
index 0000000000000000000000000000000000000000..9896df176624779c5180b27fbed85a18f6199299
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509626.node-1.1397060.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:deb1884a18441cc7eaa5e3fe8cad1088836d4346d6e17035664165b2e25fe00c
+size 1116516
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509766.node-1.1398641.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509766.node-1.1398641.0
new file mode 100644
index 0000000000000000000000000000000000000000..7df62434be7c0155b966a8685ae202e656f12fdc
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509766.node-1.1398641.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ff77690c202090ec860731ea756c1c9b2bff73f49bba95d0d5a54efd84e1009
+size 12604588
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509894.node-1.1401053.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509894.node-1.1401053.0
new file mode 100644
index 0000000000000000000000000000000000000000..00f266f848f8e96a424ef2f3a67b3799816e9865
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728509894.node-1.1401053.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f9ded0612b24ea4ed4e8251ac149fc99d818dc4efe9544ec4b521b84d9204387
+size 1177956
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728510008.node-1.1402743.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728510008.node-1.1402743.0
new file mode 100644
index 0000000000000000000000000000000000000000..48b4db97e9f573b845902ba5d8200e9ee3a03fd4
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728510008.node-1.1402743.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cb299c7fd7bb6c0f8017da19175a614e9bfb8b89ca272e4c28b8885e9949423b
+size 1143140
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728510109.node-1.1404544.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728510109.node-1.1404544.0
new file mode 100644
index 0000000000000000000000000000000000000000..053a92c7b9aa6263332e06cf96b929743aed29c6
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728510109.node-1.1404544.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b7f68f6bcda0bb2555931660172ab529217b75e243ca909a247cd8715b2d6676
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728510223.node-1.1406276.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728510223.node-1.1406276.0
new file mode 100644
index 0000000000000000000000000000000000000000..d7f18c9fb251629a4a7273689153f69b5475125f
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728510223.node-1.1406276.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4656c257e901deea87be30418c3e1e2e3dc9cad642eb45d91ac438df1565a980
+size 1749348
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728510363.node-1.1408233.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728510363.node-1.1408233.0
new file mode 100644
index 0000000000000000000000000000000000000000..478d83cc7297c5e9069b678b38366155c8fe459f
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728510363.node-1.1408233.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f42448c98409c955a52aa96e31e22ae16e6b254f6a1b42de8996a37a700c2ed5
+size 1177956
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728510438.node-1.1409781.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728510438.node-1.1409781.0
new file mode 100644
index 0000000000000000000000000000000000000000..c09ba9ec745a5cfc089b8739268240c70734c724
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728510438.node-1.1409781.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:682836fe38ca40b3036e9dbd43acf0ff83575ace77d11c4f781e84d47532e3c0
+size 1177956
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728511226.node-1.1421785.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728511226.node-1.1421785.0
new file mode 100644
index 0000000000000000000000000000000000000000..49f1009726e8b54fe0c5a1feebdb07532a4cd2ce
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728511226.node-1.1421785.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6eb418ba77535663ca54e503eac034810762907ab46c7ef7fc56119de3877cc2
+size 12474960
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728512308.node-1.1436987.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728512308.node-1.1436987.0
new file mode 100644
index 0000000000000000000000000000000000000000..23488d07bab460e37a2637bd461a51001aaa451d
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728512308.node-1.1436987.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:551cb28a08b980890a449ced08a5dc3a20b0993f6ca2e6b0897f2dc37e6e4607
+size 20213754
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728512466.node-1.1439216.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728512466.node-1.1439216.0
new file mode 100644
index 0000000000000000000000000000000000000000..c7ff955494faaf0986c1fd61a99e6451cc1d747b
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728512466.node-1.1439216.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:792d2b9ac5086e42d1060efd29c38bd27c6a49e95d395dc3775dc90187ca3cca
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728512509.node-1.1440004.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728512509.node-1.1440004.0
new file mode 100644
index 0000000000000000000000000000000000000000..acdeeed98f9ef9ce1aca5f4f1bff482dc8e36532
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728512509.node-1.1440004.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c71396c7eca35f1904f1ca204ff9be94e87c2721763689ef7fb4d0d52090734d
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728514975.node-1.1460521.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728514975.node-1.1460521.0
new file mode 100644
index 0000000000000000000000000000000000000000..a034573dad515a87f4b47e7262e94862496c04c0
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728514975.node-1.1460521.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eeb5bb9173997a8cc68c4ded9cc55952427ac85c295905a6af87dfcf82477593
+size 8904
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728546929.node-1.1555154.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728546929.node-1.1555154.0
new file mode 100644
index 0000000000000000000000000000000000000000..f4fd9f42f282a9440d5aa1a8fa46acc7874d87be
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728546929.node-1.1555154.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ee82f8fbe5c8e762fb9fd9cd25a27ae462b7e323f70bf707ddceb217505f599f
+size 716
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728547108.node-1.1557312.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728547108.node-1.1557312.0
new file mode 100644
index 0000000000000000000000000000000000000000..4e754c61acbb1cea2fd88137e326e580db7d51c8
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728547108.node-1.1557312.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9655e51b7c3e2790a2922ee791e4ca800298456ae6dc6392ea263cb20030841f
+size 33864
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728549094.node-1.1599561.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728549094.node-1.1599561.0
new file mode 100644
index 0000000000000000000000000000000000000000..57a27dd5a5eefa29f79e84de250848f8127ead68
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728549094.node-1.1599561.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2da4c660c9b1f15101524960c5f22237d43a050a19537261602e37e320d02cbe
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728549113.node-1.1599867.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728549113.node-1.1599867.0
new file mode 100644
index 0000000000000000000000000000000000000000..b36800f45ba7ec88c1c8cb87f9d182ae611a07ba
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728549113.node-1.1599867.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:48a204ee097af7ab183c7c2fd2a36827be3049532c5825d82d4d73f3adf52c76
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728549169.node-1.1600992.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728549169.node-1.1600992.0
new file mode 100644
index 0000000000000000000000000000000000000000..e59b784ef6f33eafb1e3387cdc7a2fe090ab95b3
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728549169.node-1.1600992.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f6d5cafa7657e38aeaef6ce9176a51cf01ffc3cd213009eab8ba451715548e4
+size 1177956
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728549278.node-1.1602626.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728549278.node-1.1602626.0
new file mode 100644
index 0000000000000000000000000000000000000000..01099f5c96cd2cffdfec0dbb2b118111d1c9216b
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728549278.node-1.1602626.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d7cef1dc6fc34d341b5902490f0fd6dfbd824e0ff56a60dd8cf4fb95ef1bb45b
+size 10353674
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728549384.node-1.1604690.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728549384.node-1.1604690.0
new file mode 100644
index 0000000000000000000000000000000000000000..7e18b770c2efa2f8352ee18e086a6d9ec1ff27b1
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728549384.node-1.1604690.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6534f5abf9aeeb52e71f5fb3e33087aa66ed13e568376907e4d62ee468ab6692
+size 10353674
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728549468.node-1.1606132.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728549468.node-1.1606132.0
new file mode 100644
index 0000000000000000000000000000000000000000..5ecddfebb3c1b19e8b393f90939d2e21e67dbaf7
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728549468.node-1.1606132.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eb01656ef13de92ffd967ed9388d55c3855f63d904396ae59513612eb8dcba20
+size 26226930
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728644449.node-1.3758195.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728644449.node-1.3758195.0
new file mode 100644
index 0000000000000000000000000000000000000000..5f333c004855ae9da590be1b3deb3d59f76f5716
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728644449.node-1.3758195.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dfabbf46b77b0e5dc36abf2d56c29e2296b61431ebf028a895f9ab8c54ccba6b
+size 716
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728644681.node-1.3762770.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728644681.node-1.3762770.0
new file mode 100644
index 0000000000000000000000000000000000000000..e239c33d2376bab97dfb2ba55c763984b542bb82
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728644681.node-1.3762770.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c8caa078ccf4e86d9896cc1ee66f3f0c77e83ba94839d29d7b0df73141de1643
+size 3856
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728645398.node-1.3770141.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728645398.node-1.3770141.0
new file mode 100644
index 0000000000000000000000000000000000000000..abeb7415d4071012737e4852ccecc3082719ccf8
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728645398.node-1.3770141.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:33526d4c6b8187c53ff8d4116c70d3bdeb961749764cb67ca7f5a0f162cdfa2f
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728645543.node-1.3772439.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728645543.node-1.3772439.0
new file mode 100644
index 0000000000000000000000000000000000000000..19d80745cac3c89bf164a83d788035f4ca8c6b05
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728645543.node-1.3772439.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:508e3074aa886213bb29ca55771242731d2a2bead768ab6e5d02f7880b4e2106
+size 26863728
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728979971.node-1.3878055.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728979971.node-1.3878055.0
new file mode 100644
index 0000000000000000000000000000000000000000..d4887c8d419e2183e811765940705b938891da4c
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728979971.node-1.3878055.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a5078b88d04560e9848d24fe17b89793a3735d6af47b6812cc77488ec2e06b5
+size 1972
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728980219.node-1.3882534.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728980219.node-1.3882534.0
new file mode 100644
index 0000000000000000000000000000000000000000..1f19bdcbf8b17d162d34a6c4a0d6242fcae7107f
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728980219.node-1.3882534.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:662075f09de281ffc7369405a090fb382db7ad84f18a565c5636ac88b1b1b479
+size 15944
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728982418.node-1.3922374.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728982418.node-1.3922374.0
new file mode 100644
index 0000000000000000000000000000000000000000..ec0734fa300da36013a4cb1b8c8db589938582a0
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728982418.node-1.3922374.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:711fcd754302021296ca28f9c6427ac8e78bbbe07454ff31c7b40b3bd85dfe35
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728982429.node-1.3922691.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728982429.node-1.3922691.0
new file mode 100644
index 0000000000000000000000000000000000000000..d5419d3f3651e84510f7d833c5f138f03b80c074
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728982429.node-1.3922691.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b83a092d0e0da3c2d79b43e0723ad2fdc3eb99e2a47ff85074e7590ffe299ddf
+size 1344
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728982593.node-1.3927127.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728982593.node-1.3927127.0
new file mode 100644
index 0000000000000000000000000000000000000000..b552fe88238c85d1bbc8f4e3f52e6f3ed9248741
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728982593.node-1.3927127.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e9654fcf1590397703c01a11e3d757bd99da0bf01fbc31ffc6e97c78de7a7782
+size 10824
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728983684.node-1.3958141.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728983684.node-1.3958141.0
new file mode 100644
index 0000000000000000000000000000000000000000..82d142834a83ae73c7aef90903adf13e1b5d46ff
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1728983684.node-1.3958141.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a5c85fecf2e019a31ede15a54777624665665f8ffb7247a89473dc3478673a9
+size 322504
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729294875.node-1.2731766.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729294875.node-1.2731766.0
new file mode 100644
index 0000000000000000000000000000000000000000..8024fe5c416805da03d4c5da51dd41b327e0fc12
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729294875.node-1.2731766.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b48be5f9d6bf47c248d0387597976ac6a331f3c3cedd81886042a7cb6fda1564
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729294914.node-1.2732510.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729294914.node-1.2732510.0
new file mode 100644
index 0000000000000000000000000000000000000000..9bb9cf650e4075e0d344009cd39db74ad97d4096
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729294914.node-1.2732510.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:77452339a7f6a513721d58e487581c2036b9baace383cbecfab6b4df6edf9e94
+size 51784
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729303723.node-1.2886912.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729303723.node-1.2886912.0
new file mode 100644
index 0000000000000000000000000000000000000000..f3386c4c72a097d0c0460fc9d34c2decc9d8e2da
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729303723.node-1.2886912.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:10be19a09bcebfe89130ac4bc613301afb841786af3867f586f4ece5669e3435
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729303742.node-1.2887266.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729303742.node-1.2887266.0
new file mode 100644
index 0000000000000000000000000000000000000000..07d2fac48dee5579b7803dce961f38671ceb790c
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729303742.node-1.2887266.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3270320eab02544321f469f3a8e705c80b49378050faf3dfe7d799f4b9cb45a7
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729303823.node-1.2888113.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729303823.node-1.2888113.0
new file mode 100644
index 0000000000000000000000000000000000000000..d653c508008d689315edea180b344ea586604d58
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729303823.node-1.2888113.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:318204bf5c77de485661b1d356cbb27cf6525e4875bce6e481813dff10379762
+size 54984
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729327950.node-1.3039261.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729327950.node-1.3039261.0
new file mode 100644
index 0000000000000000000000000000000000000000..b3069135a799ce7024d4a8d1dcc747dbdc60a38f
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729327950.node-1.3039261.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6b60cd1628924beb9edae9e527ea696ec08ba96a5ec3318f01d4bdb3779fdf84
+size 3228
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729328371.node-1.3049098.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729328371.node-1.3049098.0
new file mode 100644
index 0000000000000000000000000000000000000000..305f1a7fe73d2d60187d1cb960457337e18aa7ca
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729328371.node-1.3049098.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:45ae36671e2d59975d2d137a015111bba55b4dbe4598cf34db5c3ad9a62251c0
+size 2600
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729328736.node-1.3056977.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729328736.node-1.3056977.0
new file mode 100644
index 0000000000000000000000000000000000000000..da534549d0c17d01ebffc068c51f58675a3254fe
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729328736.node-1.3056977.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a5249351aabb69b1f8d10e4815151ffb4acad52e106253cb45f9421d7c31a260
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729328839.node-1.3057462.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729328839.node-1.3057462.0
new file mode 100644
index 0000000000000000000000000000000000000000..ce684e638cafbef85d331d1af78c5206df38f80d
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729328839.node-1.3057462.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:90a956a1a3886d4ffda4497c477176f05714a965b3b53d28716de9bce4ba3e5c
+size 5740
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729329519.node-1.3074545.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729329519.node-1.3074545.0
new file mode 100644
index 0000000000000000000000000000000000000000..2ed974a15e6d938f0116b12e0ec9d096fee6de6e
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729329519.node-1.3074545.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1250999411b91ed2ea59941a3840e5adc119e0ec7e69cb84db1c8136553d4596
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729329715.node-1.3075970.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729329715.node-1.3075970.0
new file mode 100644
index 0000000000000000000000000000000000000000..c988baa83a710d5c03ef23620b974d4a08969cd6
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729329715.node-1.3075970.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a5e56af97cb89592f5c230189ff97deaa9526358b5512437ad3b9f0e4d293b41
+size 2600
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729330165.node-1.3085370.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729330165.node-1.3085370.0
new file mode 100644
index 0000000000000000000000000000000000000000..7aa69a71191dd2a2f1534ec877312b312b3bf96c
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729330165.node-1.3085370.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c2bfa4432554da0c4d3fb07d4130886d0fd520b186ea36f86b1cc7112d722f1a
+size 1344
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729330386.node-1.3090061.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729330386.node-1.3090061.0
new file mode 100644
index 0000000000000000000000000000000000000000..07d7b14545f0e32c3b583ad10d2db3aa7bcdac84
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729330386.node-1.3090061.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:62f8df3dc34ca0ff946b5a75c8a067357128059c162035898c1c27efe0ead5a8
+size 7624
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729331321.node-1.3112479.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729331321.node-1.3112479.0
new file mode 100644
index 0000000000000000000000000000000000000000..26912e28de780b2edd45fa86aa400cb37cd899f9
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729331321.node-1.3112479.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d2ba03d9edf26e1201394986bb06be3c46b3319ef31b92247fde97820aa78d43
+size 58824
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729339178.node-1.3272954.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729339178.node-1.3272954.0
new file mode 100644
index 0000000000000000000000000000000000000000..cb26d95c583185b94ef8cd46187fc17b5117a4de
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729339178.node-1.3272954.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a2615428643e3578c76f4bf8e31ef300e868c568bdff284194fa987c60b99603
+size 1344
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729339469.node-1.3279495.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729339469.node-1.3279495.0
new file mode 100644
index 0000000000000000000000000000000000000000..6395b352fcfe235ff67666f5c759efca94a00ed3
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729339469.node-1.3279495.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6c37c48d232237e737ba8a31da67e4845035dfc991e4a8b8b4e1cdbbf19fb996
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729339574.node-1.3281697.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729339574.node-1.3281697.0
new file mode 100644
index 0000000000000000000000000000000000000000..e3dd50e195be492ebc9c641c5695c03005e54e3c
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729339574.node-1.3281697.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1054eac4801382f231072b6a83ccfa2efd4f231eec525127d07de5b2704f224
+size 102984
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729350045.node-1.3513360.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729350045.node-1.3513360.0
new file mode 100644
index 0000000000000000000000000000000000000000..c89ce90d02db1b7a054d8e24176cfe8b4f77bb38
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729350045.node-1.3513360.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ead07f73b3db68b9e5722d419ae7c619494c7f915efce20444534fb99c26cf76
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729350148.node-1.3514408.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729350148.node-1.3514408.0
new file mode 100644
index 0000000000000000000000000000000000000000..b85763c44c7d71152644800aa951ca7242c66a7f
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729350148.node-1.3514408.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aea50cd015d362ffa98c23d9b67f46c70d425b2ac83d28136c59698faa57129a
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729350905.node-1.3516108.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729350905.node-1.3516108.0
new file mode 100644
index 0000000000000000000000000000000000000000..42c061128ebfd34646ee07d8b04fc438f59935a4
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729350905.node-1.3516108.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:21980949171373d0d0a6f55438ea3a6ae3abb6792844fb82e5076b6e23ac5d58
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729350936.node-1.3516556.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729350936.node-1.3516556.0
new file mode 100644
index 0000000000000000000000000000000000000000..ac3285d5260f493689f57f3dc9bcd50955349936
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729350936.node-1.3516556.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74d9cfe706a9363ec350d38b81b05dbffbf3d192a8ace7ee128836c53d8202f6
+size 716
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351098.node-1.3519552.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351098.node-1.3519552.0
new file mode 100644
index 0000000000000000000000000000000000000000..ec77238146701445db27e20a83a945f7595637ed
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351098.node-1.3519552.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b47be511804d2ab41622ba215b485a2539d483f5aa014fb233740aea86e0c057
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351146.node-1.3520338.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351146.node-1.3520338.0
new file mode 100644
index 0000000000000000000000000000000000000000..2210efc5866d6eb29a6a62c7e8dcde3a618ddc30
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351146.node-1.3520338.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c1583baa5054da62a47462c8268e9ed9406576fb82d3f32c15ab88d7f192cf64
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351203.node-1.3521373.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351203.node-1.3521373.0
new file mode 100644
index 0000000000000000000000000000000000000000..890fad97a857f255b4647163ce40919183472d7a
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351203.node-1.3521373.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:95f3b237334fd07abbcabbeeb3c53225aaea349048f47b2edf9cda45667d0ade
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351313.node-1.3523566.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351313.node-1.3523566.0
new file mode 100644
index 0000000000000000000000000000000000000000..f728e25ebe1e7c40ce2e19030846fd299b688aaf
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351313.node-1.3523566.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:93e6d4662994da6a2326b75d4d93ed00ffc75ec6d88761d4ec481c955120f50c
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351395.node-1.3525570.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351395.node-1.3525570.0
new file mode 100644
index 0000000000000000000000000000000000000000..b6475129f0dd453f2b4cc4aab731e6bee879a3d6
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351395.node-1.3525570.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89499b7bbbc487e79267c636a18e88eaf30b34823276725970e299f1c35a3b88
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351584.node-1.3526725.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351584.node-1.3526725.0
new file mode 100644
index 0000000000000000000000000000000000000000..6eca7dbb13e0a5c65caaa04b1862d6fb5fc056c4
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351584.node-1.3526725.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fc80399d3cedffbec423a0944bed1406a8e613377ca135027b3a2ff8c9b83acf
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351656.node-1.3527680.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351656.node-1.3527680.0
new file mode 100644
index 0000000000000000000000000000000000000000..9624edbf6da51b7a8fa030d7c0f72d27875f4caf
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351656.node-1.3527680.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:912ea5cfb0f4d0786d655ae63600779ce2d8f9e06e1017b2069afc50a0bfc68a
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351705.node-1.3528448.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351705.node-1.3528448.0
new file mode 100644
index 0000000000000000000000000000000000000000..a16a50eb84898d8170832f22df67d993098c2696
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729351705.node-1.3528448.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3cddfce93e7d5c9cb4cfeab74f4822cea6cee285e6585c686e5e039adb0de61b
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729352225.node-1.3529938.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729352225.node-1.3529938.0
new file mode 100644
index 0000000000000000000000000000000000000000..24d142fd0706942e06ece852f1e65baeaefc8e4a
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729352225.node-1.3529938.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:75a1d803390ebfb7039b5360f88a6d8be2972bdab3e7a8438d686e448a90d92c
+size 716
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729352337.node-1.3532266.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729352337.node-1.3532266.0
new file mode 100644
index 0000000000000000000000000000000000000000..9dc1eebf1ecfc41f28a955e60d7a5e94238cefbb
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729352337.node-1.3532266.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:938c0e4711c05b49bab94f234a986d2db3920b1610f346c2150acfbddbea748f
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729352705.node-1.3537325.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729352705.node-1.3537325.0
new file mode 100644
index 0000000000000000000000000000000000000000..64dd6b3c5a38625dda2a8abb3f926b730e9d71dd
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729352705.node-1.3537325.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c97e773db236c08398ff49fe75898e21e331faa9045f5b6adc164e2c8f32be5b
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729352812.node-1.3539723.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729352812.node-1.3539723.0
new file mode 100644
index 0000000000000000000000000000000000000000..412d1107889e3ed8d38a80a824b71fe9f8b78472
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729352812.node-1.3539723.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3cdc1e4510a348091af69839df26e403889931b48faf491f50059e400c05d3ca
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729352997.node-1.3543604.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729352997.node-1.3543604.0
new file mode 100644
index 0000000000000000000000000000000000000000..8547c53421242e85e6ae97b24126423b8b2eaee6
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729352997.node-1.3543604.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:351bf5242d9e2f18c5c0a75e3fe3197e55aa5dd04f0123eb68f4ff8a112ce88c
+size 1344
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729353330.node-1.3549111.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729353330.node-1.3549111.0
new file mode 100644
index 0000000000000000000000000000000000000000..626cb403e2aa6b7f242798e9aacfd1c4589b265e
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729353330.node-1.3549111.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:54c3967e735f17eb1d60628ee033c1d8d4a346ee693b6c49f4a6f81584303f4e
+size 1972
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729354011.node-1.3565052.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729354011.node-1.3565052.0
new file mode 100644
index 0000000000000000000000000000000000000000..5383bcf87f81b3202f456ff8da6fb2eec3904534
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729354011.node-1.3565052.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b28bd84112ee687c269afc082936d9087cdff5906cafcb70b541409ba5529452
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729354116.node-1.3567445.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729354116.node-1.3567445.0
new file mode 100644
index 0000000000000000000000000000000000000000..187c9d097ac7221c0ebd0268bab6a064b37463e4
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729354116.node-1.3567445.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66b7baf399618002b609c928706803abbd6434f0761324f12342e061deba2f48
+size 18504
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729356306.node-1.3618453.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729356306.node-1.3618453.0
new file mode 100644
index 0000000000000000000000000000000000000000..f60ec70f6b01d589a2a5408ee8265307e85807f5
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729356306.node-1.3618453.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bec6b47321a87e1ddad8620d2a041acffb0b048ffbc26c92e5e1c76100b8a700
+size 53704
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729362124.node-1.3743068.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729362124.node-1.3743068.0
new file mode 100644
index 0000000000000000000000000000000000000000..ca3420074e2f2163835020c6d53a5df0447c03b4
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729362124.node-1.3743068.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:56d57ad1271c7aba28f7cfc9a4cdfd3caf7e6e2b835f67b1f49f127bf4d11db7
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729362193.node-1.3744335.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729362193.node-1.3744335.0
new file mode 100644
index 0000000000000000000000000000000000000000..d1980dd126f3d91e7b74d0fc7e15ec9c61ee060a
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729362193.node-1.3744335.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bb4901d86b6f438aa91fe4fe9911a631f030bc078dd409ab633cd5aacf56fc1e
+size 451784
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729411163.node-1.523489.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729411163.node-1.523489.0
new file mode 100644
index 0000000000000000000000000000000000000000..34493f9931b86220023fa253643d637aa8c885c0
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729411163.node-1.523489.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e2bdea5afdff4a3e89e8d28b13753be827d1a21851bb37cdaba96c24786da5bc
+size 1344
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729699748.node-1.1333787.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729699748.node-1.1333787.0
new file mode 100644
index 0000000000000000000000000000000000000000..c8f64075f18177f0d22c4d6d08c5bd1f6cf71bf4
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729699748.node-1.1333787.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5f1994ab89a011f089a4caeb82c928fa5820c9f6ed69245d7e2f675ecf9fcea4
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729699815.node-1.1334619.0 b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729699815.node-1.1334619.0
new file mode 100644
index 0000000000000000000000000000000000000000..a6d9b6f415dad60c73e92c36a7aaa2c107b144d6
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/tensorboard/events.out.tfevents.1729699815.node-1.1334619.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b9490daa88a8ec6657230c850fba64cb219231dfbb1892acffe70d35918ffacf
+size 88
diff --git a/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/train.log b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/train.log
new file mode 100644
index 0000000000000000000000000000000000000000..f21909e35d5da7d2839576cd553756d8232ff3ec
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Models/Style_Kanade_48khz/train.log
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7f2fd3e2a47413fc1d956df0234b3e2a48ed8f72752fd894dc435f9c8e2d8a34
+size 13066864
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/__init__.py b/stts_48khz/StyleTTS2_48khz/Modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Modules/__init__.py
@@ -0,0 +1 @@
+
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/__init__.cpython-310.pyc b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b6c91763ac76a51b042d30c354b5e00222668e06
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/__init__.cpython-310.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/__init__.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1cd33090f0ca6280c3c437d46c25a60a8d14d09
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/__init__.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/__init__.cpython-39.pyc b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f209a95bdf83f0745f02866b40cd873732f5102c
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/__init__.cpython-39.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/discriminators.cpython-310.pyc b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/discriminators.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b097bbd355c0a0139e87bce18e54bc9a7dad3a62
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/discriminators.cpython-310.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/discriminators.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/discriminators.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ab91625bb9deba3a188a1c8fecc125e3c74322c
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/discriminators.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/discriminators.cpython-39.pyc b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/discriminators.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..503c054aa82576387372bc797f171a847bdca098
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/discriminators.cpython-39.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/hifigan.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/hifigan.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b8b52707d718be5687c72eeb905dc4577fd4df1a
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/hifigan.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/istftnet.cpython-310.pyc b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/istftnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ccfe9e19b57de14c7bba70a621951a6234a8de1
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/istftnet.cpython-310.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/istftnet.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/istftnet.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..787d49b5d1e8878e3e7445c82ce272d9f2b23099
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/istftnet.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/istftnet.cpython-39.pyc b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/istftnet.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b964b8f6c9502446e5312788ec0845674a8826a1
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/istftnet.cpython-39.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/slmadv.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/slmadv.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a0d16b7bc8e128c54c9865c2e53c850c0475b0f
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/slmadv.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/utils.cpython-310.pyc b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..755551c91ee299efb0a63fcc563273e66ce026a9
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/utils.cpython-310.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/utils.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dac1fa78045599f8b6b5fc0fdd4c79334f87a31d
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/utils.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/utils.cpython-39.pyc b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fe897bc5ea7c7d986ed3b85c00ed73ed16f2f767
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/__pycache__/utils.cpython-39.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__init__.py b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__init__.py
@@ -0,0 +1 @@
+
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/__init__.cpython-310.pyc b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d27c56cfc7be01e90961c778f42e2232dd67ecb0
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/__init__.cpython-310.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/__init__.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..27a9f04cec9e931982554967e29005dc6e4392ed
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/__init__.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/__init__.cpython-39.pyc b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8b8cfd23db0c37e4c1ae7e665b8b308ceb06e794
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/__init__.cpython-39.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/diffusion.cpython-310.pyc b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2abc47ba98ff21e96e87f4efb12a7bb5523fcb3f
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/diffusion.cpython-310.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/diffusion.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/diffusion.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4159f9f8662c54be5e784f07c947aea40c19ec5
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/diffusion.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/diffusion.cpython-39.pyc b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/diffusion.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6bd95c24fd88e43a8429c50f463cf7a8443b5779
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/diffusion.cpython-39.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/modules.cpython-310.pyc b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/modules.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a23277e94ce35f01248a98f4edc6cdd48d2ed4e4
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/modules.cpython-310.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/modules.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/modules.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e960ba0c1f8f47db3e1feba78142430082c54a67
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/modules.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/modules.cpython-39.pyc b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/modules.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a2973dee6978fdb2ad34d381742db8bedebf298
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/modules.cpython-39.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/sampler.cpython-310.pyc b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/sampler.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..77b6346eec451328527277141a2e3da97b391e9c
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/sampler.cpython-310.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/sampler.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/sampler.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2f5e5d925c7a95d1b6f4f16d892a73e9cdc7ca54
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/sampler.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/sampler.cpython-39.pyc b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/sampler.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83f8f2ec069ae65ea0467a95559bc75056a4fede
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/sampler.cpython-39.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/utils.cpython-310.pyc b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..90e01780589df03d4ce609905373be1d244e8cfa
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/utils.cpython-310.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/utils.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e759f7923b7855ac669d3198c6059ffa82cdde6
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/utils.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/utils.cpython-39.pyc b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..980bc9b5601b5eeefec20693523e5932f16409de
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/__pycache__/utils.cpython-39.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/diffusion/diffusion.py b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f8d40041e4493f1d037cdbd33bd25091ffbdfe4
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/diffusion.py
@@ -0,0 +1,106 @@
+from math import pi
+from random import randint
+from typing import Any, Optional, Sequence, Tuple, Union
+
+import torch
+from einops import rearrange
+from torch import Tensor, nn
+from tqdm import tqdm
+
+from .utils import *
+from .sampler import *
+
+"""
+Diffusion Classes (generic for 1d data)
+"""
+
+
+class Model1d(nn.Module):
+ def __init__(self, unet_type: str = "base", **kwargs):
+ super().__init__()
+        # The "diffusion_"-prefixed kwargs are split off but left unused here;
+        # the UNet and diffusion objects are attached externally after
+        # construction, which is why both start as None.
+        diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
+        self.unet = None
+        self.diffusion = None
+
+ def forward(self, x: Tensor, **kwargs) -> Tensor:
+ return self.diffusion(x, **kwargs)
+
+ def sample(self, *args, **kwargs) -> Tensor:
+ return self.diffusion.sample(*args, **kwargs)
+
+
+"""
+Audio Diffusion Classes (specific for 1d audio data)
+"""
+
+
+def get_default_model_kwargs():
+ return dict(
+ channels=128,
+ patch_size=16,
+ multipliers=[1, 2, 4, 4, 4, 4, 4],
+ factors=[4, 4, 4, 2, 2, 2],
+ num_blocks=[2, 2, 2, 2, 2, 2],
+ attentions=[0, 0, 0, 1, 1, 1, 1],
+ attention_heads=8,
+ attention_features=64,
+ attention_multiplier=2,
+ attention_use_rel_pos=False,
+ diffusion_type="v",
+ diffusion_sigma_distribution=UniformDistribution(),
+ )
+
+
+def get_default_sampling_kwargs():
+ return dict(sigma_schedule=LinearSchedule(), sampler=VSampler(), clamp=True)
+
+
+class AudioDiffusionModel(Model1d):
+ def __init__(self, **kwargs):
+ super().__init__(**{**get_default_model_kwargs(), **kwargs})
+
+ def sample(self, *args, **kwargs):
+ return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})
+
+
+class AudioDiffusionConditional(Model1d):
+ def __init__(
+ self,
+ embedding_features: int,
+ embedding_max_length: int,
+ embedding_mask_proba: float = 0.1,
+ **kwargs,
+ ):
+ self.embedding_mask_proba = embedding_mask_proba
+ default_kwargs = dict(
+ **get_default_model_kwargs(),
+ unet_type="cfg",
+ context_embedding_features=embedding_features,
+ context_embedding_max_length=embedding_max_length,
+ )
+ super().__init__(**{**default_kwargs, **kwargs})
+
+ def forward(self, *args, **kwargs):
+ default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
+ return super().forward(*args, **{**default_kwargs, **kwargs})
+
+ def sample(self, *args, **kwargs):
+ default_kwargs = dict(
+ **get_default_sampling_kwargs(),
+ embedding_scale=5.0,
+ )
+ return super().sample(*args, **{**default_kwargs, **kwargs})
+
+
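+# Minimal usage sketch (illustrative only): in this stripped-down module the
+# transformer network and diffusion objects are attached externally after
+# construction, so the wiring below is an assumption, not this repo's API.
+#
+#     model = AudioDiffusionConditional(embedding_features=768,
+#                                       embedding_max_length=512)
+#     model.diffusion = KDiffusion(net=model.unet, ...)  # hypothetical wiring
+#     loss = model(x, embedding=text_emb)                # training objective
+#     audio = model.sample(noise, num_steps=50, embedding=text_emb)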
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/diffusion/modules.py b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..9df85b993c3839ab01df6b6168334bdbf3cb1f30
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/modules.py
@@ -0,0 +1,706 @@
+from math import floor, log, pi
+from typing import Any, List, Optional, Sequence, Tuple, Union
+
+from .utils import *
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, reduce, repeat
+from einops.layers.torch import Rearrange
+from einops_exts import rearrange_many
+from torch import Tensor, einsum
+
+
+"""
+Utils
+"""
+
+class AdaLayerNorm(nn.Module):
+ def __init__(self, style_dim, channels, eps=1e-5):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.fc = nn.Linear(style_dim, channels*2)
+
+    def forward(self, x, s):
+        # For 3-D (batch, time, channel) input these two transposes undo each
+        # other; they are kept to preserve the original layout handling.
+        x = x.transpose(-1, -2)
+        x = x.transpose(1, -1)
+
+        h = self.fc(s)
+        h = h.view(h.size(0), h.size(1), 1)
+        gamma, beta = torch.chunk(h, chunks=2, dim=1)
+        gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
+
+        x = F.layer_norm(x, (self.channels,), eps=self.eps)
+        x = (1 + gamma) * x + beta
+        return x.transpose(1, -1).transpose(-1, -2)
+
+class StyleTransformer1d(nn.Module):
+ def __init__(
+ self,
+ num_layers: int,
+ channels: int,
+ num_heads: int,
+ head_features: int,
+ multiplier: int,
+ use_context_time: bool = True,
+ use_rel_pos: bool = False,
+ context_features_multiplier: int = 1,
+ rel_pos_num_buckets: Optional[int] = None,
+ rel_pos_max_distance: Optional[int] = None,
+ context_features: Optional[int] = None,
+ context_embedding_features: Optional[int] = None,
+ embedding_max_length: int = 512,
+ ):
+ super().__init__()
+
+ self.blocks = nn.ModuleList(
+ [
+ StyleTransformerBlock(
+ features=channels + context_embedding_features,
+ head_features=head_features,
+ num_heads=num_heads,
+ multiplier=multiplier,
+ style_dim=context_features,
+ use_rel_pos=use_rel_pos,
+ rel_pos_num_buckets=rel_pos_num_buckets,
+ rel_pos_max_distance=rel_pos_max_distance,
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ self.to_out = nn.Sequential(
+ Rearrange("b t c -> b c t"),
+ nn.Conv1d(
+ in_channels=channels + context_embedding_features,
+ out_channels=channels,
+ kernel_size=1,
+ ),
+ )
+
+ use_context_features = exists(context_features)
+ self.use_context_features = use_context_features
+ self.use_context_time = use_context_time
+
+ if use_context_time or use_context_features:
+ context_mapping_features = channels + context_embedding_features
+
+ self.to_mapping = nn.Sequential(
+ nn.Linear(context_mapping_features, context_mapping_features),
+ nn.GELU(),
+ nn.Linear(context_mapping_features, context_mapping_features),
+ nn.GELU(),
+ )
+
+ if use_context_time:
+ assert exists(context_mapping_features)
+ self.to_time = nn.Sequential(
+ TimePositionalEmbedding(
+ dim=channels, out_features=context_mapping_features
+ ),
+ nn.GELU(),
+ )
+
+ if use_context_features:
+ assert exists(context_features) and exists(context_mapping_features)
+ self.to_features = nn.Sequential(
+ nn.Linear(
+ in_features=context_features, out_features=context_mapping_features
+ ),
+ nn.GELU(),
+ )
+
+ self.fixed_embedding = FixedEmbedding(
+ max_length=embedding_max_length, features=context_embedding_features
+ )
+
+
+ def get_mapping(
+ self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
+ ) -> Optional[Tensor]:
+ """Combines context time features and features into mapping"""
+ items, mapping = [], None
+ # Compute time features
+ if self.use_context_time:
+ assert_message = "use_context_time=True but no time features provided"
+ assert exists(time), assert_message
+ items += [self.to_time(time)]
+ # Compute features
+ if self.use_context_features:
+ assert_message = "context_features exists but no features provided"
+ assert exists(features), assert_message
+ items += [self.to_features(features)]
+
+ # Compute joint mapping
+ if self.use_context_time or self.use_context_features:
+ mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
+ mapping = self.to_mapping(mapping)
+
+ return mapping
+
+ def run(self, x, time, embedding, features):
+
+ mapping = self.get_mapping(time, features)
+ x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
+ mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
+
+ for block in self.blocks:
+ x = x + mapping
+ x = block(x, features)
+
+ x = x.mean(axis=1).unsqueeze(1)
+ x = self.to_out(x)
+ x = x.transpose(-1, -2)
+
+ return x
+
+ def forward(self, x: Tensor,
+ time: Tensor,
+ embedding_mask_proba: float = 0.0,
+ embedding: Optional[Tensor] = None,
+ features: Optional[Tensor] = None,
+ embedding_scale: float = 1.0) -> Tensor:
+
+ b, device = embedding.shape[0], embedding.device
+ fixed_embedding = self.fixed_embedding(embedding)
+ if embedding_mask_proba > 0.0:
+ # Randomly mask embedding
+ batch_mask = rand_bool(
+ shape=(b, 1, 1), proba=embedding_mask_proba, device=device
+ )
+ embedding = torch.where(batch_mask, fixed_embedding, embedding)
+
+ if embedding_scale != 1.0:
+ # Compute both normal and fixed embedding outputs
+ out = self.run(x, time, embedding=embedding, features=features)
+ out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
+ # Scale conditional output using classifier-free guidance
+ return out_masked + (out - out_masked) * embedding_scale
+ else:
+ return self.run(x, time, embedding=embedding, features=features)
+
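+# Classifier-free guidance, as implemented in forward() above: with u the
+# output under the fixed (unconditional) embedding and c the conditional
+# output, the guided prediction is u + embedding_scale * (c - u). A scale of
+# 1.0 recovers the purely conditional output; AudioDiffusionConditional.sample
+# defaults to a scale of 5.0.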
+
+class StyleTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ features: int,
+ num_heads: int,
+ head_features: int,
+ style_dim: int,
+ multiplier: int,
+ use_rel_pos: bool,
+ rel_pos_num_buckets: Optional[int] = None,
+ rel_pos_max_distance: Optional[int] = None,
+ context_features: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.use_cross_attention = exists(context_features) and context_features > 0
+
+ self.attention = StyleAttention(
+ features=features,
+ style_dim=style_dim,
+ num_heads=num_heads,
+ head_features=head_features,
+ use_rel_pos=use_rel_pos,
+ rel_pos_num_buckets=rel_pos_num_buckets,
+ rel_pos_max_distance=rel_pos_max_distance,
+ )
+
+ if self.use_cross_attention:
+ self.cross_attention = StyleAttention(
+ features=features,
+ style_dim=style_dim,
+ num_heads=num_heads,
+ head_features=head_features,
+ context_features=context_features,
+ use_rel_pos=use_rel_pos,
+ rel_pos_num_buckets=rel_pos_num_buckets,
+ rel_pos_max_distance=rel_pos_max_distance,
+ )
+
+ self.feed_forward = FeedForward(features=features, multiplier=multiplier)
+
+ def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
+ x = self.attention(x, s) + x
+ if self.use_cross_attention:
+ x = self.cross_attention(x, s, context=context) + x
+ x = self.feed_forward(x) + x
+ return x
+
+class StyleAttention(nn.Module):
+ def __init__(
+ self,
+ features: int,
+ *,
+ style_dim: int,
+ head_features: int,
+ num_heads: int,
+ context_features: Optional[int] = None,
+ use_rel_pos: bool,
+ rel_pos_num_buckets: Optional[int] = None,
+ rel_pos_max_distance: Optional[int] = None,
+ ):
+ super().__init__()
+ self.context_features = context_features
+ mid_features = head_features * num_heads
+ context_features = default(context_features, features)
+
+ self.norm = AdaLayerNorm(style_dim, features)
+ self.norm_context = AdaLayerNorm(style_dim, context_features)
+ self.to_q = nn.Linear(
+ in_features=features, out_features=mid_features, bias=False
+ )
+ self.to_kv = nn.Linear(
+ in_features=context_features, out_features=mid_features * 2, bias=False
+ )
+ self.attention = AttentionBase(
+ features,
+ num_heads=num_heads,
+ head_features=head_features,
+ use_rel_pos=use_rel_pos,
+ rel_pos_num_buckets=rel_pos_num_buckets,
+ rel_pos_max_distance=rel_pos_max_distance,
+ )
+
+ def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
+ assert_message = "You must provide a context when using context_features"
+ assert not self.context_features or exists(context), assert_message
+ # Use context if provided
+ context = default(context, x)
+ # Normalize then compute q from input and k,v from context
+ x, context = self.norm(x, s), self.norm_context(context, s)
+
+ q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
+ # Compute and return attention
+ return self.attention(q, k, v)
+
+class Transformer1d(nn.Module):
+ def __init__(
+ self,
+ num_layers: int,
+ channels: int,
+ num_heads: int,
+ head_features: int,
+ multiplier: int,
+ use_context_time: bool = True,
+ use_rel_pos: bool = False,
+ context_features_multiplier: int = 1,
+ rel_pos_num_buckets: Optional[int] = None,
+ rel_pos_max_distance: Optional[int] = None,
+ context_features: Optional[int] = None,
+ context_embedding_features: Optional[int] = None,
+ embedding_max_length: int = 512,
+ ):
+ super().__init__()
+
+ self.blocks = nn.ModuleList(
+ [
+ TransformerBlock(
+ features=channels + context_embedding_features,
+ head_features=head_features,
+ num_heads=num_heads,
+ multiplier=multiplier,
+ use_rel_pos=use_rel_pos,
+ rel_pos_num_buckets=rel_pos_num_buckets,
+ rel_pos_max_distance=rel_pos_max_distance,
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ self.to_out = nn.Sequential(
+ Rearrange("b t c -> b c t"),
+ nn.Conv1d(
+ in_channels=channels + context_embedding_features,
+ out_channels=channels,
+ kernel_size=1,
+ ),
+ )
+
+ use_context_features = exists(context_features)
+ self.use_context_features = use_context_features
+ self.use_context_time = use_context_time
+
+ if use_context_time or use_context_features:
+ context_mapping_features = channels + context_embedding_features
+
+ self.to_mapping = nn.Sequential(
+ nn.Linear(context_mapping_features, context_mapping_features),
+ nn.GELU(),
+ nn.Linear(context_mapping_features, context_mapping_features),
+ nn.GELU(),
+ )
+
+ if use_context_time:
+ assert exists(context_mapping_features)
+ self.to_time = nn.Sequential(
+ TimePositionalEmbedding(
+ dim=channels, out_features=context_mapping_features
+ ),
+ nn.GELU(),
+ )
+
+ if use_context_features:
+ assert exists(context_features) and exists(context_mapping_features)
+ self.to_features = nn.Sequential(
+ nn.Linear(
+ in_features=context_features, out_features=context_mapping_features
+ ),
+ nn.GELU(),
+ )
+
+ self.fixed_embedding = FixedEmbedding(
+ max_length=embedding_max_length, features=context_embedding_features
+ )
+
+
+ def get_mapping(
+ self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
+ ) -> Optional[Tensor]:
+ """Combines context time features and features into mapping"""
+ items, mapping = [], None
+ # Compute time features
+ if self.use_context_time:
+ assert_message = "use_context_time=True but no time features provided"
+ assert exists(time), assert_message
+ items += [self.to_time(time)]
+ # Compute features
+ if self.use_context_features:
+ assert_message = "context_features exists but no features provided"
+ assert exists(features), assert_message
+ items += [self.to_features(features)]
+
+ # Compute joint mapping
+ if self.use_context_time or self.use_context_features:
+ mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
+ mapping = self.to_mapping(mapping)
+
+ return mapping
+
+ def run(self, x, time, embedding, features):
+
+ mapping = self.get_mapping(time, features)
+ x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
+ mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
+
+ for block in self.blocks:
+ x = x + mapping
+ x = block(x)
+
+ x = x.mean(axis=1).unsqueeze(1)
+ x = self.to_out(x)
+ x = x.transpose(-1, -2)
+
+ return x
+
+ def forward(self, x: Tensor,
+ time: Tensor,
+ embedding_mask_proba: float = 0.0,
+ embedding: Optional[Tensor] = None,
+ features: Optional[Tensor] = None,
+ embedding_scale: float = 1.0) -> Tensor:
+
+ b, device = embedding.shape[0], embedding.device
+ fixed_embedding = self.fixed_embedding(embedding)
+ if embedding_mask_proba > 0.0:
+ # Randomly mask embedding
+ batch_mask = rand_bool(
+ shape=(b, 1, 1), proba=embedding_mask_proba, device=device
+ )
+ embedding = torch.where(batch_mask, fixed_embedding, embedding)
+
+ if embedding_scale != 1.0:
+ # Compute both normal and fixed embedding outputs
+ out = self.run(x, time, embedding=embedding, features=features)
+ out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
+ # Scale conditional output using classifier-free guidance
+ return out_masked + (out - out_masked) * embedding_scale
+ else:
+ return self.run(x, time, embedding=embedding, features=features)
+
+
+"""
+Attention Components
+"""
+
+
+class RelativePositionBias(nn.Module):
+ def __init__(self, num_buckets: int, max_distance: int, num_heads: int):
+ super().__init__()
+ self.num_buckets = num_buckets
+ self.max_distance = max_distance
+ self.num_heads = num_heads
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
+
+ @staticmethod
+ def _relative_position_bucket(
+ relative_position: Tensor, num_buckets: int, max_distance: int
+ ):
+ num_buckets //= 2
+ ret = (relative_position >= 0).to(torch.long) * num_buckets
+ n = torch.abs(relative_position)
+
+ max_exact = num_buckets // 2
+ is_small = n < max_exact
+
+ val_if_large = (
+ max_exact
+ + (
+ torch.log(n.float() / max_exact)
+ / log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).long()
+ )
+ val_if_large = torch.min(
+ val_if_large, torch.full_like(val_if_large, num_buckets - 1)
+ )
+
+ ret += torch.where(is_small, n, val_if_large)
+ return ret
+
+ def forward(self, num_queries: int, num_keys: int) -> Tensor:
+ i, j, device = num_queries, num_keys, self.relative_attention_bias.weight.device
+ q_pos = torch.arange(j - i, j, dtype=torch.long, device=device)
+ k_pos = torch.arange(j, dtype=torch.long, device=device)
+ rel_pos = rearrange(k_pos, "j -> 1 j") - rearrange(q_pos, "i -> i 1")
+
+ relative_position_bucket = self._relative_position_bucket(
+ rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance
+ )
+
+ bias = self.relative_attention_bias(relative_position_bucket)
+ bias = rearrange(bias, "m n h -> 1 h m n")
+ return bias
+
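+# T5-style relative-position bucketing: half of the buckets index small exact
+# offsets, the rest cover larger offsets on a log scale up to max_distance.
+# E.g. with num_buckets=32 and max_distance=128, offsets 0..7 each get their
+# own bucket, offsets 8..127 share log-spaced buckets, and positive versus
+# negative offsets use disjoint halves of the bucket range.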
+
+def FeedForward(features: int, multiplier: int) -> nn.Module:
+ mid_features = features * multiplier
+ return nn.Sequential(
+ nn.Linear(in_features=features, out_features=mid_features),
+ nn.GELU(),
+ nn.Linear(in_features=mid_features, out_features=features),
+ )
+
+
+class AttentionBase(nn.Module):
+ def __init__(
+ self,
+ features: int,
+ *,
+ head_features: int,
+ num_heads: int,
+ use_rel_pos: bool,
+ out_features: Optional[int] = None,
+ rel_pos_num_buckets: Optional[int] = None,
+ rel_pos_max_distance: Optional[int] = None,
+ ):
+ super().__init__()
+ self.scale = head_features ** -0.5
+ self.num_heads = num_heads
+ self.use_rel_pos = use_rel_pos
+ mid_features = head_features * num_heads
+
+ if use_rel_pos:
+ assert exists(rel_pos_num_buckets) and exists(rel_pos_max_distance)
+ self.rel_pos = RelativePositionBias(
+ num_buckets=rel_pos_num_buckets,
+ max_distance=rel_pos_max_distance,
+ num_heads=num_heads,
+ )
+ if out_features is None:
+ out_features = features
+
+ self.to_out = nn.Linear(in_features=mid_features, out_features=out_features)
+
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+ # Split heads
+ q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
+ # Compute similarity matrix
+ sim = einsum("... n d, ... m d -> ... n m", q, k)
+ sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim
+ sim = sim * self.scale
+ # Get attention matrix with softmax
+ attn = sim.softmax(dim=-1)
+ # Compute values
+ out = einsum("... n m, ... m d -> ... n d", attn, v)
+ out = rearrange(out, "b h n d -> b n (h d)")
+ return self.to_out(out)
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ features: int,
+ *,
+ head_features: int,
+ num_heads: int,
+ out_features: Optional[int] = None,
+ context_features: Optional[int] = None,
+ use_rel_pos: bool,
+ rel_pos_num_buckets: Optional[int] = None,
+ rel_pos_max_distance: Optional[int] = None,
+ ):
+ super().__init__()
+ self.context_features = context_features
+ mid_features = head_features * num_heads
+ context_features = default(context_features, features)
+
+ self.norm = nn.LayerNorm(features)
+ self.norm_context = nn.LayerNorm(context_features)
+ self.to_q = nn.Linear(
+ in_features=features, out_features=mid_features, bias=False
+ )
+ self.to_kv = nn.Linear(
+ in_features=context_features, out_features=mid_features * 2, bias=False
+ )
+
+ self.attention = AttentionBase(
+ features,
+ out_features=out_features,
+ num_heads=num_heads,
+ head_features=head_features,
+ use_rel_pos=use_rel_pos,
+ rel_pos_num_buckets=rel_pos_num_buckets,
+ rel_pos_max_distance=rel_pos_max_distance,
+ )
+
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
+ assert_message = "You must provide a context when using context_features"
+ assert not self.context_features or exists(context), assert_message
+ # Use context if provided
+ context = default(context, x)
+ # Normalize then compute q from input and k,v from context
+ x, context = self.norm(x), self.norm_context(context)
+ q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
+ # Compute and return attention
+ return self.attention(q, k, v)
+
+
+"""
+Transformer Blocks
+"""
+
+
+class TransformerBlock(nn.Module):
+ def __init__(
+ self,
+ features: int,
+ num_heads: int,
+ head_features: int,
+ multiplier: int,
+ use_rel_pos: bool,
+ rel_pos_num_buckets: Optional[int] = None,
+ rel_pos_max_distance: Optional[int] = None,
+ context_features: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.use_cross_attention = exists(context_features) and context_features > 0
+
+ self.attention = Attention(
+ features=features,
+ num_heads=num_heads,
+ head_features=head_features,
+ use_rel_pos=use_rel_pos,
+ rel_pos_num_buckets=rel_pos_num_buckets,
+ rel_pos_max_distance=rel_pos_max_distance,
+ )
+
+ if self.use_cross_attention:
+ self.cross_attention = Attention(
+ features=features,
+ num_heads=num_heads,
+ head_features=head_features,
+ context_features=context_features,
+ use_rel_pos=use_rel_pos,
+ rel_pos_num_buckets=rel_pos_num_buckets,
+ rel_pos_max_distance=rel_pos_max_distance,
+ )
+
+ self.feed_forward = FeedForward(features=features, multiplier=multiplier)
+
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
+ x = self.attention(x) + x
+ if self.use_cross_attention:
+ x = self.cross_attention(x, context=context) + x
+ x = self.feed_forward(x) + x
+ return x
+
+
+
+"""
+Time Embeddings
+"""
+
+
+class SinusoidalEmbedding(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x: Tensor) -> Tensor:
+ device, half_dim = x.device, self.dim // 2
+ emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+ emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
+ return torch.cat((emb.sin(), emb.cos()), dim=-1)
+
+
+class LearnedPositionalEmbedding(nn.Module):
+ """Used for continuous time"""
+
+ def __init__(self, dim: int):
+ super().__init__()
+ assert (dim % 2) == 0
+ half_dim = dim // 2
+ self.weights = nn.Parameter(torch.randn(half_dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = rearrange(x, "b -> b 1")
+ freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
+ fouriered = torch.cat((x, fouriered), dim=-1)
+ return fouriered
+
+
+def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
+ return nn.Sequential(
+ LearnedPositionalEmbedding(dim),
+ nn.Linear(in_features=dim + 1, out_features=out_features),
+ )
+
+class FixedEmbedding(nn.Module):
+ def __init__(self, max_length: int, features: int):
+ super().__init__()
+ self.max_length = max_length
+ self.embedding = nn.Embedding(max_length, features)
+
+ def forward(self, x: Tensor) -> Tensor:
+ batch_size, length, device = *x.shape[0:2], x.device
+ assert_message = "Input sequence length must be <= max_length"
+ assert length <= self.max_length, assert_message
+ position = torch.arange(length, device=device)
+ fixed_embedding = self.embedding(position)
+ fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
+ return fixed_embedding
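+
+# Shape sketch for FixedEmbedding: given an embedding of shape
+# (batch, length, features) with length <= max_length, the same learned
+# positional table of shape (length, features) is repeated over the batch,
+# yielding the content-independent "unconditional" branch used for CFG above.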
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/diffusion/sampler.py b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..467bd0d16093564e3ddcfe9033d2097986d54a27
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/sampler.py
@@ -0,0 +1,707 @@
+from math import atan, cos, pi, sin, sqrt
+from typing import Any, Callable, List, Optional, Tuple, Type
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, reduce
+from torch import Tensor
+
+from .utils import *
+
+"""
+Diffusion Training
+"""
+
+""" Distributions """
+
+
+class Distribution:
+ def __call__(self, num_samples: int, device: torch.device):
+ raise NotImplementedError()
+
+
+class LogNormalDistribution(Distribution):
+ def __init__(self, mean: float, std: float):
+ self.mean = mean
+ self.std = std
+
+ def __call__(
+ self, num_samples: int, device: torch.device = torch.device("cpu")
+ ) -> Tensor:
+ normal = self.mean + self.std * torch.randn((num_samples,), device=device)
+ return normal.exp()
+
+
+class UniformDistribution(Distribution):
+ def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
+ return torch.rand(num_samples, device=device)
+
+
+class VKDistribution(Distribution):
+ def __init__(
+ self,
+ min_value: float = 0.0,
+ max_value: float = float("inf"),
+ sigma_data: float = 1.0,
+ ):
+ self.min_value = min_value
+ self.max_value = max_value
+ self.sigma_data = sigma_data
+
+ def __call__(
+ self, num_samples: int, device: torch.device = torch.device("cpu")
+ ) -> Tensor:
+ sigma_data = self.sigma_data
+ min_cdf = atan(self.min_value / sigma_data) * 2 / pi
+ max_cdf = atan(self.max_value / sigma_data) * 2 / pi
+ u = (max_cdf - min_cdf) * torch.randn((num_samples,), device=device) + min_cdf
+ return torch.tan(u * pi / 2) * sigma_data
+
+
+""" Diffusion Classes """
+
+
+def pad_dims(x: Tensor, ndim: int) -> Tensor:
+ # Pads additional ndims to the right of the tensor
+ return x.view(*x.shape, *((1,) * ndim))
+
+
+def clip(x: Tensor, dynamic_threshold: float = 0.0):
+ if dynamic_threshold == 0.0:
+ return x.clamp(-1.0, 1.0)
+ else:
+ # Dynamic thresholding
+ # Find dynamic threshold quantile for each batch
+ x_flat = rearrange(x, "b ... -> b (...)")
+ scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1)
+ # Clamp to a min of 1.0
+ scale.clamp_(min=1.0)
+ # Clamp all values and scale
+ scale = pad_dims(scale, ndim=x.ndim - scale.ndim)
+ x = x.clamp(-scale, scale) / scale
+ return x
+
+
+def to_batch(
+ batch_size: int,
+ device: torch.device,
+ x: Optional[float] = None,
+ xs: Optional[Tensor] = None,
+) -> Tensor:
+ assert exists(x) ^ exists(xs), "Either x or xs must be provided"
+ # If x provided use the same for all batch items
+ if exists(x):
+ xs = torch.full(size=(batch_size,), fill_value=x).to(device)
+ assert exists(xs)
+ return xs
+
+
+class Diffusion(nn.Module):
+    """Base diffusion class"""
+
+    alias: str = ""
+
+ def denoise_fn(
+ self,
+ x_noisy: Tensor,
+ sigmas: Optional[Tensor] = None,
+ sigma: Optional[float] = None,
+ **kwargs,
+ ) -> Tensor:
+ raise NotImplementedError("Diffusion class missing denoise_fn")
+
+    def forward(self, x: Tensor, noise: Optional[Tensor] = None, **kwargs) -> Tensor:
+ raise NotImplementedError("Diffusion class missing forward function")
+
+
+class VDiffusion(Diffusion):
+
+ alias = "v"
+
+ def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
+ super().__init__()
+ self.net = net
+ self.sigma_distribution = sigma_distribution
+
+ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
+ angle = sigmas * pi / 2
+ alpha = torch.cos(angle)
+ beta = torch.sin(angle)
+ return alpha, beta
+
+ def denoise_fn(
+ self,
+ x_noisy: Tensor,
+ sigmas: Optional[Tensor] = None,
+ sigma: Optional[float] = None,
+ **kwargs,
+ ) -> Tensor:
+ batch_size, device = x_noisy.shape[0], x_noisy.device
+ sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
+ return self.net(x_noisy, sigmas, **kwargs)
+
+    def forward(self, x: Tensor, noise: Optional[Tensor] = None, **kwargs) -> Tensor:
+ batch_size, device = x.shape[0], x.device
+
+ # Sample amount of noise to add for each batch element
+ sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
+ sigmas_padded = rearrange(sigmas, "b -> b 1 1")
+
+ # Get noise
+ noise = default(noise, lambda: torch.randn_like(x))
+
+ # Combine input and noise weighted by half-circle
+ alpha, beta = self.get_alpha_beta(sigmas_padded)
+ x_noisy = x * alpha + noise * beta
+ x_target = noise * alpha - x * beta
+
+ # Denoise and return loss
+ x_denoised = self.denoise_fn(x_noisy, sigmas, **kwargs)
+ return F.mse_loss(x_denoised, x_target)
+
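+# The v-objective in brief: with alpha = cos(pi * sigma / 2) and
+# beta = sin(pi * sigma / 2), the noisy input is alpha * x + beta * noise and
+# the regression target is v = alpha * noise - beta * x, so the target moves
+# from pure noise at sigma = 0 to -x at sigma = 1.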
+
+class KDiffusion(Diffusion):
+ """Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364"""
+
+ alias = "k"
+
+ def __init__(
+ self,
+ net: nn.Module,
+ *,
+ sigma_distribution: Distribution,
+ sigma_data: float, # data distribution standard deviation
+ dynamic_threshold: float = 0.0,
+ ):
+ super().__init__()
+ self.net = net
+ self.sigma_data = sigma_data
+ self.sigma_distribution = sigma_distribution
+ self.dynamic_threshold = dynamic_threshold
+
+ def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
+ sigma_data = self.sigma_data
+ c_noise = torch.log(sigmas) * 0.25
+ sigmas = rearrange(sigmas, "b -> b 1 1")
+ c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
+ c_out = sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
+ c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
+ return c_skip, c_out, c_in, c_noise
+
+ def denoise_fn(
+ self,
+ x_noisy: Tensor,
+ sigmas: Optional[Tensor] = None,
+ sigma: Optional[float] = None,
+ **kwargs,
+ ) -> Tensor:
+ batch_size, device = x_noisy.shape[0], x_noisy.device
+ sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
+
+ # Predict network output and add skip connection
+ c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
+ x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
+ x_denoised = c_skip * x_noisy + c_out * x_pred
+
+ return x_denoised
+
+ def loss_weight(self, sigmas: Tensor) -> Tensor:
+ # Computes weight depending on data distribution
+ return (sigmas ** 2 + self.sigma_data ** 2) * (sigmas * self.sigma_data) ** -2
+
+    def forward(self, x: Tensor, noise: Optional[Tensor] = None, **kwargs) -> Tensor:
+ batch_size, device = x.shape[0], x.device
+
+ # Sample amount of noise to add for each batch element
+ sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
+ sigmas_padded = rearrange(sigmas, "b -> b 1 1")
+
+ # Add noise to input
+ noise = default(noise, lambda: torch.randn_like(x))
+ x_noisy = x + sigmas_padded * noise
+
+ # Compute denoised values
+ x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)
+
+ # Compute weighted loss
+ losses = F.mse_loss(x_denoised, x, reduction="none")
+ losses = reduce(losses, "b ... -> b", "mean")
+ losses = losses * self.loss_weight(sigmas)
+ loss = losses.mean()
+ return loss
+
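+# The c_skip/c_out/c_in/c_noise factors above follow the EDM preconditioning
+# of Karras et al. 2022: the network sees a roughly unit-variance input
+# c_in * x_noisy, and its output is blended with a skip connection so the
+# denoiser stays well-conditioned across noise levels.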
+
+class VKDiffusion(Diffusion):
+
+ alias = "vk"
+
+ def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
+ super().__init__()
+ self.net = net
+ self.sigma_distribution = sigma_distribution
+
+ def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
+ sigma_data = 1.0
+ sigmas = rearrange(sigmas, "b -> b 1 1")
+ c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
+ c_out = -sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
+ c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
+ return c_skip, c_out, c_in
+
+ def sigma_to_t(self, sigmas: Tensor) -> Tensor:
+ return sigmas.atan() / pi * 2
+
+ def t_to_sigma(self, t: Tensor) -> Tensor:
+ return (t * pi / 2).tan()
+
+ def denoise_fn(
+ self,
+ x_noisy: Tensor,
+ sigmas: Optional[Tensor] = None,
+ sigma: Optional[float] = None,
+ **kwargs,
+ ) -> Tensor:
+ batch_size, device = x_noisy.shape[0], x_noisy.device
+ sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
+
+ # Predict network output and add skip connection
+ c_skip, c_out, c_in = self.get_scale_weights(sigmas)
+ x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
+ x_denoised = c_skip * x_noisy + c_out * x_pred
+ return x_denoised
+
+    def forward(self, x: Tensor, noise: Optional[Tensor] = None, **kwargs) -> Tensor:
+ batch_size, device = x.shape[0], x.device
+
+ # Sample amount of noise to add for each batch element
+ sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
+ sigmas_padded = rearrange(sigmas, "b -> b 1 1")
+
+ # Add noise to input
+ noise = default(noise, lambda: torch.randn_like(x))
+ x_noisy = x + sigmas_padded * noise
+
+ # Compute model output
+ c_skip, c_out, c_in = self.get_scale_weights(sigmas)
+ x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
+
+ # Compute v-objective target
+ v_target = (x - c_skip * x_noisy) / (c_out + 1e-7)
+
+ # Compute loss
+ loss = F.mse_loss(x_pred, v_target)
+ return loss
+
+
+"""
+Diffusion Sampling
+"""
+
+""" Schedules """
+
+
+class Schedule(nn.Module):
+ """Interface used by different sampling schedules"""
+
+ def forward(self, num_steps: int, device: torch.device) -> Tensor:
+ raise NotImplementedError()
+
+
+class LinearSchedule(Schedule):
+    def forward(self, num_steps: int, device: Any) -> Tensor:
+        sigmas = torch.linspace(1, 0, num_steps + 1, device=device)[:-1]
+        return sigmas
+
+
+class KarrasSchedule(Schedule):
+ """https://arxiv.org/abs/2206.00364 equation 5"""
+
+ def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
+ super().__init__()
+ self.sigma_min = sigma_min
+ self.sigma_max = sigma_max
+ self.rho = rho
+
+ def forward(self, num_steps: int, device: Any) -> Tensor:
+ rho_inv = 1.0 / self.rho
+ steps = torch.arange(num_steps, device=device, dtype=torch.float32)
+ sigmas = (
+ self.sigma_max ** rho_inv
+ + (steps / (num_steps - 1))
+ * (self.sigma_min ** rho_inv - self.sigma_max ** rho_inv)
+ ) ** self.rho
+ sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
+ return sigmas
+
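+# Example: KarrasSchedule(sigma_min=0.01, sigma_max=80.0, rho=7.0) spaces the
+# num_steps sigmas densely near sigma_min and sparsely near sigma_max, then
+# appends a final 0.0 (so the schedule actually returns num_steps + 1 values).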
+
+""" Samplers """
+
+
+class Sampler(nn.Module):
+
+ diffusion_types: List[Type[Diffusion]] = []
+
+ def forward(
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
+ ) -> Tensor:
+ raise NotImplementedError()
+
+ def inpaint(
+ self,
+ source: Tensor,
+ mask: Tensor,
+ fn: Callable,
+ sigmas: Tensor,
+ num_steps: int,
+ num_resamples: int,
+ ) -> Tensor:
+ raise NotImplementedError("Inpainting not available with current sampler")
+
+
+class VSampler(Sampler):
+
+ diffusion_types = [VDiffusion]
+
+ def get_alpha_beta(self, sigma: float) -> Tuple[float, float]:
+ angle = sigma * pi / 2
+ alpha = cos(angle)
+ beta = sin(angle)
+ return alpha, beta
+
+ def forward(
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
+ ) -> Tensor:
+ x = sigmas[0] * noise
+ alpha, beta = self.get_alpha_beta(sigmas[0].item())
+
+        for i in range(num_steps - 1):
+            is_last = i == num_steps - 2  # true on the final loop iteration
+
+ x_denoised = fn(x, sigma=sigmas[i])
+ x_pred = x * alpha - x_denoised * beta
+ x_eps = x * beta + x_denoised * alpha
+
+ if not is_last:
+ alpha, beta = self.get_alpha_beta(sigmas[i + 1].item())
+ x = x_pred * alpha + x_eps * beta
+
+ return x_pred
+
+
+class KarrasSampler(Sampler):
+ """https://arxiv.org/abs/2206.00364 algorithm 1"""
+
+ diffusion_types = [KDiffusion, VKDiffusion]
+
+ def __init__(
+ self,
+ s_tmin: float = 0,
+ s_tmax: float = float("inf"),
+ s_churn: float = 0.0,
+ s_noise: float = 1.0,
+ ):
+ super().__init__()
+ self.s_tmin = s_tmin
+ self.s_tmax = s_tmax
+ self.s_noise = s_noise
+ self.s_churn = s_churn
+
+ def step(
+ self, x: Tensor, fn: Callable, sigma: float, sigma_next: float, gamma: float
+ ) -> Tensor:
+ """Algorithm 2 (step)"""
+ # Select temporarily increased noise level
+ sigma_hat = sigma + gamma * sigma
+ # Add noise to move from sigma to sigma_hat
+ epsilon = self.s_noise * torch.randn_like(x)
+ x_hat = x + sqrt(sigma_hat ** 2 - sigma ** 2) * epsilon
+ # Evaluate ∂x/∂sigma at sigma_hat
+ d = (x_hat - fn(x_hat, sigma=sigma_hat)) / sigma_hat
+ # Take euler step from sigma_hat to sigma_next
+ x_next = x_hat + (sigma_next - sigma_hat) * d
+ # Second order correction
+ if sigma_next != 0:
+ model_out_next = fn(x_next, sigma=sigma_next)
+ d_prime = (x_next - model_out_next) / sigma_next
+ x_next = x_hat + 0.5 * (sigma - sigma_hat) * (d + d_prime)
+ return x_next
+
+ def forward(
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
+ ) -> Tensor:
+ x = sigmas[0] * noise
+ # Compute gammas
+ gammas = torch.where(
+ (sigmas >= self.s_tmin) & (sigmas <= self.s_tmax),
+ min(self.s_churn / num_steps, sqrt(2) - 1),
+ 0.0,
+ )
+ # Denoise to sample
+ for i in range(num_steps - 1):
+ x = self.step(
+ x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1], gamma=gammas[i] # type: ignore # noqa
+ )
+
+ return x
+
+
+class AEulerSampler(Sampler):
+
+ diffusion_types = [KDiffusion, VKDiffusion]
+
+ def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]:
+ sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
+ sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
+ return sigma_up, sigma_down
+
+ def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
+ # Sigma steps
+ sigma_up, sigma_down = self.get_sigmas(sigma, sigma_next)
+ # Derivative at sigma (∂x/∂sigma)
+ d = (x - fn(x, sigma=sigma)) / sigma
+ # Euler method
+ x_next = x + d * (sigma_down - sigma)
+ # Add randomness
+ x_next = x_next + torch.randn_like(x) * sigma_up
+ return x_next
+
+ def forward(
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
+ ) -> Tensor:
+ x = sigmas[0] * noise
+ # Denoise to sample
+ for i in range(num_steps - 1):
+ x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
+ return x
+
+
+class ADPM2Sampler(Sampler):
+ """https://www.desmos.com/calculator/jbxjlqd9mb"""
+
+ diffusion_types = [KDiffusion, VKDiffusion]
+
+ def __init__(self, rho: float = 1.0):
+ super().__init__()
+ self.rho = rho
+
+ def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]:
+ r = self.rho
+ sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
+ sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
+ sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
+ return sigma_up, sigma_down, sigma_mid
+
+ def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
+ # Sigma steps
+ sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
+ # Derivative at sigma (∂x/∂sigma)
+ d = (x - fn(x, sigma=sigma)) / sigma
+ # Denoise to midpoint
+ x_mid = x + d * (sigma_mid - sigma)
+ # Derivative at sigma_mid (∂x_mid/∂sigma_mid)
+ d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
+ # Denoise to next
+ x = x + d_mid * (sigma_down - sigma)
+ # Add randomness
+ x_next = x + torch.randn_like(x) * sigma_up
+ return x_next
+
+ def forward(
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
+ ) -> Tensor:
+ x = sigmas[0] * noise
+ # Denoise to sample
+ for i in range(num_steps - 1):
+ x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
+ return x
+
+ def inpaint(
+ self,
+ source: Tensor,
+ mask: Tensor,
+ fn: Callable,
+ sigmas: Tensor,
+ num_steps: int,
+ num_resamples: int,
+ ) -> Tensor:
+ x = sigmas[0] * torch.randn_like(source)
+
+ for i in range(num_steps - 1):
+ # Noise source to current noise level
+ source_noisy = source + sigmas[i] * torch.randn_like(source)
+ for r in range(num_resamples):
+ # Merge noisy source and current then denoise
+ x = source_noisy * mask + x * ~mask
+ x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
+ # Renoise if not last resample step
+ if r < num_resamples - 1:
+ sigma = sqrt(sigmas[i] ** 2 - sigmas[i + 1] ** 2)
+ x = x + sigma * torch.randn_like(x)
+
+ return source * mask + x * ~mask
+
+
+""" Main Classes """
+
+
+class DiffusionSampler(nn.Module):
+ def __init__(
+ self,
+ diffusion: Diffusion,
+ *,
+ sampler: Sampler,
+ sigma_schedule: Schedule,
+ num_steps: Optional[int] = None,
+ clamp: bool = True,
+ ):
+ super().__init__()
+ self.denoise_fn = diffusion.denoise_fn
+ self.sampler = sampler
+ self.sigma_schedule = sigma_schedule
+ self.num_steps = num_steps
+ self.clamp = clamp
+
+ # Check sampler is compatible with diffusion type
+ sampler_class = sampler.__class__.__name__
+ diffusion_class = diffusion.__class__.__name__
+ message = f"{sampler_class} incompatible with {diffusion_class}"
+ assert diffusion.alias in [t.alias for t in sampler.diffusion_types], message
+
+ def forward(
+ self, noise: Tensor, num_steps: Optional[int] = None, **kwargs
+ ) -> Tensor:
+ device = noise.device
+ num_steps = default(num_steps, self.num_steps) # type: ignore
+ assert exists(num_steps), "Parameter `num_steps` must be provided"
+ # Compute sigmas using schedule
+ sigmas = self.sigma_schedule(num_steps, device)
+ # Append additional kwargs to denoise function (used e.g. for conditional unet)
+ fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs}) # noqa
+ # Sample using sampler
+ x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
+ x = x.clamp(-1.0, 1.0) if self.clamp else x
+ return x
+
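+# Minimal sampling sketch (illustrative; assumes a KDiffusion-wrapped net and
+# hypothetical shapes):
+#
+#     sampler = DiffusionSampler(diffusion, sampler=ADPM2Sampler(),
+#                                sigma_schedule=KarrasSchedule(1e-4, 3.0),
+#                                num_steps=30)
+#     audio = sampler(torch.randn(1, channels, length))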
+
+class DiffusionInpainter(nn.Module):
+ def __init__(
+ self,
+ diffusion: Diffusion,
+ *,
+ num_steps: int,
+ num_resamples: int,
+ sampler: Sampler,
+ sigma_schedule: Schedule,
+ ):
+ super().__init__()
+ self.denoise_fn = diffusion.denoise_fn
+ self.num_steps = num_steps
+ self.num_resamples = num_resamples
+ self.inpaint_fn = sampler.inpaint
+ self.sigma_schedule = sigma_schedule
+
+ @torch.no_grad()
+ def forward(self, inpaint: Tensor, inpaint_mask: Tensor) -> Tensor:
+ x = self.inpaint_fn(
+ source=inpaint,
+ mask=inpaint_mask,
+ fn=self.denoise_fn,
+ sigmas=self.sigma_schedule(self.num_steps, inpaint.device),
+ num_steps=self.num_steps,
+ num_resamples=self.num_resamples,
+ )
+ return x
+
+
+def sequential_mask(like: Tensor, start: int) -> Tensor:
+ length, device = like.shape[2], like.device
+ mask = torch.ones_like(like, dtype=torch.bool)
+ mask[:, :, start:] = torch.zeros((length - start,), device=device)
+ return mask
+
+
+class SpanBySpanComposer(nn.Module):
+ def __init__(
+ self,
+ inpainter: DiffusionInpainter,
+ *,
+ num_spans: int,
+ ):
+ super().__init__()
+ self.inpainter = inpainter
+ self.num_spans = num_spans
+
+ def forward(self, start: Tensor, keep_start: bool = False) -> Tensor:
+ half_length = start.shape[2] // 2
+
+ spans = list(start.chunk(chunks=2, dim=-1)) if keep_start else []
+ # Inpaint second half from first half
+ inpaint = torch.zeros_like(start)
+ inpaint[:, :, :half_length] = start[:, :, half_length:]
+ inpaint_mask = sequential_mask(like=start, start=half_length)
+
+ for i in range(self.num_spans):
+ # Inpaint second half
+ span = self.inpainter(inpaint=inpaint, inpaint_mask=inpaint_mask)
+ # Replace first half with generated second half
+ second_half = span[:, :, half_length:]
+ inpaint[:, :, :half_length] = second_half
+ # Save generated span
+ spans.append(second_half)
+
+ return torch.cat(spans, dim=2)
+
+
+class XDiffusion(nn.Module):
+ def __init__(self, type: str, net: nn.Module, **kwargs):
+ super().__init__()
+
+ diffusion_classes = [VDiffusion, KDiffusion, VKDiffusion]
+ aliases = [t.alias for t in diffusion_classes] # type: ignore
+ message = f"type='{type}' must be one of {*aliases,}"
+ assert type in aliases, message
+ self.net = net
+
+        for diffusion_cls in diffusion_classes:
+            if diffusion_cls.alias == type:  # type: ignore
+                self.diffusion = diffusion_cls(net=net, **kwargs)
+
+ def forward(self, *args, **kwargs) -> Tensor:
+ return self.diffusion(*args, **kwargs)
+
+ def sample(
+ self,
+ noise: Tensor,
+ num_steps: int,
+ sigma_schedule: Schedule,
+ sampler: Sampler,
+ clamp: bool,
+ **kwargs,
+ ) -> Tensor:
+ diffusion_sampler = DiffusionSampler(
+ diffusion=self.diffusion,
+ sampler=sampler,
+ sigma_schedule=sigma_schedule,
+ num_steps=num_steps,
+ clamp=clamp,
+ )
+ return diffusion_sampler(noise, **kwargs)
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/diffusion/utils.py b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f330b98dddf65f6ae8473303a7d5abd31563f362
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Modules/diffusion/utils.py
@@ -0,0 +1,86 @@
+from functools import reduce
+from inspect import isfunction
+from math import ceil, floor, log2, pi
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import Generator, Tensor
+from typing_extensions import TypeGuard
+
+T = TypeVar("T")
+
+
+def exists(val: Optional[T]) -> TypeGuard[T]:
+ return val is not None
+
+
+def iff(condition: bool, value: T) -> Optional[T]:
+ return value if condition else None
+
+
+def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
+ return isinstance(obj, list) or isinstance(obj, tuple)
+
+
+def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def to_list(val: Union[T, Sequence[T]]) -> List[T]:
+ if isinstance(val, tuple):
+ return list(val)
+ if isinstance(val, list):
+ return val
+ return [val] # type: ignore
+
+
+def prod(vals: Sequence[int]) -> int:
+ return reduce(lambda x, y: x * y, vals)
+
+
+def closest_power_2(x: float) -> int:
+ exponent = log2(x)
+ distance_fn = lambda z: abs(x - 2 ** z) # noqa
+ exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
+ return 2 ** int(exponent_closest)
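+# Example: closest_power_2(100) -> 128, since |100 - 128| = 28 < |100 - 64| = 36.
+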
+
+def rand_bool(shape, proba, device = None):
+ if proba == 1:
+ return torch.ones(shape, device=device, dtype=torch.bool)
+ elif proba == 0:
+ return torch.zeros(shape, device=device, dtype=torch.bool)
+ else:
+ return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
+
+
+"""
+Kwargs Utils
+"""
+
+
+def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
+ return_dicts: Tuple[Dict, Dict] = ({}, {})
+ for key in d.keys():
+ no_prefix = int(not key.startswith(prefix))
+ return_dicts[no_prefix][key] = d[key]
+ return return_dicts
+
+
+def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
+ kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
+ if keep_prefix:
+ return kwargs_with_prefix, kwargs
+ kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
+ return kwargs_no_prefix, kwargs
+
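+# Example: groupby("diffusion_", {"diffusion_type": "v", "channels": 128})
+# returns ({"type": "v"}, {"channels": 128}).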
+
+def prefix_dict(prefix: str, d: Dict) -> Dict:
+ return {prefix + str(k): v for k, v in d.items()}
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/discriminators.py b/stts_48khz/StyleTTS2_48khz/Modules/discriminators.py
new file mode 100644
index 0000000000000000000000000000000000000000..31a187acc41d04d6afb29a05dbbf5d951889598c
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Modules/discriminators.py
@@ -0,0 +1,190 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, spectral_norm
+
+from .utils import get_padding
+
+LRELU_SLOPE = 0.1
+
+def stft(x, fft_size, hop_size, win_length, window):
+    """Perform STFT and convert to magnitude spectrogram.
+    Args:
+        x (Tensor): Input signal tensor (B, T).
+        fft_size (int): FFT size.
+        hop_size (int): Hop size.
+        win_length (int): Window length.
+        window (Tensor): Window tensor (e.g. a Hann window of win_length).
+    Returns:
+        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
+    """
+    x_stft = torch.stft(x, fft_size, hop_size, win_length, window,
+                        return_complex=True)
+    # With return_complex=True the result is a complex tensor, so the
+    # magnitude is taken directly; the old real/imag indexing no longer applies.
+    return torch.abs(x_stft).transpose(2, 1)
+
+class SpecDiscriminator(nn.Module):
+ """docstring for Discriminator."""
+
+ def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
+ super(SpecDiscriminator, self).__init__()
+        norm_f = spectral_norm if use_spectral_norm else weight_norm
+ self.fft_size = fft_size
+ self.shift_size = shift_size
+ self.win_length = win_length
+ self.window = getattr(torch, window)(win_length)
+ self.discriminators = nn.ModuleList([
+ norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1,1), padding=(1, 1))),
+ ])
+
+ self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
+
+ def forward(self, y):
+
+ fmap = []
+ y = y.squeeze(1)
+        y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
+ y = y.unsqueeze(1)
+ for i, d in enumerate(self.discriminators):
+ y = d(y)
+ y = F.leaky_relu(y, LRELU_SLOPE)
+ fmap.append(y)
+
+ y = self.out(y)
+ fmap.append(y)
+
+ return torch.flatten(y, 1, -1), fmap
+
+class MultiResSpecDiscriminator(torch.nn.Module):
+
+ def __init__(self,
+ fft_sizes=[1024, 2048, 512],
+ hop_sizes=[120, 240, 50],
+ win_lengths=[600, 1200, 240],
+ window="hann_window"):
+
+ super(MultiResSpecDiscriminator, self).__init__()
+ self.discriminators = nn.ModuleList([
+ SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
+ SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
+ SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)
+ ])
+
+ def forward(self, y, y_hat):
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+ for i, d in enumerate(self.discriminators):
+ y_d_r, fmap_r = d(y)
+ y_d_g, fmap_g = d(y_hat)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
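+# The three SpecDiscriminators above score the same audio at different STFT
+# resolutions (1024-, 2048- and 512-point FFTs), trading frequency detail
+# against time detail so artifacts are penalized at several scales.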
+
+class DiscriminatorP(torch.nn.Module):
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+ super(DiscriminatorP, self).__init__()
+ self.period = period
+        norm_f = spectral_norm if use_spectral_norm else weight_norm
+ self.convs = nn.ModuleList([
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+ ])
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+ def forward(self, x):
+ fmap = []
+
+ # 1d to 2d
+ b, c, t = x.shape
+ if t % self.period != 0: # pad first
+ n_pad = self.period - (t % self.period)
+ x = F.pad(x, (0, n_pad), "reflect")
+ t = t + n_pad
+ x = x.view(b, c, t // self.period, self.period)
+
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class MultiPeriodDiscriminator(torch.nn.Module):
+ def __init__(self):
+ super(MultiPeriodDiscriminator, self).__init__()
+ self.discriminators = nn.ModuleList([
+ DiscriminatorP(2),
+ DiscriminatorP(3),
+ DiscriminatorP(5),
+ DiscriminatorP(7),
+ DiscriminatorP(11),
+ ])
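+ # the periods are distinct primes (as in HiFi-GAN), so the 2-D reshaped views of the
+ # waveform in DiscriminatorP do not share sub-periodic structure across sub-discriminators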
+
+ def forward(self, y, y_hat):
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+ for i, d in enumerate(self.discriminators):
+ y_d_r, fmap_r = d(y)
+ y_d_g, fmap_g = d(y_hat)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+class WavLMDiscriminator(nn.Module):
+ """docstring for Discriminator."""
+
+ def __init__(self, slm_hidden=768,
+ slm_layers=13,
+ initial_channel=64,
+ use_spectral_norm=False):
+ super(WavLMDiscriminator, self).__init__()
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+ self.pre = norm_f(Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0))
+
+ self.convs = nn.ModuleList([
+ norm_f(nn.Conv1d(initial_channel, initial_channel * 2, kernel_size=5, padding=2)),
+ norm_f(nn.Conv1d(initial_channel * 2, initial_channel * 4, kernel_size=5, padding=2)),
+ norm_f(nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)),
+ ])
+
+ self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
+
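+ # expected input (inferred from slm_hidden * slm_layers above): x stacks all WavLM
+ # hidden states along the channel axis, shape (B, 768 * 13, frames) with the defaults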
+ def forward(self, x):
+ x = self.pre(x)
+
+ fmap = []
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ fmap.append(x)
+ x = self.conv_post(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x
\ No newline at end of file
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/hifigan.py b/stts_48khz/StyleTTS2_48khz/Modules/hifigan.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7ed14df89d35d4c810364f9cfb1b96a6a55056d
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Modules/hifigan.py
@@ -0,0 +1,477 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+from .utils import init_weights, get_padding
+
+import math
+import random
+import numpy as np
+
+LRELU_SLOPE = 0.1
+
+class AdaIN1d(nn.Module):
+ def __init__(self, style_dim, num_features):
+ super().__init__()
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
+ self.fc = nn.Linear(style_dim, num_features*2)
+
+ def forward(self, x, s):
+ h = self.fc(s)
+ h = h.view(h.size(0), h.size(1), 1)
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
+ return (1 + gamma) * self.norm(x) + beta
+
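+# AdaIN1d applies adaptive instance normalisation: the style vector s is projected to a
+# per-channel (gamma, beta) pair and the instance-normalised input is modulated as
+# (1 + gamma) * InstanceNorm(x) + beta, so the style controls channel-wise statistics.
+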
+class AdaINResBlock1(torch.nn.Module):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
+ super(AdaINResBlock1, self).__init__()
+ self.convs1 = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2])))
+ ])
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1)))
+ ])
+ self.convs2.apply(init_weights)
+
+ self.adain1 = nn.ModuleList([
+ AdaIN1d(style_dim, channels),
+ AdaIN1d(style_dim, channels),
+ AdaIN1d(style_dim, channels),
+ ])
+
+ self.adain2 = nn.ModuleList([
+ AdaIN1d(style_dim, channels),
+ AdaIN1d(style_dim, channels),
+ AdaIN1d(style_dim, channels),
+ ])
+
+ self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
+ self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
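+ # alpha1/alpha2 are learnable per-channel frequencies for the snake activation
+ # x + (1/alpha) * sin^2(alpha * x) applied in forward(), a periodic inductive bias
+ # suited to waveform generation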
+
+
+ def forward(self, x, s):
+ for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
+ xt = n1(x, s)
+ xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
+ xt = c1(xt)
+ xt = n2(xt, s)
+ xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
+ xt = c2(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+class SineGen(torch.nn.Module):
+ """ Definition of sine generator
+ SineGen(samp_rate, harmonic_num = 0,
+ sine_amp = 0.1, noise_std = 0.003,
+ voiced_threshold = 0,
+ flag_for_pulse=False)
+ samp_rate: sampling rate in Hz
+ harmonic_num: number of harmonic overtones (default 0)
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
+ noise_std: std of Gaussian noise (default 0.003)
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
+ Note: when flag_for_pulse is True, the first time step of a voiced
+ segment is always sin(np.pi) or cos(0)
+ """
+
+ def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
+ sine_amp=0.1, noise_std=0.003,
+ voiced_threshold=0,
+ flag_for_pulse=False):
+ super(SineGen, self).__init__()
+ self.sine_amp = sine_amp
+ self.noise_std = noise_std
+ self.harmonic_num = harmonic_num
+ self.dim = self.harmonic_num + 1
+ self.sampling_rate = samp_rate
+ self.voiced_threshold = voiced_threshold
+ self.flag_for_pulse = flag_for_pulse
+ self.upsample_scale = upsample_scale
+
+ def _f02uv(self, f0):
+ # generate uv signal
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
+ return uv
+
+ def _f02sine(self, f0_values):
+ """ f0_values: (batchsize, length, dim)
+ where dim indicates fundamental tone and overtones
+ """
+ # convert to F0 in rad. The integer part n can be ignored
+ # because 2 * np.pi * n doesn't affect phase
+ rad_values = (f0_values / self.sampling_rate) % 1
+
+ # initial phase noise (no noise for fundamental component)
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
+ device=f0_values.device)
+ rand_ini[:, 0] = 0
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+
+ # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
+ if not self.flag_for_pulse:
+# # for normal case
+
+# # To prevent torch.cumsum numerical overflow,
+# # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
+# # Buffer tmp_over_one_idx indicates the time step to add -1.
+# # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
+# tmp_over_one = torch.cumsum(rad_values, 1) % 1
+# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
+# cumsum_shift = torch.zeros_like(rad_values)
+# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+
+# phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
+ scale_factor=1/self.upsample_scale,
+ mode="linear").transpose(1, 2)
+
+# tmp_over_one = torch.cumsum(rad_values, 1) % 1
+# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
+# cumsum_shift = torch.zeros_like(rad_values)
+# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
+ scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
+ sines = torch.sin(phase)
+
+ else:
+ # If necessary, make sure that the first time step of every
+ # voiced segment is sin(pi) or cos(0)
+ # This is used for pulse-train generation
+
+ # identify the last time step in unvoiced segments
+ uv = self._f02uv(f0_values)
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
+ uv_1[:, -1, :] = 1
+ u_loc = (uv < 1) * (uv_1 > 0)
+
+ # get the instantaneous phase
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
+ # different batch needs to be processed differently
+ for idx in range(f0_values.shape[0]):
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
+ # stores the accumulation of i.phase within
+ # each voiced segments
+ tmp_cumsum[idx, :, :] = 0
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
+
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
+ # within the previous voiced segment.
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
+
+ # get the sines
+ sines = torch.cos(i_phase * 2 * np.pi)
+ return sines
+
+ def forward(self, f0):
+ """ sine_tensor, uv = forward(f0)
+ input F0: tensor(batchsize=1, length, dim=1)
+ f0 for unvoiced steps should be 0
+ output sine_tensor: tensor(batchsize=1, length, dim)
+ output uv: tensor(batchsize=1, length, 1)
+ """
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
+ device=f0.device)
+ # fundamental component
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
+
+ # generate sine waveforms
+ sine_waves = self._f02sine(fn) * self.sine_amp
+
+ # generate uv signal
+ # uv = torch.ones(f0.shape)
+ # uv = uv * (f0 > self.voiced_threshold)
+ uv = self._f02uv(f0)
+
+ # noise: for unvoiced should be similar to sine_amp
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
+ # . for voiced regions is self.noise_std
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+ noise = noise_amp * torch.randn_like(sine_waves)
+
+ # first: set the unvoiced part to 0 by uv
+ # then: additive noise
+ sine_waves = sine_waves * uv + noise
+ return sine_waves, uv, noise
+
+
+class SourceModuleHnNSF(torch.nn.Module):
+ """ SourceModule for hn-nsf
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
+ add_noise_std=0.003, voiced_threshod=0)
+ sampling_rate: sampling_rate in Hz
+ harmonic_num: number of harmonic above F0 (default: 0)
+ sine_amp: amplitude of sine source signal (default: 0.1)
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
+ note that amplitude of noise in unvoiced is decided
+ by sine_amp
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+ F0_sampled (batchsize, length, 1)
+ Sine_source (batchsize, length, 1)
+ noise_source (batchsize, length 1)
+ uv (batchsize, length, 1)
+ """
+
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
+ add_noise_std=0.003, voiced_threshod=0):
+ super(SourceModuleHnNSF, self).__init__()
+
+ self.sine_amp = sine_amp
+ self.noise_std = add_noise_std
+
+ # to produce sine waveforms
+ self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
+ sine_amp, add_noise_std, voiced_threshod)
+
+ # to merge source harmonics into a single excitation
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+ self.l_tanh = torch.nn.Tanh()
+
+ def forward(self, x):
+ """
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+ F0_sampled (batchsize, length, 1)
+ Sine_source (batchsize, length, 1)
+ noise_source (batchsize, length 1)
+ """
+ # source for harmonic branch
+ with torch.no_grad():
+ sine_wavs, uv, _ = self.l_sin_gen(x)
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+
+ # source for noise branch, in the same shape as uv
+ noise = torch.randn_like(uv) * self.sine_amp / 3
+ return sine_merge, noise, uv
+
+
+def padDiff(x):
+ return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
+
+class Generator(torch.nn.Module):
+ def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes):
+ super(Generator, self).__init__()
+ self.num_kernels = len(resblock_kernel_sizes)
+ self.num_upsamples = len(upsample_rates)
+ resblock = AdaINResBlock1
+
+ self.m_source = SourceModuleHnNSF(
+ sampling_rate=48000,
+ upsample_scale=np.prod(upsample_rates),
+ harmonic_num=8, voiced_threshod=10)
+
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
+ self.noise_convs = nn.ModuleList()
+ self.ups = nn.ModuleList()
+ self.noise_res = nn.ModuleList()
+
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
+
+ self.ups.append(weight_norm(ConvTranspose1d(upsample_initial_channel//(2**i),
+ upsample_initial_channel//(2**(i+1)),
+ k, u, padding=(u//2 + u%2), output_padding=u%2)))
+
+ if i + 1 < len(upsample_rates): #
+ stride_f0 = np.prod(upsample_rates[i + 1:])
+ self.noise_convs.append(Conv1d(
+ 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
+ self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
+ else:
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
+ self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))
+
+ self.resblocks = nn.ModuleList()
+
+ self.alphas = nn.ParameterList()
+ self.alphas.append(nn.Parameter(torch.ones(1, upsample_initial_channel, 1)))
+
+ for i in range(len(self.ups)):
+ ch = upsample_initial_channel//(2**(i+1))
+ self.alphas.append(nn.Parameter(torch.ones(1, ch, 1)))
+
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+ self.resblocks.append(resblock(ch, k, d, style_dim))
+
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ def forward(self, x, s, f0):
+
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
+
+ har_source, noi_source, uv = self.m_source(f0)
+ har_source = har_source.transpose(1, 2)
+
+ for i in range(self.num_upsamples):
+ x = x + (1 / self.alphas[i]) * (torch.sin(self.alphas[i] * x) ** 2)
+ x_source = self.noise_convs[i](har_source)
+ x_source = self.noise_res[i](x_source, s)
+
+ x = self.ups[i](x)
+ x = x + x_source
+
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
+ else:
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
+ x = xs / self.num_kernels
+ x = x + (1 / self.alphas[i+1]) * (torch.sin(self.alphas[i+1] * x) ** 2)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ print('Removing weight norm...')
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ # this Generator defines no conv_pre layer, so only conv_post carries weight norm here
+ remove_weight_norm(self.conv_post)
+
+
+class AdainResBlk1d(nn.Module):
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
+ upsample='none', dropout_p=0.0):
+ super().__init__()
+ self.actv = actv
+ self.upsample_type = upsample
+ self.upsample = UpSample1d(upsample)
+ self.learned_sc = dim_in != dim_out
+ self._build_weights(dim_in, dim_out, style_dim)
+ self.dropout = nn.Dropout(dropout_p)
+
+ if upsample == 'none':
+ self.pool = nn.Identity()
+ else:
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
+
+
+ def _build_weights(self, dim_in, dim_out, style_dim):
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
+ self.norm1 = AdaIN1d(style_dim, dim_in)
+ self.norm2 = AdaIN1d(style_dim, dim_out)
+ if self.learned_sc:
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+ def _shortcut(self, x):
+ x = self.upsample(x)
+ if self.learned_sc:
+ x = self.conv1x1(x)
+ return x
+
+ def _residual(self, x, s):
+ x = self.norm1(x, s)
+ x = self.actv(x)
+ x = self.pool(x)
+ x = self.conv1(self.dropout(x))
+ x = self.norm2(x, s)
+ x = self.actv(x)
+ x = self.conv2(self.dropout(x))
+ return x
+
+ def forward(self, x, s):
+ out = self._residual(x, s)
+ out = (out + self._shortcut(x)) / math.sqrt(2)
+ return out
+
+class UpSample1d(nn.Module):
+ def __init__(self, layer_type):
+ super().__init__()
+ self.layer_type = layer_type
+
+ def forward(self, x):
+ if self.layer_type == 'none':
+ return x
+ else:
+ return F.interpolate(x, scale_factor=2, mode='nearest')
+
+class Decoder(nn.Module):
+ def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
+ resblock_kernel_sizes = [3,7,11],
+ upsample_rates = [10,5,3,2],
+ upsample_initial_channel=512,
+ resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
+ upsample_kernel_sizes=[20,10,6,4]):
+ super().__init__()
+
+ self.decode = nn.ModuleList()
+
+ self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
+
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
+
+ self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
+
+ self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
+
+ self.asr_res = nn.Sequential(
+ weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
+ )
+
+
+ self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes)
+
+
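+ # forward() shape sketch (an assumption read off the layers above): asr is (B, 512, T)
+ # aligned text features, F0_curve and N are (B, 2*T) frame-level curves that F0_conv /
+ # N_conv halve to length T, and s is a (B, style_dim) style vector; the generator then
+ # upsamples the fused features back to a waveform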
+ def forward(self, asr, F0_curve, N, s):
+ if self.training:
+ downlist = [0, 3, 7]
+ F0_down = downlist[random.randint(0, 2)]
+ downlist = [0, 3, 7, 15]
+ N_down = downlist[random.randint(0, 3)]
+ if F0_down:
+ F0_curve = nn.functional.conv1d(F0_curve.unsqueeze(1), torch.ones(1, 1, F0_down).to(F0_curve.device), padding=F0_down//2).squeeze(1) / F0_down
+ if N_down:
+ N = nn.functional.conv1d(N.unsqueeze(1), torch.ones(1, 1, N_down).to(N.device), padding=N_down//2).squeeze(1) / N_down
+
+
+ F0 = self.F0_conv(F0_curve.unsqueeze(1))
+ N = self.N_conv(N.unsqueeze(1))
+
+ x = torch.cat([asr, F0, N], axis=1)
+ x = self.encode(x, s)
+
+ asr_res = self.asr_res(asr)
+
+ res = True
+ for block in self.decode:
+ if res:
+ x = torch.cat([x, asr_res, F0, N], axis=1)
+ x = block(x, s)
+ if block.upsample_type != "none":
+ res = False
+
+ x = self.generator(x, s, F0_curve)
+ return x
+
+
\ No newline at end of file
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/istftnet.py b/stts_48khz/StyleTTS2_48khz/Modules/istftnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..80a30b89fbf4c6b77ab845736f7b7c447e05a33e
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Modules/istftnet.py
@@ -0,0 +1,530 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+from .utils import init_weights, get_padding
+
+import math
+import random
+import numpy as np
+from scipy.signal import get_window
+
+LRELU_SLOPE = 0.1
+
+class AdaIN1d(nn.Module):
+ def __init__(self, style_dim, num_features):
+ super().__init__()
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
+ self.fc = nn.Linear(style_dim, num_features*2)
+
+ def forward(self, x, s):
+ h = self.fc(s)
+ h = h.view(h.size(0), h.size(1), 1)
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
+ return (1 + gamma) * self.norm(x) + beta
+
+class AdaINResBlock1(torch.nn.Module):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
+ super(AdaINResBlock1, self).__init__()
+ self.convs1 = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2])))
+ ])
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1)))
+ ])
+ self.convs2.apply(init_weights)
+
+ self.adain1 = nn.ModuleList([
+ AdaIN1d(style_dim, channels),
+ AdaIN1d(style_dim, channels),
+ AdaIN1d(style_dim, channels),
+ ])
+
+ self.adain2 = nn.ModuleList([
+ AdaIN1d(style_dim, channels),
+ AdaIN1d(style_dim, channels),
+ AdaIN1d(style_dim, channels),
+ ])
+
+ self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
+ self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
+
+
+ def forward(self, x, s):
+ for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
+ xt = n1(x, s)
+ xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
+ xt = c1(xt)
+ xt = n2(xt, s)
+ xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
+ xt = c2(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+class TorchSTFT(torch.nn.Module):
+ def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
+ super().__init__()
+ self.filter_length = filter_length
+ self.hop_length = hop_length
+ self.win_length = win_length
+ self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))
+
+ def transform(self, input_data):
+ forward_transform = torch.stft(
+ input_data,
+ self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
+ return_complex=True)
+
+ return torch.abs(forward_transform), torch.angle(forward_transform)
+
+ def inverse(self, magnitude, phase):
+ inverse_transform = torch.istft(
+ magnitude * torch.exp(phase * 1j),
+ self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
+
+ return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
+
+ def forward(self, input_data):
+ self.magnitude, self.phase = self.transform(input_data)
+ reconstruction = self.inverse(self.magnitude, self.phase)
+ return reconstruction
+
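+# Round-trip sketch (illustrative): with stft = TorchSTFT(), mag, phase = stft.transform(wav)
+# followed by stft.inverse(mag, phase) reconstructs the waveform; the Generator below reuses
+# inverse() to turn predicted magnitude/phase frames into audio (iSTFTNet-style decoding).
+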
+class SineGen(torch.nn.Module):
+ """ Definition of sine generator
+ SineGen(samp_rate, harmonic_num = 0,
+ sine_amp = 0.1, noise_std = 0.003,
+ voiced_threshold = 0,
+ flag_for_pulse=False)
+ samp_rate: sampling rate in Hz
+ harmonic_num: number of harmonic overtones (default 0)
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
+ noise_std: std of Gaussian noise (default 0.003)
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
+ Note: when flag_for_pulse is True, the first time step of a voiced
+ segment is always sin(np.pi) or cos(0)
+ """
+
+ def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
+ sine_amp=0.1, noise_std=0.003,
+ voiced_threshold=0,
+ flag_for_pulse=False):
+ super(SineGen, self).__init__()
+ self.sine_amp = sine_amp
+ self.noise_std = noise_std
+ self.harmonic_num = harmonic_num
+ self.dim = self.harmonic_num + 1
+ self.sampling_rate = samp_rate
+ self.voiced_threshold = voiced_threshold
+ self.flag_for_pulse = flag_for_pulse
+ self.upsample_scale = upsample_scale
+
+ def _f02uv(self, f0):
+ # generate uv signal
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
+ return uv
+
+ def _f02sine(self, f0_values):
+ """ f0_values: (batchsize, length, dim)
+ where dim indicates fundamental tone and overtones
+ """
+ # convert to F0 in rad. The integer part n can be ignored
+ # because 2 * np.pi * n doesn't affect phase
+ rad_values = (f0_values / self.sampling_rate) % 1
+
+ # initial phase noise (no noise for fundamental component)
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
+ device=f0_values.device)
+ rand_ini[:, 0] = 0
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+
+ # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
+ if not self.flag_for_pulse:
+# # for normal case
+
+# # To prevent torch.cumsum numerical overflow,
+# # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
+# # Buffer tmp_over_one_idx indicates the time step to add -1.
+# # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
+# tmp_over_one = torch.cumsum(rad_values, 1) % 1
+# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
+# cumsum_shift = torch.zeros_like(rad_values)
+# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+
+# phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
+ scale_factor=1/self.upsample_scale,
+ mode="linear").transpose(1, 2)
+
+# tmp_over_one = torch.cumsum(rad_values, 1) % 1
+# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
+# cumsum_shift = torch.zeros_like(rad_values)
+# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
+ scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
+ sines = torch.sin(phase)
+
+ else:
+ # If necessary, make sure that the first time step of every
+ # voiced segment is sin(pi) or cos(0)
+ # This is used for pulse-train generation
+
+ # identify the last time step in unvoiced segments
+ uv = self._f02uv(f0_values)
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
+ uv_1[:, -1, :] = 1
+ u_loc = (uv < 1) * (uv_1 > 0)
+
+ # get the instantaneous phase
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
+ # different batch needs to be processed differently
+ for idx in range(f0_values.shape[0]):
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
+ # stores the accumulation of i.phase within
+ # each voiced segments
+ tmp_cumsum[idx, :, :] = 0
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
+
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
+ # within the previous voiced segment.
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
+
+ # get the sines
+ sines = torch.cos(i_phase * 2 * np.pi)
+ return sines
+
+ def forward(self, f0):
+ """ sine_tensor, uv = forward(f0)
+ input F0: tensor(batchsize=1, length, dim=1)
+ f0 for unvoiced steps should be 0
+ output sine_tensor: tensor(batchsize=1, length, dim)
+ output uv: tensor(batchsize=1, length, 1)
+ """
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
+ device=f0.device)
+ # fundamental component
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
+
+ # generate sine waveforms
+ sine_waves = self._f02sine(fn) * self.sine_amp
+
+ # generate uv signal
+ # uv = torch.ones(f0.shape)
+ # uv = uv * (f0 > self.voiced_threshold)
+ uv = self._f02uv(f0)
+
+ # noise: for unvoiced should be similar to sine_amp
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
+ # . for voiced regions is self.noise_std
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+ noise = noise_amp * torch.randn_like(sine_waves)
+
+ # first: set the unvoiced part to 0 by uv
+ # then: additive noise
+ sine_waves = sine_waves * uv + noise
+ return sine_waves, uv, noise
+
+
+class SourceModuleHnNSF(torch.nn.Module):
+ """ SourceModule for hn-nsf
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
+ add_noise_std=0.003, voiced_threshod=0)
+ sampling_rate: sampling_rate in Hz
+ harmonic_num: number of harmonic above F0 (default: 0)
+ sine_amp: amplitude of sine source signal (default: 0.1)
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
+ note that amplitude of noise in unvoiced is decided
+ by sine_amp
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+ F0_sampled (batchsize, length, 1)
+ Sine_source (batchsize, length, 1)
+ noise_source (batchsize, length 1)
+ uv (batchsize, length, 1)
+ """
+
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
+ add_noise_std=0.003, voiced_threshod=0):
+ super(SourceModuleHnNSF, self).__init__()
+
+ self.sine_amp = sine_amp
+ self.noise_std = add_noise_std
+
+ # to produce sine waveforms
+ self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
+ sine_amp, add_noise_std, voiced_threshod)
+
+ # to merge source harmonics into a single excitation
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+ self.l_tanh = torch.nn.Tanh()
+
+ def forward(self, x):
+ """
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+ F0_sampled (batchsize, length, 1)
+ Sine_source (batchsize, length, 1)
+ noise_source (batchsize, length 1)
+ """
+ # source for harmonic branch
+ with torch.no_grad():
+ sine_wavs, uv, _ = self.l_sin_gen(x)
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+
+ # source for noise branch, in the same shape as uv
+ noise = torch.randn_like(uv) * self.sine_amp / 3
+ return sine_merge, noise, uv
+
+
+def padDiff(x):
+ return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
+
+
+class Generator(torch.nn.Module):
+ def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
+ super(Generator, self).__init__()
+
+ self.num_kernels = len(resblock_kernel_sizes)
+ self.num_upsamples = len(upsample_rates)
+ resblock = AdaINResBlock1
+
+ self.m_source = SourceModuleHnNSF(
+ sampling_rate=48000,
+ upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
+ harmonic_num=8, voiced_threshod=10)
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size)
+ self.noise_convs = nn.ModuleList()
+ self.noise_res = nn.ModuleList()
+
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ self.ups.append(weight_norm(
+ ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
+ k, u, padding=(k-u)//2)))
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = upsample_initial_channel//(2**(i+1))
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)):
+ self.resblocks.append(resblock(ch, k, d, style_dim))
+
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
+
+ if i + 1 < len(upsample_rates): #
+ stride_f0 = np.prod(upsample_rates[i + 1:])
+ self.noise_convs.append(Conv1d(
+ gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
+ self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
+ else:
+ self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
+ self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))
+
+
+ self.post_n_fft = gen_istft_n_fft
+ self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+ self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
+ self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
+
+
+ def forward(self, x, s, f0):
+ with torch.no_grad():
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
+
+ har_source, noi_source, uv = self.m_source(f0)
+ har_source = har_source.transpose(1, 2).squeeze(1)
+ har_spec, har_phase = self.stft.transform(har_source)
+ har = torch.cat([har_spec, har_phase], dim=1)
+
+ for i in range(self.num_upsamples):
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ x_source = self.noise_convs[i](har)
+ x_source = self.noise_res[i](x_source, s)
+
+ x = self.ups[i](x)
+ if i == self.num_upsamples - 1:
+ x = self.reflection_pad(x)
+
+ x = x + x_source
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
+ else:
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
+ x = xs / self.num_kernels
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
+ spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
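+ # conv_post predicts n_fft // 2 + 1 log-magnitude bins and n_fft // 2 + 1 phase channels;
+ # exp() recovers the magnitudes and sin() bounds the phase term passed to stft.inverse()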
+ return self.stft.inverse(spec, phase)
+
+ def fw_phase(self, x, s):
+ for i in range(self.num_upsamples):
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ x = self.ups[i](x)
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
+ else:
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
+ x = xs / self.num_kernels
+ x = F.leaky_relu(x)
+ x = self.reflection_pad(x)
+ x = self.conv_post(x)
+ spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
+ return spec, phase
+
+ def remove_weight_norm(self):
+ print('Removing weight norm...')
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ # this Generator defines no conv_pre layer, so only conv_post carries weight norm here
+ remove_weight_norm(self.conv_post)
+
+
+class AdainResBlk1d(nn.Module):
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
+ upsample='none', dropout_p=0.0):
+ super().__init__()
+ self.actv = actv
+ self.upsample_type = upsample
+ self.upsample = UpSample1d(upsample)
+ self.learned_sc = dim_in != dim_out
+ self._build_weights(dim_in, dim_out, style_dim)
+ self.dropout = nn.Dropout(dropout_p)
+
+ if upsample == 'none':
+ self.pool = nn.Identity()
+ else:
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
+
+
+ def _build_weights(self, dim_in, dim_out, style_dim):
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
+ self.norm1 = AdaIN1d(style_dim, dim_in)
+ self.norm2 = AdaIN1d(style_dim, dim_out)
+ if self.learned_sc:
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+ def _shortcut(self, x):
+ x = self.upsample(x)
+ if self.learned_sc:
+ x = self.conv1x1(x)
+ return x
+
+ def _residual(self, x, s):
+ x = self.norm1(x, s)
+ x = self.actv(x)
+ x = self.pool(x)
+ x = self.conv1(self.dropout(x))
+ x = self.norm2(x, s)
+ x = self.actv(x)
+ x = self.conv2(self.dropout(x))
+ return x
+
+ def forward(self, x, s):
+ out = self._residual(x, s)
+ out = (out + self._shortcut(x)) / math.sqrt(2)
+ return out
+
+class UpSample1d(nn.Module):
+ def __init__(self, layer_type):
+ super().__init__()
+ self.layer_type = layer_type
+
+ def forward(self, x):
+ if self.layer_type == 'none':
+ return x
+ else:
+ return F.interpolate(x, scale_factor=2, mode='nearest')
+
+class Decoder(nn.Module):
+ def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
+ resblock_kernel_sizes = [3,7,11],
+ upsample_rates = [10, 6],
+ upsample_initial_channel=512,
+ resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
+ upsample_kernel_sizes=[20, 12],
+ gen_istft_n_fft=20, gen_istft_hop_size=5):
+ super().__init__()
+
+ self.decode = nn.ModuleList()
+
+ self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
+
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
+
+ self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
+
+ self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
+
+ self.asr_res = nn.Sequential(
+ weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
+ )
+
+
+ self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
+ upsample_initial_channel, resblock_dilation_sizes,
+ upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)
+
+ def forward(self, asr, F0_curve, N, s):
+ if self.training:
+ downlist = [0, 3, 7]
+ F0_down = downlist[random.randint(0, 2)]
+ downlist = [0, 3, 7, 15]
+ N_down = downlist[random.randint(0, 3)]
+ if F0_down:
+ F0_curve = nn.functional.conv1d(F0_curve.unsqueeze(1), torch.ones(1, 1, F0_down).to(F0_curve.device), padding=F0_down//2).squeeze(1) / F0_down
+ if N_down:
+ N = nn.functional.conv1d(N.unsqueeze(1), torch.ones(1, 1, N_down).to(N.device), padding=N_down//2).squeeze(1) / N_down
+
+
+ F0 = self.F0_conv(F0_curve.unsqueeze(1))
+ N = self.N_conv(N.unsqueeze(1))
+
+ x = torch.cat([asr, F0, N], axis=1)
+ x = self.encode(x, s)
+
+ asr_res = self.asr_res(asr)
+
+ res = True
+ for block in self.decode:
+ if res:
+ x = torch.cat([x, asr_res, F0, N], axis=1)
+ x = block(x, s)
+ if block.upsample_type != "none":
+ res = False
+
+ x = self.generator(x, s, F0_curve)
+ return x
+
+
\ No newline at end of file
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/slmadv.py b/stts_48khz/StyleTTS2_48khz/Modules/slmadv.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b8b81fde40780104d49b386954b8eaba8838fef
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Modules/slmadv.py
@@ -0,0 +1,400 @@
+import torch
+import numpy as np
+import torch.nn.functional as F
+
+class SLMAdversarialLoss(torch.nn.Module):
+
+ def __init__(self, model, wl, sampler, min_len, max_len, batch_percentage=0.5, skip_update=10, sig=1.5):
+ super(SLMAdversarialLoss, self).__init__()
+ self.model = model
+ self.wl = wl
+ self.sampler = sampler
+
+ self.min_len = min_len
+ self.max_len = max_len
+ self.batch_percentage = batch_percentage
+
+ self.sig = sig
+ self.skip_update = skip_update
+
+ def forward(self, iters, y_rec_gt, y_rec_gt_pred, waves, mel_input_length, ref_text, ref_lengths, use_ind, s_trg, ref_s=None):
+
+ text_mask = length_to_mask(ref_lengths).to(ref_text.device)
+
+
+ # if text_mask.shape[1] > 512:
+ # text_mask = text_mask[:, :512]
+
+ bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())
+
+
+ d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)
+
+ if use_ind and np.random.rand() < 0.5:
+ s_preds = s_trg
+ else:
+ num_steps = np.random.randint(3, 5)
+ if ref_s is not None:
+ s_preds = self.sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device),
+ embedding=bert_dur,
+ embedding_scale=1,
+ features=ref_s, # reference from the same speaker as the embedding
+ embedding_mask_proba=0.1,
+ num_steps=num_steps).squeeze(1)
+ else:
+ s_preds = self.sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device),
+ embedding=bert_dur,
+ embedding_scale=1,
+ embedding_mask_proba=0.1,
+ num_steps=num_steps).squeeze(1)
+
+ s_dur = s_preds[:, 128:]
+ s = s_preds[:, :128]
+
+ d, _ = self.model.predictor(d_en, s_dur,
+ ref_lengths,
+ torch.randn(ref_lengths.shape[0], ref_lengths.max(), 2).to(ref_text.device),
+ text_mask)
+
+ bib = 0
+
+ output_lengths = []
+ attn_preds = []
+
+ # differentiable duration modeling
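+ # (each token's predicted duration logits are spread over output frames with a Gaussian
+ # window of width self.sig centred at the token's cumulative predicted duration, and the
+ # softmax below turns the result into a soft, differentiable alignment matrix)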
+ for _s2s_pred, _text_length in zip(d, ref_lengths):
+
+ _s2s_pred_org = _s2s_pred[:_text_length, :]
+
+ _s2s_pred = torch.sigmoid(_s2s_pred_org)
+ _dur_pred = _s2s_pred.sum(axis=-1)
+
+ l = int(torch.round(_s2s_pred.sum()).item())
+
+ t = torch.arange(0, l).unsqueeze(0).expand((len(_s2s_pred), l)).to(ref_text.device)
+ loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2
+
+ h = torch.exp(-0.5 * torch.square(t - (l - loc.unsqueeze(-1))) / (self.sig)**2)
+
+ out = torch.nn.functional.conv1d(_s2s_pred_org.unsqueeze(0),
+ h.unsqueeze(1),
+ padding=h.shape[-1] - 1, groups=int(_text_length))[..., :l]
+ attn_preds.append(F.softmax(out.squeeze(), dim=0))
+
+ output_lengths.append(l)
+
+ max_len = max(output_lengths)
+
+ with torch.no_grad():
+ t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask)
+
+ s2s_attn = torch.zeros(len(ref_lengths), int(ref_lengths.max()), max_len).to(ref_text.device)
+ for bib in range(len(output_lengths)):
+ s2s_attn[bib, :ref_lengths[bib], :output_lengths[bib]] = attn_preds[bib]
+
+ asr_pred = t_en @ s2s_attn
+
+ _, p_pred = self.model.predictor(d_en, s_dur,
+ ref_lengths,
+ s2s_attn,
+ text_mask)
+
+ mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2)
+ mel_len = min(mel_len, self.max_len // 2)
+
+ # get clips
+
+ en = []
+ p_en = []
+ sp = []
+
+ F0_fakes = []
+ N_fakes = []
+
+ wav = []
+
+ for bib in range(len(output_lengths)):
+ mel_length_pred = output_lengths[bib]
+ mel_length_gt = int(mel_input_length[bib].item() / 2)
+ if mel_length_gt <= mel_len or mel_length_pred <= mel_len:
+ continue
+
+ sp.append(s_preds[bib])
+
+ random_start = np.random.randint(0, mel_length_pred - mel_len)
+ en.append(asr_pred[bib, :, random_start:random_start+mel_len])
+ p_en.append(p_pred[bib, :, random_start:random_start+mel_len])
+
+ # get ground truth clips
+ random_start = np.random.randint(0, mel_length_gt - mel_len)
+ y = waves[bib][(random_start * 2) * 512:((random_start+mel_len) * 2) * 512]
+ wav.append(torch.from_numpy(y).to(ref_text.device))
+
+ if len(wav) >= self.batch_percentage * len(waves): # prevent OOM due to longer lengths
+ break
+
+ if len(sp) <= 1:
+ return None
+
+ sp = torch.stack(sp)
+ wav = torch.stack(wav).float()
+ en = torch.stack(en)
+ p_en = torch.stack(p_en)
+
+ F0_fake, N_fake = self.model.predictor(texts=p_en, style=sp[:, 128:], f0=True)
+ y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128])
+
+ # discriminator loss
+ if (iters + 1) % self.skip_update == 0:
+ if np.random.randint(0, 2) == 0:
+ wav = y_rec_gt_pred
+ use_rec = True
+ else:
+ use_rec = False
+
+ crop_size = min(wav.size(-1), y_pred.size(-1))
+ if use_rec: # use reconstructed (shorter lengths), do length invariant regularization
+ if wav.size(-1) > y_pred.size(-1):
+ real_GP = wav[:, : , :crop_size]
+ out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
+ out_org = self.wl.discriminator_forward(wav.detach().squeeze())
+ loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)])
+
+ if np.random.randint(0, 2) == 0:
+ d_loss = self.wl.discriminator(real_GP.detach().squeeze(), y_pred.detach().squeeze()).mean()
+ else:
+ d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
+ else:
+ real_GP = y_pred[:, : , :crop_size]
+ out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
+ out_org = self.wl.discriminator_forward(y_pred.detach().squeeze())
+ loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)])
+
+ if np.random.randint(0, 2) == 0:
+ d_loss = self.wl.discriminator(wav.detach().squeeze(), real_GP.detach().squeeze()).mean()
+ else:
+ d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
+
+ # regularization (ignore length variation)
+ d_loss += loss_reg
+
+ out_gt = self.wl.discriminator_forward(y_rec_gt.detach().squeeze())
+ out_rec = self.wl.discriminator_forward(y_rec_gt_pred.detach().squeeze())
+
+ # regularization (ignore reconstruction artifacts)
+ d_loss += F.l1_loss(out_gt, out_rec)
+
+ else:
+ d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
+ else:
+ d_loss = 0
+
+ # generator loss
+ gen_loss = self.wl.generator(y_pred.squeeze())
+
+ gen_loss = gen_loss.mean()
+
+ return d_loss, gen_loss, y_pred.detach().cpu().numpy()
+
+def length_to_mask(lengths):
+
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
+ return mask
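+
+# Example (illustrative): length_to_mask(torch.tensor([2, 4])) ->
+# tensor([[False, False,  True,  True],
+#         [False, False, False, False]])
+# i.e. True marks padding positions beyond each sequence's length.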
+
+# import torch
+# import numpy as np
+# import torch.nn.functional as F
+
+# class SLMAdversarialLoss(torch.nn.Module):
+
+# def __init__(self, model, wl, sampler, min_len, max_len, batch_percentage=0.5, skip_update=10, sig=1.5):
+# super(SLMAdversarialLoss, self).__init__()
+# self.model = model
+# self.wl = wl
+# self.sampler = sampler
+
+# self.min_len = min_len
+# self.max_len = max_len
+# self.batch_percentage = batch_percentage
+
+# self.sig = sig
+# self.skip_update = skip_update
+
+# def forward(self, iters, y_rec_gt, y_rec_gt_pred, waves, mel_input_length, ref_text, ref_lengths, use_ind, s_trg, ref_s=None):
+# text_mask = length_to_mask(ref_lengths).to(ref_text.device)
+# bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())
+# d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)
+
+# if use_ind and np.random.rand() < 0.5:
+# s_preds = s_trg
+# else:
+# num_steps = np.random.randint(3, 5)
+# if ref_s is not None:
+# s_preds = self.sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device),
+# embedding=bert_dur,
+# embedding_scale=1,
+# features=ref_s, # reference from the same speaker as the embedding
+# embedding_mask_proba=0.1,
+# num_steps=num_steps).squeeze(1)
+# else:
+# s_preds = self.sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device),
+# embedding=bert_dur,
+# embedding_scale=1,
+# embedding_mask_proba=0.1,
+# num_steps=num_steps).squeeze(1)
+
+# s_dur = s_preds[:, 128:]
+# s = s_preds[:, :128]
+
+# d, _ = self.model.predictor(d_en, s_dur,
+# ref_lengths,
+# torch.randn(ref_lengths.shape[0], ref_lengths.max(), 2).to(ref_text.device),
+# text_mask)
+
+# bib = 0
+
+# output_lengths = []
+# attn_preds = []
+
+# # differentiable duration modeling
+# for _s2s_pred, _text_length in zip(d, ref_lengths):
+
+# _s2s_pred_org = _s2s_pred[:_text_length, :]
+
+# _s2s_pred = torch.sigmoid(_s2s_pred_org)
+# _dur_pred = _s2s_pred.sum(axis=-1)
+
+# l = int(torch.round(_s2s_pred.sum()).item())
+# t = torch.arange(0, l).expand(l)
+
+# t = torch.arange(0, l).unsqueeze(0).expand((len(_s2s_pred), l)).to(ref_text.device)
+# loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2
+
+# h = torch.exp(-0.5 * torch.square(t - (l - loc.unsqueeze(-1))) / (self.sig)**2)
+
+# out = torch.nn.functional.conv1d(_s2s_pred_org.unsqueeze(0),
+# h.unsqueeze(1),
+# padding=h.shape[-1] - 1, groups=int(_text_length))[..., :l]
+# attn_preds.append(F.softmax(out.squeeze(), dim=0))
+
+# output_lengths.append(l)
+
+# max_len = max(output_lengths)
+
+# with torch.no_grad():
+# t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask)
+
+# s2s_attn = torch.zeros(len(ref_lengths), int(ref_lengths.max()), max_len).to(ref_text.device)
+# for bib in range(len(output_lengths)):
+# s2s_attn[bib, :ref_lengths[bib], :output_lengths[bib]] = attn_preds[bib]
+
+# asr_pred = t_en @ s2s_attn
+
+# _, p_pred = self.model.predictor(d_en, s_dur,
+# ref_lengths,
+# s2s_attn,
+# text_mask)
+
+# mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2)
+# mel_len = min(mel_len, self.max_len // 2)
+
+# # get clips
+
+# en = []
+# p_en = []
+# sp = []
+
+# F0_fakes = []
+# N_fakes = []
+
+# wav = []
+
+# for bib in range(len(output_lengths)):
+# mel_length_pred = output_lengths[bib]
+# mel_length_gt = int(mel_input_length[bib].item() / 2)
+# if mel_length_gt <= mel_len or mel_length_pred <= mel_len:
+# continue
+
+# sp.append(s_preds[bib])
+
+# random_start = np.random.randint(0, mel_length_pred - mel_len)
+# en.append(asr_pred[bib, :, random_start:random_start+mel_len])
+# p_en.append(p_pred[bib, :, random_start:random_start+mel_len])
+
+# # get ground truth clips
+# random_start = np.random.randint(0, mel_length_gt - mel_len)
+# y = waves[bib][(random_start * 2) * 512:((random_start+mel_len) * 2) * 512]
+# wav.append(torch.from_numpy(y).to(ref_text.device))
+
+# if len(wav) >= self.batch_percentage * len(waves): # prevent OOM due to longer lengths
+# break
+
+# if len(sp) <= 1:
+# return None
+
+# sp = torch.stack(sp)
+# wav = torch.stack(wav).float()
+# en = torch.stack(en)
+# p_en = torch.stack(p_en)
+
+# F0_fake, N_fake = self.model.predictor(texts=p_en, style=sp[:, 128:], f0=True)
+# y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128])
+
+# # discriminator loss
+# if (iters + 1) % self.skip_update == 0:
+# if np.random.randint(0, 2) == 0:
+# wav = y_rec_gt_pred
+# use_rec = True
+# else:
+# use_rec = False
+
+# crop_size = min(wav.size(-1), y_pred.size(-1))
+# if use_rec: # use reconstructed (shorter lengths), do length invariant regularization
+# if wav.size(-1) > y_pred.size(-1):
+# real_GP = wav[:, : , :crop_size]
+# out_crop = self.wl(wav = real_GP.detach().squeeze(),y_rec=None, discriminator_forward=True)
+# out_org = self.wl(wav = wav.detach().squeeze(),y_rec=None, discriminator_forward=True)
+# loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)])
+
+# if np.random.randint(0, 2) == 0:
+# d_loss = self.wl(wav = real_GP.detach().squeeze(),y_rec= y_pred.detach().squeeze(), discriminator=True).mean()
+# else:
+# d_loss = self.wl(wav = wav.detach().squeeze(), y_rec = y_pred.detach().squeeze(), discriminator=True).mean()
+# else:
+# real_GP = y_pred[:, : , :crop_size]
+# out_crop = self.wl(wav = real_GP.detach().squeeze(), y_rec=None, discriminator_forward=True)
+# out_org = self.wl(wav = y_pred.detach().squeeze(),y_rec=None, discriminator_forward=True)
+# loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)])
+
+# if np.random.randint(0, 2) == 0:
+# d_loss = self.wl(wav = wav.detach().squeeze(), y_rec = real_GP.detach().squeeze(), discriminator=True ).mean()
+# else:
+# d_loss = self.wl(wav = wav.detach().squeeze(), y_rec = y_pred.detach().squeeze(), discriminator=True).mean()
+
+# # regularization (ignore length variation)
+# d_loss += loss_reg
+
+# out_gt = self.wl(wav = y_rec_gt.detach().squeeze(),y_rec=None, discriminator_forward=True)
+# out_rec = self.wl(wav = y_rec_gt_pred.detach().squeeze(), y_rec=None, discriminator_forward=True)
+
+# # regularization (ignore reconstruction artifacts)
+# d_loss += F.l1_loss(out_gt, out_rec)
+
+# else:
+# d_loss = self.wl(wav = wav.detach().squeeze(),y_rec= y_pred.detach().squeeze(), discriminator=True).mean()
+# else:
+# d_loss = 0
+
+# # generator loss
+# gen_loss = self.wl(wav = None, y_rec = y_pred.squeeze(), generator=True)
+
+# gen_loss = gen_loss.mean()
+
+# return d_loss, gen_loss, y_pred.detach().cpu().numpy()
+
+# def length_to_mask(lengths):
+# mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+# mask = torch.gt(mask+1, lengths.unsqueeze(1))
+# return mask
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/slmadv_og.py b/stts_48khz/StyleTTS2_48khz/Modules/slmadv_og.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b8b81fde40780104d49b386954b8eaba8838fef
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Modules/slmadv_og.py
@@ -0,0 +1,400 @@
+import torch
+import numpy as np
+import torch.nn.functional as F
+
+class SLMAdversarialLoss(torch.nn.Module):
+
+ def __init__(self, model, wl, sampler, min_len, max_len, batch_percentage=0.5, skip_update=10, sig=1.5):
+ super(SLMAdversarialLoss, self).__init__()
+ self.model = model
+ self.wl = wl
+ self.sampler = sampler
+
+ self.min_len = min_len
+ self.max_len = max_len
+ self.batch_percentage = batch_percentage
+
+ self.sig = sig
+ self.skip_update = skip_update
+
+ def forward(self, iters, y_rec_gt, y_rec_gt_pred, waves, mel_input_length, ref_text, ref_lengths, use_ind, s_trg, ref_s=None):
+
+ text_mask = length_to_mask(ref_lengths).to(ref_text.device)
+
+
+ # if text_mask.shape[1] > 512:
+ # text_mask = text_mask[:, :512]
+
+ bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())
+
+
+ d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)
+
+ if use_ind and np.random.rand() < 0.5:
+ s_preds = s_trg
+ else:
+ num_steps = np.random.randint(3, 5)
+ if ref_s is not None:
+ s_preds = self.sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device),
+ embedding=bert_dur,
+ embedding_scale=1,
+ features=ref_s, # reference from the same speaker as the embedding
+ embedding_mask_proba=0.1,
+ num_steps=num_steps).squeeze(1)
+ else:
+ s_preds = self.sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device),
+ embedding=bert_dur,
+ embedding_scale=1,
+ embedding_mask_proba=0.1,
+ num_steps=num_steps).squeeze(1)
+
+ s_dur = s_preds[:, 128:]
+ s = s_preds[:, :128]
+
+ d, _ = self.model.predictor(d_en, s_dur,
+ ref_lengths,
+ torch.randn(ref_lengths.shape[0], ref_lengths.max(), 2).to(ref_text.device),
+ text_mask)
+
+ bib = 0
+
+ output_lengths = []
+ attn_preds = []
+
+ # differentiable duration modeling
+ for _s2s_pred, _text_length in zip(d, ref_lengths):
+
+ _s2s_pred_org = _s2s_pred[:_text_length, :]
+
+ _s2s_pred = torch.sigmoid(_s2s_pred_org)
+ _dur_pred = _s2s_pred.sum(axis=-1)
+
+ l = int(torch.round(_s2s_pred.sum()).item())
+ t = torch.arange(0, l).expand(l)
+
+ t = torch.arange(0, l).unsqueeze(0).expand((len(_s2s_pred), l)).to(ref_text.device)
+ loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2
+
+ h = torch.exp(-0.5 * torch.square(t - (l - loc.unsqueeze(-1))) / (self.sig)**2)
+
+ out = torch.nn.functional.conv1d(_s2s_pred_org.unsqueeze(0),
+ h.unsqueeze(1),
+ padding=h.shape[-1] - 1, groups=int(_text_length))[..., :l]
+ attn_preds.append(F.softmax(out.squeeze(), dim=0))
+
+ output_lengths.append(l)
+
+ max_len = max(output_lengths)
+
+ with torch.no_grad():
+ t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask)
+
+ s2s_attn = torch.zeros(len(ref_lengths), int(ref_lengths.max()), max_len).to(ref_text.device)
+ for bib in range(len(output_lengths)):
+ s2s_attn[bib, :ref_lengths[bib], :output_lengths[bib]] = attn_preds[bib]
+
+ asr_pred = t_en @ s2s_attn
+
+ _, p_pred = self.model.predictor(d_en, s_dur,
+ ref_lengths,
+ s2s_attn,
+ text_mask)
+
+ mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2)
+ mel_len = min(mel_len, self.max_len // 2)
+
+ # get clips
+
+ en = []
+ p_en = []
+ sp = []
+
+ F0_fakes = []
+ N_fakes = []
+
+ wav = []
+
+ for bib in range(len(output_lengths)):
+ mel_length_pred = output_lengths[bib]
+ mel_length_gt = int(mel_input_length[bib].item() / 2)
+ if mel_length_gt <= mel_len or mel_length_pred <= mel_len:
+ continue
+
+ sp.append(s_preds[bib])
+
+ random_start = np.random.randint(0, mel_length_pred - mel_len)
+ en.append(asr_pred[bib, :, random_start:random_start+mel_len])
+ p_en.append(p_pred[bib, :, random_start:random_start+mel_len])
+
+ # get ground truth clips
+ random_start = np.random.randint(0, mel_length_gt - mel_len)
+ y = waves[bib][(random_start * 2) * 512:((random_start+mel_len) * 2) * 512]
+ wav.append(torch.from_numpy(y).to(ref_text.device))
+
+ if len(wav) >= self.batch_percentage * len(waves): # prevent OOM due to longer lengths
+ break
+
+ if len(sp) <= 1:
+ return None
+
+ sp = torch.stack(sp)
+ wav = torch.stack(wav).float()
+ en = torch.stack(en)
+ p_en = torch.stack(p_en)
+
+ F0_fake, N_fake = self.model.predictor(texts=p_en, style=sp[:, 128:], f0=True)
+ y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128])
+
+ # discriminator loss
+ if (iters + 1) % self.skip_update == 0:
+ if np.random.randint(0, 2) == 0:
+ wav = y_rec_gt_pred
+ use_rec = True
+ else:
+ use_rec = False
+
+ crop_size = min(wav.size(-1), y_pred.size(-1))
+ if use_rec: # use reconstructed (shorter lengths), do length invariant regularization
+ if wav.size(-1) > y_pred.size(-1):
+ real_GP = wav[:, : , :crop_size]
+ out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
+ out_org = self.wl.discriminator_forward(wav.detach().squeeze())
+ loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)])
+
+ if np.random.randint(0, 2) == 0:
+ d_loss = self.wl.discriminator(real_GP.detach().squeeze(), y_pred.detach().squeeze()).mean()
+ else:
+ d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
+ else:
+ real_GP = y_pred[:, : , :crop_size]
+ out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
+ out_org = self.wl.discriminator_forward(y_pred.detach().squeeze())
+ loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)])
+
+ if np.random.randint(0, 2) == 0:
+ d_loss = self.wl.discriminator(wav.detach().squeeze(), real_GP.detach().squeeze()).mean()
+ else:
+ d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
+
+ # regularization (ignore length variation)
+ d_loss += loss_reg
+
+ out_gt = self.wl.discriminator_forward(y_rec_gt.detach().squeeze())
+ out_rec = self.wl.discriminator_forward(y_rec_gt_pred.detach().squeeze())
+
+ # regularization (ignore reconstruction artifacts)
+ d_loss += F.l1_loss(out_gt, out_rec)
+
+ else:
+ d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
+ else:
+ d_loss = 0
+
+ # generator loss
+ gen_loss = self.wl.generator(y_pred.squeeze())
+
+ gen_loss = gen_loss.mean()
+
+ return d_loss, gen_loss, y_pred.detach().cpu().numpy()
+
+def length_to_mask(lengths):
+    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+    mask = torch.gt(mask+1, lengths.unsqueeze(1))
+    return mask
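+
+# Example: length_to_mask(torch.tensor([3, 5])) ->
+# tensor([[False, False, False,  True,  True],
+#         [False, False, False, False, False]])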
+
diff --git a/stts_48khz/StyleTTS2_48khz/Modules/utils.py b/stts_48khz/StyleTTS2_48khz/Modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bf5fd72d7e70c4e5b443b43566b6f69e1954ee8
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Modules/utils.py
@@ -0,0 +1,17 @@
+from torch.nn.utils import weight_norm
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def apply_weight_norm(m):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ weight_norm(m)
+
+
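+# "same" padding helper for dilated convolutions, e.g. get_padding(7) == 3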
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size*dilation - dilation)/2)
\ No newline at end of file
diff --git a/stts_48khz/StyleTTS2_48khz/README.md b/stts_48khz/StyleTTS2_48khz/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d50e8ad7e117f73624656512bb203312aafdbb3c
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/README.md
@@ -0,0 +1,151 @@
+# StyleTTS 2: Towards Human-Level Text-to-Speech through Style Diffusion and Adversarial Training with Large Speech Language Models
+
+### Yinghao Aaron Li, Cong Han, Vinay S. Raghavan, Gavin Mischler, Nima Mesgarani
+
+> In this paper, we present StyleTTS 2, a text-to-speech (TTS) model that leverages style diffusion and adversarial training with large speech language models (SLMs) to achieve human-level TTS synthesis. StyleTTS 2 differs from its predecessor by modeling styles as a latent random variable through diffusion models to generate the most suitable style for the text without requiring reference speech, achieving efficient latent diffusion while benefiting from the diverse speech synthesis offered by diffusion models. Furthermore, we employ large pre-trained SLMs, such as WavLM, as discriminators with our novel differentiable duration modeling for end-to-end training, resulting in improved speech naturalness. StyleTTS 2 surpasses human recordings on the single-speaker LJSpeech dataset and matches it on the multispeaker VCTK dataset as judged by native English speakers. Moreover, when trained on the LibriTTS dataset, our model outperforms previous publicly available models for zero-shot speaker adaptation. This work achieves the first human-level TTS synthesis on both single and multispeaker datasets, showcasing the potential of style diffusion and adversarial training with large SLMs.
+
+Paper: [https://arxiv.org/abs/2306.07691](https://arxiv.org/abs/2306.07691)
+
+Audio samples: [https://styletts2.github.io/](https://styletts2.github.io/)
+
+Online demo: [Hugging Face](https://huggingface.co/spaces/styletts2/styletts2) (thanks to [@fakerybakery](https://github.com/fakerybakery) for the wonderful online demo)
+
+[](https://colab.research.google.com/github/yl4579/StyleTTS2/blob/main/) [](https://discord.gg/ha8sxdG2K4)
+
+## TODO
+- [x] Training and inference demo code for single-speaker models (LJSpeech)
+- [x] Test training code for multi-speaker models (VCTK and LibriTTS)
+- [x] Finish demo code for multispeaker model and upload pre-trained models
+- [x] Add a finetuning script for new speakers with base pre-trained multispeaker models
+- [ ] Fix DDP (accelerator) for `train_second.py` **(I have tried everything I could to fix this but had no success, so if you are willing to help, please see [#7](https://github.com/yl4579/StyleTTS2/issues/7))**
+
+## Pre-requisites
+1. Python >= 3.7
+2. Clone this repository:
+```bash
+git clone https://github.com/yl4579/StyleTTS2.git
+cd StyleTTS2
+```
+3. Install python requirements:
+```bash
+pip install -r requirements.txt
+```
+On Windows add:
+```bash
+pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 -U
+```
+Also install phonemizer and espeak if you want to run the demo:
+```bash
+pip install phonemizer
+sudo apt-get install espeak-ng
+```
+4. Download the [LJSpeech dataset](https://keithito.com/LJ-Speech-Dataset/), extract it into the data folder, and upsample the audio to 24 kHz. The text aligner and pitch extractor are pre-trained on 24 kHz data, but you can easily change the preprocessing and re-train them on your own data.
+For LibriTTS, you will need to combine train-clean-360 with train-clean-100 and rename the folder to train-clean-460 (see [val_list_libritts.txt](https://github.com/yl4579/StyleTTS/blob/main/Data/val_list_libritts.txt) as an example).
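+
+For example, a minimal resampling sketch (assuming `librosa` and `soundfile` are installed; the paths are illustrative):
+```python
+import glob, os
+import librosa
+import soundfile as sf
+
+os.makedirs("Data/wavs", exist_ok=True)
+for path in glob.glob("LJSpeech-1.1/wavs/*.wav"):
+    y, _ = librosa.load(path, sr=24000)  # load and resample to 24 kHz
+    sf.write(os.path.join("Data/wavs", os.path.basename(path)), y, 24000)
+```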
+
+## Training
+First stage training:
+```bash
+accelerate launch train_first.py --config_path ./Configs/config.yml
+```
+Second stage training **(DDP version not working, so the current version uses DP, again see [#7](https://github.com/yl4579/StyleTTS2/issues/7) if you want to help)**:
+```bash
+python train_second.py --config_path ./Configs/config.yml
+```
+You can run both consecutively to train the first and second stages. Models will be saved as "epoch_1st_%05d.pth" and "epoch_2nd_%05d.pth". Checkpoints and Tensorboard logs will be saved at `log_dir`.
+
+The data list format needs to be `filename.wav|transcription|speaker`, see [val_list.txt](https://github.com/yl4579/StyleTTS2/blob/main/Data/val_list.txt) as an example. The speaker labels are needed for multi-speaker models because we need to sample reference audio for style diffusion model training.
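+
+For example (hypothetical filenames, transcriptions, and speaker IDs):
+```
+wavs/LJ001-0001.wav|Printing, in the only sense with which we are at present concerned.|0
+wavs/LJ002-0017.wav|The proportions of the page.|0
+```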
+
+### Important Configurations
+In [config.yml](https://github.com/yl4579/StyleTTS2/blob/main/Configs/config.yml), there are a few important configurations to take care of:
+- `OOD_data`: The path for out-of-distribution texts for SLM adversarial training. The format should be `text|anything`.
+- `min_length`: Minimum length of OOD texts for training. This is to make sure the synthesized speech has a minimum length.
+- `max_len`: Maximum length of audio for training, measured in frames. Since the default hop size is 300, one frame is approximately `300 / 24000` (0.0125) seconds (see the sketch after this list). Lower this if you encounter out-of-memory issues.
+- `multispeaker`: Set to true if you want to train a multispeaker model. This is needed because the architecture of the denoiser is different for single and multispeaker models.
+- `batch_percentage`: This keeps SLM adversarial training from running out of memory (OOM). If you encounter OOM problems, set a lower value.
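+
+As a sanity check, here is a minimal sketch (not part of the repo; `max_len = 400` is only an example value) of how `max_len` maps to audio duration:
+```python
+hop_length = 300  # default hop size
+sr = 24_000       # sampling rate
+max_len = 400     # frames (example value from config.yml)
+print(max_len * hop_length / sr)  # 5.0 seconds of audio per training clip
+```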
+
+### Pre-trained modules
+In [Utils](https://github.com/yl4579/StyleTTS2/tree/main/Utils) folder, there are three pre-trained models:
+- **[ASR](https://github.com/yl4579/StyleTTS2/tree/main/Utils/ASR) folder**: It contains the pre-trained text aligner, which was pre-trained on English (LibriTTS), Japanese (JVS), and Chinese (AiShell) corpus. It works well for most other languages without fine-tuning, but you can always train your own text aligner with the code here: [yl4579/AuxiliaryASR](https://github.com/yl4579/AuxiliaryASR).
+- **[JDC](https://github.com/yl4579/StyleTTS2/tree/main/Utils/JDC) folder**: It contains the pre-trained pitch extractor, which was pre-trained on English (LibriTTS) corpus only. However, it works well for other languages too because F0 is independent of language. If you want to train on singing corpus, it is recommended to train a new pitch extractor with the code here: [yl4579/PitchExtractor](https://github.com/yl4579/PitchExtractor).
+- **[PLBERT](https://github.com/yl4579/StyleTTS2/tree/main/Utils/PLBERT) folder**: It contains the pre-trained [PL-BERT](https://arxiv.org/abs/2301.08810) model, which was pre-trained on English (Wikipedia) corpus only. It probably does not work very well on other languages, so you will need to train a different PL-BERT for different languages using the repo here: [yl4579/PL-BERT](https://github.com/yl4579/PL-BERT). You can also use the [multilingual PL-BERT](https://huggingface.co/papercup-ai/multilingual-pl-bert) which supports 14 languages.
+
+### Common Issues
+- **Loss becomes NaN**: In the first stage, make sure you do not use mixed precision, as it can cause the loss to become NaN for some datasets when the batch size is not set properly (it needs to be more than 16 to work well). For the second stage, also experiment with different batch sizes; higher batch sizes are more likely to cause NaN loss values. We recommend a batch size of 16. You can refer to issues [#10](https://github.com/yl4579/StyleTTS2/issues/10) and [#11](https://github.com/yl4579/StyleTTS2/issues/11) for more details.
+- **Out of memory**: Please either use lower `batch_size` or `max_len`. You may refer to issue [#10](https://github.com/yl4579/StyleTTS2/issues/10) for more information.
+- **Non-English dataset**: You can train on any language you want, but you will need to use a pre-trained PL-BERT model for that language. We have a pre-trained [multilingual PL-BERT](https://huggingface.co/papercup-ai/multilingual-pl-bert) that supports 14 languages. You may refer to [yl4579/StyleTTS#10](https://github.com/yl4579/StyleTTS/issues/10) and [#70](https://github.com/yl4579/StyleTTS2/issues/70) for some examples to train on Chinese datasets.
+
+## Finetuning
+The script is modified from `train_second.py` which uses DP, as DDP does not work for `train_second.py`. Please see the bold section above if you are willing to help with this problem.
+```bash
+python train_finetune.py --config_path ./Configs/config_ft.yml
+```
+Please make sure you have the LibriTTS checkpoint downloaded and unzipped under the folder. The default configuration `config_ft.yml` finetunes on LJSpeech with 1 hour of speech data (around 1k samples) for 50 epochs. This took about 4 hours to finish on four NVIDIA A100 GPUs. The quality is slightly worse (similar to NaturalSpeech on LJSpeech) than the LJSpeech model trained from scratch with 24 hours of speech data, which took around 2.5 days to finish on four A100s. The samples can be found at [#65 (comment)](https://github.com/yl4579/StyleTTS2/discussions/65#discussioncomment-7668393).
+
+If you are using a **single GPU** (because the script doesn't work with DDP) and want to save VRAM and speed up training, you can run (thanks to [@korakoe](https://github.com/korakoe) for making the script at [#100](https://github.com/yl4579/StyleTTS2/pull/100)):
+```bash
+accelerate launch --mixed_precision=fp16 --num_processes=1 train_finetune_accelerate.py --config_path ./Configs/config_ft.yml
+```
+[](https://colab.research.google.com/github/yl4579/StyleTTS2/blob/main/Colab/StyleTTS2_Finetune_Demo.ipynb)
+
+### Common Issues
+[@Kreevoz](https://github.com/Kreevoz) has made detailed notes on common issues in finetuning, with suggestions in maximizing audio quality: [#81](https://github.com/yl4579/StyleTTS2/discussions/81). Some of these also apply to training from scratch. [@IIEleven11](https://github.com/IIEleven11) has also made a guideline for fine-tuning: [#128](https://github.com/yl4579/StyleTTS2/discussions/128).
+
+- **Out of memory after `joint_epoch`**: This is likely because your GPU RAM is not big enough for the SLM adversarial training run. You may skip it, but the quality could be worse. Setting `joint_epoch` to a number larger than `epochs` skips SLM adversarial training entirely.
+
+## Inference
+Please refer to [Inference_LJSpeech.ipynb](https://github.com/yl4579/StyleTTS2/blob/main/Demo/Inference_LJSpeech.ipynb) (single-speaker) and [Inference_LibriTTS.ipynb](https://github.com/yl4579/StyleTTS2/blob/main/Demo/Inference_LibriTTS.ipynb) (multi-speaker) for details. For LibriTTS, you will also need to download [reference_audio.zip](https://huggingface.co/yl4579/StyleTTS2-LibriTTS/resolve/main/reference_audio.zip) and unzip it under the `demo` folder before running the demo.
+
+- The pretrained StyleTTS 2 on LJSpeech corpus in 24 kHz can be downloaded at [https://huggingface.co/yl4579/StyleTTS2-LJSpeech/tree/main](https://huggingface.co/yl4579/StyleTTS2-LJSpeech/tree/main).
+
+ [](https://colab.research.google.com/github/yl4579/StyleTTS2/blob/main/Colab/StyleTTS2_Demo_LJSpeech.ipynb)
+
+- The pretrained StyleTTS 2 model on LibriTTS can be downloaded at [https://huggingface.co/yl4579/StyleTTS2-LibriTTS/tree/main](https://huggingface.co/yl4579/StyleTTS2-LibriTTS/tree/main).
+
+ [](https://colab.research.google.com/github/yl4579/StyleTTS2/blob/main/Colab/StyleTTS2_Demo_LibriTTS.ipynb)
+
+
+You can import StyleTTS 2 and run it in your own code. However, inference depends on a GPL-licensed package, so it is not included directly in this repository. A [GPL-licensed fork](https://github.com/NeuralVox/StyleTTS2) has an importable script as well as an experimental streaming API. A [fully MIT-licensed package](https://pypi.org/project/styletts2/) that uses gruut (albeit with lower quality due to the mismatch between phonemizer and gruut) is also available.
+
+***Before using these pre-trained models, you agree to inform the listeners that the speech samples are synthesized by the pre-trained models, unless you have the permission to use the voice you synthesize. That is, you agree to only use voices whose speakers grant the permission to have their voice cloned, either directly or by license before making synthesized voices public, or you have to publicly announce that these voices are synthesized if you do not have the permission to use these voices.***
+
+### Common Issues
+- **High-pitched background noise**: This is caused by numerical floating-point differences on older GPUs. For more details, please refer to issue [#13](https://github.com/yl4579/StyleTTS2/issues/13). In short, you will need to use a more modern GPU or run inference on the CPU.
+- **Pre-trained model license**: You only need to abide by the above rules if you use **the pre-trained models** and the voices are **NOT** in the training set, i.e., your reference speakers are not from any open access dataset. For more details of rules to use the pre-trained models, please see [#37](https://github.com/yl4579/StyleTTS2/issues/37).
+
+## References
+- [archinetai/audio-diffusion-pytorch](https://github.com/archinetai/audio-diffusion-pytorch)
+- [jik876/hifi-gan](https://github.com/jik876/hifi-gan)
+- [rishikksh20/iSTFTNet-pytorch](https://github.com/rishikksh20/iSTFTNet-pytorch)
+- [nii-yamagishilab/project-NN-Pytorch-scripts/project/01-nsf](https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts/tree/master/project/01-nsf)
+
+## License
+
+Code: MIT License
+
+Pre-Trained Models: Before using these pre-trained models, you agree to inform the listeners that the speech samples are synthesized by the pre-trained models, unless you have the permission to use the voice you synthesize. That is, you agree to only use voices whose speakers grant the permission to have their voice cloned, either directly or by license before making synthesized voices public, or you have to publicly announce that these voices are synthesized if you do not have the permission to use these voices.
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/ASR/__init__.py b/stts_48khz/StyleTTS2_48khz/Utils/ASR/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Utils/ASR/__init__.py
@@ -0,0 +1 @@
+
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/__init__.cpython-310.pyc b/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13b4b5c779b55ff1db17baabca99e56d8098ef7f
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/__init__.cpython-310.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/__init__.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d19379da4fb920c0d2e6c728fc5e4230df786f09
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/__init__.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/__init__.cpython-39.pyc b/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..718b87c8c9503efb9865c00f20019445e5f0d344
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/__init__.cpython-39.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/layers.cpython-310.pyc b/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/layers.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..89ae7fde0a48247842ab29cd5c4166dc68338ccc
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/layers.cpython-310.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/layers.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/layers.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..25f0d8c16f75e3e3e690a432391200ce0ee991bf
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/layers.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/layers.cpython-39.pyc b/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/layers.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..67fcd3df0d13a7adb43ceaf57f5553e64bce6dd6
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/layers.cpython-39.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/models.cpython-310.pyc b/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/models.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9fe7d987baa9da7e77f171d099e8c9754e15d7ad
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/models.cpython-310.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/models.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/models.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f2d3f68e92e1b1d4cc23132736ece75d6ea12687
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/models.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/models.cpython-39.pyc b/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/models.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0de15e0ec2f78bf58120076e0965271b065341ab
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/ASR/__pycache__/models.cpython-39.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/ASR/config copy.yml b/stts_48khz/StyleTTS2_48khz/Utils/ASR/config copy.yml
new file mode 100644
index 0000000000000000000000000000000000000000..fb268a0a19e8fbfb076c27a9fc49625591ea1f39
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Utils/ASR/config copy.yml
@@ -0,0 +1,26 @@
+log_dir: "Checkpoint_new_plus"
+save_freq: 2
+device: "cuda"
+epochs: 200
+batch_size: 64
+pretrained_model: ""
+train_data: "/home/austin/disk2/llmvcs/tt/AuxiliaryASR/Data/train_list_plus.csv"
+val_data: "/home/austin/disk2/llmvcs/tt/AuxiliaryASR/Data/val_list.txt"
+
+preprocess_parasm:
+ sr: 24000
+ spect_params:
+ n_fft: 2048
+ win_length: 2048
+ hop_length: 512
+ mel_params:
+ n_mels: 80
+
+model_params:
+ input_dim: 80
+ hidden_dim: 256
+ n_token: 178
+ token_embedding_dim: 512
+
+optimizer_params:
+ lr: 0.0005
\ No newline at end of file
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/ASR/config.yml b/stts_48khz/StyleTTS2_48khz/Utils/ASR/config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..10a2c84e3882b1e666b228f712dabf13e30fe20d
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Utils/ASR/config.yml
@@ -0,0 +1,26 @@
+log_dir: "Checkpoint_new_plus"
+save_freq: 2
+device: "cuda"
+epochs: 200
+batch_size: 64
+pretrained_model: ""
+train_data: "/home/austin/disk2/llmvcs/tt/AuxiliaryASR/Data/train_list_plus.csv"
+val_data: "/home/austin/disk2/llmvcs/tt/AuxiliaryASR/Data/val_list.txt"
+
+preprocess_parasm:
+ sr: 48000
+ spect_params:
+ n_fft: 2048
+ win_length: 2048
+ hop_length: 512
+ mel_params:
+ n_mels: 80
+
+model_params:
+ input_dim: 80
+ hidden_dim: 256
+ n_token: 178
+ token_embedding_dim: 512
+
+optimizer_params:
+ lr: 0.0005
\ No newline at end of file
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/ASR/epoch_00070.pth b/stts_48khz/StyleTTS2_48khz/Utils/ASR/epoch_00070.pth
new file mode 100644
index 0000000000000000000000000000000000000000..7ed4fb67a10f40eb96c0934b42567355890b75af
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Utils/ASR/epoch_00070.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a017b00e8955827cd7522b097752efc628e62e05f209f009e0bbb7a970f468cb
+size 94572980
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/ASR/layers.py b/stts_48khz/StyleTTS2_48khz/Utils/ASR/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc6567b47ef60e9a64f85b7c088b7fac1683fdb5
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Utils/ASR/layers.py
@@ -0,0 +1,354 @@
+import math
+import torch
+from torch import nn
+from typing import Optional, Any
+from torch import Tensor
+import torch.nn.functional as F
+import torchaudio
+import torchaudio.functional as audio_F
+
+import random
+random.seed(0)
+
+
+def _get_activation_fn(activ):
+ if activ == 'relu':
+ return nn.ReLU()
+ elif activ == 'lrelu':
+ return nn.LeakyReLU(0.2)
+ elif activ == 'swish':
+ return lambda x: x*torch.sigmoid(x)
+ else:
+ raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ)
+
+class LinearNorm(torch.nn.Module):
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
+ super(LinearNorm, self).__init__()
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+
+ torch.nn.init.xavier_uniform_(
+ self.linear_layer.weight,
+ gain=torch.nn.init.calculate_gain(w_init_gain))
+
+ def forward(self, x):
+ return self.linear_layer(x)
+
+
+class ConvNorm(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
+ padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
+ super(ConvNorm, self).__init__()
+ if padding is None:
+ assert(kernel_size % 2 == 1)
+ padding = int(dilation * (kernel_size - 1) / 2)
+
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
+ kernel_size=kernel_size, stride=stride,
+ padding=padding, dilation=dilation,
+ bias=bias)
+
+ torch.nn.init.xavier_uniform_(
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
+
+ def forward(self, signal):
+ conv_signal = self.conv(signal)
+ return conv_signal
+
+class CausualConv(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1, dilation=1, bias=True, w_init_gain='linear', param=None):
+ super(CausualConv, self).__init__()
+        if padding is None:
+            assert(kernel_size % 2 == 1)
+            self.padding = int(dilation * (kernel_size - 1) / 2) * 2
+        else:
+            self.padding = padding * 2
+ self.conv = nn.Conv1d(in_channels, out_channels,
+ kernel_size=kernel_size, stride=stride,
+ padding=self.padding,
+ dilation=dilation,
+ bias=bias)
+
+ torch.nn.init.xavier_uniform_(
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = x[:, :, :-self.padding]
+ return x
+
+class CausualBlock(nn.Module):
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'):
+ super(CausualBlock, self).__init__()
+ self.blocks = nn.ModuleList([
+ self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
+ for i in range(n_conv)])
+
+ def forward(self, x):
+ for block in self.blocks:
+ res = x
+ x = block(x)
+ x += res
+ return x
+
+ def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2):
+ layers = [
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
+ _get_activation_fn(activ),
+ nn.BatchNorm1d(hidden_dim),
+ nn.Dropout(p=dropout_p),
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
+ _get_activation_fn(activ),
+ nn.Dropout(p=dropout_p)
+ ]
+ return nn.Sequential(*layers)
+
+class ConvBlock(nn.Module):
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'):
+ super().__init__()
+ self._n_groups = 8
+ self.blocks = nn.ModuleList([
+ self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
+ for i in range(n_conv)])
+
+
+ def forward(self, x):
+ for block in self.blocks:
+ res = x
+ x = block(x)
+ x += res
+ return x
+
+ def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2):
+ layers = [
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
+ _get_activation_fn(activ),
+ nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
+ nn.Dropout(p=dropout_p),
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
+ _get_activation_fn(activ),
+ nn.Dropout(p=dropout_p)
+ ]
+ return nn.Sequential(*layers)
+
+class LocationLayer(nn.Module):
+ def __init__(self, attention_n_filters, attention_kernel_size,
+ attention_dim):
+ super(LocationLayer, self).__init__()
+ padding = int((attention_kernel_size - 1) / 2)
+ self.location_conv = ConvNorm(2, attention_n_filters,
+ kernel_size=attention_kernel_size,
+ padding=padding, bias=False, stride=1,
+ dilation=1)
+ self.location_dense = LinearNorm(attention_n_filters, attention_dim,
+ bias=False, w_init_gain='tanh')
+
+ def forward(self, attention_weights_cat):
+ processed_attention = self.location_conv(attention_weights_cat)
+ processed_attention = processed_attention.transpose(1, 2)
+ processed_attention = self.location_dense(processed_attention)
+ return processed_attention
+
+
+class Attention(nn.Module):
+ def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
+ attention_location_n_filters, attention_location_kernel_size):
+ super(Attention, self).__init__()
+ self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
+ bias=False, w_init_gain='tanh')
+ self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
+ w_init_gain='tanh')
+ self.v = LinearNorm(attention_dim, 1, bias=False)
+ self.location_layer = LocationLayer(attention_location_n_filters,
+ attention_location_kernel_size,
+ attention_dim)
+ self.score_mask_value = -float("inf")
+
+ def get_alignment_energies(self, query, processed_memory,
+ attention_weights_cat):
+ """
+ PARAMS
+ ------
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
+ attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
+ RETURNS
+ -------
+ alignment (batch, max_time)
+ """
+
+ processed_query = self.query_layer(query.unsqueeze(1))
+ processed_attention_weights = self.location_layer(attention_weights_cat)
+ energies = self.v(torch.tanh(
+ processed_query + processed_attention_weights + processed_memory))
+
+ energies = energies.squeeze(-1)
+ return energies
+
+ def forward(self, attention_hidden_state, memory, processed_memory,
+ attention_weights_cat, mask):
+ """
+ PARAMS
+ ------
+ attention_hidden_state: attention rnn last output
+ memory: encoder outputs
+ processed_memory: processed encoder outputs
+        attention_weights_cat: previous and cumulative attention weights
+ mask: binary mask for padded data
+ """
+ alignment = self.get_alignment_energies(
+ attention_hidden_state, processed_memory, attention_weights_cat)
+
+ if mask is not None:
+ alignment.data.masked_fill_(mask, self.score_mask_value)
+
+ attention_weights = F.softmax(alignment, dim=1)
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
+ attention_context = attention_context.squeeze(1)
+
+ return attention_context, attention_weights
+
+
+class ForwardAttentionV2(nn.Module):
+ def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
+ attention_location_n_filters, attention_location_kernel_size):
+ super(ForwardAttentionV2, self).__init__()
+ self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
+ bias=False, w_init_gain='tanh')
+ self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
+ w_init_gain='tanh')
+ self.v = LinearNorm(attention_dim, 1, bias=False)
+ self.location_layer = LocationLayer(attention_location_n_filters,
+ attention_location_kernel_size,
+ attention_dim)
+ self.score_mask_value = -float(1e20)
+
+ def get_alignment_energies(self, query, processed_memory,
+ attention_weights_cat):
+ """
+ PARAMS
+ ------
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
+ attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
+ RETURNS
+ -------
+ alignment (batch, max_time)
+ """
+
+ processed_query = self.query_layer(query.unsqueeze(1))
+ processed_attention_weights = self.location_layer(attention_weights_cat)
+ energies = self.v(torch.tanh(
+ processed_query + processed_attention_weights + processed_memory))
+
+ energies = energies.squeeze(-1)
+ return energies
+
+ def forward(self, attention_hidden_state, memory, processed_memory,
+ attention_weights_cat, mask, log_alpha):
+ """
+ PARAMS
+ ------
+ attention_hidden_state: attention rnn last output
+ memory: encoder outputs
+ processed_memory: processed encoder outputs
+        attention_weights_cat: previous and cumulative attention weights
+ mask: binary mask for padded data
+ """
+ log_energy = self.get_alignment_energies(
+ attention_hidden_state, processed_memory, attention_weights_cat)
+
+ #log_energy =
+
+ if mask is not None:
+ log_energy.data.masked_fill_(mask, self.score_mask_value)
+
+ #attention_weights = F.softmax(alignment, dim=1)
+
+ #content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
+ #log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
+
+ #log_total_score = log_alpha + content_score
+
+ #previous_attention_weights = attention_weights_cat[:,0,:]
+
+ log_alpha_shift_padded = []
+ max_time = log_energy.size(1)
+ for sft in range(2):
+ shifted = log_alpha[:,:max_time-sft]
+ shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value)
+ log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
+
+ biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2)
+
+ log_alpha_new = biased + log_energy
+
+ attention_weights = F.softmax(log_alpha_new, dim=1)
+
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
+ attention_context = attention_context.squeeze(1)
+
+ return attention_context, attention_weights, log_alpha_new
+
+
+class PhaseShuffle2d(nn.Module):
+ def __init__(self, n=2):
+ super(PhaseShuffle2d, self).__init__()
+ self.n = n
+ self.random = random.Random(1)
+
+ def forward(self, x, move=None):
+ # x.size = (B, C, M, L)
+ if move is None:
+ move = self.random.randint(-self.n, self.n)
+
+ if move == 0:
+ return x
+ else:
+ left = x[:, :, :, :move]
+ right = x[:, :, :, move:]
+ shuffled = torch.cat([right, left], dim=3)
+ return shuffled
+
+class PhaseShuffle1d(nn.Module):
+ def __init__(self, n=2):
+ super(PhaseShuffle1d, self).__init__()
+ self.n = n
+ self.random = random.Random(1)
+
+ def forward(self, x, move=None):
+        # x.size = (B, C, L)
+ if move is None:
+ move = self.random.randint(-self.n, self.n)
+
+ if move == 0:
+ return x
+ else:
+ left = x[:, :, :move]
+ right = x[:, :, move:]
+ shuffled = torch.cat([right, left], dim=2)
+
+ return shuffled
+
+class MFCC(nn.Module):
+ def __init__(self, n_mfcc=40, n_mels=80):
+ super(MFCC, self).__init__()
+ self.n_mfcc = n_mfcc
+ self.n_mels = n_mels
+ self.norm = 'ortho'
+ dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
+ self.register_buffer('dct_mat', dct_mat)
+
+ def forward(self, mel_specgram):
+ if len(mel_specgram.shape) == 2:
+ mel_specgram = mel_specgram.unsqueeze(0)
+ unsqueezed = True
+ else:
+ unsqueezed = False
+        # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
+        # -> (channel, time, n_mfcc).transpose(...)
+ mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
+
+ # unpack batch
+ if unsqueezed:
+ mfcc = mfcc.squeeze(0)
+ return mfcc
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/ASR/models.py b/stts_48khz/StyleTTS2_48khz/Utils/ASR/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..278c8d06a1cafe77afcac208617b8716357d38d3
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Utils/ASR/models.py
@@ -0,0 +1,190 @@
+import math
+import torch
+from torch import nn
+from torch.nn import TransformerEncoder
+import torch.nn.functional as F
+from .layers import MFCC, Attention, LinearNorm, ConvNorm, ConvBlock
+
+class ASRCNN(nn.Module):
+ def __init__(self,
+ input_dim=80,
+ hidden_dim=256,
+ n_token=35,
+ n_layers=6,
+ token_embedding_dim=256,
+
+ ):
+ super().__init__()
+ self.n_token = n_token
+ self.n_down = 1
+ self.to_mfcc = MFCC()
+ self.init_cnn = ConvNorm(input_dim//2, hidden_dim, kernel_size=7, padding=3, stride=2)
+ self.cnns = nn.Sequential(
+ *[nn.Sequential(
+ ConvBlock(hidden_dim),
+ nn.GroupNorm(num_groups=1, num_channels=hidden_dim)
+ ) for n in range(n_layers)])
+ self.projection = ConvNorm(hidden_dim, hidden_dim // 2)
+ self.ctc_linear = nn.Sequential(
+ LinearNorm(hidden_dim//2, hidden_dim),
+ nn.ReLU(),
+ LinearNorm(hidden_dim, n_token))
+ self.asr_s2s = ASRS2S(
+ embedding_dim=token_embedding_dim,
+ hidden_dim=hidden_dim//2,
+ n_token=n_token)
+
+ def forward(self, x, src_key_padding_mask=None, text_input=None):
+ x = self.to_mfcc(x)
+ x = self.init_cnn(x)
+ x = self.cnns(x)
+ x = self.projection(x)
+ x = x.transpose(1, 2)
+ ctc_logit = self.ctc_linear(x)
+ if text_input is not None:
+ _, s2s_logit, s2s_attn = self.asr_s2s(x, src_key_padding_mask, text_input)
+ return ctc_logit, s2s_logit, s2s_attn
+ else:
+ return ctc_logit
+
+ def get_feature(self, x):
+ x = self.to_mfcc(x.squeeze(1))
+ x = self.init_cnn(x)
+ x = self.cnns(x)
+ x = self.projection(x)
+ return x
+
+ def length_to_mask(self, lengths):
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+ mask = torch.gt(mask+1, lengths.unsqueeze(1)).to(lengths.device)
+ return mask
+
+ def get_future_mask(self, out_length, unmask_future_steps=0):
+ """
+ Args:
+ out_length (int): returned mask shape is (out_length, out_length).
+ unmask_futre_steps (int): unmasking future step size.
+ Return:
+ mask (torch.BoolTensor): mask future timesteps mask[i, j] = True if i > j + unmask_future_steps else False
+ """
+ index_tensor = torch.arange(out_length).unsqueeze(0).expand(out_length, -1)
+ mask = torch.gt(index_tensor, index_tensor.T + unmask_future_steps)
+ return mask
+
+class ASRS2S(nn.Module):
+ def __init__(self,
+ embedding_dim=256,
+ hidden_dim=512,
+ n_location_filters=32,
+ location_kernel_size=63,
+ n_token=40):
+ super(ASRS2S, self).__init__()
+ self.embedding = nn.Embedding(n_token, embedding_dim)
+ val_range = math.sqrt(6 / hidden_dim)
+ self.embedding.weight.data.uniform_(-val_range, val_range)
+
+ self.decoder_rnn_dim = hidden_dim
+ self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token)
+ self.attention_layer = Attention(
+ self.decoder_rnn_dim,
+ hidden_dim,
+ hidden_dim,
+ n_location_filters,
+ location_kernel_size
+ )
+ self.decoder_rnn = nn.LSTMCell(self.decoder_rnn_dim + embedding_dim, self.decoder_rnn_dim)
+ self.project_to_hidden = nn.Sequential(
+ LinearNorm(self.decoder_rnn_dim * 2, hidden_dim),
+ nn.Tanh())
+ self.sos = 1
+ self.eos = 2
+
+ def initialize_decoder_states(self, memory, mask):
+ """
+        memory.shape = (B, L, H) = (batch size, max timestep, hidden dim)
+ """
+ B, L, H = memory.shape
+ self.decoder_hidden = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
+ self.decoder_cell = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
+ self.attention_weights = torch.zeros((B, L)).type_as(memory)
+ self.attention_weights_cum = torch.zeros((B, L)).type_as(memory)
+ self.attention_context = torch.zeros((B, H)).type_as(memory)
+ self.memory = memory
+ self.processed_memory = self.attention_layer.memory_layer(memory)
+ self.mask = mask
+ self.unk_index = 3
+ self.random_mask = 0.1
+
+ def forward(self, memory, memory_mask, text_input):
+ """
+        memory.shape = (B, L, H) = (batch size, max timestep, hidden dim)
+        memory_mask.shape = (B, L, )
+        text_input.shape = (B, T)
+ """
+ self.initialize_decoder_states(memory, memory_mask)
+ # text random mask
+ random_mask = (torch.rand(text_input.shape) < self.random_mask).to(text_input.device)
+ _text_input = text_input.clone()
+ _text_input.masked_fill_(random_mask, self.unk_index)
+ decoder_inputs = self.embedding(_text_input).transpose(0, 1) # -> [T, B, channel]
+ start_embedding = self.embedding(
+ torch.LongTensor([self.sos]*decoder_inputs.size(1)).to(decoder_inputs.device))
+ decoder_inputs = torch.cat((start_embedding.unsqueeze(0), decoder_inputs), dim=0)
+
+ hidden_outputs, logit_outputs, alignments = [], [], []
+ while len(hidden_outputs) < decoder_inputs.size(0):
+
+ decoder_input = decoder_inputs[len(hidden_outputs)]
+ hidden, logit, attention_weights = self.decode(decoder_input)
+ hidden_outputs += [hidden]
+ logit_outputs += [logit]
+ alignments += [attention_weights]
+
+ hidden_outputs, logit_outputs, alignments = \
+ self.parse_decoder_outputs(
+ hidden_outputs, logit_outputs, alignments)
+
+ return hidden_outputs, logit_outputs, alignments
+
+
+ def decode(self, decoder_input):
+
+ cell_input = torch.cat((decoder_input, self.attention_context), -1)
+ self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
+ cell_input,
+ (self.decoder_hidden, self.decoder_cell))
+
+ attention_weights_cat = torch.cat(
+ (self.attention_weights.unsqueeze(1),
+ self.attention_weights_cum.unsqueeze(1)),dim=1)
+
+ self.attention_context, self.attention_weights = self.attention_layer(
+ self.decoder_hidden,
+ self.memory,
+ self.processed_memory,
+ attention_weights_cat,
+ self.mask)
+
+ self.attention_weights_cum += self.attention_weights
+
+ hidden_and_context = torch.cat((self.decoder_hidden, self.attention_context), -1)
+ hidden = self.project_to_hidden(hidden_and_context)
+
+        # dropout on the hidden state for regularization before projecting to logits
+ logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training))
+
+ return hidden, logit, self.attention_weights
+
+ def parse_decoder_outputs(self, hidden, logit, alignments):
+
+ # -> [B, T_out + 1, max_time]
+ alignments = torch.stack(alignments).transpose(0,1)
+ # [T_out + 1, B, n_symbols] -> [B, T_out + 1, n_symbols]
+ logit = torch.stack(logit).transpose(0, 1).contiguous()
+ hidden = torch.stack(hidden).transpose(0, 1).contiguous()
+
+ return hidden, logit, alignments
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/JDC/PE_48khz_epoch_00060.pth b/stts_48khz/StyleTTS2_48khz/Utils/JDC/PE_48khz_epoch_00060.pth
new file mode 100644
index 0000000000000000000000000000000000000000..c611cae9992d277299f14a94ce8bc3ad3abda379
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Utils/JDC/PE_48khz_epoch_00060.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:423f5642d7ba276a9de51edf975902345411fb1014092e4356585440034d7d83
+size 63056556
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/JDC/__init__.py b/stts_48khz/StyleTTS2_48khz/Utils/JDC/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Utils/JDC/__init__.py
@@ -0,0 +1 @@
+
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/JDC/__pycache__/__init__.cpython-310.pyc b/stts_48khz/StyleTTS2_48khz/Utils/JDC/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b86f0015b6163a192a5137b8162f1d3162563bdc
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/JDC/__pycache__/__init__.cpython-310.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/JDC/__pycache__/__init__.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/Utils/JDC/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0434ab27a5d628eb8f03175072862d95b64b6e38
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/JDC/__pycache__/__init__.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/JDC/__pycache__/__init__.cpython-39.pyc b/stts_48khz/StyleTTS2_48khz/Utils/JDC/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1d4971f23f8ec5c246047c82c7c0ed9d6a20d0c4
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/JDC/__pycache__/__init__.cpython-39.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/JDC/__pycache__/model.cpython-310.pyc b/stts_48khz/StyleTTS2_48khz/Utils/JDC/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9bbf1de8a50c799458aec8a6de6f35f8517ea2cb
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/JDC/__pycache__/model.cpython-310.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/JDC/__pycache__/model.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/Utils/JDC/__pycache__/model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ff7b92ef706cc40346c1e1dddb9218cc6662b84c
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/JDC/__pycache__/model.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/JDC/__pycache__/model.cpython-39.pyc b/stts_48khz/StyleTTS2_48khz/Utils/JDC/__pycache__/model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..543364866c268c2a1c1d83a96d810a4ec1a9c794
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/JDC/__pycache__/model.cpython-39.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/JDC/bst_rmvpe_48k.t7 b/stts_48khz/StyleTTS2_48khz/Utils/JDC/bst_rmvpe_48k.t7
new file mode 100644
index 0000000000000000000000000000000000000000..96cb49d0051126c572e4e391b4a439ce872b049e
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Utils/JDC/bst_rmvpe_48k.t7
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:20cd418ed4beefd670b0d22da1a3c77892ecdde435a16a60c11755f6240d633b
+size 21029038
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/JDC/config.yml b/stts_48khz/StyleTTS2_48khz/Utils/JDC/config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f466f1bba84c6c4c08474bf4af6788e24c2e1aba
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Utils/JDC/config.yml
@@ -0,0 +1,17 @@
+log_dir: "Checkpoint_200k"
+save_freq: 5
+device: "cuda"
+epochs: 100
+batch_size: 64
+pretrained_model: ""
+train_data: "/home/austin/disk1/stts-zs_cleaning/data/train_List_PITCH_updated_48khz_200k.csv"
+val_data: "/home/austin/disk1/stts-zs_cleaning/data/val_List_PITCH_updated_48khz.csv"
+num_workers: 64
+
+
+optimizer_params:
+ lr: 0.0003
+
+loss_params:
+ lambda_f0: 0.1
+
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/JDC/model.py b/stts_48khz/StyleTTS2_48khz/Utils/JDC/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..83cd266d1cd6f054d0684e8e1a60496044048605
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Utils/JDC/model.py
@@ -0,0 +1,190 @@
+"""
+Implementation of model from:
+Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
+Convolutional Recurrent Neural Networks" (2019)
+Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
+"""
+import torch
+from torch import nn
+
+class JDCNet(nn.Module):
+ """
+ Joint Detection and Classification Network model for singing voice melody.
+ """
+ def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
+ super().__init__()
+ self.num_class = num_class
+
+ # input = (b, 1, 31, 513), b = batch size
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False), # out: (b, 64, 31, 513)
+ nn.BatchNorm2d(num_features=64),
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
+ nn.Conv2d(64, 64, 3, padding=1, bias=False), # (b, 64, 31, 513)
+ )
+
+ # res blocks
+ self.res_block1 = ResBlock(in_channels=64, out_channels=128) # (b, 128, 31, 128)
+ self.res_block2 = ResBlock(in_channels=128, out_channels=192) # (b, 192, 31, 32)
+ self.res_block3 = ResBlock(in_channels=192, out_channels=256) # (b, 256, 31, 8)
+
+ # pool block
+ self.pool_block = nn.Sequential(
+ nn.BatchNorm2d(num_features=256),
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
+ nn.MaxPool2d(kernel_size=(1, 4)), # (b, 256, 31, 2)
+ nn.Dropout(p=0.2),
+ )
+
+ # maxpool layers (for auxiliary network inputs)
+ # in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
+ self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
+ # in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
+ self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
+ # in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
+ self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))
+
+ # in = (b, 640, 31, 2), out = (b, 256, 31, 2)
+ self.detector_conv = nn.Sequential(
+ nn.Conv2d(640, 256, 1, bias=False),
+ nn.BatchNorm2d(256),
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
+ nn.Dropout(p=0.2),
+ )
+
+ # input: (b, 31, 512) - resized from (b, 256, 31, 2)
+ self.bilstm_classifier = nn.LSTM(
+ input_size=512, hidden_size=256,
+ batch_first=True, bidirectional=True) # (b, 31, 512)
+
+ # input: (b, 31, 512) - resized from (b, 256, 31, 2)
+ self.bilstm_detector = nn.LSTM(
+ input_size=512, hidden_size=256,
+ batch_first=True, bidirectional=True) # (b, 31, 512)
+
+ # input: (b * 31, 512)
+ self.classifier = nn.Linear(in_features=512, out_features=self.num_class) # (b * 31, num_class)
+
+ # input: (b * 31, 512)
+ self.detector = nn.Linear(in_features=512, out_features=2) # (b * 31, 2) - binary classifier
+
+ # initialize weights
+ self.apply(self.init_weights)
+
+ def get_feature_GAN(self, x):
+ seq_len = x.shape[-2]
+ x = x.float().transpose(-1, -2)
+
+ convblock_out = self.conv_block(x)
+
+ resblock1_out = self.res_block1(convblock_out)
+ resblock2_out = self.res_block2(resblock1_out)
+ resblock3_out = self.res_block3(resblock2_out)
+ poolblock_out = self.pool_block[0](resblock3_out)
+ poolblock_out = self.pool_block[1](poolblock_out)
+
+ return poolblock_out.transpose(-1, -2)
+
+ def get_feature(self, x):
+ seq_len = x.shape[-2]
+ x = x.float().transpose(-1, -2)
+
+ convblock_out = self.conv_block(x)
+
+ resblock1_out = self.res_block1(convblock_out)
+ resblock2_out = self.res_block2(resblock1_out)
+ resblock3_out = self.res_block3(resblock2_out)
+ poolblock_out = self.pool_block[0](resblock3_out)
+ poolblock_out = self.pool_block[1](poolblock_out)
+
+ return self.pool_block[2](poolblock_out)
+
+ def forward(self, x):
+ """
+ Returns:
+ classification_prediction, detection_prediction
+ sizes: (b, 31, 722), (b, 31, 2)
+ """
+ ###############################
+ # forward pass for classifier #
+ ###############################
+ seq_len = x.shape[-1]
+ x = x.float().transpose(-1, -2)
+
+ convblock_out = self.conv_block(x)
+
+ resblock1_out = self.res_block1(convblock_out)
+ resblock2_out = self.res_block2(resblock1_out)
+ resblock3_out = self.res_block3(resblock2_out)
+
+
+ poolblock_out = self.pool_block[0](resblock3_out)
+ poolblock_out = self.pool_block[1](poolblock_out)
+ GAN_feature = poolblock_out.transpose(-1, -2)
+ poolblock_out = self.pool_block[2](poolblock_out)
+
+ # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
+ classifier_out = poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
+ classifier_out, _ = self.bilstm_classifier(classifier_out) # ignore the hidden states
+
+ classifier_out = classifier_out.contiguous().view((-1, 512)) # (b * 31, 512)
+ classifier_out = self.classifier(classifier_out)
+ classifier_out = classifier_out.view((-1, seq_len, self.num_class)) # (b, 31, num_class)
+
+ # sizes: (b, 31, 722), (b, 31, 2)
+ # classifier output consists of predicted pitch classes per frame
+ # detector output consists of: (isvoice, notvoice) estimates per frame
+ return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
+
+ @staticmethod
+ def init_weights(m):
+ if isinstance(m, nn.Linear):
+ nn.init.kaiming_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Conv2d):
+ nn.init.xavier_normal_(m.weight)
+ elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
+ for p in m.parameters():
+ if p.data is None:
+ continue
+
+ if len(p.shape) >= 2:
+ nn.init.orthogonal_(p.data)
+ else:
+ nn.init.normal_(p.data)
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
+ super().__init__()
+ self.downsample = in_channels != out_channels
+
+ # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
+ self.pre_conv = nn.Sequential(
+ nn.BatchNorm2d(num_features=in_channels),
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
+ nn.MaxPool2d(kernel_size=(1, 2)), # apply downsampling on the y axis only
+ )
+
+ # conv layers
+ self.conv = nn.Sequential(
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
+ nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
+ )
+
+ # 1 x 1 convolution layer to match the feature dimensions
+ self.conv1by1 = None
+ if self.downsample:
+ self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
+
+ def forward(self, x):
+ x = self.pre_conv(x)
+ if self.downsample:
+ x = self.conv(x) + self.conv1by1(x)
+ else:
+ x = self.conv(x) + x
+ return x
\ No newline at end of file
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/PLBERT/__pycache__/util.cpython-310.pyc b/stts_48khz/StyleTTS2_48khz/Utils/PLBERT/__pycache__/util.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e25270bed05664ec6759e719a2c1c2a4d805706
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/PLBERT/__pycache__/util.cpython-310.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/PLBERT/__pycache__/util.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/Utils/PLBERT/__pycache__/util.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5c1ebd9b78e6883e2cd8f524babbdf9b08453e6e
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/PLBERT/__pycache__/util.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/PLBERT/__pycache__/util.cpython-39.pyc b/stts_48khz/StyleTTS2_48khz/Utils/PLBERT/__pycache__/util.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c1c5481ee13db1b3ed8ddb3d1952d1f6e8d3e9fe
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/PLBERT/__pycache__/util.cpython-39.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/PLBERT/config.yml b/stts_48khz/StyleTTS2_48khz/Utils/PLBERT/config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f5c7212028813fc4a74888eed87d89e5c6c97806
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Utils/PLBERT/config.yml
@@ -0,0 +1,30 @@
+log_dir: "Checkpoint_40b_Cotlet_Lite"
+mixed_precision: "bf16"
+data_folder: "/home/ubuntu/001_PLBERT_JA/40B_dataset_W_fixed"
+batch_size: 26
+save_interval: 10000
+log_interval: 100
+num_process: 1 # number of GPUs
+num_steps: 1000000
+
+dataset_params:
+ tokenizer: ""
+ token_separator: " " # token used for phoneme separator (space)
+ token_mask: "M" # token used for phoneme mask (M)
+ word_separator: 2 # token used for word separator ()
+ token_maps: "/home/ubuntu/001_PLBERT_JA/token_maps_jp_200k_0.pkl" # token map path
+
+ max_mel_length: 512 # max phoneme length
+
+ word_mask_prob: 0.15 # probability to mask the entire word
+ phoneme_mask_prob: 0.1 # probability to mask each phoneme
+  replace_prob: 0.2 # probability to replace phonemes
+
+model_params:
+ vocab_size: 178
+ hidden_size: 768
+ num_attention_heads: 12
+ intermediate_size: 2048
+ max_position_embeddings: 512
+ num_hidden_layers: 12
+ dropout: 0.1
\ No newline at end of file
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/PLBERT/util.py b/stts_48khz/StyleTTS2_48khz/Utils/PLBERT/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dd03f997b5d5dcce4ec99a97d2d4af994d0720
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Utils/PLBERT/util.py
@@ -0,0 +1,49 @@
+import os
+import yaml
+import torch
+from transformers import AlbertConfig, AlbertModel
+
+class CustomAlbert(AlbertModel):
+ def forward(self, *args, **kwargs):
+ # Call the original forward method
+ outputs = super().forward(*args, **kwargs)
+
+ # Only return the last_hidden_state
+ return outputs.last_hidden_state
+
+
+def load_plbert(log_dir):
+ config_path = os.path.join(log_dir, "config.yml")
+ plbert_config = yaml.safe_load(open(config_path))
+
+ albert_base_configuration = AlbertConfig(**plbert_config['model_params'])
+ bert = CustomAlbert(albert_base_configuration)
+
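+    # pick the newest "step_<N>.t7" checkpoint in log_dir by step number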
+    ckpts = [f for f in os.listdir(log_dir)
+             if f.startswith("step_") and os.path.isfile(os.path.join(log_dir, f))]
+    iters = sorted(int(f.split('_')[-1].split('.')[0]) for f in ckpts)[-1]
+
+ checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location='cpu')
+ state_dict = checkpoint['net']
+ from collections import OrderedDict
+ new_state_dict = OrderedDict()
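+    # strip the DataParallel "module." prefix (and an optional "encoder." prefix) so
+    # the keys match the bare AlbertModel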
+ for k, v in state_dict.items():
+ name = k[7:] # remove `module.`
+ if name.startswith('encoder.'):
+ name = name[8:] # remove `encoder.`
+ new_state_dict[name] = v
+
+    # newer transformers releases drop the 'embeddings.position_ids' buffer; if the
+    # model no longer has it, drop the key from the checkpoint state dict as well
+    if not hasattr(bert.embeddings, 'position_ids'):
+        new_state_dict.pop("embeddings.position_ids", None)
+
+
+ bert.load_state_dict(new_state_dict, strict=False)
+
+ return bert
+
+
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/__init__.py b/stts_48khz/StyleTTS2_48khz/Utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Utils/__init__.py
@@ -0,0 +1 @@
+
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/__pycache__/__init__.cpython-310.pyc b/stts_48khz/StyleTTS2_48khz/Utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8b55bda8460feeae4f4bc4680e585123815ef8bf
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/__pycache__/__init__.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/Utils/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..45b1f8cd9b00061363ee4dd9378a3e8b194a86ed
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/__pycache__/__init__.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/__pycache__/__init__.cpython-39.pyc b/stts_48khz/StyleTTS2_48khz/Utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..61be06ebbf493875e8fff8f4e06c79dbdbc27dc6
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/__pycache__/fsdp_patch.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/Utils/__pycache__/fsdp_patch.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ad80c4fdf04ef4ab96f9bceaa5beb780a97d1588
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/Utils/__pycache__/fsdp_patch.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/Utils/fsdp_patch.py b/stts_48khz/StyleTTS2_48khz/Utils/fsdp_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..643baa0b9bb2de5acccb5be62c3fab9d831b702e
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/Utils/fsdp_patch.py
@@ -0,0 +1,47 @@
+"""
+Monkeypatch to fix fsdp set state when no previous state was set
+
+https://github.com/OpenAccess-AI-Collective/axolotl/pull/400/files
+"""
+
+import contextlib
+from typing import Generator, Optional
+
+import torch
+from torch import nn
+from torch.distributed.fsdp.api import (
+ OptimStateDictConfig,
+ StateDictConfig,
+ StateDictType,
+)
+from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
+
+
+@staticmethod
+@contextlib.contextmanager
+def state_dict_type_patch(
+ module: nn.Module,
+ state_dict_type: StateDictType,
+ state_dict_config: Optional[StateDictConfig] = None,
+ optim_state_dict_config: Optional[OptimStateDictConfig] = None,
+) -> Generator:
+ prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type(
+ module,
+ state_dict_type,
+ state_dict_config,
+ optim_state_dict_config,
+ )
+ yield
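+    # restore the previous settings only if a state-dict type had actually been set;
+    # this is the case the unpatched implementation does not handle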
+ if prev_state_dict_settings.state_dict_type:
+ FullyShardedDataParallel.set_state_dict_type(
+ module,
+ prev_state_dict_settings.state_dict_type,
+ prev_state_dict_settings.state_dict_config,
+ prev_state_dict_settings.optim_state_dict_config,
+ )
+
+
+def replace_fsdp_state_dict_type():
+ torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel.state_dict_type = (
+ state_dict_type_patch
+ )
diff --git a/stts_48khz/StyleTTS2_48khz/__pycache__/losses.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/__pycache__/losses.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e50153f56617aa3e81231502a81f0aa597e6462d
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/__pycache__/losses.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/__pycache__/meldataset.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/__pycache__/meldataset.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..183a1e82f2e63f1ea6be0ab551ea487b540a6dbc
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/__pycache__/meldataset.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/__pycache__/models.cpython-310.pyc b/stts_48khz/StyleTTS2_48khz/__pycache__/models.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..29ffc15b5733d43cfe060fc62218495817ce02f3
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/__pycache__/models.cpython-310.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/__pycache__/models.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/__pycache__/models.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3cf540091568cf03b9d7b624ca15f5f3b0e199ad
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/__pycache__/models.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/__pycache__/models.cpython-39.pyc b/stts_48khz/StyleTTS2_48khz/__pycache__/models.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..70cfe6226371305402296cd56be655bf8ed3c22b
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/__pycache__/models.cpython-39.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/__pycache__/optimizers.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/__pycache__/optimizers.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e047df40d593eb78083e6aaecac3266e81139aa
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/__pycache__/optimizers.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/__pycache__/text_utils.cpython-310.pyc b/stts_48khz/StyleTTS2_48khz/__pycache__/text_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..691944315ef80f4524b154317c287c03bc9c9ea1
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/__pycache__/text_utils.cpython-310.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/__pycache__/text_utils.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/__pycache__/text_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..12883d42999b00fb0cb8afc4b4fb6db39926f920
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/__pycache__/text_utils.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/__pycache__/text_utils.cpython-39.pyc b/stts_48khz/StyleTTS2_48khz/__pycache__/text_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..932d9d6a9ec1e944a61027d4c0eb1e9a8f240115
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/__pycache__/text_utils.cpython-39.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/__pycache__/utils.cpython-310.pyc b/stts_48khz/StyleTTS2_48khz/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8ba897e3c051782404d3ac54dd13ee32a5e4f8e0
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/__pycache__/utils.cpython-310.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/__pycache__/utils.cpython-311.pyc b/stts_48khz/StyleTTS2_48khz/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac724edda67e41d28a2020f722b277135838dca6
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/__pycache__/utils.cpython-311.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/__pycache__/utils.cpython-39.pyc b/stts_48khz/StyleTTS2_48khz/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1cfd26e16bef0e14c29c6a88c3a3289cecbda416
Binary files /dev/null and b/stts_48khz/StyleTTS2_48khz/__pycache__/utils.cpython-39.pyc differ
diff --git a/stts_48khz/StyleTTS2_48khz/accelerate_NO_SLM_2ndstage.py b/stts_48khz/StyleTTS2_48khz/accelerate_NO_SLM_2ndstage.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6d0af3139c002fc571940492597a599f574b173
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/accelerate_NO_SLM_2ndstage.py
@@ -0,0 +1,967 @@
+# load packages
+import random
+import yaml
+import time
+from munch import Munch
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+import librosa
+import click
+import shutil
+import traceback
+import warnings
+
+warnings.simplefilter('ignore')
+
+from meldataset import build_dataloader
+
+from Utils.ASR.models import ASRCNN
+from Utils.JDC.model import JDCNet
+from Utils.PLBERT.util import load_plbert
+
+from models import *
+from losses import *
+from utils import *
+
+from Modules.slmadv import SLMAdversarialLoss
+from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
+
+from optimizers import build_optimizer
+
+from accelerate import Accelerator, DistributedDataParallelKwargs
+from accelerate.utils import tqdm, ProjectConfiguration
+
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
+# from Utils.fsdp_patch import replace_fsdp_state_dict_type
+
+# replace_fsdp_state_dict_type()
+
+import logging
+
+from accelerate.logging import get_logger
+from logging import StreamHandler
+
+logger = get_logger(__name__)
+logger.setLevel(logging.DEBUG)
+# handler.setLevel(logging.DEBUG)
+# logger.addHandler(handler)
+
+@click.command()
+@click.option('-p', '--config_path', default='Configs/config.yml', type=str)
+def main(config_path):
+ config = yaml.safe_load(open(config_path))
+
+ log_dir = config['log_dir']
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
+
+ # write logs
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
+ file_handler.setLevel(logging.DEBUG)
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
+ logger.logger.addHandler(file_handler)
+
+ batch_size = config.get('batch_size', 10)
+
+ epochs = config.get('epochs_2nd', 200)
+ save_freq = config.get('save_freq', 2)
+ log_interval = 10
+ saving_epoch = config.get('save_freq', 2)
+
+ data_params = config.get('data_params', None)
+ sr = config['preprocess_params'].get('sr', 48000)
+ hop = config['preprocess_params']["spect_params"].get('hop_length', 512)
+ win = config['preprocess_params']["spect_params"].get('win_length', 2048)
+ train_path = data_params['train_data']
+ val_path = data_params['val_data']
+ root_path = data_params['root_path']
+ min_length = data_params['min_length']
+ OOD_data = data_params['OOD_data']
+
+ max_len = config.get('max_len', 200)
+
+ loss_params = Munch(config['loss_params'])
+ diff_epoch = loss_params.diff_epoch
+ joint_epoch = loss_params.joint_epoch
+
+ optimizer_params = Munch(config['optimizer_params'])
+
+ train_list, val_list = get_data_path_list(train_path, val_path)
+
+ try:
+ tracker = data_params['logger']
+ except KeyError:
+ tracker = "mlflow"
+
+ def log_audio(accelerator, audio, bib="", name="Validation", epoch=0, sr=48000, tracker="tensorboard"):
+ if tracker == "tensorboard":
+ ltracker = accelerator.get_tracker("tensorboard")
+ np_aud = np.stack([np.asarray(aud) for aud in audio])
+ ltracker.writer.add_audio(f"{name}-{bib}", np_aud, epoch, sample_rate=sr)
+ if tracker == "wandb":
+ try:
+ ltracker = accelerator.get_tracker("wandb")
+ ltracker.log(
+ {
+ "validation": [
+ wandb.Audio(audios, caption=f"{name}-{bib}", sample_rate=sr)
+ for i, audios in enumerate(audio)
+ ]
+ }
+ , step=int(bib))
+ except IndexError:
+ pass
+
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True, broadcast_buffers=False)
+ configAcc = ProjectConfiguration(project_dir=log_dir, logging_dir=log_dir)
+ accelerator = Accelerator(log_with=tracker,
+ project_config=configAcc,
+ split_batches=True,
+ kwargs_handlers=[ddp_kwargs],
+ mixed_precision='bf16')
+
+ accelerator.init_trackers(project_name="StyleTTS2-Second-Stage",
+ config=config if tracker == "wandb" else None)
+ HF = config["data_params"].get("HF", False)
+ name = config["data_params"].get("split", None)
+ split = config["data_params"].get("split", None)
+ val_split = config["data_params"].get("val_split", None)
+ ood_split = config["data_params"].get("OOD_split", None)
+ audcol = config["data_params"].get("audio_column", "speech")
+ phoncol = config["data_params"].get("phoneme_column", "phoneme")
+ specol = config["data_params"].get("speaker_column", "speaker ID")
+
+ if not HF:
+ train_list, val_list = get_data_path_list(train_path, val_path)
+ ds_conf = {"sr": sr, "hop": hop, "win": win}
+ vds_conf = {"sr": sr, "hop": hop, "win": win}
+ else:
+ train_list, val_list = train_path, val_path
+ ds_conf = {"sr": sr,
+ "hop": hop,
+ "split": split,
+ "OOD_split": ood_split,
+ "dataset_name": name,
+ "audio_column": audcol,
+ "phoneme_column": phoncol,
+ "speaker_id_column": specol,
+ "win": win}
+ vds_conf = {"sr": sr,
+ "hop": hop,
+ "split": val_split,
+ "OOD_split": ood_split,
+ "dataset_name": name,
+ "audio_column": audcol,
+ "phoneme_column": phoncol,
+ "speaker_id_column": specol,
+ "win": win}
+ device = accelerator.device
+
+ with accelerator.main_process_first():
+ train_dataloader = build_dataloader(train_list,
+ root_path,
+ OOD_data=OOD_data,
+ min_length=min_length,
+ batch_size=batch_size,
+ num_workers=2,
+ dataset_config={},
+ device=device)
+
+ val_dataloader = build_dataloader(val_list,
+ root_path,
+ OOD_data=OOD_data,
+ min_length=min_length,
+ batch_size=batch_size,
+ validation=True,
+ num_workers=0,
+ device=device,
+ dataset_config={})
+
+ # load pretrained ASR model
+ ASR_config = config.get('ASR_config', False)
+ ASR_path = config.get('ASR_path', False)
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
+
+ # load pretrained F0 model
+ F0_path = config.get('F0_path', False)
+ pitch_extractor = load_F0_models(F0_path)
+
+ # load PL-BERT model
+ BERT_path = config.get('PLBERT_dir', False)
+ plbert = load_plbert(BERT_path)
+
+ # build model
+ config['model_params']["sr"] = sr
+
+ model_params = recursive_munch(config['model_params'])
+ multispeaker = model_params.multispeaker
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
+ _ = [model[key].to(device) for key in model]
+
+ # DP
+ for key in model:
+ if key != "mpd" and key != "msd" and key != "wd":
+ model[key] = accelerator.prepare(model[key])
+
+ start_epoch = 0
+ iters = 0
+
+ load_pretrained = config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False)
+
+ if not load_pretrained:
+ if config.get('first_stage_path', '') != '':
+ first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
+ accelerator.print('Loading the first stage model at %s ...' % first_stage_path)
+ model, _, start_epoch, iters = load_checkpoint(model,
+ None,
+ first_stage_path,
+ load_only_params=True,
+ ignore_modules=['bert', 'bert_encoder', 'predictor',
+ 'predictor_encoder', 'msd', 'mpd', 'wd',
+ 'diffusion']) # keep starting epoch for tensorboard log
+
+ # these epochs should be counted from the start epoch
+ diff_epoch += start_epoch
+ joint_epoch += start_epoch
+ epochs += start_epoch
+ model.style_encoder.train()
+ model.predictor_encoder = copy.deepcopy(model.style_encoder)
+ else:
+ raise ValueError('You need to specify the path to the first stage model.')
+
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
+ wl = WavLMLoss(model_params.slm.model,
+ model.wd,
+ sr,
+ model_params.slm.sr).to(device)
+
+ gl = accelerator.prepare(gl)
+ dl = accelerator.prepare(dl)
+ wl = accelerator.prepare(wl)
+ wl = wl.eval()
+
+ sampler = DiffusionSampler(
+ model.diffusion.module.diffusion,
+ sampler=ADPM2Sampler(),
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
+ clamp=False
+ )
+
+ scheduler_params = {
+ "max_lr": optimizer_params.lr * accelerator.num_processes,
+ "pct_start": float(0),
+ "epochs": epochs,
+ "steps_per_epoch": len(train_dataloader),
+ }
+ scheduler_params_dict = {key: scheduler_params.copy() for key in model}
+ scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2
+ scheduler_params_dict['decoder']['max_lr'] = optimizer_params.ft_lr * 2
+ scheduler_params_dict['style_encoder']['max_lr'] = optimizer_params.ft_lr * 2
+
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
+ scheduler_params_dict=scheduler_params_dict,
+ lr=optimizer_params.lr * accelerator.num_processes)
+
+ # adjust BERT learning rate
+ for g in optimizer.optimizers['bert'].param_groups:
+ g['betas'] = (0.9, 0.99)
+ g['lr'] = optimizer_params.bert_lr
+ g['initial_lr'] = optimizer_params.bert_lr
+ g['min_lr'] = 0
+ g['weight_decay'] = 0.01
+
+ # adjust acoustic module learning rate
+ for module in ["decoder", "style_encoder"]:
+ for g in optimizer.optimizers[module].param_groups:
+ g['betas'] = (0.0, 0.99)
+ g['lr'] = optimizer_params.ft_lr
+ g['initial_lr'] = optimizer_params.ft_lr
+ g['min_lr'] = 0
+ g['weight_decay'] = 1e-4
+
+ # load models if there is a model
+ if load_pretrained:
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
+ load_only_params=config.get('load_only_params', True))
+
+ n_down = model.text_aligner.module.n_down
+
+ # for k in model:
+ # model[k] = accelerator.prepare(model[k])
+
+ best_loss = float('inf') # best test loss
+ iters = 0
+
+ criterion = nn.L1Loss() # F0 loss (regression)
+ torch.cuda.empty_cache()
+
+ stft_loss = MultiResolutionSTFTLoss().to(device)
+
+ accelerator.print('BERT', optimizer.optimizers['bert'])
+ accelerator.print('decoder', optimizer.optimizers['decoder'])
+
+ start_ds = False
+
+ running_std = []
+
+ slmadv_params = Munch(config['slmadv_params'])
+
+ slmadv = SLMAdversarialLoss(model, wl, sampler,
+ slmadv_params.min_len,
+ slmadv_params.max_len,
+ batch_percentage=slmadv_params.batch_percentage,
+ skip_update=slmadv_params.iter,
+ sig=slmadv_params.sig
+ )
+
+ for k, v in optimizer.optimizers.items():
+ optimizer.optimizers[k] = accelerator.prepare(optimizer.optimizers[k])
+ optimizer.schedulers[k] = accelerator.prepare(optimizer.schedulers[k])
+
+ train_dataloader = accelerator.prepare(train_dataloader)
+
+ for epoch in range(start_epoch, epochs):
+ running_loss = 0
+ start_time = time.time()
+
+ _ = [model[key].eval() for key in model]
+
+ model.text_aligner.train()
+ model.text_encoder.train()
+
+ model.predictor.train()
+ model.predictor_encoder.train()
+ model.bert_encoder.train()
+ model.bert.train()
+ model.msd.train()
+ model.mpd.train()
+ model.wd.train()
+
+ if epoch >= diff_epoch:
+ start_ds = True
+
+ for i, batch in enumerate(train_dataloader):
+ waves = batch[0]
+ batch = [b.to(device) for b in batch[1:]]
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
+
+ with torch.no_grad():
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
+ mel_mask = length_to_mask(mel_input_length).to(device)
+ text_mask = length_to_mask(input_lengths).to(texts.device)
+
+ try:
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
+ s2s_attn = s2s_attn.transpose(-1, -2)
+ s2s_attn = s2s_attn[..., 1:]
+ s2s_attn = s2s_attn.transpose(-1, -2)
+ except:
+ continue
+
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
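+            # collapse the soft attention into a hard monotonic text-to-mel alignment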
+
+ # encode
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
+ asr = (t_en @ s2s_attn_mono)
+
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
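+            # ground-truth duration of each phoneme = number of mel frames aligned to it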
+
+ # compute reference styles
+ if multispeaker and epoch >= diff_epoch:
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
+ ref = torch.cat([ref_ss, ref_sp], dim=1)
+
+ # compute the style of the entire utterance
+ # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
+ ss = []
+ gs = []
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item())
+ mel = mels[bib, :, :mel_input_length[bib]]
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
+ ss.append(s)
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
+ gs.append(s)
+
+ s_dur = torch.stack(ss).squeeze(1) # global prosodic styles
+ gs = torch.stack(gs).squeeze(1) # global acoustic styles
+ s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
+
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+
+ # denoiser training
+ if epoch >= diff_epoch:
+ num_steps = np.random.randint(3, 5)
+
+ if model_params.diffusion.dist.estimate_sigma_data:
+ model.diffusion.module.diffusion.sigma_data = s_trg.std(
+ axis=-1).mean().item() # batch-wise std estimation
+ running_std.append(model.diffusion.module.diffusion.sigma_data)
+
+ if multispeaker:
+ s_preds = sampler(noise=torch.randn_like(s_trg).unsqueeze(1).to(device),
+ embedding=bert_dur,
+ embedding_scale=1,
+ features=ref, # reference from the same speaker as the embedding
+ embedding_mask_proba=0.1,
+ num_steps=num_steps).squeeze(1)
+ loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
+ else:
+ s_preds = sampler(noise=torch.randn_like(s_trg).unsqueeze(1).to(device),
+ embedding=bert_dur,
+ embedding_scale=1,
+ embedding_mask_proba=0.1,
+ num_steps=num_steps).squeeze(1)
+ loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1),
+ embedding=bert_dur).mean() # EDM loss
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
+ # print(loss_sty)
+ else:
+ # print("here")
+ loss_sty = 0
+ loss_diff = 0
+
+ d, p = model.predictor(d_en, s_dur,
+ input_lengths,
+ s2s_attn_mono,
+ text_mask)
+
+ # mel_len = int(mel_input_length.min().item() / 2 - 1)
+
+ mel_input_length_all = accelerator.gather(mel_input_length) # for balanced load
+ mel_len = min([int(mel_input_length_all.min().item() / 2 - 1), max_len // 2])
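+            # use a single crop length shared across processes (gathered min), capped by
+            # max_len, so every sample below is clipped to the same number of frames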
+
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
+ en = []
+ gt = []
+ st = []
+ p_en = []
+ wav = []
+
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item() / 2)
+
+ random_start = np.random.randint(0, mel_length - mel_len)
+ en.append(asr[bib, :, random_start:random_start + mel_len])
+ p_en.append(p[bib, :, random_start:random_start + mel_len])
+ gt.append(mels[bib, :, (random_start * 2):((random_start + mel_len) * 2)])
+
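+                # matching raw-audio clip: the *2 maps the downsampled frame index back to
+                # mel frames, and the *512 maps mel frames to waveform samples (hop length)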
+ y = waves[bib][(random_start * 2) * 512:((random_start + mel_len) * 2) * 512]
+ wav.append(torch.from_numpy(y).to(device))
+
+ # style reference (better to be different from the GT)
+ random_start = np.random.randint(0, mel_length - mel_len_st)
+ st.append(mels[bib, :, (random_start * 2):((random_start + mel_len_st) * 2)])
+
+ wav = torch.stack(wav).float().detach()
+
+ en = torch.stack(en)
+ p_en = torch.stack(p_en)
+ gt = torch.stack(gt).detach()
+ st = torch.stack(st).detach()
+
+ if gt.size(-1) < 80:
+ continue
+
+ s_dur = model.predictor_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
+ s = model.style_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
+
+ with torch.no_grad():
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
+ F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2])
+
+ asr_real = model.text_aligner.module.get_feature(gt)
+
+ N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
+
+ y_rec_gt = wav.unsqueeze(1)
+ y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
+
+ if epoch >= joint_epoch:
+ # ground truth from recording
+ wav = y_rec_gt # use recording since decoder is tuned
+ else:
+ # ground truth from reconstruction
+ wav = y_rec_gt_pred # use reconstruction since decoder is fixed
+
+ F0_fake, N_fake = model.predictor(texts=p_en, style=s_dur, f0=True)
+
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
+
+ loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
+ loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
+
+ if start_ds:
+ optimizer.zero_grad()
+ d_loss = dl(wav.detach(), y_rec.detach()).mean()
+ accelerator.backward(d_loss)
+ optimizer.step('msd')
+ optimizer.step('mpd')
+ else:
+ d_loss = 0
+
+ # generator loss
+ optimizer.zero_grad()
+
+ loss_mel = stft_loss(y_rec, wav)
+ if start_ds:
+ loss_gen_all = gl(wav, y_rec).mean()
+ else:
+ loss_gen_all = 0
+ loss_lm = wl(wav.detach().squeeze(1), y_rec.squeeze(1)).mean()
+
+ loss_ce = 0
+ loss_dur = 0
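+            # duration / CE losses: each phoneme's target row is 1 for its ground-truth
+            # number of frames; the predicted duration is the sum of per-frame sigmoids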
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
+ _s2s_pred = _s2s_pred[:_text_length, :]
+ _text_input = _text_input[:_text_length].long()
+ _s2s_trg = torch.zeros_like(_s2s_pred)
+                for idx in range(_s2s_trg.shape[0]):
+                    _s2s_trg[idx, :_text_input[idx]] = 1
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
+
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length - 1],
+ _text_input[1:_text_length - 1])
+ loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())
+
+ loss_ce /= texts.size(0)
+ loss_dur /= texts.size(0)
+
+ g_loss = loss_params.lambda_mel * loss_mel + \
+ loss_params.lambda_F0 * loss_F0_rec + \
+ loss_params.lambda_ce * loss_ce + \
+ loss_params.lambda_norm * loss_norm_rec + \
+ loss_params.lambda_dur * loss_dur + \
+ loss_params.lambda_gen * loss_gen_all + \
+ loss_params.lambda_slm * loss_lm + \
+ loss_params.lambda_sty * loss_sty + \
+ loss_params.lambda_diff * loss_diff
+
+ running_loss += accelerator.gather(loss_mel).mean().item()
+ accelerator.backward(g_loss)
+
+ # if torch.isnan(g_loss):
+ # from IPython.core.debugger import set_trace
+ # set_trace()
+
+
+ # if epoch >= diff_epoch:
+
+
+ if epoch >= joint_epoch:
+ d_loss_slm, loss_gen_lm = 0, 0
+ optimizer.step('style_encoder')
+ optimizer.step('decoder')
+
+
+
+
+ optimizer.step('bert_encoder')
+ optimizer.step('bert')
+ optimizer.step('predictor')
+ optimizer.step('predictor_encoder')
+
+ optimizer.step('diffusion')
+
+
+ # # randomly pick whether to use in-distribution text
+ # if np.random.rand() < 0.5:
+ # use_ind = True
+ # else:
+ # use_ind = False
+
+ # if use_ind:
+ # ref_lengths = input_lengths
+ # ref_texts = texts
+
+ # slm_out = slmadv(i,
+ # y_rec_gt,
+ # y_rec_gt_pred,
+ # waves,
+ # mel_input_length,
+ # ref_texts,
+ # ref_lengths, use_ind, s_trg.detach(), ref if multispeaker else None)
+
+
+
+ # if slm_out is None:
+ # continue
+ # # if slm_out is not None:
+ # # d_loss_slm, loss_gen_lm, y_pred = slm_out
+ # # optimizer.zero_grad()
+ # # # accelerator.clip_grad_norm_(model.decoder.parameters(), 1)
+ # # # print("here")
+ # # accelerator.backward(loss_gen_lm)
+ # # # print("here2")
+ # # # SLM discriminator loss
+
+
+ # # # compute the gradient norm
+
+
+ # # total_norm = {}
+ # # for key in model.keys():
+ # # total_norm[key] = 0
+ # # parameters = [p for p in model[key].parameters() if p.grad is not None and p.requires_grad]
+ # # for p in parameters:
+ # # param_norm = p.grad.detach().data.norm(2)
+ # # total_norm[key] += param_norm.item() ** 2
+ # # total_norm[key] = total_norm[key] ** 0.5
+
+ # # # gradient scaling
+ # # if total_norm['predictor'] > slmadv_params.thresh:
+ # # for key in model.keys():
+ # # for p in model[key].parameters():
+ # # if p.grad is not None:
+ # # p.grad *= (1 / total_norm['predictor'])
+
+ # # for p in model.predictor.module.duration_proj.parameters():
+ # # if p.grad is not None:
+ # # p.grad *= slmadv_params.scale
+
+ # # for p in model.predictor.module.lstm.parameters():
+ # # if p.grad is not None:
+ # # p.grad *= slmadv_params.scale
+
+ # # for p in model.diffusion.module.parameters():
+ # # if p.grad is not None:
+ # # p.grad *= slmadv_params.scale
+
+ # # optimizer.step('bert_encoder')
+ # # optimizer.step('bert')
+ # # optimizer.step('predictor')
+ # # optimizer.step('diffusion')
+
+ # # # SLM discriminator loss
+ # # if d_loss_slm != 0:
+ # # optimizer.zero_grad()
+ # # # print("hey1")
+ # # accelerator.backward(d_loss_slm, retain_graph=True)
+ # # optimizer.step('wd')
+ # # # print("hey2")
+
+ else:
+ d_loss_slm, loss_gen_lm = 0, 0
+
+ iters = iters + 1
+ if (i + 1) % log_interval == 0:
+ logger.info(
+ 'Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f'
+ % (epoch + 1, epochs, i + 1, len(train_list) // batch_size, running_loss / log_interval, d_loss,
+ loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff,
+ d_loss_slm, loss_gen_lm), main_process_only=True)
+ if accelerator.is_main_process:
+ print('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f'
+ % (epoch + 1, epochs, i + 1, len(train_list) // batch_size, running_loss / log_interval, d_loss,
+ loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff,
+ d_loss_slm, loss_gen_lm))
+ accelerator.log({'train/mel_loss': float(running_loss / log_interval),
+ 'train/gen_loss': float(loss_gen_all),
+ 'train/d_loss': float(d_loss),
+ 'train/ce_loss': float(loss_ce),
+ 'train/dur_loss': float(loss_dur),
+ 'train/slm_loss': float(loss_lm),
+ 'train/norm_loss': float(loss_norm_rec),
+ 'train/F0_loss': float(loss_F0_rec),
+ 'train/sty_loss': float(loss_sty),
+ 'train/diff_loss': float(loss_diff),
+ 'train/d_loss_slm': float(d_loss_slm),
+ 'train/gen_loss_slm': float(loss_gen_lm),
+ 'epoch': int(epoch) + 1}, step=iters)
+
+ running_loss = 0
+
+ accelerator.print('Time elasped:', time.time() - start_time)
+
+ loss_test = 0
+ loss_align = 0
+ loss_f = 0
+ _ = [model[key].eval() for key in model]
+
+ with torch.no_grad():
+ iters_test = 0
+ for batch_idx, batch in enumerate(val_dataloader):
+ optimizer.zero_grad()
+
+ try:
+ waves = batch[0]
+ batch = [b.to(device) for b in batch[1:]]
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
+ with torch.no_grad():
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
+ text_mask = length_to_mask(input_lengths).to(texts.device)
+
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
+ s2s_attn = s2s_attn.transpose(-1, -2)
+ s2s_attn = s2s_attn[..., 1:]
+ s2s_attn = s2s_attn.transpose(-1, -2)
+
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
+
+ # encode
+ # print("t_en", t_en.shape, t_en)
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
+ asr = (t_en @ s2s_attn_mono)
+
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
+
+ ss = []
+ gs = []
+
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item())
+ mel = mels[bib, :, :mel_input_length[bib]]
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
+ ss.append(s)
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
+ gs.append(s)
+
+ s = torch.stack(ss).squeeze(1)
+ gs = torch.stack(gs).squeeze(1)
+ s_trg = torch.cat([s, gs], dim=-1).detach()
+ # print("texts", texts.shape, texts)
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+ d, p = model.predictor(d_en, s,
+ input_lengths,
+ s2s_attn_mono,
+ text_mask)
+ # get clips
+ mel_len = int(mel_input_length.min().item() / 2 - 1)
+ en = []
+ gt = []
+ p_en = []
+ wav = []
+
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item() / 2)
+
+ random_start = np.random.randint(0, mel_length - mel_len)
+ en.append(asr[bib, :, random_start:random_start + mel_len])
+ p_en.append(p[bib, :, random_start:random_start + mel_len])
+
+ gt.append(mels[bib, :, (random_start * 2):((random_start + mel_len) * 2)])
+
+ y = waves[bib][(random_start * 2) * 512:((random_start + mel_len) * 2) * 512]
+ wav.append(torch.from_numpy(y).to(device))
+
+ wav = torch.stack(wav).float().detach()
+
+ en = torch.stack(en)
+ p_en = torch.stack(p_en)
+ gt = torch.stack(gt).detach()
+
+ s = model.predictor_encoder(gt.unsqueeze(1))
+
+ F0_fake, N_fake = model.predictor(texts=p_en, style=s, f0=True)
+
+ loss_dur = 0
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
+ _s2s_pred = _s2s_pred[:_text_length, :]
+ _text_input = _text_input[:_text_length].long()
+ _s2s_trg = torch.zeros_like(_s2s_pred)
+ for bib in range(_s2s_trg.shape[0]):
+ _s2s_trg[bib, :_text_input[bib]] = 1
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length - 1],
+ _text_input[1:_text_length - 1])
+
+ loss_dur /= texts.size(0)
+
+ s = model.style_encoder(gt.unsqueeze(1))
+
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
+ loss_mel = stft_loss(y_rec.squeeze(1), wav.detach())
+
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
+
+ loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
+
+ loss_test += accelerator.gather(loss_mel).mean()
+ loss_align += accelerator.gather(loss_dur).mean()
+ loss_f += accelerator.gather(loss_F0).mean()
+
+ iters_test += 1
+ except Exception as e:
+ accelerator.print(f"Eval errored with: \n {str(e)}")
+ continue
+
+ accelerator.print('Epochs:', epoch + 1)
+ try:
+ logger.info('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (
+ loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n', main_process_only=True)
+
+
+            accelerator.log({'eval/mel_loss': float(loss_test / iters_test),
+                             'eval/dur_loss': float(loss_align / iters_test),
+                             'eval/F0_loss': float(loss_f / iters_test)},
+                            step=(i + 1) * (epoch + 1))
+ except ZeroDivisionError:
+ accelerator.print("Eval loss was divided by zero... skipping eval cycle")
+
+ if epoch < diff_epoch:
+ # generating reconstruction examples with GT duration
+
+ with torch.no_grad():
+ for bib in range(len(asr)):
+ mel_length = int(mel_input_length[bib].item())
+ gt = mels[bib, :, :mel_length].unsqueeze(0)
+ en = asr[bib, :, :mel_length // 2].unsqueeze(0)
+
+ F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
+ F0_real = F0_real.unsqueeze(0)
+ s = model.style_encoder(gt.unsqueeze(1))
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
+
+ try:
+ y_rec = model.decoder(en, F0_real.squeeze(0), real_norm, s)
+ except Exception as e:
+ accelerator.print(str(e))
+ accelerator.print(F0_real.size())
+ accelerator.print(F0_real.squeeze(0).size())
+
+ s_dur = model.predictor_encoder(gt.unsqueeze(1))
+ p_en = p[bib, :, :mel_length // 2].unsqueeze(0)
+ F0_fake, N_fake = model.predictor(texts=p_en, style=s_dur, f0=True)
+
+ y_pred = model.decoder(en, F0_fake, N_fake, s)
+
+ # writer.add_audio('pred/y' + str(bib), y_pred.cpu().numpy().squeeze(), epoch, sample_rate=sr)
+ if accelerator.is_main_process:
+ log_audio(accelerator, y_pred.detach().cpu().numpy().squeeze(), bib, "pred/y", epoch, sr, tracker=tracker)
+
+ if epoch == 0:
+ # writer.add_audio('gt/y' + str(bib), waves[bib].squeeze(), epoch, sample_rate=sr)
+ if accelerator.is_main_process:
+ log_audio(accelerator, waves[bib].squeeze(), bib, "gt/y", epoch, sr, tracker=tracker)
+
+ if bib >= 10:
+ break
+ else:
+
+ try:
+ # generating sampled speech from text directly
+ with torch.no_grad():
+ # compute reference styles
+ if multispeaker and epoch >= diff_epoch:
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
+ ref_s = torch.cat([ref_ss, ref_sp], dim=1)
+
+ for bib in range(len(d_en)):
+ if multispeaker:
+ s_pred = sampler(noise=torch.randn((1, 256)).unsqueeze(1).to(texts.device),
+ embedding=bert_dur[bib].unsqueeze(0),
+ embedding_scale=1,
+ features=ref_s[bib].unsqueeze(0),
+ # reference from the same speaker as the embedding
+ num_steps=5).squeeze(1)
+ else:
+ s_pred = sampler(noise=torch.ones((1, 1, 256)).to(texts.device)*0.5,
+ embedding=bert_dur[bib].unsqueeze(0),
+ embedding_scale=1,
+ num_steps=5).squeeze(1)
+
+ s = s_pred[:, 128:]
+ ref = s_pred[:, :128]
+ # print(model.predictor)
+ # print(d_en[bib, :, :input_lengths[bib]])
+ d = model.predictor.module.text_encoder(d_en[bib, :, :input_lengths[bib]].unsqueeze(0),
+ s, input_lengths[bib, ...].unsqueeze(0),
+ text_mask[bib, :input_lengths[bib]].unsqueeze(0))
+
+ x = model.predictor.module.lstm(d)
+ x_mod = model.predictor.module.prepare_projection(x) # 640 -> 512
+ duration = model.predictor.module.duration_proj(x_mod)
+
+ duration = torch.sigmoid(duration).sum(axis=-1)
+ pred_dur = torch.round(duration.squeeze(0)).clamp(min=1)
+
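+                        # give the last phoneme a few extra frames so the end of the
+                        # synthesized sample is not clipped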
+ pred_dur[-1] += 5
+
+ pred_aln_trg = torch.zeros(input_lengths[bib], int(pred_dur.sum().data))
+ c_frame = 0
+ for i in range(pred_aln_trg.size(0)):
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
+ c_frame += int(pred_dur[i].data)
+
+ # encode prosody
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(texts.device))
+ F0_pred, N_pred = model.predictor(texts=en, style=s, f0=True)
+ out = model.decoder(
+ (t_en[bib, :, :input_lengths[bib]].unsqueeze(0) @ pred_aln_trg.unsqueeze(0).to(texts.device)),
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
+
+ # writer.add_audio('pred/y' + str(bib), out.cpu().numpy().squeeze(), epoch, sample_rate=sr)
+ if accelerator.is_main_process:
+ log_audio(accelerator, out.detach().cpu().numpy().squeeze(), bib, "pred/y", epoch, sr, tracker=tracker)
+
+ if bib >= 5:
+ break
+ except Exception as e:
+ accelerator.print('error -> ', e)
+ accelerator.print("some of the samples couldn't be evaluated, skipping those.")
+
+
+ if epoch % saving_epoch == 0:
+ if (loss_test / iters_test) < best_loss:
+ best_loss = loss_test / iters_test
+ try:
+ accelerator.print('Saving..')
+ state = {
+ 'net': {key: model[key].state_dict() for key in model},
+ 'optimizer': optimizer.state_dict(),
+ 'iters': iters,
+ 'val_loss': loss_test / iters_test,
+ 'epoch': epoch,
+ }
+ except ZeroDivisionError:
+ accelerator.print('No iter test, Re-Saving..')
+ state = {
+ 'net': {key: model[key].state_dict() for key in model},
+ 'optimizer': optimizer.state_dict(),
+ 'iters': iters,
+ 'val_loss': 0.1, # not zero just in case
+ 'epoch': epoch,
+ }
+
+ if accelerator.is_main_process:
+ save_path = osp.join(log_dir, 'epoch_2nd_%05d.pth' % epoch)
+ torch.save(state, save_path)
+
+            # if estimating sigma, save the estimated sigma
+ if model_params.diffusion.dist.estimate_sigma_data:
+ config['model_params']['diffusion']['dist']['sigma_data'] = float(np.mean(running_std))
+
+ with open(osp.join(log_dir, osp.basename(config_path)), 'w') as outfile:
+ yaml.dump(config, outfile, default_flow_style=True)
+ if accelerator.is_main_process:
+ print('Saving last pth..')
+ state = {
+ 'net': {key: model[key].state_dict() for key in model},
+ 'optimizer': optimizer.state_dict(),
+ 'iters': iters,
+ 'val_loss': loss_test / iters_test,
+ 'epoch': epoch,
+ }
+ save_path = osp.join(log_dir, '2nd_phase_last.pth')
+ torch.save(state, save_path)
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/stts_48khz/StyleTTS2_48khz/accelerate_train_finetune.py b/stts_48khz/StyleTTS2_48khz/accelerate_train_finetune.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d048276c539ea9223a0109b6eadbdd4da1d07b3
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/accelerate_train_finetune.py
@@ -0,0 +1,838 @@
+# load packages
+import random
+import yaml
+import time
+from munch import Munch
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+import librosa
+import click
+import shutil
+import warnings
+warnings.simplefilter('ignore')
+from torch.utils.tensorboard import SummaryWriter
+
+from meldataset import build_dataloader
+
+from Utils.ASR.models import ASRCNN
+from Utils.JDC.model import JDCNet
+from Utils.PLBERT.util import load_plbert
+
+from models import *
+from losses import *
+from utils import *
+
+from Modules.slmadv import SLMAdversarialLoss
+from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
+from accelerate import Accelerator, DistributedDataParallelKwargs
+from accelerate.utils import tqdm, ProjectConfiguration
+
+from optimizers import build_optimizer
+
+# wandb is optional; it is only required when the configured tracker is "wandb"
+try:
+    import wandb
+except ImportError:
+    wandb = None
+
+# # simple fix for dataparallel that allows access to class attributes
+# class MyDataParallel(torch.nn.DataParallel):
+# def __getattr__(self, name):
+# try:
+# return super().__getattr__(name)
+# except AttributeError:
+# return getattr(self.module, name)
+
+import logging
+# from logging import StreamHandler
+# logger = logging.getLogger(__name__)
+# logger.setLevel(logging.DEBUG)
+# handler = StreamHandler()
+# handler.setLevel(logging.DEBUG)
+# logger.addHandler(handler)
+
+
+from accelerate.logging import get_logger
+from logging import StreamHandler
+
+logger = get_logger(__name__)
+logger.setLevel(logging.DEBUG)
+
+@click.command()
+@click.option('-p', '--config_path', default='Configs/config_ft.yml', type=str)
+def main(config_path):
+ config = yaml.safe_load(open(config_path))
+
+ log_dir = config['log_dir']
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
+
+ # write logs
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
+ file_handler.setLevel(logging.DEBUG)
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
+ logger.logger.addHandler(file_handler)
+
+
+
+ batch_size = config.get('batch_size', 10)
+
+ epochs = config.get('epochs', 200)
+ save_freq = config.get('save_freq', 2)
+ log_interval = config.get('log_interval', 10)
+ saving_epoch = config.get('save_freq', 2)
+
+ data_params = config.get('data_params', None)
+ sr = config['preprocess_params'].get('sr', 48000)
+ train_path = data_params['train_data']
+ val_path = data_params['val_data']
+ root_path = data_params['root_path']
+ min_length = data_params['min_length']
+ OOD_data = data_params['OOD_data']
+
+ max_len = config.get('max_len', 200)
+
+ loss_params = Munch(config['loss_params'])
+ diff_epoch = loss_params.diff_epoch
+ joint_epoch = loss_params.joint_epoch
+
+ optimizer_params = Munch(config['optimizer_params'])
+
+ train_list, val_list = get_data_path_list(train_path, val_path)
+
+ try:
+ tracker = data_params['logger']
+ except KeyError:
+ tracker = "mlflow"
+
+ def log_audio(accelerator, audio, bib="", name="Validation", epoch=0, sr=48000, tracker="tensorboard"):
+ if tracker == "tensorboard":
+ ltracker = accelerator.get_tracker("tensorboard")
+ np_aud = np.stack([np.asarray(aud) for aud in audio])
+ ltracker.writer.add_audio(f"{name}-{bib}", np_aud, epoch, sample_rate=sr)
+ if tracker == "wandb":
+ try:
+ ltracker = accelerator.get_tracker("wandb")
+ ltracker.log(
+ {
+ "validation": [
+ wandb.Audio(audios, caption=f"{name}-{bib}", sample_rate=sr)
+ for i, audios in enumerate(audio)
+ ]
+ }
+ , step=int(bib))
+ except IndexError:
+ pass
+
+
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True, broadcast_buffers=False)
+ configAcc = ProjectConfiguration(project_dir=log_dir, logging_dir=log_dir)
+ accelerator = Accelerator(log_with=tracker,
+ project_config=configAcc,
+ split_batches=True,
+ kwargs_handlers=[ddp_kwargs],
+ mixed_precision='bf16')
+
+ accelerator.init_trackers(project_name="StyleTTS2-Second-Stage_finetune",
+ config=config if tracker == "wandb" else None)
+
+
+
+ device = accelerator.device
+
+ with accelerator.main_process_first():
+ train_dataloader = build_dataloader(train_list,
+ root_path,
+ OOD_data=OOD_data,
+ min_length=min_length,
+ batch_size=batch_size,
+ num_workers=2,
+ dataset_config={},
+ device=device)
+
+ val_dataloader = build_dataloader(val_list,
+ root_path,
+ OOD_data=OOD_data,
+ min_length=min_length,
+ batch_size=batch_size,
+ validation=True,
+ num_workers=0,
+ device=device,
+ dataset_config={})
+
+
+ # load pretrained ASR model
+ ASR_config = config.get('ASR_config', False)
+ ASR_path = config.get('ASR_path', False)
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
+
+ # load pretrained F0 model
+ F0_path = config.get('F0_path', False)
+ pitch_extractor = load_F0_models(F0_path)
+
+ # load PL-BERT model
+ BERT_path = config.get('PLBERT_dir', False)
+ plbert = load_plbert(BERT_path)
+
+ # build model
+ model_params = recursive_munch(config['model_params'])
+ multispeaker = model_params.multispeaker
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
+ _ = [model[key].to(device) for key in model]
+
+ # DP
+ for key in model:
+ if key != "mpd" and key != "msd" and key != "wd":
+ model[key] = accelerator.prepare(model[key])
+
+ start_epoch = 0
+ iters = 0
+
+ load_pretrained = config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False)
+
+ if not load_pretrained:
+ if config.get('first_stage_path', '') != '':
+ first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
+ accelerator.print('Loading the first stage model at %s ...' % first_stage_path)
+ model, _, start_epoch, iters = load_checkpoint(model,
+ None,
+ first_stage_path,
+ load_only_params=True,
+ ignore_modules=['bert', 'bert_encoder', 'predictor', 'predictor_encoder', 'msd', 'mpd', 'wd', 'diffusion']) # keep starting epoch for tensorboard log
+
+ # these epochs should be counted from the start epoch
+
+
+ diff_epoch += start_epoch
+ joint_epoch += start_epoch
+ epochs += start_epoch
+
+ model.predictor_encoder = copy.deepcopy(model.style_encoder)
+ else:
+ raise ValueError('You need to specify the path to the first stage model.')
+
+
+
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
+ wl = WavLMLoss(model_params.slm.model,
+ model.wd,
+ sr,
+ model_params.slm.sr).to(device)
+
+ gl = accelerator.prepare(gl)
+ dl = accelerator.prepare(dl)
+ wl = accelerator.prepare(wl)
+
+ sampler = DiffusionSampler(
+ model.diffusion.module.diffusion,
+ sampler=ADPM2Sampler(),
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
+ clamp=False
+ )
+
+ scheduler_params = {
+ "max_lr": optimizer_params.lr * accelerator.num_processes,
+ "pct_start": float(0),
+ "epochs": epochs,
+ "steps_per_epoch": len(train_dataloader),
+ }
+ scheduler_params_dict= {key: scheduler_params.copy() for key in model}
+ scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2
+ scheduler_params_dict['decoder']['max_lr'] = optimizer_params.ft_lr * 2
+ scheduler_params_dict['style_encoder']['max_lr'] = optimizer_params.ft_lr * 2
+
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
+ scheduler_params_dict=scheduler_params_dict, lr=optimizer_params.lr * accelerator.num_processes)
+
+ # adjust BERT learning rate
+ for g in optimizer.optimizers['bert'].param_groups:
+ g['betas'] = (0.9, 0.99)
+ g['lr'] = optimizer_params.bert_lr
+ g['initial_lr'] = optimizer_params.bert_lr
+ g['min_lr'] = 0
+ g['weight_decay'] = 0.01
+
+ # adjust acoustic module learning rate
+ for module in ["decoder", "style_encoder"]:
+ for g in optimizer.optimizers[module].param_groups:
+ g['betas'] = (0.0, 0.99)
+ g['lr'] = optimizer_params.ft_lr
+ g['initial_lr'] = optimizer_params.ft_lr
+ g['min_lr'] = 0
+ g['weight_decay'] = 1e-4
+
+ # load models if there is a model
+ if load_pretrained:
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
+ load_only_params=config.get('load_only_params', True))
+
+
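+    # fine-tuning always restarts the epoch counter, even if the checkpoint stored one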
+ start_from_zero = True
+ if start_from_zero:
+ start_epoch = 0
+
+ n_down = model.text_aligner.module.n_down
+
+ best_loss = float('inf') # best test loss
+    loss_train_record = []
+    loss_test_record = []
+ iters = 0
+
+ criterion = nn.L1Loss() # F0 loss (regression)
+ torch.cuda.empty_cache()
+
+ stft_loss = MultiResolutionSTFTLoss().to(device)
+
+ accelerator.print('BERT', optimizer.optimizers['bert'])
+ accelerator.print('decoder', optimizer.optimizers['decoder'])
+
+ start_ds = False
+
+ running_std = []
+
+ slmadv_params = Munch(config['slmadv_params'])
+ slmadv = SLMAdversarialLoss(model, wl, sampler,
+ slmadv_params.min_len,
+ slmadv_params.max_len,
+ batch_percentage=slmadv_params.batch_percentage,
+ skip_update=slmadv_params.iter,
+ sig=slmadv_params.sig
+ )
+
+ for k, v in optimizer.optimizers.items():
+ optimizer.optimizers[k] = accelerator.prepare(optimizer.optimizers[k])
+ optimizer.schedulers[k] = accelerator.prepare(optimizer.schedulers[k])
+
+ train_dataloader = accelerator.prepare(train_dataloader)
+
+
+ for epoch in range(0, epochs):
+
+ running_loss = 0
+ start_time = time.time()
+
+ _ = [model[key].eval() for key in model]
+
+ model.text_aligner.train()
+ model.text_encoder.train()
+
+ model.predictor.train()
+ model.bert_encoder.train()
+ model.bert.train()
+ model.msd.train()
+ model.mpd.train()
+
+ for i, batch in enumerate(train_dataloader):
+ waves = batch[0]
+ batch = [b.to(device) for b in batch[1:]]
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
+
+ with torch.no_grad():
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
+ mel_mask = length_to_mask(mel_input_length).to(device)
+ text_mask = length_to_mask(input_lengths).to(texts.device)
+
+ # compute reference styles
+ if multispeaker and epoch >= diff_epoch:
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
+ ref = torch.cat([ref_ss, ref_sp], dim=1)
+
+ try:
+ ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
+ s2s_attn = s2s_attn.transpose(-1, -2)
+ s2s_attn = s2s_attn[..., 1:]
+ s2s_attn = s2s_attn.transpose(-1, -2)
+ except:
+ continue
+
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
+
+ # encode
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
+
+ # 50% of chance of using monotonic version
+ if bool(random.getrandbits(1)):
+ asr = (t_en @ s2s_attn)
+ else:
+ asr = (t_en @ s2s_attn_mono)
+
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
+
+ # compute the style of the entire utterance
+ # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
+ ss = []
+ gs = []
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item())
+ mel = mels[bib, :, :mel_input_length[bib]]
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
+ ss.append(s)
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
+ gs.append(s)
+
+ s_dur = torch.stack(ss).squeeze() # global prosodic styles
+ gs = torch.stack(gs).squeeze() # global acoustic styles
+ s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
+
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+
+ # denoiser training
+ if epoch >= diff_epoch:
+ num_steps = np.random.randint(3, 5)
+
+ if model_params.diffusion.dist.estimate_sigma_data:
+ model.diffusion.module.diffusion.sigma_data = s_trg.std(axis=-1).mean().item() # batch-wise std estimation
+ running_std.append(model.diffusion.module.diffusion.sigma_data)
+
+ if multispeaker:
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
+ embedding=bert_dur,
+ embedding_scale=1,
+ features=ref, # reference from the same speaker as the embedding
+ embedding_mask_proba=0.1,
+ num_steps=num_steps).squeeze(1)
+ loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
+ else:
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
+ embedding=bert_dur,
+ embedding_scale=1,
+ embedding_mask_proba=0.1,
+ num_steps=num_steps).squeeze(1)
+ loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur).mean() # EDM loss
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
+ else:
+ loss_sty = 0
+ loss_diff = 0
+
+
+ s_loss = 0
+
+
+ d, p = model.predictor(d_en, s_dur,
+ input_lengths,
+ s2s_attn_mono,
+ text_mask)
+
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
+ mel_len = min(int(mel_input_length.min().item() / 2 - 1), max_len // 2)
+ en = []
+ gt = []
+ p_en = []
+ wav = []
+ st = []
+
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item() / 2)
+
+ random_start = np.random.randint(0, mel_length - mel_len)
+ en.append(asr[bib, :, random_start:random_start+mel_len])
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
+
+ y = waves[bib][(random_start * 2) * 512:((random_start+mel_len) * 2) * 512]
+ wav.append(torch.from_numpy(y).to(device))
+
+ # style reference (better to be different from the GT)
+ random_start = np.random.randint(0, mel_length - mel_len_st)
+ st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
+
+ wav = torch.stack(wav).float().detach()
+
+ en = torch.stack(en)
+ p_en = torch.stack(p_en)
+ gt = torch.stack(gt).detach()
+ st = torch.stack(st).detach()
+
+
+ if gt.size(-1) < 80:
+ continue
+
+ s = model.style_encoder(gt.unsqueeze(1))
+ s_dur = model.predictor_encoder(gt.unsqueeze(1))
+
+ with torch.no_grad():
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
+ F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
+
+ N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
+
+ y_rec_gt = wav.unsqueeze(1)
+ y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
+
+ wav = y_rec_gt
+
+ F0_fake, N_fake = model.predictor(texts=p_en, style=s_dur, f0=True)
+
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
+
+ loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
+ loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
+
+ if start_ds:
+ optimizer.zero_grad()
+ d_loss = dl(wav.detach(), y_rec.detach()).mean()
+ accelerator.backward(d_loss)
+ optimizer.step('msd')
+ optimizer.step('mpd')
+ else:
+ d_loss = 0
+
+ # generator loss
+ optimizer.zero_grad()
+
+ loss_mel = stft_loss(y_rec, wav)
+ loss_gen_all = gl(wav, y_rec).mean()
+ loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
+
+ loss_ce = 0
+ loss_dur = 0
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
+ _s2s_pred = _s2s_pred[:_text_length, :]
+ _text_input = _text_input[:_text_length].long()
+ _s2s_trg = torch.zeros_like(_s2s_pred)
+ for p in range(_s2s_trg.shape[0]):
+ _s2s_trg[p, :_text_input[p]] = 1
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
+
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
+ _text_input[1:_text_length-1])
+ loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())
+
+ loss_ce /= texts.size(0)
+ loss_dur /= texts.size(0)
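+            # Worked example of the duration/CE targets above (illustrative only): if a
+            # phoneme's ground-truth duration _text_input[p] is 3 and the prediction row
+            # has width 50, then _s2s_trg[p] = [1, 1, 1, 0, ..., 0]; the BCE term pushes
+            # the first three logits up, while sigmoid(_s2s_pred[p]).sum() ~ 3 is matched
+            # to the target by the L1 duration loss (the first and last tokens are
+            # excluded via the [1:_text_length-1] slice).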
+
+ loss_s2s = 0
+ for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
+ loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
+ loss_s2s /= texts.size(0)
+
+ loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
+
+ g_loss = loss_params.lambda_mel * loss_mel + \
+ loss_params.lambda_F0 * loss_F0_rec + \
+ loss_params.lambda_ce * loss_ce + \
+ loss_params.lambda_norm * loss_norm_rec + \
+ loss_params.lambda_dur * loss_dur + \
+ loss_params.lambda_gen * loss_gen_all + \
+ loss_params.lambda_slm * loss_lm + \
+ loss_params.lambda_sty * loss_sty + \
+ loss_params.lambda_diff * loss_diff + \
+ loss_params.lambda_mono * loss_mono + \
+ loss_params.lambda_s2s * loss_s2s
+
+
+
+ running_loss += loss_mel.item()
+ accelerator.backward(g_loss)
+ # if torch.isnan(g_loss):
+ # from IPython.core.debugger import set_trace
+ # set_trace()
+
+
+ optimizer.step('bert_encoder')
+ optimizer.step('bert')
+ optimizer.step('predictor')
+ optimizer.step('predictor_encoder')
+ optimizer.step('style_encoder')
+ optimizer.step('decoder')
+
+ optimizer.step('text_encoder')
+ optimizer.step('text_aligner')
+
+
+
+ if epoch >= diff_epoch:
+ optimizer.step('diffusion')
+
+ d_loss_slm, loss_gen_lm = 0, 0
+ if epoch >= joint_epoch:
+ # randomly pick whether to use in-distribution text
+ if np.random.rand() < 0.5:
+ use_ind = True
+ else:
+ use_ind = False
+
+ if use_ind:
+ ref_lengths = input_lengths
+ ref_texts = texts
+
+ slm_out = slmadv(i,
+ y_rec_gt,
+ y_rec_gt_pred,
+ waves,
+ mel_input_length,
+ ref_texts,
+ ref_lengths, use_ind, s_trg.detach(), ref if multispeaker else None)
+
+ if slm_out is not None:
+ d_loss_slm, loss_gen_lm, y_pred = slm_out
+ optimizer.zero_grad()
+ # accelerator.clip_grad_norm_(model.decoder.parameters(), 1)
+ # print("here")
+ accelerator.backward(loss_gen_lm)
+ # print("here2")
+ # SLM discriminator loss
+ if d_loss_slm != 0:
+ print("loss!", d_loss_slm)
+ else:
+ print("no loss")
+ # compute the gradient norm
+ total_norm = {}
+ for key in model.keys():
+ total_norm[key] = 0
+ parameters = [p for p in model[key].parameters() if p.grad is not None and p.requires_grad]
+ for p in parameters:
+ param_norm = p.grad.detach().data.norm(2)
+ total_norm[key] += param_norm.item() ** 2
+ total_norm[key] = total_norm[key] ** 0.5
+
+ # gradient scaling
+ if total_norm['predictor'] > slmadv_params.thresh:
+ for key in model.keys():
+ for p in model[key].parameters():
+ if p.grad is not None:
+ p.grad *= (1 / total_norm['predictor'])
+
+ for p in model.predictor.module.duration_proj.parameters():
+ if p.grad is not None:
+ p.grad *= slmadv_params.scale
+
+ for p in model.predictor.module.lstm.parameters():
+ if p.grad is not None:
+ p.grad *= slmadv_params.scale
+
+ for p in model.diffusion.module.parameters():
+ if p.grad is not None:
+ p.grad *= slmadv_params.scale
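+                # Gradient-scaling note (describing the code above): the per-module grad
+                # norm is ||g|| = sqrt(sum_p ||g_p||_2^2); if the predictor's norm exceeds
+                # slmadv_params.thresh every gradient is divided by that norm, and the
+                # duration projection, prosody LSTM and diffusion gradients are then
+                # multiplied by slmadv_params.scale before the optimizer steps below.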
+
+ optimizer.step('bert_encoder')
+ optimizer.step('bert')
+ optimizer.step('predictor')
+ optimizer.step('diffusion')
+
+ # SLM discriminator loss
+ if d_loss_slm != 0:
+ optimizer.zero_grad()
+ accelerator.backward(d_loss_slm, retain_graph=True)
+ optimizer.step('wd')
+
+
+
+
+ iters = iters + 1
+ if (i + 1) % log_interval == 0:
+ logger.info(
+ 'Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f'
+ % (epoch + 1, epochs, i + 1, len(train_list) // batch_size, running_loss / log_interval, d_loss,
+ loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff,
+ d_loss_slm, loss_gen_lm), main_process_only=True)
+ if accelerator.is_main_process:
+ print('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f'
+ % (epoch + 1, epochs, i + 1, len(train_list) // batch_size, running_loss / log_interval, d_loss,
+ loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff,
+ d_loss_slm, loss_gen_lm))
+ accelerator.log({'train/mel_loss': float(running_loss / log_interval),
+ 'train/gen_loss': float(loss_gen_all),
+ 'train/d_loss': float(d_loss),
+ 'train/ce_loss': float(loss_ce),
+ 'train/dur_loss': float(loss_dur),
+ 'train/slm_loss': float(loss_lm),
+ 'train/norm_loss': float(loss_norm_rec),
+ 'train/F0_loss': float(loss_F0_rec),
+ 'train/sty_loss': float(loss_sty),
+ 'train/diff_loss': float(loss_diff),
+ 'train/d_loss_slm': float(d_loss_slm),
+ 'train/gen_loss_slm': float(loss_gen_lm),
+ 'epoch': int(epoch) + 1}, step=iters)
+
+ running_loss = 0
+
+                accelerator.print('Time elapsed:', time.time() - start_time)
+
+
+ loss_test = 0
+ loss_align = 0
+ loss_f = 0
+ _ = [model[key].eval() for key in model]
+
+ with torch.no_grad():
+ iters_test = 0
+ for batch_idx, batch in enumerate(val_dataloader):
+ optimizer.zero_grad()
+
+ try:
+ waves = batch[0]
+ batch = [b.to(device) for b in batch[1:]]
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
+ with torch.no_grad():
+                        mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
+ text_mask = length_to_mask(input_lengths).to(texts.device)
+
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
+ s2s_attn = s2s_attn.transpose(-1, -2)
+ s2s_attn = s2s_attn[..., 1:]
+ s2s_attn = s2s_attn.transpose(-1, -2)
+
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
+
+ # encode
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
+ asr = (t_en @ s2s_attn_mono)
+
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
+
+ ss = []
+ gs = []
+
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item())
+ mel = mels[bib, :, :mel_input_length[bib]]
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
+ ss.append(s)
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
+ gs.append(s)
+
+                    s = torch.stack(ss).squeeze(1)
+                    gs = torch.stack(gs).squeeze(1)
+ s_trg = torch.cat([s, gs], dim=-1).detach()
+
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+ d, p = model.predictor(d_en, s,
+ input_lengths,
+ s2s_attn_mono,
+ text_mask)
+ # get clips
+ mel_len = int(mel_input_length.min().item() / 2 - 1)
+ en = []
+ gt = []
+
+ p_en = []
+ wav = []
+
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item() / 2)
+
+ random_start = np.random.randint(0, mel_length - mel_len)
+ en.append(asr[bib, :, random_start:random_start+mel_len])
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
+
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
+ y = waves[bib][(random_start * 2) * 512:((random_start+mel_len) * 2) * 512]
+ wav.append(torch.from_numpy(y).to(device))
+
+ wav = torch.stack(wav).float().detach()
+
+ en = torch.stack(en)
+ p_en = torch.stack(p_en)
+ gt = torch.stack(gt).detach()
+ s = model.predictor_encoder(gt.unsqueeze(1))
+
+                    F0_fake, N_fake = model.predictor(texts=p_en, style=s, f0=True)
+
+ loss_dur = 0
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
+ _s2s_pred = _s2s_pred[:_text_length, :]
+ _text_input = _text_input[:_text_length].long()
+ _s2s_trg = torch.zeros_like(_s2s_pred)
+ for bib in range(_s2s_trg.shape[0]):
+ _s2s_trg[bib, :_text_input[bib]] = 1
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
+ _text_input[1:_text_length-1])
+
+ loss_dur /= texts.size(0)
+
+ s = model.style_encoder(gt.unsqueeze(1))
+
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
+
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
+
+ loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
+
+ loss_test += (loss_mel).mean()
+ loss_align += (loss_dur).mean()
+ loss_f += (loss_F0).mean()
+
+ iters_test += 1
+
+ except Exception as e:
+ accelerator.print(f"Eval errored with: \n {str(e)}")
+ continue
+
+
+ # print('Epochs:', epoch + 1)
+ # logger.info('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n\n\n')
+ # print('\n\n\n')
+ # writer.add_scalar('eval/mel_loss', loss_test / iters_test, epoch + 1)
+ # writer.add_scalar('eval/dur_loss', loss_test / iters_test, epoch + 1)
+ # writer.add_scalar('eval/F0_loss', loss_f / iters_test, epoch + 1)
+
+ accelerator.print('Epochs:', epoch + 1)
+ try:
+ logger.info('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (
+ loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n', main_process_only=True)
+
+
+ accelerator.log({'eval/mel_loss': float(loss_test / iters_test),
+ 'eval/dur_loss': float(loss_test / iters_test),
+ 'eval/F0_loss': float(loss_f / iters_test)},
+ step=(i + 1) * (epoch + 1))
+ except ZeroDivisionError:
+ accelerator.print("Eval loss was divided by zero... skipping eval cycle")
+
+ if epoch % saving_epoch == 0:
+ if (loss_test / iters_test) < best_loss:
+ best_loss = loss_test / iters_test
+ try:
+ accelerator.print('Saving..')
+ state = {
+ 'net': {key: model[key].state_dict() for key in model},
+ 'optimizer': optimizer.state_dict(),
+ 'iters': iters,
+ 'val_loss': loss_test / iters_test,
+ 'epoch': epoch,
+ }
+ except ZeroDivisionError:
+ accelerator.print('No iter test, Re-Saving..')
+ state = {
+ 'net': {key: model[key].state_dict() for key in model},
+ 'optimizer': optimizer.state_dict(),
+ 'iters': iters,
+ 'val_loss': 0.1, # not zero just in case
+ 'epoch': epoch,
+ }
+
+ if accelerator.is_main_process:
+ save_path = osp.join(log_dir, 'epoch_fine_tune_%05d.pth' % epoch)
+ torch.save(state, save_path)
+
+            # if estimate sigma, save the estimated sigma
+ if model_params.diffusion.dist.estimate_sigma_data:
+ config['model_params']['diffusion']['dist']['sigma_data'] = float(np.mean(running_std))
+
+ with open(osp.join(log_dir, osp.basename(config_path)), 'w') as outfile:
+ yaml.dump(config, outfile, default_flow_style=True)
+ if accelerator.is_main_process:
+ print('Saving last pth..')
+ state = {
+ 'net': {key: model[key].state_dict() for key in model},
+ 'optimizer': optimizer.state_dict(),
+ 'iters': iters,
+ 'val_loss': loss_test / iters_test,
+ 'epoch': epoch,
+ }
+ save_path = osp.join(log_dir, 'finetune_phase_last.pth')
+ torch.save(state, save_path)
+
+ accelerator.end_training()
+
+
+
+if __name__ == "__main__":
+ main()
diff --git a/stts_48khz/StyleTTS2_48khz/accelerate_train_second.py b/stts_48khz/StyleTTS2_48khz/accelerate_train_second.py
new file mode 100644
index 0000000000000000000000000000000000000000..0594ad10dbf228c356513b967bc3ee0cf2c77b40
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/accelerate_train_second.py
@@ -0,0 +1,975 @@
+# load packages
+import random
+import yaml
+import time
+from munch import Munch
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+import librosa
+import click
+import shutil
+import traceback
+import warnings
+
+warnings.simplefilter('ignore')
+
+from meldataset import build_dataloader
+
+from Utils.ASR.models import ASRCNN
+from Utils.JDC.model import JDCNet
+from Utils.PLBERT.util import load_plbert
+
+from models import *
+from losses import *
+from utils import *
+
+from Modules.slmadv import SLMAdversarialLoss
+from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
+
+from optimizers import build_optimizer
+
+from accelerate import Accelerator, DistributedDataParallelKwargs
+from accelerate.utils import tqdm, ProjectConfiguration
+
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
+# from Utils.fsdp_patch import replace_fsdp_state_dict_type
+
+# replace_fsdp_state_dict_type()
+
+import logging
+
+from accelerate.logging import get_logger
+from logging import StreamHandler
+
+logger = get_logger(__name__)
+logger.setLevel(logging.DEBUG)
+# handler.setLevel(logging.DEBUG)
+# logger.addHandler(handler)
+
+@click.command()
+@click.option('-p', '--config_path', default='Configs/config.yml', type=str)
+def main(config_path):
+ config = yaml.safe_load(open(config_path))
+
+ log_dir = config['log_dir']
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
+
+ # write logs
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
+ file_handler.setLevel(logging.DEBUG)
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
+ logger.logger.addHandler(file_handler)
+
+ batch_size = config.get('batch_size', 10)
+
+ epochs = config.get('epochs_2nd', 200)
+ save_freq = config.get('save_freq', 2)
+ log_interval = 10
+ saving_epoch = config.get('save_freq', 2)
+
+ data_params = config.get('data_params', None)
+ sr = config['preprocess_params'].get('sr', 48000)
+ hop = config['preprocess_params']["spect_params"].get('hop_length', 512)
+ win = config['preprocess_params']["spect_params"].get('win_length', 2048)
+ train_path = data_params['train_data']
+ val_path = data_params['val_data']
+ root_path = data_params['root_path']
+ min_length = data_params['min_length']
+ OOD_data = data_params['OOD_data']
+
+ max_len = config.get('max_len', 200)
+
+ loss_params = Munch(config['loss_params'])
+ diff_epoch = loss_params.diff_epoch
+ joint_epoch = loss_params.joint_epoch
+
+ optimizer_params = Munch(config['optimizer_params'])
+
+ train_list, val_list = get_data_path_list(train_path, val_path)
+
+ try:
+ tracker = data_params['logger']
+ except KeyError:
+ tracker = "mlflow"
+
+ def log_audio(accelerator, audio, bib="", name="Validation", epoch=0, sr=48000, tracker="tensorboard"):
+ if tracker == "tensorboard":
+ ltracker = accelerator.get_tracker("tensorboard")
+ np_aud = np.stack([np.asarray(aud) for aud in audio])
+ ltracker.writer.add_audio(f"{name}-{bib}", np_aud, epoch, sample_rate=sr)
+ if tracker == "wandb":
+ try:
+ ltracker = accelerator.get_tracker("wandb")
+ ltracker.log(
+ {
+ "validation": [
+ wandb.Audio(audios, caption=f"{name}-{bib}", sample_rate=sr)
+ for i, audios in enumerate(audio)
+ ]
+ }
+ , step=int(bib))
+ except IndexError:
+ pass
+
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True, broadcast_buffers=False)
+ configAcc = ProjectConfiguration(project_dir=log_dir, logging_dir=log_dir)
+ accelerator = Accelerator(log_with=tracker,
+ project_config=configAcc,
+ split_batches=True,
+ kwargs_handlers=[ddp_kwargs],
+ mixed_precision='bf16')
+
+ accelerator.init_trackers(project_name="StyleTTS2-Second-Stage",
+ config=config if tracker == "wandb" else None)
+ HF = config["data_params"].get("HF", False)
+    name = config["data_params"].get("dataset_name", None)
+ split = config["data_params"].get("split", None)
+ val_split = config["data_params"].get("val_split", None)
+ ood_split = config["data_params"].get("OOD_split", None)
+ audcol = config["data_params"].get("audio_column", "speech")
+ phoncol = config["data_params"].get("phoneme_column", "phoneme")
+ specol = config["data_params"].get("speaker_column", "speaker ID")
+
+ if not HF:
+ train_list, val_list = get_data_path_list(train_path, val_path)
+ ds_conf = {"sr": sr, "hop": hop, "win": win}
+ vds_conf = {"sr": sr, "hop": hop, "win": win}
+ else:
+ train_list, val_list = train_path, val_path
+ ds_conf = {"sr": sr,
+ "hop": hop,
+ "split": split,
+ "OOD_split": ood_split,
+ "dataset_name": name,
+ "audio_column": audcol,
+ "phoneme_column": phoncol,
+ "speaker_id_column": specol,
+ "win": win}
+ vds_conf = {"sr": sr,
+ "hop": hop,
+ "split": val_split,
+ "OOD_split": ood_split,
+ "dataset_name": name,
+ "audio_column": audcol,
+ "phoneme_column": phoncol,
+ "speaker_id_column": specol,
+ "win": win}
+ device = accelerator.device
+
+ with accelerator.main_process_first():
+ train_dataloader = build_dataloader(train_list,
+ root_path,
+ OOD_data=OOD_data,
+ min_length=min_length,
+ batch_size=batch_size,
+ num_workers=2,
+ dataset_config={},
+ device=device)
+
+ val_dataloader = build_dataloader(val_list,
+ root_path,
+ OOD_data=OOD_data,
+ min_length=min_length,
+ batch_size=batch_size,
+ validation=True,
+ num_workers=0,
+ device=device,
+ dataset_config={})
+
+ # load pretrained ASR model
+ ASR_config = config.get('ASR_config', False)
+ ASR_path = config.get('ASR_path', False)
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
+
+ # load pretrained F0 model
+ F0_path = config.get('F0_path', False)
+ pitch_extractor = load_F0_models(F0_path)
+
+ # load PL-BERT model
+ BERT_path = config.get('PLBERT_dir', False)
+ plbert = load_plbert(BERT_path)
+
+ # build model
+ config['model_params']["sr"] = sr
+
+ model_params = recursive_munch(config['model_params'])
+ multispeaker = model_params.multispeaker
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
+ _ = [model[key].to(device) for key in model]
+
+ # DP
+ for key in model:
+ if key != "mpd" and key != "msd" and key != "wd":
+ model[key] = accelerator.prepare(model[key])
+
+ start_epoch = 0
+ iters = 0
+
+ load_pretrained = config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False)
+
+ if not load_pretrained:
+ if config.get('first_stage_path', '') != '':
+ first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
+ accelerator.print('Loading the first stage model at %s ...' % first_stage_path)
+ model, _, start_epoch, iters = load_checkpoint(model,
+ None,
+ first_stage_path,
+ load_only_params=True,
+ ignore_modules=['bert', 'bert_encoder', 'predictor',
+ 'predictor_encoder', 'msd', 'mpd', 'wd',
+ 'diffusion']) # keep starting epoch for tensorboard log
+
+ # these epochs should be counted from the start epoch
+ diff_epoch += start_epoch
+ joint_epoch += start_epoch
+ epochs += start_epoch
+ model.style_encoder.train()
+ model.predictor_encoder = copy.deepcopy(model.style_encoder)
+ else:
+ raise ValueError('You need to specify the path to the first stage model.')
+
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
+ wl = WavLMLoss(model_params.slm.model,
+ model.wd,
+ sr,
+ model_params.slm.sr).to(device)
+
+ gl = accelerator.prepare(gl)
+ dl = accelerator.prepare(dl)
+ wl = accelerator.prepare(wl)
+ wl = wl.eval()
+
+ sampler = DiffusionSampler(
+ model.diffusion.module.diffusion,
+ sampler=ADPM2Sampler(),
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
+ clamp=False
+ )
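+    # Minimal usage sketch for the sampler (assumes the 256-dim style space used later in
+    # this file, split into a 128-dim acoustic half and a 128-dim prosodic half):
+    #   noise  = torch.randn(1, 1, 256).to(device)
+    #   s_pred = sampler(noise, embedding=bert_dur[0:1], embedding_scale=1, num_steps=5).squeeze(1)
+    #   ref, s = s_pred[:, :128], s_pred[:, 128:]   # acoustic / prosodic styles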
+
+ scheduler_params = {
+ "max_lr": optimizer_params.lr * accelerator.num_processes,
+ "pct_start": float(0),
+ "epochs": epochs,
+ "steps_per_epoch": len(train_dataloader),
+ }
+ scheduler_params_dict = {key: scheduler_params.copy() for key in model}
+ scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2
+ scheduler_params_dict['decoder']['max_lr'] = optimizer_params.ft_lr * 2
+ scheduler_params_dict['style_encoder']['max_lr'] = optimizer_params.ft_lr * 2
+
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
+ scheduler_params_dict=scheduler_params_dict,
+ lr=optimizer_params.lr * accelerator.num_processes)
+
+ # adjust BERT learning rate
+ for g in optimizer.optimizers['bert'].param_groups:
+ g['betas'] = (0.9, 0.99)
+ g['lr'] = optimizer_params.bert_lr
+ g['initial_lr'] = optimizer_params.bert_lr
+ g['min_lr'] = 0
+ g['weight_decay'] = 0.01
+
+ # adjust acoustic module learning rate
+ for module in ["decoder", "style_encoder"]:
+ for g in optimizer.optimizers[module].param_groups:
+ g['betas'] = (0.0, 0.99)
+ g['lr'] = optimizer_params.ft_lr
+ g['initial_lr'] = optimizer_params.ft_lr
+ g['min_lr'] = 0
+ g['weight_decay'] = 1e-4
+
+ # load models if there is a model
+ if load_pretrained:
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
+ load_only_params=config.get('load_only_params', True))
+
+ n_down = model.text_aligner.module.n_down
+
+ # for k in model:
+ # model[k] = accelerator.prepare(model[k])
+
+ best_loss = float('inf') # best test loss
+ iters = 0
+
+ criterion = nn.L1Loss() # F0 loss (regression)
+ torch.cuda.empty_cache()
+
+ stft_loss = MultiResolutionSTFTLoss().to(device)
+
+ accelerator.print('BERT', optimizer.optimizers['bert'])
+ accelerator.print('decoder', optimizer.optimizers['decoder'])
+
+ start_ds = False
+
+ running_std = []
+
+ slmadv_params = Munch(config['slmadv_params'])
+
+ slmadv = SLMAdversarialLoss(model, wl, sampler,
+ slmadv_params.min_len,
+ slmadv_params.max_len,
+ batch_percentage=slmadv_params.batch_percentage,
+ skip_update=slmadv_params.iter,
+ sig=slmadv_params.sig
+ )
+
+ for k, v in optimizer.optimizers.items():
+ optimizer.optimizers[k] = accelerator.prepare(optimizer.optimizers[k])
+ optimizer.schedulers[k] = accelerator.prepare(optimizer.schedulers[k])
+
+ train_dataloader = accelerator.prepare(train_dataloader)
+
+ for epoch in range(start_epoch, epochs):
+ running_loss = 0
+ start_time = time.time()
+
+ _ = [model[key].eval() for key in model]
+
+ model.text_aligner.train()
+ model.text_encoder.train()
+
+ model.predictor.train()
+ model.predictor_encoder.train()
+ model.bert_encoder.train()
+ model.bert.train()
+ model.msd.train()
+ model.mpd.train()
+ model.wd.train()
+
+ if epoch >= diff_epoch:
+ start_ds = True
+
+ for i, batch in enumerate(train_dataloader):
+ waves = batch[0]
+ batch = [b.to(device) for b in batch[1:]]
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
+
+ with torch.no_grad():
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
+ mel_mask = length_to_mask(mel_input_length).to(device)
+ text_mask = length_to_mask(input_lengths).to(texts.device)
+
+ try:
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
+ s2s_attn = s2s_attn.transpose(-1, -2)
+ s2s_attn = s2s_attn[..., 1:]
+ s2s_attn = s2s_attn.transpose(-1, -2)
+ except:
+ continue
+
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
+
+ # encode
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
+ asr = (t_en @ s2s_attn_mono)
+
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
+
+ # compute reference styles
+ if multispeaker and epoch >= diff_epoch:
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
+ ref = torch.cat([ref_ss, ref_sp], dim=1)
+
+ # compute the style of the entire utterance
+ # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
+ ss = []
+ gs = []
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item())
+ mel = mels[bib, :, :mel_input_length[bib]]
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
+ ss.append(s)
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
+ gs.append(s)
+
+ s_dur = torch.stack(ss).squeeze(1) # global prosodic styles
+ gs = torch.stack(gs).squeeze(1) # global acoustic styles
+ s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
+
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+
+ # denoiser training
+ if epoch >= diff_epoch:
+ num_steps = np.random.randint(3, 5)
+
+ if model_params.diffusion.dist.estimate_sigma_data:
+ model.diffusion.module.diffusion.sigma_data = s_trg.std(
+ axis=-1).mean().item() # batch-wise std estimation
+ running_std.append(model.diffusion.module.diffusion.sigma_data)
+
+ if multispeaker:
+ s_preds = sampler(noise=torch.randn_like(s_trg).unsqueeze(1).to(device),
+ embedding=bert_dur,
+ embedding_scale=1,
+ features=ref, # reference from the same speaker as the embedding
+ embedding_mask_proba=0.1,
+ num_steps=num_steps).squeeze(1)
+ loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
+ else:
+ s_preds = sampler(noise=torch.randn_like(s_trg).unsqueeze(1).to(device),
+ embedding=bert_dur,
+ embedding_scale=1,
+ embedding_mask_proba=0.1,
+ num_steps=num_steps).squeeze(1)
+ loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1),
+ embedding=bert_dur).mean() # EDM loss
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
+ # print(loss_sty)
+ else:
+ # print("here")
+ loss_sty = 0
+ loss_diff = 0
+
+ d, p = model.predictor(d_en, s_dur,
+ input_lengths,
+ s2s_attn_mono,
+ text_mask)
+
+ # mel_len = int(mel_input_length.min().item() / 2 - 1)
+
+ mel_input_length_all = accelerator.gather(mel_input_length) # for balanced load
+ mel_len = min([int(mel_input_length_all.min().item() / 2 - 1), max_len // 2])
+
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
+ en = []
+ gt = []
+ st = []
+ p_en = []
+ wav = []
+
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item() / 2)
+
+ random_start = np.random.randint(0, mel_length - mel_len)
+ en.append(asr[bib, :, random_start:random_start + mel_len])
+ p_en.append(p[bib, :, random_start:random_start + mel_len])
+ gt.append(mels[bib, :, (random_start * 2):((random_start + mel_len) * 2)])
+
+ y = waves[bib][(random_start * 2) * 512:((random_start + mel_len) * 2) * 512]
+ wav.append(torch.from_numpy(y).to(device))
+
+ # style reference (better to be different from the GT)
+ random_start = np.random.randint(0, mel_length - mel_len_st)
+ st.append(mels[bib, :, (random_start * 2):((random_start + mel_len_st) * 2)])
+
+ wav = torch.stack(wav).float().detach()
+
+ en = torch.stack(en)
+ p_en = torch.stack(p_en)
+ gt = torch.stack(gt).detach()
+ st = torch.stack(st).detach()
+
+ if gt.size(-1) < 80:
+ continue
+
+ s_dur = model.predictor_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
+ s = model.style_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
+
+ with torch.no_grad():
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
+ F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2])
+
+ asr_real = model.text_aligner.module.get_feature(gt)
+
+ N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
+
+ y_rec_gt = wav.unsqueeze(1)
+ y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
+
+ if epoch >= joint_epoch:
+ # ground truth from recording
+ wav = y_rec_gt # use recording since decoder is tuned
+ else:
+ # ground truth from reconstruction
+ wav = y_rec_gt_pred # use reconstruction since decoder is fixed
+
+ F0_fake, N_fake = model.predictor(texts=p_en, style=s_dur, f0=True)
+
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
+
+ loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
+ loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
+
+ if start_ds:
+ optimizer.zero_grad()
+ d_loss = dl(wav.detach(), y_rec.detach()).mean()
+ accelerator.backward(d_loss)
+ optimizer.step('msd')
+ optimizer.step('mpd')
+ else:
+ d_loss = 0
+
+ # generator loss
+ optimizer.zero_grad()
+
+ loss_mel = stft_loss(y_rec, wav)
+ if start_ds:
+ loss_gen_all = gl(wav, y_rec).mean()
+ else:
+ loss_gen_all = 0
+ loss_lm = wl(wav.detach().squeeze(1), y_rec.squeeze(1)).mean()
+
+ loss_ce = 0
+ loss_dur = 0
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
+ _s2s_pred = _s2s_pred[:_text_length, :]
+ _text_input = _text_input[:_text_length].long()
+ _s2s_trg = torch.zeros_like(_s2s_pred)
+ for p in range(_s2s_trg.shape[0]):
+ _s2s_trg[p, :_text_input[p]] = 1
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
+
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length - 1],
+ _text_input[1:_text_length - 1])
+ loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())
+
+ loss_ce /= texts.size(0)
+ loss_dur /= texts.size(0)
+
+ g_loss = loss_params.lambda_mel * loss_mel + \
+ loss_params.lambda_F0 * loss_F0_rec + \
+ loss_params.lambda_ce * loss_ce + \
+ loss_params.lambda_norm * loss_norm_rec + \
+ loss_params.lambda_dur * loss_dur + \
+ loss_params.lambda_gen * loss_gen_all + \
+ loss_params.lambda_slm * loss_lm + \
+ loss_params.lambda_sty * loss_sty + \
+ loss_params.lambda_diff * loss_diff
+
+ running_loss += accelerator.gather(loss_mel).mean().item()
+ accelerator.backward(g_loss)
+
+ # if torch.isnan(g_loss):
+ # from IPython.core.debugger import set_trace
+ # set_trace()
+
+ optimizer.step('bert_encoder')
+ optimizer.step('bert')
+ optimizer.step('predictor')
+ optimizer.step('predictor_encoder')
+
+ if epoch >= diff_epoch:
+ optimizer.step('diffusion')
+
+ if epoch >= joint_epoch:
+
+ optimizer.step('style_encoder')
+ optimizer.step('decoder')
+
+ d_loss_slm, loss_gen_lm = 0, 0
+
+ # randomly pick whether to use in-distribution text
+ # if np.random.rand() < 0.5:
+ # use_ind = True
+ # else:
+ # use_ind = False
+
+ # if use_ind:
+ # ref_lengths = input_lengths
+ # ref_texts = texts
+
+ # try:
+ # slm_out = slmadv(i,
+ # accelerator,
+ # y_rec_gt,
+ # y_rec_gt_pred,
+ # waves,
+ # mel_input_length,
+ # ref_texts,
+ # ref_lengths, use_ind, s_trg.detach(), ref if multispeaker else None,
+ # )
+
+ # except:
+ # # # print('oops! sounds like something has happened!')
+ # # print('error -> ', e)
+ # slm_out = None
+
+
+
+
+ # if slm_out is None:
+ # d_loss_slm, loss_gen_lm = 0, 0
+
+ # else:
+ # d_loss_slm, loss_gen_lm, y_pred = slm_out
+ # optimizer.zero_grad()
+ # # accelerator.clip_grad_norm_(model.decoder.parameters(), 1)
+ # # print("here")
+ # accelerator.backward(loss_gen_lm)
+ # # print("here2")
+ # # SLM discriminator loss
+ # # if d_loss_slm != 0:
+ # # print("loss!", d_loss_slm)
+ # # else:
+ # # print("no loss")
+
+ # # compute the gradient norm
+
+
+ # total_norm = {}
+ # for key in model.keys():
+ # total_norm[key] = 0
+ # parameters = [p for p in model[key].parameters() if p.grad is not None and p.requires_grad]
+ # for p in parameters:
+ # param_norm = p.grad.detach().data.norm(2)
+ # total_norm[key] += param_norm.item() ** 2
+ # total_norm[key] = total_norm[key] ** 0.5
+
+ # # gradient scaling
+ # if total_norm['predictor'] > slmadv_params.thresh:
+ # for key in model.keys():
+ # for p in model[key].parameters():
+ # if p.grad is not None:
+ # p.grad *= (1 / total_norm['predictor'])
+
+ # for p in model.predictor.module.duration_proj.parameters():
+ # if p.grad is not None:
+ # p.grad *= slmadv_params.scale
+
+ # for p in model.predictor.module.lstm.parameters():
+ # if p.grad is not None:
+ # p.grad *= slmadv_params.scale
+
+ # for p in model.diffusion.module.parameters():
+ # if p.grad is not None:
+ # p.grad *= slmadv_params.scale
+
+ # optimizer.step('bert_encoder')
+ # optimizer.step('bert')
+ # optimizer.step('predictor')
+ # optimizer.step('diffusion')
+
+ # # SLM discriminator loss
+ # if d_loss_slm != 0:
+ # optimizer.zero_grad()
+ # # print("hey1")
+ # accelerator.backward(d_loss_slm, retain_graph=True)
+ # optimizer.step('wd')
+ # # print("hey2")
+
+ else:
+ d_loss_slm, loss_gen_lm = 0, 0
+
+ iters = iters + 1
+ if (i + 1) % log_interval == 0:
+ logger.info(
+ 'Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f'
+ % (epoch + 1, epochs, i + 1, len(train_list) // batch_size, running_loss / log_interval, d_loss,
+ loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff,
+ d_loss_slm, loss_gen_lm), main_process_only=True)
+ if accelerator.is_main_process:
+ print('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f'
+ % (epoch + 1, epochs, i + 1, len(train_list) // batch_size, running_loss / log_interval, d_loss,
+ loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff,
+ d_loss_slm, loss_gen_lm))
+ accelerator.log({'train/mel_loss': float(running_loss / log_interval),
+ 'train/gen_loss': float(loss_gen_all),
+ 'train/d_loss': float(d_loss),
+ 'train/ce_loss': float(loss_ce),
+ 'train/dur_loss': float(loss_dur),
+ 'train/slm_loss': float(loss_lm),
+ 'train/norm_loss': float(loss_norm_rec),
+ 'train/F0_loss': float(loss_F0_rec),
+ 'train/sty_loss': float(loss_sty),
+ 'train/diff_loss': float(loss_diff),
+ 'train/d_loss_slm': float(d_loss_slm),
+ 'train/gen_loss_slm': float(loss_gen_lm),
+ 'epoch': int(epoch) + 1}, step=iters)
+
+ running_loss = 0
+
+                accelerator.print('Time elapsed:', time.time() - start_time)
+
+ loss_test = 0
+ loss_align = 0
+ loss_f = 0
+ _ = [model[key].eval() for key in model]
+
+ with torch.no_grad():
+ iters_test = 0
+ for batch_idx, batch in enumerate(val_dataloader):
+ optimizer.zero_grad()
+
+ try:
+ waves = batch[0]
+ batch = [b.to(device) for b in batch[1:]]
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
+ with torch.no_grad():
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
+ text_mask = length_to_mask(input_lengths).to(texts.device)
+
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
+ s2s_attn = s2s_attn.transpose(-1, -2)
+ s2s_attn = s2s_attn[..., 1:]
+ s2s_attn = s2s_attn.transpose(-1, -2)
+
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
+
+ # encode
+ # print("t_en", t_en.shape, t_en)
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
+ asr = (t_en @ s2s_attn_mono)
+
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
+
+ ss = []
+ gs = []
+
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item())
+ mel = mels[bib, :, :mel_input_length[bib]]
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
+ ss.append(s)
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
+ gs.append(s)
+
+ s = torch.stack(ss).squeeze(1)
+ gs = torch.stack(gs).squeeze(1)
+ s_trg = torch.cat([s, gs], dim=-1).detach()
+ # print("texts", texts.shape, texts)
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+ d, p = model.predictor(d_en, s,
+ input_lengths,
+ s2s_attn_mono,
+ text_mask)
+ # get clips
+ mel_len = int(mel_input_length.min().item() / 2 - 1)
+ en = []
+ gt = []
+ p_en = []
+ wav = []
+
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item() / 2)
+
+ random_start = np.random.randint(0, mel_length - mel_len)
+ en.append(asr[bib, :, random_start:random_start + mel_len])
+ p_en.append(p[bib, :, random_start:random_start + mel_len])
+
+ gt.append(mels[bib, :, (random_start * 2):((random_start + mel_len) * 2)])
+
+ y = waves[bib][(random_start * 2) * 512:((random_start + mel_len) * 2) * 512]
+ wav.append(torch.from_numpy(y).to(device))
+
+ wav = torch.stack(wav).float().detach()
+
+ en = torch.stack(en)
+ p_en = torch.stack(p_en)
+ gt = torch.stack(gt).detach()
+
+ s = model.predictor_encoder(gt.unsqueeze(1))
+
+ F0_fake, N_fake = model.predictor(texts=p_en, style=s, f0=True)
+
+ loss_dur = 0
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
+ _s2s_pred = _s2s_pred[:_text_length, :]
+ _text_input = _text_input[:_text_length].long()
+ _s2s_trg = torch.zeros_like(_s2s_pred)
+ for bib in range(_s2s_trg.shape[0]):
+ _s2s_trg[bib, :_text_input[bib]] = 1
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length - 1],
+ _text_input[1:_text_length - 1])
+
+ loss_dur /= texts.size(0)
+
+ s = model.style_encoder(gt.unsqueeze(1))
+
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
+ loss_mel = stft_loss(y_rec.squeeze(1), wav.detach())
+
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
+
+ loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
+
+ loss_test += accelerator.gather(loss_mel).mean()
+ loss_align += accelerator.gather(loss_dur).mean()
+ loss_f += accelerator.gather(loss_F0).mean()
+
+ iters_test += 1
+ except Exception as e:
+ accelerator.print(f"Eval errored with: \n {str(e)}")
+ continue
+
+ accelerator.print('Epochs:', epoch + 1)
+ try:
+ logger.info('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (
+ loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n', main_process_only=True)
+
+
+ accelerator.log({'eval/mel_loss': float(loss_test / iters_test),
+ 'eval/dur_loss': float(loss_test / iters_test),
+ 'eval/F0_loss': float(loss_f / iters_test)},
+ step=(i + 1) * (epoch + 1))
+ except ZeroDivisionError:
+ accelerator.print("Eval loss was divided by zero... skipping eval cycle")
+
+ if epoch < diff_epoch:
+ # generating reconstruction examples with GT duration
+
+ with torch.no_grad():
+ for bib in range(len(asr)):
+ mel_length = int(mel_input_length[bib].item())
+ gt = mels[bib, :, :mel_length].unsqueeze(0)
+ en = asr[bib, :, :mel_length // 2].unsqueeze(0)
+
+ F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
+ F0_real = F0_real.unsqueeze(0)
+ s = model.style_encoder(gt.unsqueeze(1))
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
+
+ try:
+ y_rec = model.decoder(en, F0_real.squeeze(0), real_norm, s)
+ except Exception as e:
+ accelerator.print(str(e))
+ accelerator.print(F0_real.size())
+ accelerator.print(F0_real.squeeze(0).size())
+
+ s_dur = model.predictor_encoder(gt.unsqueeze(1))
+ p_en = p[bib, :, :mel_length // 2].unsqueeze(0)
+ F0_fake, N_fake = model.predictor(texts=p_en, style=s_dur, f0=True)
+
+ y_pred = model.decoder(en, F0_fake, N_fake, s)
+
+ # writer.add_audio('pred/y' + str(bib), y_pred.cpu().numpy().squeeze(), epoch, sample_rate=sr)
+ if accelerator.is_main_process:
+ log_audio(accelerator, y_pred.detach().cpu().numpy().squeeze(), bib, "pred/y", epoch, sr, tracker=tracker)
+
+ if epoch == 0:
+ # writer.add_audio('gt/y' + str(bib), waves[bib].squeeze(), epoch, sample_rate=sr)
+ if accelerator.is_main_process:
+ log_audio(accelerator, waves[bib].squeeze(), bib, "gt/y", epoch, sr, tracker=tracker)
+
+ if bib >= 10:
+ break
+ else:
+
+ try:
+ # generating sampled speech from text directly
+ with torch.no_grad():
+ # compute reference styles
+ if multispeaker and epoch >= diff_epoch:
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
+ ref_s = torch.cat([ref_ss, ref_sp], dim=1)
+
+ for bib in range(len(d_en)):
+ if multispeaker:
+ s_pred = sampler(noise=torch.randn((1, 256)).unsqueeze(1).to(texts.device),
+ embedding=bert_dur[bib].unsqueeze(0),
+ embedding_scale=1,
+ features=ref_s[bib].unsqueeze(0),
+ # reference from the same speaker as the embedding
+ num_steps=5).squeeze(1)
+ else:
+ s_pred = sampler(noise=torch.ones((1, 1, 256)).to(texts.device)*0.5,
+ embedding=bert_dur[bib].unsqueeze(0),
+ embedding_scale=1,
+ num_steps=5).squeeze(1)
+
+ s = s_pred[:, 128:]
+ ref = s_pred[:, :128]
+ # print(model.predictor)
+ # print(d_en[bib, :, :input_lengths[bib]])
+ d = model.predictor.module.text_encoder(d_en[bib, :, :input_lengths[bib]].unsqueeze(0),
+ s, input_lengths[bib, ...].unsqueeze(0),
+ text_mask[bib, :input_lengths[bib]].unsqueeze(0))
+
+ x = model.predictor.module.lstm(d)
+ x_mod = model.predictor.module.prepare_projection(x) # 640 -> 512
+ duration = model.predictor.module.duration_proj(x_mod)
+
+ duration = torch.sigmoid(duration).sum(axis=-1)
+ pred_dur = torch.round(duration.squeeze(0)).clamp(min=1)
+
+ pred_dur[-1] += 5
+
+ pred_aln_trg = torch.zeros(input_lengths[bib], int(pred_dur.sum().data))
+ c_frame = 0
+ for i in range(pred_aln_trg.size(0)):
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
+ c_frame += int(pred_dur[i].data)
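+                        # Example of the expansion above: pred_dur = [2, 3, 1] yields a
+                        # 3 x 6 alignment where phoneme 0 covers frames 0-1, phoneme 1
+                        # frames 2-4 and phoneme 2 frame 5; multiplying the text features
+                        # by this 0/1 matrix below repeats each phoneme vector for its
+                        # predicted number of frames.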
+
+ # encode prosody
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(texts.device))
+ F0_pred, N_pred = model.predictor(texts=en, style=s, f0=True)
+ out = model.decoder(
+ (t_en[bib, :, :input_lengths[bib]].unsqueeze(0) @ pred_aln_trg.unsqueeze(0).to(texts.device)),
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
+
+ # writer.add_audio('pred/y' + str(bib), out.cpu().numpy().squeeze(), epoch, sample_rate=sr)
+ if accelerator.is_main_process:
+ log_audio(accelerator, out.detach().cpu().numpy().squeeze(), bib, "pred/y", epoch, sr, tracker=tracker)
+
+ if bib >= 5:
+ break
+ except Exception as e:
+ accelerator.print('error -> ', e)
+ accelerator.print("some of the samples couldn't be evaluated, skipping those.")
+
+
+ if epoch % saving_epoch == 0:
+ if (loss_test / iters_test) < best_loss:
+ best_loss = loss_test / iters_test
+ try:
+ accelerator.print('Saving..')
+ state = {
+ 'net': {key: model[key].state_dict() for key in model},
+ 'optimizer': optimizer.state_dict(),
+ 'iters': iters,
+ 'val_loss': loss_test / iters_test,
+ 'epoch': epoch,
+ }
+ except ZeroDivisionError:
+ accelerator.print('No iter test, Re-Saving..')
+ state = {
+ 'net': {key: model[key].state_dict() for key in model},
+ 'optimizer': optimizer.state_dict(),
+ 'iters': iters,
+ 'val_loss': 0.1, # not zero just in case
+ 'epoch': epoch,
+ }
+
+ if accelerator.is_main_process:
+ save_path = osp.join(log_dir, 'epoch_2nd_%05d.pth' % epoch)
+ torch.save(state, save_path)
+
+            # if estimate sigma, save the estimated sigma
+ if model_params.diffusion.dist.estimate_sigma_data:
+ config['model_params']['diffusion']['dist']['sigma_data'] = float(np.mean(running_std))
+
+ with open(osp.join(log_dir, osp.basename(config_path)), 'w') as outfile:
+ yaml.dump(config, outfile, default_flow_style=True)
+ if accelerator.is_main_process:
+ print('Saving last pth..')
+ state = {
+ 'net': {key: model[key].state_dict() for key in model},
+ 'optimizer': optimizer.state_dict(),
+ 'iters': iters,
+ 'val_loss': loss_test / iters_test,
+ 'epoch': epoch,
+ }
+ save_path = osp.join(log_dir, '2nd_phase_last.pth')
+ torch.save(state, save_path)
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/stts_48khz/StyleTTS2_48khz/infer.ipynb b/stts_48khz/StyleTTS2_48khz/infer.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..0140c7471ca9621536c3742046c9f0d6a2e5fc26
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/infer.ipynb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:06f49761442c3d69c2bcc7ae3ba0aca0e9ffd47b88f3c1dd4a9ade8018176445
+size 16571199
diff --git a/stts_48khz/StyleTTS2_48khz/losses.py b/stts_48khz/StyleTTS2_48khz/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b0f4316d6af9955f437ff57169b9e9f835a7952
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/losses.py
@@ -0,0 +1,591 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+from transformers import AutoModel, WhisperConfig, WhisperPreTrainedModel
+import whisper
+
+from transformers.models.whisper.modeling_whisper import WhisperEncoder
+
+
+class SpectralConvergengeLoss(torch.nn.Module):
+ """Spectral convergence loss module."""
+
+ def __init__(self):
+        """Initialize spectral convergence loss module."""
+ super(SpectralConvergengeLoss, self).__init__()
+
+ def forward(self, x_mag, y_mag):
+ """Calculate forward propagation.
+ Args:
+ x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+ y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+ Returns:
+ Tensor: Spectral convergence loss value.
+ """
+ return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1)
+
+class STFTLoss(torch.nn.Module):
+ """STFT loss module."""
+
+ def __init__(self, fft_size=1024, shift_size=120, win_length=600, window=torch.hann_window):
+ """Initialize STFT loss module."""
+ super(STFTLoss, self).__init__()
+ self.fft_size = fft_size
+ self.shift_size = shift_size
+ self.win_length = win_length
+ self.to_mel = torchaudio.transforms.MelSpectrogram(sample_rate=48000, n_fft=fft_size, win_length=win_length, hop_length=shift_size, window_fn=window)
+
+ self.spectral_convergenge_loss = SpectralConvergengeLoss()
+
+ def forward(self, x, y):
+ """Calculate forward propagation.
+ Args:
+ x (Tensor): Predicted signal (B, T).
+ y (Tensor): Groundtruth signal (B, T).
+ Returns:
+ Tensor: Spectral convergence loss value.
+ Tensor: Log STFT magnitude loss value.
+ """
+ x_mag = self.to_mel(x)
+ mean, std = -4, 4
+ x_mag = (torch.log(1e-5 + x_mag) - mean) / std
+
+ y_mag = self.to_mel(y)
+ mean, std = -4, 4
+ y_mag = (torch.log(1e-5 + y_mag) - mean) / std
+
+ sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
+ return sc_loss
+
+
+class MultiResolutionSTFTLoss(torch.nn.Module):
+ """Multi resolution STFT loss module."""
+
+ def __init__(self,
+ fft_sizes=[1024, 2048, 512],
+ hop_sizes=[120, 240, 50],
+ win_lengths=[600, 1200, 240],
+ window=torch.hann_window):
+ """Initialize Multi resolution STFT loss module.
+ Args:
+ fft_sizes (list): List of FFT sizes.
+ hop_sizes (list): List of hop sizes.
+ win_lengths (list): List of window lengths.
+ window (str): Window function type.
+ """
+ super(MultiResolutionSTFTLoss, self).__init__()
+ assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
+ self.stft_losses = torch.nn.ModuleList()
+ for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
+ self.stft_losses += [STFTLoss(fs, ss, wl, window)]
+
+ def forward(self, x, y):
+ """Calculate forward propagation.
+ Args:
+ x (Tensor): Predicted signal (B, T).
+ y (Tensor): Groundtruth signal (B, T).
+ Returns:
+ Tensor: Multi resolution spectral convergence loss value.
+ Tensor: Multi resolution log STFT magnitude loss value.
+ """
+ sc_loss = 0.0
+ for f in self.stft_losses:
+ sc_l = f(x, y)
+ sc_loss += sc_l
+ sc_loss /= len(self.stft_losses)
+
+ return sc_loss
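+# Usage sketch (mirroring the training scripts in this repo): the loss averages the
+# log-mel spectral-convergence term over the three resolutions above and expects raw
+# 48 kHz waveform tensors of shape (..., T), e.g.
+#   stft_loss = MultiResolutionSTFTLoss().to(device)
+#   loss_mel = stft_loss(y_rec.squeeze(1), wav)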
+
+
+def feature_loss(fmap_r, fmap_g):
+ loss = 0
+ for dr, dg in zip(fmap_r, fmap_g):
+ for rl, gl in zip(dr, dg):
+ loss += torch.mean(torch.abs(rl - gl))
+
+ return loss*2
+
+
+def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+ loss = 0
+ r_losses = []
+ g_losses = []
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+ r_loss = torch.mean((1-dr)**2)
+ g_loss = torch.mean(dg**2)
+ loss += (r_loss + g_loss)
+ r_losses.append(r_loss.item())
+ g_losses.append(g_loss.item())
+
+ return loss, r_losses, g_losses
+
+
+def generator_loss(disc_outputs):
+ loss = 0
+ gen_losses = []
+ for dg in disc_outputs:
+ l = torch.mean((1-dg)**2)
+ gen_losses.append(l)
+ loss += l
+
+ return loss, gen_losses
+
+""" https://dl.acm.org/doi/abs/10.1145/3573834.3574506 """
+def discriminator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
+ loss = 0
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+ tau = 0.04
+ m_DG = torch.median((dr-dg))
+ L_rel = torch.mean((((dr - dg) - m_DG)**2)[dr < dg + m_DG])
+ loss += tau - F.relu(tau - L_rel)
+ return loss
+
+def generator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
+ loss = 0
+ for dg, dr in zip(disc_real_outputs, disc_generated_outputs):
+ tau = 0.04
+ m_DG = torch.median((dr-dg))
+ L_rel = torch.mean((((dr - dg) - m_DG)**2)[dr < dg + m_DG])
+ loss += tau - F.relu(tau - L_rel)
+ return loss
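+# TPRLS sketch (truncated pointwise relativistic loss, per the paper linked above):
+#   m     = median(D(real) - D(fake))
+#   L_rel = mean(((D(real) - D(fake)) - m)^2) over pairs where D(real) < D(fake) + m
+#   loss  = tau - relu(tau - L_rel) = min(tau, L_rel), i.e. the relativistic penalty
+#   is capped at tau = 0.04 per discriminator output.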
+
+class GeneratorLoss(torch.nn.Module):
+
+ def __init__(self, mpd, msd):
+ super(GeneratorLoss, self).__init__()
+ self.mpd = mpd
+ self.msd = msd
+
+ def forward(self, y, y_hat):
+ y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
+ y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
+ loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
+ loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
+ loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
+ loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
+
+ loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
+
+ loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_rel
+
+ return loss_gen_all.mean()
+
+class DiscriminatorLoss(torch.nn.Module):
+
+ def __init__(self, mpd, msd):
+ super(DiscriminatorLoss, self).__init__()
+ self.mpd = mpd
+ self.msd = msd
+
+ def forward(self, y, y_hat):
+ # MPD
+ y_df_hat_r, y_df_hat_g, _, _ = self.mpd(y, y_hat)
+ loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
+ # MSD
+ y_ds_hat_r, y_ds_hat_g, _, _ = self.msd(y, y_hat)
+ loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
+
+ loss_rel = discriminator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
+
+
+ d_loss = loss_disc_s + loss_disc_f + loss_rel
+
+ return d_loss.mean()
+
+
+
+
+# #####################
+# MIXED PRECISION
+
+
+class WhisperEncoderOnly(WhisperPreTrainedModel):
+ def __init__(self, config: WhisperConfig):
+ super().__init__(config)
+ self.encoder = WhisperEncoder(config)
+
+ def forward(self, input_features, attention_mask=None):
+ return self.encoder(input_features, attention_mask)
+
+
+
+class WavLMLoss(torch.nn.Module):
+ def __init__(self, model, wd, model_sr, slm_sr=16000):
+ super(WavLMLoss, self).__init__()
+
+ config = WhisperConfig.from_pretrained("Respair/Whisper_Large_v2_Encoder_Block")
+
+ # this will load the full model and keep only the encoder
+ full_model = WhisperEncoderOnly.from_pretrained("openai/whisper-large-v2", config=config, device_map='auto',torch_dtype=torch.bfloat16)
+ model = WhisperEncoderOnly(config)
+ model.encoder.load_state_dict(full_model.encoder.state_dict())
+ del full_model
+
+
+ self.wavlm = model.to(torch.bfloat16)
+ self.wd = wd
+ self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
+
+ def forward(self, wav, y_rec, generator=False, discriminator=False, discriminator_forward=False):
+
+ if generator:
+ y_rec = y_rec.squeeze(1)
+
+
+ y_rec = whisper.pad_or_trim(y_rec)
+ y_rec = whisper.log_mel_spectrogram(y_rec)
+
+ with torch.no_grad():
+ y_rec_embeddings = self.wavlm.encoder(y_rec.to(torch.bfloat16), output_hidden_states=True).hidden_states
+ y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+ y_df_hat_g = self.wd(y_rec_embeddings.to(torch.float32))
+ loss_gen = torch.mean((1-y_df_hat_g)**2)
+
+ return loss_gen.to(torch.float32)
+
+ elif discriminator:
+
+ wav = wav.squeeze(1)
+ y_rec = y_rec.squeeze(1)
+
+ wav = whisper.pad_or_trim(wav)
+ wav = whisper.log_mel_spectrogram(wav)
+
+ y_rec = whisper.pad_or_trim(y_rec)
+ y_rec = whisper.log_mel_spectrogram(y_rec)
+
+ with torch.no_grad():
+ wav_embeddings = self.wavlm.encoder(wav.to(torch.bfloat16), output_hidden_states=True).hidden_states
+ y_rec_embeddings = self.wavlm.encoder(y_rec.to(torch.bfloat16), output_hidden_states=True).hidden_states
+
+ y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+ y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+
+ y_d_rs = self.wd(y_embeddings.to(torch.float32))
+ y_d_gs = self.wd(y_rec_embeddings.to(torch.float32))
+
+ y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
+
+ r_loss = torch.mean((1-y_df_hat_r)**2)
+ g_loss = torch.mean((y_df_hat_g)**2)
+
+ loss_disc_f = r_loss + g_loss
+
+ return loss_disc_f.mean().to(torch.float32)
+
+
+
+ elif discriminator_forward:
+ # Squeeze the channel dimension if it's unnecessary
+ wav = wav.squeeze(1) # Adjust this line if the channel dimension is not at dim=1
+
+
+ with torch.no_grad():
+
+ wav_16 = self.resample(wav)
+ wav_16 = whisper.pad_or_trim(wav_16)
+ wav_16 = whisper.log_mel_spectrogram(wav_16)
+
+ wav_embeddings = self.wavlm.encoder(wav_16.to(torch.bfloat16) , output_hidden_states=True).hidden_states
+ y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+
+ y_d_rs = self.wd(y_embeddings.to(torch.float32))
+
+ return y_d_rs
+
+ else:
+
+ wav = wav.squeeze(1)
+ y_rec = y_rec.squeeze(1)
+
+ wav = whisper.pad_or_trim(wav)
+ wav = whisper.log_mel_spectrogram(wav)
+
+ y_rec = whisper.pad_or_trim(y_rec)
+ y_rec = whisper.log_mel_spectrogram(y_rec)
+
+ with torch.no_grad():
+ wav_embeddings = self.wavlm.encoder(wav.to(torch.bfloat16), output_hidden_states=True).hidden_states
+
+ y_rec_embeddings = self.wavlm.encoder(y_rec.to(torch.bfloat16), output_hidden_states=True).hidden_states
+
+
+ floss = 0
+ for er, eg in zip([e.to(torch.float32) for e in wav_embeddings], [e.to(torch.float32) for e in y_rec_embeddings]):
+ floss += torch.mean(torch.abs(er - eg))
+
+ return floss.mean()
+
+
+
+ def generator(self, y_rec):
+
+ y_rec = y_rec.squeeze(1)
+
+
+ y_rec = whisper.pad_or_trim(y_rec)
+ y_rec = whisper.log_mel_spectrogram(y_rec)
+
+ with torch.no_grad():
+ y_rec_embeddings = self.wavlm.encoder(y_rec.to(torch.bfloat16), output_hidden_states=True).hidden_states
+ y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+ y_df_hat_g = self.wd(y_rec_embeddings.to(torch.float32))
+ loss_gen = torch.mean((1-y_df_hat_g)**2)
+
+ return loss_gen.to(torch.float32)
+
+ def discriminator(self, wav, y_rec):
+
+ wav = wav.squeeze(1)
+ y_rec = y_rec.squeeze(1)
+
+ wav = whisper.pad_or_trim(wav)
+ wav = whisper.log_mel_spectrogram(wav)
+
+ y_rec = whisper.pad_or_trim(y_rec)
+ y_rec = whisper.log_mel_spectrogram(y_rec)
+
+ with torch.no_grad():
+ wav_embeddings = self.wavlm.encoder(wav.to(torch.bfloat16), output_hidden_states=True).hidden_states
+ y_rec_embeddings = self.wavlm.encoder(y_rec.to(torch.bfloat16), output_hidden_states=True).hidden_states
+
+ y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+ y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+
+ y_d_rs = self.wd(y_embeddings.to(torch.float32))
+ y_d_gs = self.wd(y_rec_embeddings.to(torch.float32))
+
+ y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
+
+ r_loss = torch.mean((1-y_df_hat_r)**2)
+ g_loss = torch.mean((y_df_hat_g)**2)
+
+ loss_disc_f = r_loss + g_loss
+
+ return loss_disc_f.mean().to(torch.float32)
+
+
+
+
+ def discriminator_forward(self, wav):
+ # Squeeze the channel dimension if it's unnecessary
+ wav = wav.squeeze(1) # Adjust this line if the channel dimension is not at dim=1
+
+
+ with torch.no_grad():
+
+ wav_16 = self.resample(wav)
+ wav_16 = whisper.pad_or_trim(wav_16)
+ wav_16 = whisper.log_mel_spectrogram(wav_16)
+
+ wav_embeddings = self.wavlm.encoder(wav_16.to(torch.bfloat16) , output_hidden_states=True).hidden_states
+ y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+
+ y_d_rs = self.wd(y_embeddings.to(torch.float32))
+
+ return y_d_rs
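+    # Usage sketch (mirroring the second-stage training script in this repo): the Whisper
+    # encoder is only used under no_grad as a feature extractor, while `wd` is the
+    # trainable SLM discriminator head.
+    #   wl = WavLMLoss(model_params.slm.model, model.wd, sr, model_params.slm.sr).to(device)
+    #   loss_lm   = wl(wav.detach().squeeze(1), y_rec.squeeze(1)).mean()  # feature matching
+    #   loss_gen  = wl.generator(y_rec)                                   # SLM generator loss
+    #   loss_disc = wl.discriminator(wav, y_rec)                          # SLM discriminator loss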
+
+
+
+
+ # def discriminator_forward(self, wav):
+ # # Squeeze the channel dimension if it's unnecessary
+ # wav = wav.squeeze(1) # Adjust this line if the channel dimension is not at dim=1
+
+
+ # with torch.no_grad():
+
+ # wav_16 = self.resample(wav)
+ # wav_16 = whisper.pad_or_trim(wav_16)
+ # wav_16 = whisper.log_mel_spectrogram(wav_16)
+
+ # wav_embeddings = self.wavlm.encoder(wav_16.to(torch.bfloat16) , output_hidden_states=True).hidden_states
+ # y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+
+ # y_d_rs = self.wd(y_embeddings.to(torch.bfloat16))
+
+ # return y_d_rs
+
+
+
+################### Whisper - No Mixed Precision!
+
+# class WhisperEncoderOnly(WhisperPreTrainedModel):
+# def __init__(self, config: WhisperConfig):
+# super().__init__(config)
+# self.encoder = WhisperEncoder(config)
+
+# def forward(self, input_features, attention_mask=None):
+# return self.encoder(input_features, attention_mask)
+
+# class WavLMLoss(torch.nn.Module):
+
+# def __init__(self, model, wd, model_sr, slm_sr=16000):
+# super(WavLMLoss, self).__init__()
+
+# config = WhisperConfig.from_pretrained("Respair/Whisper_Large_v2_Encoder_Block")
+
+# # this will load the full model and keep only the encoder
+# full_model = WhisperEncoderOnly.from_pretrained("openai/whisper-large-v2", config=config, device_map='auto')
+# model = WhisperEncoderOnly(config)
+# model.encoder.load_state_dict(full_model.encoder.state_dict())
+# del full_model
+
+
+# self.wavlm = model
+# self.wd = wd
+# self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
+
+# def forward(self, wav, y_rec):
+
+# wav = wav.squeeze(1)
+# y_rec = y_rec.squeeze(1)
+
+# wav = whisper.pad_or_trim(wav)
+# wav = whisper.log_mel_spectrogram(wav)
+
+# y_rec = whisper.pad_or_trim(y_rec)
+# y_rec = whisper.log_mel_spectrogram(y_rec)
+
+# with torch.no_grad():
+# wav_embeddings = self.wavlm.encoder(wav, output_hidden_states=True).hidden_states
+# y_rec_embeddings = self.wavlm.encoder(y_rec, output_hidden_states=True).hidden_states
+
+# floss = 0
+# for er, eg in zip(wav_embeddings, y_rec_embeddings):
+# floss += torch.mean(torch.abs(er - eg))
+
+# return floss.mean()
+
+# def generator(self, y_rec):
+
+# y_rec = y_rec.squeeze(1)
+
+# y_rec = whisper.pad_or_trim(y_rec)
+# y_rec = whisper.log_mel_spectrogram(y_rec)
+
+# with torch.no_grad():
+# y_rec_embeddings = self.wavlm.encoder(y_rec, output_hidden_states=True).hidden_states
+# y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+# y_df_hat_g = self.wd(y_rec_embeddings)
+# loss_gen = torch.mean((1-y_df_hat_g)**2)
+
+# return loss_gen
+
+# def discriminator(self, wav, y_rec):
+
+# wav = wav.squeeze(1)
+# y_rec = y_rec.squeeze(1)
+
+# wav = whisper.pad_or_trim(wav)
+# wav = whisper.log_mel_spectrogram(wav)
+
+# y_rec = whisper.pad_or_trim(y_rec)
+# y_rec = whisper.log_mel_spectrogram(y_rec)
+
+# with torch.no_grad():
+# wav_embeddings = self.wavlm.encoder(wav, output_hidden_states=True).hidden_states
+# y_rec_embeddings = self.wavlm.encoder(y_rec, output_hidden_states=True).hidden_states
+
+# y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+# y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+
+# y_d_rs = self.wd(y_embeddings)
+# y_d_gs = self.wd(y_rec_embeddings)
+
+# y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
+
+# r_loss = torch.mean((1-y_df_hat_r)**2)
+# g_loss = torch.mean((y_df_hat_g)**2)
+
+# loss_disc_f = r_loss + g_loss
+
+# return loss_disc_f.mean()
+
+
+# def discriminator_forward(self, wav):
+# # Squeeze the channel dimension if it's unnecessary
+# wav = wav.squeeze(1) # Adjust this line if the channel dimension is not at dim=1
+
+
+# with torch.no_grad():
+
+# wav_16 = self.resample(wav)
+# wav_16 = whisper.pad_or_trim(wav_16)
+# wav_16 = whisper.log_mel_spectrogram(wav_16)
+
+# wav_embeddings = self.wavlm.encoder(wav_16, output_hidden_states=True).hidden_states
+# y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+
+# y_d_rs = self.wd(y_embeddings)
+
+# return y_d_rs
+
+
+
+###### ORIGINAL
+# class WavLMLoss(torch.nn.Module):
+
+# def __init__(self, model, wd, model_sr, slm_sr=16000):
+# super(WavLMLoss, self).__init__()
+# self.wavlm = AutoModel.from_pretrained(model)
+# self.wd = wd
+# self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
+
+# def forward(self, wav, y_rec):
+
+
+
+# with torch.no_grad():
+# wav_16 = self.resample(wav)
+# wav_embeddings = self.wavlm(wav_16, output_hidden_states=True).hidden_states
+# y_rec_16 = self.resample(y_rec)
+# y_rec_embeddings = self.wavlm(y_rec_16.squeeze(), output_hidden_states=True).hidden_states
+
+# floss = 0
+# for er, eg in zip(wav_embeddings, y_rec_embeddings):
+# floss += torch.mean(torch.abs(er - eg))
+
+# return floss.mean()
+
+# def generator(self, y_rec):
+# y_rec_16 = self.resample(y_rec)
+# y_rec_embeddings = self.wavlm(y_rec_16, output_hidden_states=True).hidden_states
+# y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+# y_df_hat_g = self.wd(y_rec_embeddings)
+# loss_gen = torch.mean((1-y_df_hat_g)**2)
+
+# return loss_gen
+
+# def discriminator(self, wav, y_rec):
+# with torch.no_grad():
+# wav_16 = self.resample(wav)
+# wav_embeddings = self.wavlm(wav_16, output_hidden_states=True).hidden_states
+# y_rec_16 = self.resample(y_rec)
+# y_rec_embeddings = self.wavlm(y_rec_16, output_hidden_states=True).hidden_states
+
+# y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+# y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+
+# y_d_rs = self.wd(y_embeddings)
+# y_d_gs = self.wd(y_rec_embeddings)
+
+# y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
+
+# r_loss = torch.mean((1-y_df_hat_r)**2)
+# g_loss = torch.mean((y_df_hat_g)**2)
+
+# loss_disc_f = r_loss + g_loss
+
+# return loss_disc_f.mean()
+
+# def discriminator_forward(self, wav):
+# with torch.no_grad():
+# wav_16 = self.resample(wav)
+# wav_embeddings = self.wavlm(wav_16, output_hidden_states=True).hidden_states
+# y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+
+# y_d_rs = self.wd(y_embeddings)
+
+# return y_d_rs
\ No newline at end of file
diff --git a/stts_48khz/StyleTTS2_48khz/meldataset.py b/stts_48khz/StyleTTS2_48khz/meldataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a6390ea0866bebc00ceda0f30d82c8ab9d697ea
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/meldataset.py
@@ -0,0 +1,276 @@
+import os
+import os.path as osp
+import time
+import random
+import numpy as np
+import random
+import soundfile as sf
+import librosa
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+from torch.utils.data import DataLoader
+
+import logging
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+import pandas as pd
+
+# _pad = "$"
+# _punctuation = ';:,.!?¡¿—…"«»“” '
+# _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
+# _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
+# _vokan_symbols = "()♪😂💨😮🥱😱😡😭1234567890`-~⁼"
+# _ligature = "͡"
+# # Export all symbols:
+# symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + list(_vokan_symbols) + list(_ligature)
+
+# IPA Phonemizer: https://github.com/bootphon/phonemizer
+
+_pad = "$"
+_punctuation = ';:,.!?¡¿—…"«»“” '
+_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
+_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
+
+# Export all symbols:
+symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
+
+
+dicts = {}
+for i in range(len((symbols))):
+ dicts[symbols[i]] = i
+
+class TextCleaner:
+ def __init__(self, dummy=None):
+ self.word_index_dictionary = dicts
+ def __call__(self, text):
+ indexes = []
+ for char in text:
+ try:
+ indexes.append(self.word_index_dictionary[char])
+ except KeyError:
+ print(f"text_cleaner: unknown character {char!r} in: {text}")
+ return indexes
+
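+# rough usage sketch: TextCleaner()("ˈhɛloʊ.") returns a list of symbol indices;
+# characters outside `symbols` are skipped (a warning is printed)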
+np.random.seed(1)
+random.seed(1)
+SPECT_PARAMS = {
+ "n_fft": 2048,
+ "win_length": 2048,
+ "hop_length": 512,
+}
+MEL_PARAMS = {
+ "n_mels": 80,
+}
+
+
+to_mel = torchaudio.transforms.MelSpectrogram(
+ n_mels=80, n_fft=2048, win_length=2048, hop_length=512)
+mean, std = -4, 4
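+# at 48 kHz with hop_length=512 this is ~93.75 mel frames per second; mean/std
+# are fixed constants used to normalize the log-mel spectrograms below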
+
+def preprocess(wave):
+ wave_tensor = torch.from_numpy(wave).float()
+ mel_tensor = to_mel(wave_tensor)
+ mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
+ return mel_tensor
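+# preprocess() expects a float numpy waveform at 48 kHz and returns a normalized
+# log-mel tensor of shape (1, 80, T)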
+
+class FilePathDataset(torch.utils.data.Dataset):
+ def __init__(self,
+ data_list,
+ root_path,
+ sr=48000,
+ data_augmentation=False,
+ validation=False,
+ OOD_data="Data/OOD_texts.txt",
+ min_length=50,
+ ):
+
+ spect_params = SPECT_PARAMS
+ mel_params = MEL_PARAMS
+
+ _data_list = [l.strip().split('|') for l in data_list]
+ self.data_list = [data if len(data) == 3 else (*data, 0) for data in _data_list]
+ self.text_cleaner = TextCleaner()
+ self.sr = sr
+
+ self.df = pd.DataFrame(self.data_list)
+
+ self.to_melspec = torchaudio.transforms.MelSpectrogram(n_mels=80, n_fft=2048, win_length=2048, hop_length=512)
+
+ self.mean, self.std = -4, 4
+ self.data_augmentation = data_augmentation and (not validation)
+ self.max_mel_length = 192
+
+ self.min_length = min_length
+ with open(OOD_data, 'r', encoding='utf-8') as f:
+ tl = f.readlines()
+ idx = 1 if '.wav' in tl[0].split('|')[0] else 0
+ self.ptexts = [t.split('|')[idx] for t in tl]
+
+ self.root_path = root_path
+
+ def __len__(self):
+ return len(self.data_list)
+
+ def __getitem__(self, idx):
+ data = self.data_list[idx]
+ path = data[0]
+
+ wave, text_tensor, speaker_id = self._load_tensor(data)
+
+ mel_tensor = preprocess(wave).squeeze()
+
+ acoustic_feature = mel_tensor.squeeze()
+ length_feature = acoustic_feature.size(1)
+ acoustic_feature = acoustic_feature[:, :(length_feature - length_feature % 2)]
+
+ # get reference sample
+ ref_data = (self.df[self.df[2] == str(speaker_id)]).sample(n=1).iloc[0].tolist()
+ ref_mel_tensor, ref_label = self._load_data(ref_data[:3])
+
+ # get OOD text
+
+ ps = ""
+
+ while len(ps) < self.min_length:
+ rand_idx = np.random.randint(0, len(self.ptexts) - 1)
+ ps = self.ptexts[rand_idx]
+
+ text = self.text_cleaner(ps)
+
+ # if len(text) > 505:
+ # text = text[:505]
+ if text[-1] not in [4, 5, 6]:
+ text = text + [4]
+
+ text.insert(0, 0)
+ text.append(0)
+
+
+
+ ref_text = torch.LongTensor(text)
+
+ return speaker_id, acoustic_feature, text_tensor, ref_text, ref_mel_tensor, ref_label, path, wave
+
+ def _load_tensor(self, data):
+ wave_path, text, speaker_id = data
+ speaker_id = int(speaker_id)
+ wave, sr = sf.read(osp.join(self.root_path, wave_path))
+ if wave.shape[-1] == 2:
+ wave = wave[:, 0].squeeze()
+ if sr != 48000:
+ wave = librosa.resample(wave, orig_sr=sr, target_sr=48000)
+ print(f"resampled {wave_path} from {sr} Hz to 48000 Hz")
+
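+ # pad roughly 0.1 s of silence (5000 samples at 48 kHz) on each side of the waveform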
+ wave = np.concatenate([np.zeros([5000]), wave, np.zeros([5000])], axis=0)
+
+ text = self.text_cleaner(text)
+
+ text.insert(0, 0)
+ text.append(0)
+
+ text = torch.LongTensor(text)
+
+ return wave, text, speaker_id
+
+ def _load_data(self, data):
+ wave, text_tensor, speaker_id = self._load_tensor(data)
+ mel_tensor = preprocess(wave).squeeze()
+
+ mel_length = mel_tensor.size(1)
+ if mel_length > self.max_mel_length:
+ random_start = np.random.randint(0, mel_length - self.max_mel_length)
+ mel_tensor = mel_tensor[:, random_start:random_start + self.max_mel_length]
+
+ return mel_tensor, speaker_id
+
+
+class Collater(object):
+ """
+ Collates dataset items into padded batches sorted by descending mel length.
+ """
+
+ def __init__(self, return_wave=False):
+ self.text_pad_index = 0
+ self.min_mel_length = 192
+ self.max_mel_length = 192
+ self.return_wave = return_wave
+
+
+ def __call__(self, batch):
+ # each batch element: (speaker_id, mel, text, ref_text, ref_mel, ref_label, path, wave)
+ batch_size = len(batch)
+
+ # sort by mel length
+ lengths = [b[1].shape[1] for b in batch]
+ batch_indexes = np.argsort(lengths)[::-1]
+ batch = [batch[bid] for bid in batch_indexes]
+
+ nmels = batch[0][1].size(0)
+ max_mel_length = max([b[1].shape[1] for b in batch])
+ max_text_length = max([b[2].shape[0] for b in batch])
+ max_rtext_length = max([b[3].shape[0] for b in batch])
+
+ labels = torch.zeros((batch_size)).long()
+ mels = torch.zeros((batch_size, nmels, max_mel_length)).float()
+ texts = torch.zeros((batch_size, max_text_length)).long()
+ ref_texts = torch.zeros((batch_size, max_rtext_length)).long()
+
+ input_lengths = torch.zeros(batch_size).long()
+ ref_lengths = torch.zeros(batch_size).long()
+ output_lengths = torch.zeros(batch_size).long()
+ ref_mels = torch.zeros((batch_size, nmels, self.max_mel_length)).float()
+ ref_labels = torch.zeros((batch_size)).long()
+ paths = ['' for _ in range(batch_size)]
+ waves = [None for _ in range(batch_size)]
+
+ for bid, (label, mel, text, ref_text, ref_mel, ref_label, path, wave) in enumerate(batch):
+ mel_size = mel.size(1)
+ text_size = text.size(0)
+ rtext_size = ref_text.size(0)
+ labels[bid] = label
+ mels[bid, :, :mel_size] = mel
+ texts[bid, :text_size] = text
+ ref_texts[bid, :rtext_size] = ref_text
+ input_lengths[bid] = text_size
+ ref_lengths[bid] = rtext_size
+ output_lengths[bid] = mel_size
+ paths[bid] = path
+ ref_mel_size = ref_mel.size(1)
+ ref_mels[bid, :, :ref_mel_size] = ref_mel
+
+ ref_labels[bid] = ref_label
+ waves[bid] = wave
+
+
+
+ return waves, texts, input_lengths, ref_texts, ref_lengths, mels, output_lengths, ref_mels
+
+
+
+def build_dataloader(path_list,
+ root_path,
+ validation=False,
+ OOD_data="Data/OOD_texts.txt",
+ min_length=50,
+ batch_size=4,
+ num_workers=1,
+ device='cpu',
+ collate_config={},
+ dataset_config={}):
+
+ dataset = FilePathDataset(path_list, root_path, OOD_data=OOD_data, min_length=min_length, validation=validation, **dataset_config)
+ collate_fn = Collater(**collate_config)
+ data_loader = DataLoader(dataset,
+ batch_size=batch_size,
+ shuffle=(not validation),
+ num_workers=num_workers,
+ drop_last=True,
+ collate_fn=collate_fn,
+ pin_memory=(device != 'cpu'))
+
+ return data_loader
\ No newline at end of file
diff --git a/stts_48khz/StyleTTS2_48khz/mixed_precision_2nd_stage.py b/stts_48khz/StyleTTS2_48khz/mixed_precision_2nd_stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..54168c548b15a3d21843123ef44ecea4acd64b41
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/mixed_precision_2nd_stage.py
@@ -0,0 +1,833 @@
+# load packages
+import random
+import yaml
+import time
+from munch import Munch
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+import librosa
+import click
+import shutil
+import traceback
+import warnings
+warnings.simplefilter('ignore')
+from torch.utils.tensorboard import SummaryWriter
+
+from meldataset import build_dataloader
+
+from Utils.ASR.models import ASRCNN
+from Utils.JDC.model import JDCNet
+from Utils.PLBERT.util import load_plbert
+
+from models import *
+from losses import *
+from utils import *
+
+from Modules.slmadv import SLMAdversarialLoss
+from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
+
+from optimizers import build_optimizer
+
+# simple fix for dataparallel that allows access to class attributes
+class MyDataParallel(torch.nn.DataParallel):
+ def __getattr__(self, name):
+ try:
+ return super().__getattr__(name)
+ except AttributeError:
+ return getattr(self.module, name)
+
+import logging
+from logging import StreamHandler
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+handler = StreamHandler()
+handler.setLevel(logging.DEBUG)
+logger.addHandler(handler)
+
+
+@click.command()
+@click.option('-p', '--config_path', default='Configs/config.yml', type=str)
+def main(config_path):
+ config = yaml.safe_load(open(config_path))
+
+ log_dir = config['log_dir']
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
+ writer = SummaryWriter(log_dir + "/tensorboard")
+
+ # write logs
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
+ file_handler.setLevel(logging.DEBUG)
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
+ logger.addHandler(file_handler)
+
+
+ batch_size = config.get('batch_size', 10)
+
+ save_iter = 2200
+
+ epochs = config.get('epochs_2nd', 200)
+ save_freq = config.get('save_freq', 2)
+ log_interval = config.get('log_interval', 10)
+ saving_epoch = config.get('save_freq', 2)
+
+ data_params = config.get('data_params', None)
+ sr = config['preprocess_params'].get('sr', 48000)
+ train_path = data_params['train_data']
+ val_path = data_params['val_data']
+ root_path = data_params['root_path']
+ min_length = data_params['min_length']
+ OOD_data = data_params['OOD_data']
+
+ max_len = config.get('max_len', 200)
+
+ loss_params = Munch(config['loss_params'])
+ diff_epoch = loss_params.diff_epoch
+ joint_epoch = loss_params.joint_epoch
+
+ optimizer_params = Munch(config['optimizer_params'])
+
+ train_list, val_list = get_data_path_list(train_path, val_path)
+ device = 'cuda'
+
+ train_dataloader = build_dataloader(train_list,
+ root_path,
+ OOD_data=OOD_data,
+ min_length=min_length,
+ batch_size=batch_size,
+ num_workers=2,
+ dataset_config={},
+ device=device)
+
+ val_dataloader = build_dataloader(val_list,
+ root_path,
+ OOD_data=OOD_data,
+ min_length=min_length,
+ batch_size=batch_size,
+ validation=True,
+ num_workers=0,
+ device=device,
+ dataset_config={})
+
+ # load pretrained ASR model
+ ASR_config = config.get('ASR_config', False)
+ ASR_path = config.get('ASR_path', False)
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
+
+ # load pretrained F0 model
+ F0_path = config.get('F0_path', False)
+ pitch_extractor = load_F0_models(F0_path)
+
+ # load PL-BERT model
+ BERT_path = config.get('PLBERT_dir', False)
+ plbert = load_plbert(BERT_path)
+
+ # build model
+ model_params = recursive_munch(config['model_params'])
+ multispeaker = model_params.multispeaker
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
+ _ = [model[key].to(device) for key in model]
+
+ # DP
+ for key in model:
+ if key != "mpd" and key != "msd" and key != "wd":
+ model[key] = MyDataParallel(model[key])
+
+ start_epoch = 0
+ iters = 0
+
+ load_pretrained = config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False)
+
+ if not load_pretrained:
+ if config.get('first_stage_path', '') != '':
+ first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
+ print('Loading the first stage model at %s ...' % first_stage_path)
+ model, _, start_epoch, iters = load_checkpoint(model,
+ None,
+ first_stage_path,
+ load_only_params=True,
+ ignore_modules=['bert', 'bert_encoder', 'predictor', 'predictor_encoder', 'msd', 'mpd', 'wd', 'diffusion']) # keep starting epoch for tensorboard log
+
+ # these epochs should be counted from the start epoch
+ diff_epoch += start_epoch
+ joint_epoch += start_epoch
+ epochs += start_epoch
+
+ model.predictor_encoder = copy.deepcopy(model.style_encoder)
+ else:
+ raise ValueError('You need to specify the path to the first stage model.')
+
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
+ wl = WavLMLoss(model_params.slm.model,
+ model.wd,
+ sr,
+ model_params.slm.sr).to(device)
+
+ gl = MyDataParallel(gl)
+ dl = MyDataParallel(dl)
+ wl = MyDataParallel(wl)
+
+ sampler = DiffusionSampler(
+ model.diffusion.diffusion,
+ sampler=ADPM2Sampler(),
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
+ clamp=False
+ )
+
+ scheduler_params = {
+ "max_lr": optimizer_params.lr,
+ "pct_start": float(0),
+ "epochs": epochs,
+ "steps_per_epoch": len(train_dataloader),
+ }
+ scheduler_params_dict= {key: scheduler_params.copy() for key in model}
+ scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2
+ scheduler_params_dict['decoder']['max_lr'] = optimizer_params.ft_lr * 2
+ scheduler_params_dict['style_encoder']['max_lr'] = optimizer_params.ft_lr * 2
+
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
+ scheduler_params_dict=scheduler_params_dict, lr=optimizer_params.lr)
+
+ # adjust BERT learning rate
+ for g in optimizer.optimizers['bert'].param_groups:
+ g['betas'] = (0.9, 0.99)
+ g['lr'] = optimizer_params.bert_lr
+ g['initial_lr'] = optimizer_params.bert_lr
+ g['min_lr'] = 0
+ g['weight_decay'] = 0.01
+
+ # adjust acoustic module learning rate
+ for module in ["decoder", "style_encoder"]:
+ for g in optimizer.optimizers[module].param_groups:
+ g['betas'] = (0.0, 0.99)
+ g['lr'] = optimizer_params.ft_lr
+ g['initial_lr'] = optimizer_params.ft_lr
+ g['min_lr'] = 0
+ g['weight_decay'] = 1e-4
+
+ # load models if there is a model
+ if load_pretrained:
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
+ load_only_params=config.get('load_only_params', True))
+
+ n_down = model.text_aligner.n_down
+
+ best_loss = float('inf') # best test loss
+ loss_train_record = list([])
+ loss_test_record = list([])
+ iters = 0
+
+ criterion = nn.L1Loss() # F0 loss (regression)
+ torch.cuda.empty_cache()
+
+ stft_loss = MultiResolutionSTFTLoss().to(device)
+
+ print('BERT', optimizer.optimizers['bert'])
+ print('decoder', optimizer.optimizers['decoder'])
+
+ start_ds = False
+
+ running_std = []
+
+ slmadv_params = Munch(config['slmadv_params'])
+ slmadv = SLMAdversarialLoss(model, wl, sampler,
+ slmadv_params.min_len,
+ slmadv_params.max_len,
+ batch_percentage=slmadv_params.batch_percentage,
+ skip_update=slmadv_params.iter,
+ sig=slmadv_params.sig
+ )
+
+
+ for epoch in range(start_epoch, epochs):
+ running_loss = 0
+ start_time = time.time()
+
+ _ = [model[key].eval() for key in model]
+
+ model.predictor.train()
+ model.bert_encoder.train()
+ model.bert.train()
+ model.msd.train()
+ model.mpd.train()
+
+
+ if epoch >= diff_epoch:
+ start_ds = True
+
+ for i, batch in enumerate(train_dataloader):
+ waves = batch[0]
+ batch = [b.to(device) for b in batch[1:]]
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
+
+ with torch.no_grad():
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
+ mel_mask = length_to_mask(mel_input_length).to(device)
+ text_mask = length_to_mask(input_lengths).to(texts.device)
+
+ try:
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
+ s2s_attn = s2s_attn.transpose(-1, -2)
+ s2s_attn = s2s_attn[..., 1:]
+ s2s_attn = s2s_attn.transpose(-1, -2)
+ except Exception:  # skip batches where the aligner fails
+ continue
+
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
+
+ # encode
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
+ asr = (t_en @ s2s_attn_mono)
+
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
+
+ # compute reference styles
+ if multispeaker and epoch >= diff_epoch:
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
+ ref = torch.cat([ref_ss, ref_sp], dim=1)
+
+ # compute the style of the entire utterance
+ # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
+ ss = []
+ gs = []
+ for bib in range(len(mel_input_length)):
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ mel_length = int(mel_input_length[bib].item())
+ mel = mels[bib, :, :mel_input_length[bib]]
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
+ ss.append(s)
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
+ gs.append(s)
+
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ s_dur = torch.stack(ss).squeeze() # global prosodic styles
+ gs = torch.stack(gs).squeeze() # global acoustic styles
+ s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
+
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
+
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+
+ # denoiser training
+ if epoch >= diff_epoch:
+ num_steps = np.random.randint(3, 5)
+
+ if model_params.diffusion.dist.estimate_sigma_data:
+ model.diffusion.module.diffusion.sigma_data = s_trg.std(axis=-1).mean().item() # batch-wise std estimation
+ running_std.append(model.diffusion.module.diffusion.sigma_data)
+
+ if multispeaker:
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
+ embedding=bert_dur,
+ embedding_scale=1,
+ features=ref, # reference from the same speaker as the embedding
+ embedding_mask_proba=0.1,
+ num_steps=num_steps).squeeze(1)
+ loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
+ else:
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
+ embedding=bert_dur,
+ embedding_scale=1,
+ embedding_mask_proba=0.1,
+ num_steps=num_steps).squeeze(1)
+ loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur).mean() # EDM loss
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
+ else:
+ loss_sty = 0
+ loss_diff = 0
+
+
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ d, p = model.predictor(d_en, s_dur,
+ input_lengths,
+ s2s_attn_mono,
+ text_mask)
+
+ mel_len = min(int(mel_input_length.min().item() / 2 - 1), max_len // 2)
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
+ en = []
+ gt = []
+ st = []
+ p_en = []
+ wav = []
+
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item() / 2)
+
+ random_start = np.random.randint(0, mel_length - mel_len)
+ en.append(asr[bib, :, random_start:random_start+mel_len])
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
+
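+ # asr/p frames are at half the mel resolution and the mel hop length is 512,
+ # so frame index random_start maps to sample index random_start * 2 * 512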
+ y = waves[bib][(random_start * 2) * 512:((random_start+mel_len) * 2) * 512]
+ wav.append(torch.from_numpy(y).to(device))
+
+ # style reference (better to be different from the GT)
+ random_start = np.random.randint(0, mel_length - mel_len_st)
+ st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
+
+ wav = torch.stack(wav).float().detach()
+
+ en = torch.stack(en)
+ p_en = torch.stack(p_en)
+ gt = torch.stack(gt).detach()
+ st = torch.stack(st).detach()
+
+ if gt.size(-1) < 80:
+ continue
+
+ s_dur = model.predictor_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
+ s = model.style_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
+
+ with torch.no_grad():
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
+ F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
+
+ asr_real = model.text_aligner.get_feature(gt)
+
+ N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
+
+ y_rec_gt = wav.unsqueeze(1)
+ y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
+
+ if epoch >= joint_epoch:
+ # ground truth from recording
+ wav = y_rec_gt # use recording since decoder is tuned
+ else:
+ # ground truth from reconstruction
+ wav = y_rec_gt_pred # use reconstruction since decoder is fixed
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
+
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
+
+ loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
+ loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
+
+ if start_ds:
+ optimizer.zero_grad()
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ d_loss = dl(wav.detach(), y_rec.detach()).mean()
+ d_loss.backward()
+ optimizer.step('msd')
+ optimizer.step('mpd')
+ else:
+ d_loss = 0
+
+ # generator loss
+ optimizer.zero_grad()
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ loss_mel = stft_loss(y_rec, wav)
+ if start_ds:
+ loss_gen_all = gl(wav, y_rec).mean()
+ else:
+ loss_gen_all = 0
+ loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
+
+ loss_ce = 0
+ loss_dur = 0
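+ # duration/CE targets: each phoneme row is 1 for its first d_gt frames, and the
+ # sigmoid-summed logits approximate the predicted duration in frames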
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ _s2s_pred = _s2s_pred[:_text_length, :]
+ _text_input = _text_input[:_text_length].long()
+ _s2s_trg = torch.zeros_like(_s2s_pred)
+ for p in range(_s2s_trg.shape[0]):
+ _s2s_trg[p, :_text_input[p]] = 1
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
+
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
+ _text_input[1:_text_length-1])
+ loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())
+
+ loss_ce /= texts.size(0)
+ loss_dur /= texts.size(0)
+
+ g_loss = loss_params.lambda_mel * loss_mel + \
+ loss_params.lambda_F0 * loss_F0_rec + \
+ loss_params.lambda_ce * loss_ce + \
+ loss_params.lambda_norm * loss_norm_rec + \
+ loss_params.lambda_dur * loss_dur + \
+ loss_params.lambda_gen * loss_gen_all + \
+ loss_params.lambda_slm * loss_lm + \
+ loss_params.lambda_sty * loss_sty + \
+ loss_params.lambda_diff * loss_diff
+
+ running_loss += loss_mel.item()
+ g_loss.backward()
+ # if torch.isnan(g_loss):
+ # from IPython.core.debugger import set_trace
+ # set_trace()
+
+ optimizer.step('bert_encoder')
+ optimizer.step('bert')
+ optimizer.step('predictor')
+ optimizer.step('predictor_encoder')
+
+ if epoch >= diff_epoch:
+ optimizer.step('diffusion')
+
+ if epoch >= joint_epoch:
+ optimizer.step('style_encoder')
+ optimizer.step('decoder')
+
+ # randomly pick whether to use in-distribution text
+ if np.random.rand() < 0.5:
+ use_ind = True
+ else:
+ use_ind = False
+
+ if use_ind:
+ ref_lengths = input_lengths
+ ref_texts = texts
+
+ slm_out = slmadv(i,
+ y_rec_gt,
+ y_rec_gt_pred,
+ waves,
+ mel_input_length,
+ ref_texts,
+ ref_lengths, use_ind, s_trg.detach(), ref if multispeaker else None)
+
+ if slm_out is None:
+ continue
+
+
+ d_loss_slm, loss_gen_lm, y_pred = slm_out
+
+ # SLM generator loss
+ optimizer.zero_grad()
+ loss_gen_lm.backward()
+
+ # compute the gradient norm
+ total_norm = {}
+ for key in model.keys():
+ total_norm[key] = 0
+ parameters = [p for p in model[key].parameters() if p.grad is not None and p.requires_grad]
+ for p in parameters:
+ param_norm = p.grad.detach().data.norm(2)
+ total_norm[key] += param_norm.item() ** 2
+ total_norm[key] = total_norm[key] ** 0.5
+
+ # gradient scaling
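+ # if the predictor gradient norm exceeds the threshold, every gradient is rescaled
+ # by 1/norm so the SLM adversarial update stays bounded; duration_proj, the predictor
+ # lstm and the diffusion gradients are further scaled by slmadv_params.scale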
+ if total_norm['predictor'] > slmadv_params.thresh:
+ for key in model.keys():
+ for p in model[key].parameters():
+ if p.grad is not None:
+ p.grad *= (1 / total_norm['predictor'])
+
+ for p in model.predictor.duration_proj.parameters():
+ if p.grad is not None:
+ p.grad *= slmadv_params.scale
+
+ for p in model.predictor.lstm.parameters():
+ if p.grad is not None:
+ p.grad *= slmadv_params.scale
+
+ for p in model.diffusion.parameters():
+ if p.grad is not None:
+ p.grad *= slmadv_params.scale
+
+ optimizer.step('bert_encoder')
+ optimizer.step('bert')
+ optimizer.step('predictor')
+ optimizer.step('diffusion')
+
+ # SLM discriminator loss
+ if d_loss_slm != 0:
+ optimizer.zero_grad()
+ d_loss_slm.backward(retain_graph=True)
+ optimizer.step('wd')
+
+ else:
+ d_loss_slm, loss_gen_lm = 0, 0
+
+ iters = iters + 1
+
+ if (i+1)%log_interval == 0:
+ logger.info ('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f'
+ %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, d_loss, loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff, d_loss_slm, loss_gen_lm))
+
+ writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
+ writer.add_scalar('train/gen_loss', loss_gen_all, iters)
+ writer.add_scalar('train/d_loss', d_loss, iters)
+ writer.add_scalar('train/ce_loss', loss_ce, iters)
+ writer.add_scalar('train/dur_loss', loss_dur, iters)
+ writer.add_scalar('train/slm_loss', loss_lm, iters)
+ writer.add_scalar('train/norm_loss', loss_norm_rec, iters)
+ writer.add_scalar('train/F0_loss', loss_F0_rec, iters)
+ writer.add_scalar('train/sty_loss', loss_sty, iters)
+ writer.add_scalar('train/diff_loss', loss_diff, iters)
+ writer.add_scalar('train/d_loss_slm', d_loss_slm, iters)
+ writer.add_scalar('train/gen_loss_slm', loss_gen_lm, iters)
+
+ running_loss = 0
+
+ print('Time elapsed:', time.time()-start_time)
+
+ if (i+1)%save_iter == 0:
+ print(f'Saving on step {epoch*len(train_dataloader)+i}...')
+ state = {
+ 'net': {key: model[key].state_dict() for key in model},
+ 'optimizer': optimizer.state_dict(),
+ 'iters': iters,
+ 'epoch': epoch,
+ }
+ save_path = osp.join(log_dir, f'2nd_phase_{epoch*len(train_dataloader)+i}.pth')
+ torch.save(state, save_path)
+
+ loss_test = 0
+ loss_align = 0
+ loss_f = 0
+ _ = [model[key].eval() for key in model]
+
+ with torch.no_grad():
+ iters_test = 0
+ for batch_idx, batch in enumerate(val_dataloader):
+ optimizer.zero_grad()
+
+ try:
+ waves = batch[0]
+ batch = [b.to(device) for b in batch[1:]]
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
+ with torch.no_grad():
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
+ text_mask = length_to_mask(input_lengths).to(texts.device)
+
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
+ s2s_attn = s2s_attn.transpose(-1, -2)
+ s2s_attn = s2s_attn[..., 1:]
+ s2s_attn = s2s_attn.transpose(-1, -2)
+
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
+
+ # encode
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
+ asr = (t_en @ s2s_attn_mono)
+
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
+
+ ss = []
+ gs = []
+
+ for bib in range(len(mel_input_length)):
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ mel_length = int(mel_input_length[bib].item())
+ mel = mels[bib, :, :mel_input_length[bib]]
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
+ ss.append(s)
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
+ gs.append(s)
+
+ s = torch.stack(ss).squeeze()
+ gs = torch.stack(gs).squeeze()
+ s_trg = torch.cat([s, gs], dim=-1).detach()
+
+
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+ d, p = model.predictor(d_en, s,
+ input_lengths,
+ s2s_attn_mono,
+ text_mask)
+ # get clips
+ mel_len = int(mel_input_length.min().item() / 2 - 1)
+ en = []
+ gt = []
+ p_en = []
+ wav = []
+
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item() / 2)
+
+ random_start = np.random.randint(0, mel_length - mel_len)
+ en.append(asr[bib, :, random_start:random_start+mel_len])
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
+
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
+
+ y = waves[bib][(random_start * 2) * 512:((random_start+mel_len) * 2) * 512]
+ wav.append(torch.from_numpy(y).to(device))
+
+ wav = torch.stack(wav).float().detach()
+
+ en = torch.stack(en)
+ p_en = torch.stack(p_en)
+ gt = torch.stack(gt).detach()
+
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+
+ s = model.predictor_encoder(gt.unsqueeze(1))
+
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s)
+
+ loss_dur = 0
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
+ _s2s_pred = _s2s_pred[:_text_length, :]
+ _text_input = _text_input[:_text_length].long()
+ _s2s_trg = torch.zeros_like(_s2s_pred)
+ for bib in range(_s2s_trg.shape[0]):
+ _s2s_trg[bib, :_text_input[bib]] = 1
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
+ _text_input[1:_text_length-1])
+
+ loss_dur /= texts.size(0)
+
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+
+ s = model.style_encoder(gt.unsqueeze(1))
+
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
+
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
+
+ loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
+
+ loss_test += (loss_mel).mean()
+ loss_align += (loss_dur).mean()
+ loss_f += (loss_F0).mean()
+
+ iters_test += 1
+ except Exception as e:
+ print(f"run into exception", e)
+ traceback.print_exc()
+ continue
+
+ print('Epochs:', epoch + 1)
+ logger.info('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n\n\n')
+ print('\n\n\n')
+ writer.add_scalar('eval/mel_loss', loss_test / iters_test, epoch + 1)
+ writer.add_scalar('eval/dur_loss', loss_align / iters_test, epoch + 1)
+ writer.add_scalar('eval/F0_loss', loss_f / iters_test, epoch + 1)
+
+ if epoch < joint_epoch:
+ # generating reconstruction examples with GT duration
+
+ with torch.no_grad():
+ for bib in range(len(asr)):
+ mel_length = int(mel_input_length[bib].item())
+ gt = mels[bib, :, :mel_length].unsqueeze(0)
+ en = asr[bib, :, :mel_length // 2].unsqueeze(0)
+
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+
+ F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
+ F0_real = F0_real.unsqueeze(0)
+ s = model.style_encoder(gt.unsqueeze(1))
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
+
+ y_rec = model.decoder(en, F0_real, real_norm, s)
+
+ writer.add_audio('eval/y' + str(bib), y_rec.cpu().numpy().squeeze(), epoch, sample_rate=sr)
+
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+
+ s_dur = model.predictor_encoder(gt.unsqueeze(1))
+ p_en = p[bib, :, :mel_length // 2].unsqueeze(0)
+
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
+
+ y_pred = model.decoder(en, F0_fake, N_fake, s)
+
+ writer.add_audio('pred/y' + str(bib), y_pred.cpu().numpy().squeeze(), epoch, sample_rate=sr)
+
+ if epoch == 0:
+ writer.add_audio('gt/y' + str(bib), waves[bib].squeeze(), epoch, sample_rate=sr)
+
+ if bib >= 5:
+ break
+ else:
+ # generating sampled speech from text directly
+ with torch.no_grad():
+ # compute reference styles
+ if multispeaker and epoch >= diff_epoch:
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
+ ref_s = torch.cat([ref_ss, ref_sp], dim=1)
+
+ for bib in range(len(d_en)):
+ if multispeaker:
+ s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(texts.device),
+ embedding=bert_dur[bib].unsqueeze(0),
+ embedding_scale=1,
+ features=ref_s[bib].unsqueeze(0), # reference from the same speaker as the embedding
+ num_steps=5).squeeze(1)
+ else:
+ s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(texts.device),
+ embedding=bert_dur[bib].unsqueeze(0),
+ embedding_scale=1,
+ num_steps=5).squeeze(1)
+
+ s = s_pred[:, 128:]
+ ref = s_pred[:, :128]
+
+
+
+ d = model.predictor.text_encoder(d_en[bib, :, :input_lengths[bib]].unsqueeze(0),
+ s, input_lengths[bib, ...].unsqueeze(0), text_mask[bib, :input_lengths[bib]].unsqueeze(0))
+
+ x = model.predictor.lstm(d)
+ x_mod = model.predictor.prepare_projection(x) # 640 -> 512
+ duration = model.predictor.duration_proj(x_mod)
+
+ duration = torch.sigmoid(duration).sum(axis=-1)
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
+
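+ # lengthen the final phoneme by a few frames so the end of the utterance is not clipped (heuristic)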
+ pred_dur[-1] += 5
+
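+ # expand the predicted durations into a hard monotonic alignment matrix of shape (text_length, total_frames)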
+ pred_aln_trg = torch.zeros(input_lengths[bib], int(pred_dur.sum().data))
+ c_frame = 0
+ for i in range(pred_aln_trg.size(0)):
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
+ c_frame += int(pred_dur[i].data)
+
+ # encode prosody
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(texts.device))
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
+ out = model.decoder((t_en[bib, :, :input_lengths[bib]].unsqueeze(0) @ pred_aln_trg.unsqueeze(0).to(texts.device)),
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
+
+ writer.add_audio('pred/y' + str(bib), out.cpu().numpy().squeeze(), epoch, sample_rate=sr)
+
+ if bib >= 5:
+ break
+
+ if epoch % saving_epoch == 0:
+ if (loss_test / iters_test) < best_loss:
+ best_loss = loss_test / iters_test
+ print('Saving..')
+ state = {
+ 'net': {key: model[key].state_dict() for key in model},
+ 'optimizer': optimizer.state_dict(),
+ 'iters': iters,
+ 'val_loss': loss_test / iters_test,
+ 'epoch': epoch,
+ }
+ save_path = osp.join(log_dir, 'epoch_2nd_%05d.pth' % epoch)
+ torch.save(state, save_path)
+
+ # if sigma_data was estimated, save the estimated sigma
+ if model_params.diffusion.dist.estimate_sigma_data:
+ config['model_params']['diffusion']['dist']['sigma_data'] = float(np.mean(running_std))
+
+ with open(osp.join(log_dir, osp.basename(config_path)), 'w') as outfile:
+ yaml.dump(config, outfile, default_flow_style=True)
+
+if __name__=="__main__":
+ main()
diff --git a/stts_48khz/StyleTTS2_48khz/models.py b/stts_48khz/StyleTTS2_48khz/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf6c0737bf774eb98857de194cf27ffe22728e70
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/models.py
@@ -0,0 +1,2656 @@
+# import os
+# import os.path as osp
+
+# import copy
+# import math
+
+# import numpy as np
+# import torch
+# import torch.nn as nn
+# import torch.nn.functional as F
+# from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+
+# from Utils.ASR.models import ASRCNN
+# from Utils.JDC.model import JDCNet
+
+# from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
+# from Modules.diffusion.modules import Transformer1d, StyleTransformer1d
+# from Modules.diffusion.diffusion import AudioDiffusionConditional
+
+# from Modules.discriminators import MultiPeriodDiscriminator, MultiResSpecDiscriminator, WavLMDiscriminator
+
+# from munch import Munch
+# import yaml
+
+# from hflayers import Hopfield, HopfieldPooling, HopfieldLayer
+# from hflayers.auxiliary.data import BitPatternSet
+
+# # Import auxiliary modules.
+# from distutils.version import LooseVersion
+# from typing import List, Tuple
+
+# import math
+
+# import torch
+
+# from xlstm import (
+# xLSTMBlockStack,
+# xLSTMBlockStackConfig,
+# mLSTMBlockConfig,
+# mLSTMLayerConfig,
+# sLSTMBlockConfig,
+# sLSTMLayerConfig,
+# FeedForwardConfig,
+# )
+
+# class LearnedDownSample(nn.Module):
+# def __init__(self, layer_type, dim_in):
+# super().__init__()
+# self.layer_type = layer_type
+
+# if self.layer_type == 'none':
+# self.conv = nn.Identity()
+# elif self.layer_type == 'timepreserve':
+# self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0)))
+# elif self.layer_type == 'half':
+# self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1))
+# else:
+# raise RuntimeError('Got unexpected downsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
+
+# def forward(self, x):
+# return self.conv(x)
+
+# class LearnedUpSample(nn.Module):
+# def __init__(self, layer_type, dim_in):
+# super().__init__()
+# self.layer_type = layer_type
+
+# if self.layer_type == 'none':
+# self.conv = nn.Identity()
+# elif self.layer_type == 'timepreserve':
+# self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0))
+# elif self.layer_type == 'half':
+# self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1)
+# else:
+# raise RuntimeError('Got unexpected upsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
+
+
+# def forward(self, x):
+# return self.conv(x)
+
+# class DownSample(nn.Module):
+# def __init__(self, layer_type):
+# super().__init__()
+# self.layer_type = layer_type
+
+# def forward(self, x):
+# if self.layer_type == 'none':
+# return x
+# elif self.layer_type == 'timepreserve':
+# return F.avg_pool2d(x, (2, 1))
+# elif self.layer_type == 'half':
+# if x.shape[-1] % 2 != 0:
+# x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
+# return F.avg_pool2d(x, 2)
+# else:
+# raise RuntimeError('Got unexpected downsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
+
+
+# class UpSample(nn.Module):
+# def __init__(self, layer_type):
+# super().__init__()
+# self.layer_type = layer_type
+
+# def forward(self, x):
+# if self.layer_type == 'none':
+# return x
+# elif self.layer_type == 'timepreserve':
+# return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
+# elif self.layer_type == 'half':
+# return F.interpolate(x, scale_factor=2, mode='nearest')
+# else:
+# raise RuntimeError('Got unexpected upsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
+
+
+# class ResBlk(nn.Module):
+# def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
+# normalize=False, downsample='none'):
+# super().__init__()
+# self.actv = actv
+# self.normalize = normalize
+# self.downsample = DownSample(downsample)
+# self.downsample_res = LearnedDownSample(downsample, dim_in)
+# self.learned_sc = dim_in != dim_out
+# self._build_weights(dim_in, dim_out)
+
+# def _build_weights(self, dim_in, dim_out):
+# self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
+# self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
+# if self.normalize:
+# self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
+# self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
+# if self.learned_sc:
+# self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+# def _shortcut(self, x):
+# if self.learned_sc:
+# x = self.conv1x1(x)
+# if self.downsample:
+# x = self.downsample(x)
+# return x
+
+# def _residual(self, x):
+# if self.normalize:
+# x = self.norm1(x)
+# x = self.actv(x)
+# x = self.conv1(x)
+# x = self.downsample_res(x)
+# if self.normalize:
+# x = self.norm2(x)
+# x = self.actv(x)
+# x = self.conv2(x)
+# return x
+
+# def forward(self, x):
+# x = self._shortcut(x) + self._residual(x)
+# return x / math.sqrt(2) # unit variance
+
+# class StyleEncoder(nn.Module):
+# def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
+# super().__init__()
+# blocks = []
+# blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
+
+# repeat_num = 4
+# for _ in range(repeat_num):
+# dim_out = min(dim_in*2, max_conv_dim)
+# blocks += [ResBlk(dim_in, dim_out, downsample='half')]
+# dim_in = dim_out
+
+# blocks += [nn.LeakyReLU(0.2)]
+# blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
+# blocks += [nn.AdaptiveAvgPool2d(1)]
+# blocks += [nn.LeakyReLU(0.2)]
+# self.shared = nn.Sequential(*blocks)
+
+# self.unshared = nn.Linear(dim_out, style_dim)
+
+# def forward(self, x):
+# h = self.shared(x)
+# h = h.view(h.size(0), -1)
+# s = self.unshared(h)
+
+# return s
+
+# class LinearNorm(torch.nn.Module):
+# def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
+# super(LinearNorm, self).__init__()
+# self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+
+# torch.nn.init.xavier_uniform_(
+# self.linear_layer.weight,
+# gain=torch.nn.init.calculate_gain(w_init_gain))
+
+# def forward(self, x):
+# return self.linear_layer(x)
+
+# class Discriminator2d(nn.Module):
+# def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
+# super().__init__()
+# blocks = []
+# blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
+
+# for lid in range(repeat_num):
+# dim_out = min(dim_in*2, max_conv_dim)
+# blocks += [ResBlk(dim_in, dim_out, downsample='half')]
+# dim_in = dim_out
+
+# blocks += [nn.LeakyReLU(0.2)]
+# blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
+# blocks += [nn.LeakyReLU(0.2)]
+# blocks += [nn.AdaptiveAvgPool2d(1)]
+# blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
+# self.main = nn.Sequential(*blocks)
+
+# def get_feature(self, x):
+# features = []
+# for l in self.main:
+# x = l(x)
+# features.append(x)
+# out = features[-1]
+# out = out.view(out.size(0), -1) # (batch, num_domains)
+# return out, features
+
+# def forward(self, x):
+# out, features = self.get_feature(x)
+# out = out.squeeze() # (batch)
+# return out, features
+
+# class ResBlk1d(nn.Module):
+# def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
+# normalize=False, downsample='none', dropout_p=0.2):
+# super().__init__()
+# self.actv = actv
+# self.normalize = normalize
+# self.downsample_type = downsample
+# self.learned_sc = dim_in != dim_out
+# self._build_weights(dim_in, dim_out)
+# self.dropout_p = dropout_p
+
+# if self.downsample_type == 'none':
+# self.pool = nn.Identity()
+# else:
+# self.pool = weight_norm(nn.Conv1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1))
+
+# def _build_weights(self, dim_in, dim_out):
+# self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
+# self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
+# if self.normalize:
+# self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
+# self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
+# if self.learned_sc:
+# self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+# def downsample(self, x):
+# if self.downsample_type == 'none':
+# return x
+# else:
+# if x.shape[-1] % 2 != 0:
+# x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
+# return F.avg_pool1d(x, 2)
+
+# def _shortcut(self, x):
+# if self.learned_sc:
+# x = self.conv1x1(x)
+# x = self.downsample(x)
+# return x
+
+# def _residual(self, x):
+# if self.normalize:
+# x = self.norm1(x)
+# x = self.actv(x)
+# x = F.dropout(x, p=self.dropout_p, training=self.training)
+
+# x = self.conv1(x)
+# x = self.pool(x)
+# if self.normalize:
+# x = self.norm2(x)
+
+# x = self.actv(x)
+# x = F.dropout(x, p=self.dropout_p, training=self.training)
+
+# x = self.conv2(x)
+# return x
+
+# def forward(self, x):
+# x = self._shortcut(x) + self._residual(x)
+# return x / math.sqrt(2) # unit variance
+
+# class LayerNorm(nn.Module):
+# def __init__(self, channels, eps=1e-5):
+# super().__init__()
+# self.channels = channels
+# self.eps = eps
+
+# self.gamma = nn.Parameter(torch.ones(channels))
+# self.beta = nn.Parameter(torch.zeros(channels))
+
+# def forward(self, x):
+# x = x.transpose(1, -1)
+# x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+# return x.transpose(1, -1)
+
+# class TextEncoder(nn.Module):
+# def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
+# super().__init__()
+# self.embedding = nn.Embedding(n_symbols, channels)
+
+
+
+
+# padding = (kernel_size - 1) // 2
+# self.cnn = nn.ModuleList()
+# for _ in range(depth):
+# self.cnn.append(nn.Sequential(
+
+# weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
+# LayerNorm(channels),
+# actv,
+# nn.Dropout(0.2),
+# ))
+# # self.cnn = nn.Sequential(*self.cnn)
+
+
+# self.lstm = Hopfield(input_size=channels,
+# hidden_size=channels // 2,
+# num_heads=32,
+# # scaling=.75,
+# add_zero_association=True,
+# batch_first=True)
+
+# def forward(self, x, input_lengths, m):
+
+# x = self.embedding(x) # [B, T, emb]
+
+# x = x.transpose(1, 2) # [B, emb, T]
+# m = m.to(input_lengths.device).unsqueeze(1)
+# x.masked_fill_(m, 0.0)
+
+# for c in self.cnn:
+# x = c(x)
+# x.masked_fill_(m, 0.0)
+
+# x = x.transpose(1, 2) # [B, T, chn]
+
+
+# input_lengths = input_lengths.cpu().numpy()
+
+
+
+
+
+# # x = nn.utils.rnn.pack_padded_sequence(
+# # x, input_lengths, batch_first=True, enforce_sorted=False)
+
+# # self.lstm.flatten_parameters()
+# x = self.lstm(x)
+
+
+# # x, _ = nn.utils.rnn.pad_packed_sequence(
+# # x, batch_first=True)
+
+# x = x.transpose(-1, -2)
+# # x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
+
+# # x_pad[:, :, :x.shape[-1]] = x
+# # x = x_pad.to(x.device)
+
+# x.masked_fill_(m, 0.0)
+
+# return x
+
+# def inference(self, x):
+# x = self.embedding(x)
+# x = x.transpose(1, 2)
+# x = self.cnn(x)
+# x = x.transpose(1, 2)
+# # self.lstm.flatten_parameters()
+# x = self.lstm(x)
+# return x
+
+# def length_to_mask(self, lengths):
+# mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+# mask = torch.gt(mask+1, lengths.unsqueeze(1))
+# return mask
+
+
+
+# class AdaIN1d(nn.Module):
+# def __init__(self, style_dim, num_features):
+# super().__init__()
+# self.norm = nn.InstanceNorm1d(num_features, affine=False)
+# self.fc = nn.Linear(style_dim, num_features*2)
+
+# def forward(self, x, s):
+# h = self.fc(s)
+
+# h = h.view(h.size(0), h.size(1), 1)
+# gamma, beta = torch.chunk(h, chunks=2, dim=1)
+# return (1 + gamma) * self.norm(x) + beta
+
+# class UpSample1d(nn.Module):
+# def __init__(self, layer_type):
+# super().__init__()
+# self.layer_type = layer_type
+
+# def forward(self, x):
+# if self.layer_type == 'none':
+# return x
+# else:
+# return F.interpolate(x, scale_factor=2, mode='nearest')
+
+# class AdainResBlk1d(nn.Module):
+# def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
+# upsample='none', dropout_p=0.0):
+# super().__init__()
+# self.actv = actv
+# self.upsample_type = upsample
+# self.upsample = UpSample1d(upsample)
+# self.learned_sc = dim_in != dim_out
+# self._build_weights(dim_in, dim_out, style_dim)
+# self.dropout = nn.Dropout(dropout_p)
+
+# if upsample == 'none':
+# self.pool = nn.Identity()
+# else:
+# self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
+
+
+# def _build_weights(self, dim_in, dim_out, style_dim):
+# self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
+# self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
+# self.norm1 = AdaIN1d(style_dim, dim_in)
+# self.norm2 = AdaIN1d(style_dim, dim_out)
+# if self.learned_sc:
+# self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+# def _shortcut(self, x):
+# x = self.upsample(x)
+# if self.learned_sc:
+# x = self.conv1x1(x)
+# return x
+
+# def _residual(self, x, s):
+# x = self.norm1(x, s)
+# x = self.actv(x)
+# x = self.pool(x)
+# x = self.conv1(self.dropout(x))
+# x = self.norm2(x, s)
+# x = self.actv(x)
+# x = self.conv2(self.dropout(x))
+# return x
+
+# def forward(self, x, s):
+# out = self._residual(x, s)
+# out = (out + self._shortcut(x)) / math.sqrt(2)
+# return out
+
+# class AdaLayerNorm(nn.Module):
+# def __init__(self, style_dim, channels, eps=1e-5):
+# super().__init__()
+# self.channels = channels
+# self.eps = eps
+
+# self.fc = nn.Linear(style_dim, channels*2)
+
+# def forward(self, x, s):
+# x = x.transpose(-1, -2)
+# x = x.transpose(1, -1)
+
+# h = self.fc(s)
+# h = h.view(h.size(0), h.size(1), 1)
+# gamma, beta = torch.chunk(h, chunks=2, dim=1)
+# gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
+
+
+# x = F.layer_norm(x, (self.channels,), eps=self.eps)
+# x = (1 + gamma) * x + beta
+# return x.transpose(1, -1).transpose(-1, -2)
+
+# # class ProsodyPredictor(nn.Module):
+
+# # def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
+# # super().__init__()
+
+# # self.text_encoder = DurationEncoder(sty_dim=style_dim,
+# # d_model=d_hid,
+# # nlayers=nlayers,
+# # dropout=dropout)
+
+# # self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
+# # self.duration_proj = LinearNorm(d_hid, max_dur)
+
+# # self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
+# # self.F0 = nn.ModuleList()
+# # self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
+# # self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
+# # self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
+
+# # self.N = nn.ModuleList()
+# # self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
+# # self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
+# # self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
+
+# # self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
+# # self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
+
+
+# # def forward(self, texts, style, text_lengths, alignment, m):
+# # d = self.text_encoder(texts, style, text_lengths, m)
+
+# # batch_size = d.shape[0]
+# # text_size = d.shape[1]
+
+# # # predict duration
+# # input_lengths = text_lengths.cpu().numpy()
+# # x = nn.utils.rnn.pack_padded_sequence(
+# # d, input_lengths, batch_first=True, enforce_sorted=False)
+
+# # m = m.to(text_lengths.device).unsqueeze(1)
+
+# # self.lstm.flatten_parameters()
+# # x, _ = self.lstm(x)
+# # x, _ = nn.utils.rnn.pad_packed_sequence(
+# # x, batch_first=True)
+
+# # x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
+
+# # x_pad[:, :x.shape[1], :] = x
+# # x = x_pad.to(x.device)
+
+# # duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
+
+# # en = (d.transpose(-1, -2) @ alignment)
+
+# # return duration.squeeze(-1), en
+
+
+# class ProsodyPredictor(nn.Module):
+
+# def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
+# super().__init__()
+
+
+
+
+
+# self.text_encoder = DurationEncoder(sty_dim=style_dim,
+# d_model=d_hid,
+# nlayers=nlayers,
+# dropout=dropout)
+
+
+# self.lstm = Hopfield(input_size=d_hid + style_dim,
+# hidden_size=d_hid // 2,
+# num_heads=32,
+# # scaling=.75,
+# add_zero_association=True,
+# batch_first=True)
+
+
+# self.prepare_projection = nn.Linear(d_hid + style_dim, d_hid)
+
+# self.duration_proj = LinearNorm(d_hid , max_dur)
+
+# self.shared = Hopfield(input_size=d_hid + style_dim,
+# hidden_size=d_hid // 2,
+# num_heads=32,
+# # scaling=.75,
+# add_zero_association=True,
+# batch_first=True)
+
+# #self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
+# self.F0 = nn.ModuleList()
+# self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
+# self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
+# self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
+
+# self.N = nn.ModuleList()
+# self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
+# self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
+# self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
+
+# self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
+# self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
+
+
+# def forward(self, texts, style, text_lengths, alignment, m):
+# d = self.text_encoder(texts, style, text_lengths, m)
+
+# batch_size = d.shape[0]
+# text_size = d.shape[1]
+
+# # predict duration
+
+
+# input_lengths = text_lengths.cpu().numpy()
+
+
+# # x = nn.utils.rnn.pack_padded_sequence(
+# # d, input_lengths, batch_first=True, enforce_sorted=False)
+
+# x = d # this dude can handle variable seq len so no need for packing
+
+
+# m = m.to(text_lengths.device).unsqueeze(1)
+
+# # self.lstm.flatten_parameters()
+# x = self.lstm(x) # no longer using lstm
+# x = self.prepare_projection(x)
+
+
+# # x, _ = nn.utils.rnn.pad_packed_sequence(
+# # x, batch_first=True)
+
+# x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
+
+# x_pad[:, :x.shape[1], :] = x
+# x = x_pad.to(x.device)
+
+# x = x.transpose(-1,-2)
+# x = x.permute(0,2,1)
+# duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
+
+
+
+# en = (d.transpose(-1, -2) @ alignment)
+
+# return duration.squeeze(-1), en
+
+
+# def F0Ntrain(self, x, s):
+
+
+
+# x = self.shared(x.transpose(-1, -2))
+# x = self.prepare_projection(x)
+
+
+# F0 = x.transpose(-1, -2)
+
+# for block in self.F0:
+# F0 = block(F0, s)
+# F0 = self.F0_proj(F0)
+
+# N = x.transpose(-1, -2)
+# for block in self.N:
+# N = block(N, s)
+# N = self.N_proj(N)
+
+# return F0.squeeze(1), N.squeeze(1)
+
+# def length_to_mask(self, lengths):
+# mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+# mask = torch.gt(mask+1, lengths.unsqueeze(1))
+# return mask
+
+# class DurationEncoder(nn.Module):
+
+# def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
+# super().__init__()
+# self.lstms = nn.ModuleList()
+# for _ in range(nlayers):
+# self.lstms.append(nn.GRU(d_model + sty_dim,
+# d_model // 2,
+# num_layers=1,
+# batch_first=True,
+# bidirectional=True,
+# dropout=dropout))
+# self.lstms.append(AdaLayerNorm(sty_dim, d_model))
+
+
+# self.dropout = dropout
+# self.d_model = d_model
+# self.sty_dim = sty_dim
+
+# def forward(self, x, style, text_lengths, m):
+# masks = m.to(text_lengths.device)
+
+# x = x.permute(2, 0, 1)
+# s = style.expand(x.shape[0], x.shape[1], -1)
+# x = torch.cat([x, s], axis=-1)
+# x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
+
+# x = x.transpose(0, 1)
+# input_lengths = text_lengths.cpu().numpy()
+# x = x.transpose(-1, -2)
+
+# for block in self.lstms:
+# if isinstance(block, AdaLayerNorm):
+# x = block(x.transpose(-1, -2), style).transpose(-1, -2)
+# x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
+# x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
+# else:
+# x = x.transpose(-1, -2)
+# x = nn.utils.rnn.pack_padded_sequence(
+# x, input_lengths, batch_first=True, enforce_sorted=False)
+# block.flatten_parameters()
+# x, _ = block(x)
+# x, _ = nn.utils.rnn.pad_packed_sequence(
+# x, batch_first=True)
+# x = F.dropout(x, p=self.dropout, training=self.training)
+# x = x.transpose(-1, -2)
+
+# x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
+
+# x_pad[:, :, :x.shape[-1]] = x
+# x = x_pad.to(x.device)
+
+# return x.transpose(-1, -2)
+
+# def inference(self, x, style):
+# x = self.embedding(x.transpose(-1, -2)) * math.sqrt(self.d_model)
+# style = style.expand(x.shape[0], x.shape[1], -1)
+# x = torch.cat([x, style], axis=-1)
+# src = self.pos_encoder(x)
+# output = self.transformer_encoder(src).transpose(0, 1)
+# return output
+
+# def length_to_mask(self, lengths):
+# mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+# mask = torch.gt(mask+1, lengths.unsqueeze(1))
+# return mask
+
+
+
+
+
+
+
+
+# def load_F0_models(path):
+# # load F0 model
+
+# F0_model = JDCNet(num_class=1, seq_len=192)
+# params = torch.load(path, map_location='cpu')['net']
+# F0_model.load_state_dict(params)
+# _ = F0_model.train()
+
+# return F0_model
+
+# def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
+# # load ASR model
+# def _load_config(path):
+# with open(path) as f:
+# config = yaml.safe_load(f)
+# model_config = config['model_params']
+# return model_config
+
+# def _load_model(model_config, model_path):
+# model = ASRCNN(**model_config)
+# params = torch.load(model_path, map_location='cpu')['model']
+# model.load_state_dict(params)
+# return model
+
+# asr_model_config = _load_config(ASR_MODEL_CONFIG)
+# asr_model = _load_model(asr_model_config, ASR_MODEL_PATH)
+# _ = asr_model.train()
+
+# return asr_model
+
+# def build_model(args, text_aligner, pitch_extractor, bert):
+# assert args.decoder.type in ['istftnet', 'hifigan'], 'Decoder type unknown'
+
+# if args.decoder.type == "istftnet":
+# from Modules.istftnet import Decoder
+# decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
+# resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
+# upsample_rates = args.decoder.upsample_rates,
+# upsample_initial_channel=args.decoder.upsample_initial_channel,
+# resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
+# upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
+# gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
+# else:
+# from Modules.hifigan import Decoder
+# decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
+# resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
+# upsample_rates = args.decoder.upsample_rates,
+# upsample_initial_channel=args.decoder.upsample_initial_channel,
+# resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
+# upsample_kernel_sizes=args.decoder.upsample_kernel_sizes)
+
+# text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
+
+# predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
+
+# style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # acoustic style encoder
+# predictor_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # prosodic style encoder
+
+# # define diffusion model
+# if args.multispeaker:
+# transformer = StyleTransformer1d(channels=args.style_dim*2,
+# context_embedding_features=bert.config.hidden_size,
+# context_features=args.style_dim*2,
+# **args.diffusion.transformer)
+# else:
+# transformer = Transformer1d(channels=args.style_dim*2,
+# context_embedding_features=bert.config.hidden_size,
+# **args.diffusion.transformer)
+
+# diffusion = AudioDiffusionConditional(
+# in_channels=1,
+# embedding_max_length=bert.config.max_position_embeddings,
+# embedding_features=bert.config.hidden_size,
+# embedding_mask_proba=args.diffusion.embedding_mask_proba, # Conditional dropout of batch elements,
+# channels=args.style_dim*2,
+# context_features=args.style_dim*2,
+# )
+
+# diffusion.diffusion = KDiffusion(
+# net=diffusion.unet,
+# sigma_distribution=LogNormalDistribution(mean = args.diffusion.dist.mean, std = args.diffusion.dist.std),
+# sigma_data=args.diffusion.dist.sigma_data, # a placeholder, will be changed dynamically when start training diffusion model
+# dynamic_threshold=0.0
+# )
+# diffusion.diffusion.net = transformer
+# diffusion.unet = transformer
+
+
+# nets = Munch(
+# bert=bert,
+# bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
+
+# predictor=predictor,
+# decoder=decoder,
+# text_encoder=text_encoder,
+
+# predictor_encoder=predictor_encoder,
+# style_encoder=style_encoder,
+# diffusion=diffusion,
+
+# text_aligner = text_aligner,
+# pitch_extractor=pitch_extractor,
+
+# mpd = MultiPeriodDiscriminator(),
+# msd = MultiResSpecDiscriminator(),
+
+# # slm discriminator head
+# wd = WavLMDiscriminator(args.slm.hidden, args.slm.nlayers, args.slm.initial_channel),
+# )
+
+# return nets
+
+# def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[]):
+# state = torch.load(path, map_location='cpu')
+# params = state['net']
+
+# for key in model:
+# if key in params and key not in ignore_modules:
+# print('%s loaded' % key)
+# try:
+# model[key].load_state_dict(params[key], strict=True)
+# except:
+# from collections import OrderedDict
+# state_dict = params[key]
+# new_state_dict = OrderedDict()
+# print(f'{key} key length: {len(model[key].state_dict().keys())}, state_dict length: {len(state_dict.keys())}')
+# for (k_m, v_m), (k_c, v_c) in zip(model[key].state_dict().items(), state_dict.items()):
+# new_state_dict[k_m] = v_c
+# model[key].load_state_dict(new_state_dict, strict=True)
+# _ = [model[key].eval() for key in model]
+
+# if not load_only_params:
+# epoch = state["epoch"]
+# iters = state["iters"]
+# optimizer.load_state_dict(state["optimizer"])
+# else:
+# epoch = 0
+# iters = 0
+
+# return model, optimizer, epoch, iters
+
+##############################################################################################################
+
+##############################################################################################################
+
+##############################################################################################################
+
+
+# mLSTM
+
+
+##############################################################################################################
+
+##############################################################################################################
+
+##############################################################################################################
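+# Active model definitions for the 48 kHz variant. The recurrent layers of
+# TextEncoder and ProsodyPredictor below are replaced with mLSTM-based xLSTM
+# block stacks (from the `xlstm` package), with linear projections adapting
+# channel widths at the stack boundaries. The commented blocks above (Hopfield
+# variant) and below (original LSTM version) are kept for reference.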
+
+
+import os
+import os.path as osp
+
+import copy
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+
+from Utils.ASR.models import ASRCNN
+from Utils.JDC.model import JDCNet
+
+from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
+from Modules.diffusion.modules import Transformer1d, StyleTransformer1d
+from Modules.diffusion.diffusion import AudioDiffusionConditional
+
+from Modules.discriminators import MultiPeriodDiscriminator, MultiResSpecDiscriminator, WavLMDiscriminator
+
+from munch import Munch
+import yaml
+
+# from hflayers import Hopfield, HopfieldPooling, HopfieldLayer
+# from hflayers.auxiliary.data import BitPatternSet
+
+# Import auxiliary modules.
+from distutils.version import LooseVersion
+from typing import List, Tuple
+
+import math
+# from liger_kernel.ops.layer_norm import LigerLayerNormFunction
+# from liger_kernel.transformers.experimental.embedding import nn.Embedding
+
+import torch
+
+from xlstm import (
+ xLSTMBlockStack,
+ xLSTMBlockStackConfig,
+ mLSTMBlockConfig,
+ mLSTMLayerConfig,
+ sLSTMBlockConfig,
+ sLSTMLayerConfig,
+ FeedForwardConfig,
+)
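+# Only the mLSTM pieces of xLSTM are used below; the sLSTM configs are imported
+# but remain commented out in the block-stack configurations.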
+
+class LearnedDownSample(nn.Module):
+ def __init__(self, layer_type, dim_in):
+ super().__init__()
+ self.layer_type = layer_type
+
+ if self.layer_type == 'none':
+ self.conv = nn.Identity()
+ elif self.layer_type == 'timepreserve':
+ self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0)))
+ elif self.layer_type == 'half':
+ self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1))
+ else:
+            raise RuntimeError('Got unexpected downsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
+
+ def forward(self, x):
+ return self.conv(x)
+
+class LearnedUpSample(nn.Module):
+ def __init__(self, layer_type, dim_in):
+ super().__init__()
+ self.layer_type = layer_type
+
+ if self.layer_type == 'none':
+ self.conv = nn.Identity()
+ elif self.layer_type == 'timepreserve':
+ self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0))
+ elif self.layer_type == 'half':
+ self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1)
+ else:
+            raise RuntimeError('Got unexpected upsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
+
+
+ def forward(self, x):
+ return self.conv(x)
+
+class DownSample(nn.Module):
+ def __init__(self, layer_type):
+ super().__init__()
+ self.layer_type = layer_type
+
+ def forward(self, x):
+ if self.layer_type == 'none':
+ return x
+ elif self.layer_type == 'timepreserve':
+ return F.avg_pool2d(x, (2, 1))
+ elif self.layer_type == 'half':
+ if x.shape[-1] % 2 != 0:
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
+ return F.avg_pool2d(x, 2)
+ else:
+            raise RuntimeError('Got unexpected downsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
+
+
+class UpSample(nn.Module):
+ def __init__(self, layer_type):
+ super().__init__()
+ self.layer_type = layer_type
+
+ def forward(self, x):
+ if self.layer_type == 'none':
+ return x
+ elif self.layer_type == 'timepreserve':
+ return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
+ elif self.layer_type == 'half':
+ return F.interpolate(x, scale_factor=2, mode='nearest')
+ else:
+            raise RuntimeError('Got unexpected upsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
+
+
+class ResBlk(nn.Module):
+ def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
+ normalize=False, downsample='none'):
+ super().__init__()
+ self.actv = actv
+ self.normalize = normalize
+ self.downsample = DownSample(downsample)
+ self.downsample_res = LearnedDownSample(downsample, dim_in)
+ self.learned_sc = dim_in != dim_out
+ self._build_weights(dim_in, dim_out)
+
+ def _build_weights(self, dim_in, dim_out):
+ self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
+ self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
+ if self.normalize:
+ self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
+ self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
+ if self.learned_sc:
+ self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+ def _shortcut(self, x):
+ if self.learned_sc:
+ x = self.conv1x1(x)
+ if self.downsample:
+ x = self.downsample(x)
+ return x
+
+ def _residual(self, x):
+ if self.normalize:
+ x = self.norm1(x)
+ x = self.actv(x)
+ x = self.conv1(x)
+ x = self.downsample_res(x)
+ if self.normalize:
+ x = self.norm2(x)
+ x = self.actv(x)
+ x = self.conv2(x)
+ return x
+
+ def forward(self, x):
+ x = self._shortcut(x) + self._residual(x)
+ return x / math.sqrt(2) # unit variance
+
+class StyleEncoder(nn.Module):
+ def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
+ super().__init__()
+ blocks = []
+ blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
+
+ repeat_num = 4
+ for _ in range(repeat_num):
+ dim_out = min(dim_in*2, max_conv_dim)
+ blocks += [ResBlk(dim_in, dim_out, downsample='half')]
+ dim_in = dim_out
+
+ blocks += [nn.LeakyReLU(0.2)]
+ blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
+ blocks += [nn.AdaptiveAvgPool2d(1)]
+ blocks += [nn.LeakyReLU(0.2)]
+ self.shared = nn.Sequential(*blocks)
+
+ self.unshared = nn.Linear(dim_out, style_dim)
+
+ def forward(self, x):
+ h = self.shared(x)
+ h = h.view(h.size(0), -1)
+ s = self.unshared(h)
+
+ return s
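+# StyleEncoder pools a mel-spectrogram (treated as a 1-channel image) through
+# strided ResBlks and global average pooling into a single style vector of size
+# style_dim; build_model below reuses the same architecture for both the
+# acoustic and the prosodic style encoders.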
+
+class LinearNorm(torch.nn.Module):
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
+ super(LinearNorm, self).__init__()
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+
+ torch.nn.init.xavier_uniform_(
+ self.linear_layer.weight,
+ gain=torch.nn.init.calculate_gain(w_init_gain))
+
+ def forward(self, x):
+ return self.linear_layer(x)
+
+class Discriminator2d(nn.Module):
+ def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
+ super().__init__()
+ blocks = []
+ blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
+
+ for lid in range(repeat_num):
+ dim_out = min(dim_in*2, max_conv_dim)
+ blocks += [ResBlk(dim_in, dim_out, downsample='half')]
+ dim_in = dim_out
+
+ blocks += [nn.LeakyReLU(0.2)]
+ blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
+ blocks += [nn.LeakyReLU(0.2)]
+ blocks += [nn.AdaptiveAvgPool2d(1)]
+ blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
+ self.main = nn.Sequential(*blocks)
+
+ def get_feature(self, x):
+ features = []
+ for l in self.main:
+ x = l(x)
+ features.append(x)
+ out = features[-1]
+ out = out.view(out.size(0), -1) # (batch, num_domains)
+ return out, features
+
+ def forward(self, x):
+ out, features = self.get_feature(x)
+ out = out.squeeze() # (batch)
+ return out, features
+
+class ResBlk1d(nn.Module):
+ def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
+ normalize=False, downsample='none', dropout_p=0.2):
+ super().__init__()
+ self.actv = actv
+ self.normalize = normalize
+ self.downsample_type = downsample
+ self.learned_sc = dim_in != dim_out
+ self._build_weights(dim_in, dim_out)
+ self.dropout_p = dropout_p
+
+ if self.downsample_type == 'none':
+ self.pool = nn.Identity()
+ else:
+ self.pool = weight_norm(nn.Conv1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1))
+
+ def _build_weights(self, dim_in, dim_out):
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
+ self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
+ if self.normalize:
+ self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
+ self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
+ if self.learned_sc:
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+ def downsample(self, x):
+ if self.downsample_type == 'none':
+ return x
+ else:
+ if x.shape[-1] % 2 != 0:
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
+ return F.avg_pool1d(x, 2)
+
+ def _shortcut(self, x):
+ if self.learned_sc:
+ x = self.conv1x1(x)
+ x = self.downsample(x)
+ return x
+
+ def _residual(self, x):
+ if self.normalize:
+ x = self.norm1(x)
+ x = self.actv(x)
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
+
+ x = self.conv1(x)
+ x = self.pool(x)
+ if self.normalize:
+ x = self.norm2(x)
+
+ x = self.actv(x)
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
+
+ x = self.conv2(x)
+ return x
+
+ def forward(self, x):
+ x = self._shortcut(x) + self._residual(x)
+ return x / math.sqrt(2) # unit variance
+
+class LayerNorm(nn.Module):
+ def __init__(self, channels, eps=1e-5):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.gamma = nn.Parameter(torch.ones(channels))
+ self.beta = nn.Parameter(torch.zeros(channels))
+
+ def forward(self, x):
+ x = x.transpose(1, -1)
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+ return x.transpose(1, -1)
+
+
+class TextEncoder(nn.Module):
+ def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
+ super().__init__()
+ self.embedding = nn.Embedding(n_symbols, channels)
+
+        self.prepare_projection = LinearNorm(channels, channels // 2)
+        self.post_projection = LinearNorm(channels // 2, channels)
+ self.cfg = xLSTMBlockStackConfig(
+ mlstm_block=mLSTMBlockConfig(
+ mlstm=mLSTMLayerConfig(
+ conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
+ )
+ ),
+ # slstm_block=sLSTMBlockConfig(
+ # slstm=sLSTMLayerConfig(
+ # backend="cuda",
+ # num_heads=4,
+ # conv1d_kernel_size=4,
+ # bias_init="powerlaw_blockdependent",
+ # ),
+ # feedforward=FeedForwardConfig(proj_factor=1.3, act_fn="gelu"),
+ # ),
+ context_length=channels,
+ num_blocks=8,
+ embedding_dim=channels // 2,
+ # slstm_at=[1],
+
+ )
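+        # The xLSTM stack runs at half width: prepare_projection maps the CNN
+        # output down to channels // 2 before the stack and post_projection maps
+        # it back to `channels` afterwards. Note that context_length is set to
+        # `channels` here; whether that covers the longest token sequence seen at
+        # runtime depends on the training config and is not checked in this file.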
+
+
+
+ padding = (kernel_size - 1) // 2
+ self.cnn = nn.ModuleList()
+ for _ in range(depth):
+ self.cnn.append(nn.Sequential(
+
+ weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
+ LayerNorm(channels),
+ actv,
+ nn.Dropout(0.2),
+ ))
+ # self.cnn = nn.Sequential(*self.cnn)
+
+
+        self.lstm = xLSTMBlockStack(self.cfg)
+
+    def forward(self, x, input_lengths, m):
+
+ x = self.embedding(x) # [B, T, emb]
+
+
+ x = x.transpose(1, 2) # [B, emb, T]
+ m = m.to(input_lengths.device).unsqueeze(1)
+ x.masked_fill_(m, 0.0)
+
+ for c in self.cnn:
+ x = c(x)
+ x.masked_fill_(m, 0.0)
+
+ x = x.transpose(1, 2) # [B, T, chn]
+
+
+ input_lengths = input_lengths.cpu().numpy()
+
+
+
+ x = self.prepare_projection(x)
+
+ # x = nn.utils.rnn.pack_padded_sequence(
+ # x, input_lengths, batch_first=True, enforce_sorted=False)
+
+ # self.lstm.flatten_parameters()
+ x = self.lstm(x)
+
+ x = self.post_projection(x)
+ # x, _ = nn.utils.rnn.pad_packed_sequence(
+ # x, batch_first=True)
+
+ x = x.transpose(-1, -2)
+# x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
+
+# x_pad[:, :, :x.shape[-1]] = x
+# x = x_pad.to(x.device)
+
+ x.masked_fill_(m, 0.0)
+
+ return x
+
+ def inference(self, x):
+ x = self.embedding(x)
+ x = x.transpose(1, 2)
+ x = self.cnn(x)
+ x = x.transpose(1, 2)
+ # self.lstm.flatten_parameters()
+ x = self.lstm(x)
+ return x
+
+ def length_to_mask(self, lengths):
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
+ return mask
+
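+# Minimal usage sketch for TextEncoder (illustrative only; the sizes below are
+# example assumptions, not values taken from this repo's config):
+#   enc = TextEncoder(channels=512, kernel_size=5, depth=3, n_symbols=178)
+#   tokens = torch.randint(0, 178, (2, 50))    # [B, T] phoneme ids
+#   lengths = torch.LongTensor([50, 40])
+#   mask = enc.length_to_mask(lengths)         # True at padded positions
+#   feats = enc(tokens, lengths, mask)         # -> [B, channels, T]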
+
+
+class AdaIN1d(nn.Module):
+ def __init__(self, style_dim, num_features):
+ super().__init__()
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
+ self.fc = nn.Linear(style_dim, num_features*2)
+
+ def forward(self, x, s):
+ h = self.fc(s)
+
+ h = h.view(h.size(0), h.size(1), 1)
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
+ return (1 + gamma) * self.norm(x) + beta
+
+class UpSample1d(nn.Module):
+ def __init__(self, layer_type):
+ super().__init__()
+ self.layer_type = layer_type
+
+ def forward(self, x):
+ if self.layer_type == 'none':
+ return x
+ else:
+ return F.interpolate(x, scale_factor=2, mode='nearest')
+
+class AdainResBlk1d(nn.Module):
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
+ upsample='none', dropout_p=0.0):
+ super().__init__()
+ self.actv = actv
+ self.upsample_type = upsample
+ self.upsample = UpSample1d(upsample)
+ self.learned_sc = dim_in != dim_out
+ self._build_weights(dim_in, dim_out, style_dim)
+ self.dropout = nn.Dropout(dropout_p)
+
+ if upsample == 'none':
+ self.pool = nn.Identity()
+ else:
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
+
+
+ def _build_weights(self, dim_in, dim_out, style_dim):
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
+ self.norm1 = AdaIN1d(style_dim, dim_in)
+ self.norm2 = AdaIN1d(style_dim, dim_out)
+ if self.learned_sc:
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+ def _shortcut(self, x):
+ x = self.upsample(x)
+ if self.learned_sc:
+ x = self.conv1x1(x)
+ return x
+
+ def _residual(self, x, s):
+ x = self.norm1(x, s)
+ x = self.actv(x)
+ x = self.pool(x)
+ x = self.conv1(self.dropout(x))
+ x = self.norm2(x, s)
+ x = self.actv(x)
+ x = self.conv2(self.dropout(x))
+ return x
+
+ def forward(self, x, s):
+ out = self._residual(x, s)
+ out = (out + self._shortcut(x)) / math.sqrt(2)
+ return out
+
+class AdaLayerNorm(nn.Module):
+ def __init__(self, style_dim, channels, eps=1e-5):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.fc = nn.Linear(style_dim, channels*2)
+
+ def forward(self, x, s):
+ x = x.transpose(-1, -2)
+ x = x.transpose(1, -1)
+
+ h = self.fc(s)
+ h = h.view(h.size(0), h.size(1), 1)
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
+ gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
+
+
+ x = F.layer_norm(x, (self.channels,), eps=self.eps)
+ x = (1 + gamma) * x + beta
+ return x.transpose(1, -1).transpose(-1, -2)
+
+# class ProsodyPredictor(nn.Module):
+
+# def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
+# super().__init__()
+
+# self.text_encoder = DurationEncoder(sty_dim=style_dim,
+# d_model=d_hid,
+# nlayers=nlayers,
+# dropout=dropout)
+
+# self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
+# self.duration_proj = LinearNorm(d_hid, max_dur)
+
+# self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
+# self.F0 = nn.ModuleList()
+# self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
+# self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
+# self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
+
+# self.N = nn.ModuleList()
+# self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
+# self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
+# self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
+
+# self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
+# self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
+
+
+# def forward(self, texts, style, text_lengths, alignment, m):
+# d = self.text_encoder(texts, style, text_lengths, m)
+
+# batch_size = d.shape[0]
+# text_size = d.shape[1]
+
+# # predict duration
+# input_lengths = text_lengths.cpu().numpy()
+# x = nn.utils.rnn.pack_padded_sequence(
+# d, input_lengths, batch_first=True, enforce_sorted=False)
+
+# m = m.to(text_lengths.device).unsqueeze(1)
+
+# self.lstm.flatten_parameters()
+# x, _ = self.lstm(x)
+# x, _ = nn.utils.rnn.pad_packed_sequence(
+# x, batch_first=True)
+
+# x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
+
+# x_pad[:, :x.shape[1], :] = x
+# x = x_pad.to(x.device)
+
+# duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
+
+# en = (d.transpose(-1, -2) @ alignment)
+
+# return duration.squeeze(-1), en
+
+
+class ProsodyPredictor(nn.Module):
+
+ def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
+ super().__init__()
+
+ self.cfg = xLSTMBlockStackConfig(
+ mlstm_block=mLSTMBlockConfig(
+ mlstm=mLSTMLayerConfig(
+ conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
+ )
+ ),
+ # slstm_block=sLSTMBlockConfig(
+ # slstm=sLSTMLayerConfig(
+ # backend="cuda",
+ # num_heads=4,
+ # conv1d_kernel_size=4,
+ # bias_init="powerlaw_blockdependent",
+ # ),
+ # feedforward=FeedForwardConfig(proj_factor=1.3, act_fn="gelu"),
+ # ),
+ context_length=d_hid,
+ num_blocks=8,
+ embedding_dim=d_hid + style_dim,
+ # slstm_at=[1],
+
+ )
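+        # The same config is reused for both xLSTM stacks below: self.lstm for
+        # duration prediction and self.shared for the F0/energy path. Its
+        # embedding_dim is d_hid + style_dim because DurationEncoder concatenates
+        # the style vector onto the text features; prepare_projection maps the
+        # stack output back down to d_hid.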
+
+
+
+
+ self.text_encoder = DurationEncoder(sty_dim=style_dim,
+ d_model=d_hid,
+ nlayers=nlayers,
+ dropout=dropout)
+
+
+ self.lstm = xLSTMBlockStack(self.cfg)
+
+ self.prepare_projection = nn.Linear(d_hid + style_dim, d_hid)
+
+        self.duration_proj = LinearNorm(d_hid, max_dur)
+
+ self.shared = xLSTMBlockStack(self.cfg)
+
+ #self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
+ self.F0 = nn.ModuleList()
+ self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
+ self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
+ self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
+
+ self.N = nn.ModuleList()
+ self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
+ self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
+ self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
+
+ self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
+ self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
+
+
+ # def forward(self, texts, style, text_lengths, alignment, m):
+ # d = self.text_encoder(texts, style, text_lengths, m)
+
+ # batch_size = d.shape[0]
+ # text_size = d.shape[1]
+
+ # # predict duration
+
+
+ # input_lengths = text_lengths.cpu().numpy()
+
+
+ # # x = nn.utils.rnn.pack_padded_sequence(
+ # # d, input_lengths, batch_first=True, enforce_sorted=False)
+
+ # x = d # this dude can handle variable seq len so no need for packing
+
+
+ # m = m.to(text_lengths.device).unsqueeze(1)
+
+ # # self.lstm.flatten_parameters()
+ # x = self.lstm(x) # no longer using lstm
+ # x = self.prepare_projection(x)
+
+
+ # # x, _ = nn.utils.rnn.pad_packed_sequence(
+ # # x, batch_first=True)
+
+ # # x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
+
+ # # x_pad[:, :x.shape[1], :] = x
+ # # x = x_pad.to(x.device)
+
+ # x = x.transpose(-1,-2)
+ # x = x.permute(0,2,1)
+ # duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
+
+
+
+ # en = (d.transpose(-1, -2) @ alignment)
+
+ # return duration.squeeze(-1), en
+
+
+ # def F0Ntrain(self, x, s):
+
+
+
+ # x = self.shared(x.transpose(-1, -2))
+ # x = self.prepare_projection(x)
+
+
+ # F0 = x.transpose(-1, -2)
+
+ # for block in self.F0:
+ # F0 = block(F0, s)
+ # F0 = self.F0_proj(F0)
+
+ # N = x.transpose(-1, -2)
+ # for block in self.N:
+ # N = block(N, s)
+ # N = self.N_proj(N)
+
+ # return F0.squeeze(1), N.squeeze(1)
+
+
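+    # forward() covers both roles of the original predictor: with f0=True it
+    # behaves like F0Ntrain (shared xLSTM stack followed by the F0/N heads on
+    # pre-aligned features); otherwise it encodes text + style and predicts
+    # per-token durations. The standalone F0Ntrain method further below is
+    # identical to the f0=True branch.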
+ def forward(self, texts, style, text_lengths=None, alignment=None, m=None, f0=False):
+
+ if f0:
+ x, s = texts, style
+ x = self.shared(x.transpose(-1, -2))
+ x = self.prepare_projection(x)
+
+ F0 = x.transpose(-1, -2)
+ for block in self.F0:
+ F0 = block(F0, s)
+ F0 = self.F0_proj(F0)
+
+ N = x.transpose(-1, -2)
+ for block in self.N:
+ N = block(N, s)
+ N = self.N_proj(N)
+
+ return F0.squeeze(1), N.squeeze(1)
+ else:
+ # Problem is here
+ d = self.text_encoder(texts, style, text_lengths, m)
+
+ batch_size = d.shape[0]
+ text_size = d.shape[1]
+
+ # predict duration
+
+
+ input_lengths = text_lengths.cpu().numpy()
+
+
+ # x = nn.utils.rnn.pack_padded_sequence(
+ # d, input_lengths, batch_first=True, enforce_sorted=False)
+
+            x = d  # the xLSTM stack handles variable-length sequences, so packing is not needed
+
+
+ m = m.to(text_lengths.device).unsqueeze(1)
+
+ # self.lstm.flatten_parameters()
+            x = self.lstm(x)  # self.lstm is an xLSTM block stack, not an nn.LSTM
+ x = self.prepare_projection(x)
+
+
+ # x, _ = nn.utils.rnn.pad_packed_sequence(
+ # x, batch_first=True)
+
+ # x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
+
+ # x_pad[:, :x.shape[1], :] = x
+ # x = x_pad.to(x.device)
+
+            x = x.transpose(-1, -2)
+            x = x.permute(0, 2, 1)  # undoes the transpose above, leaving x as [B, T, d_hid] for the duration projection
+ duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
+
+
+
+            en = (d.transpose(-1, -2) @ alignment)  # expand text features to frame rate via the alignment matrix
+
+ return duration.squeeze(-1), en
+
+
+ def F0Ntrain(self, x, s):
+
+
+
+ x = self.shared(x.transpose(-1, -2))
+ x = self.prepare_projection(x)
+
+
+ F0 = x.transpose(-1, -2)
+
+ for block in self.F0:
+ F0 = block(F0, s)
+ F0 = self.F0_proj(F0)
+
+ N = x.transpose(-1, -2)
+ for block in self.N:
+ N = block(N, s)
+ N = self.N_proj(N)
+
+ return F0.squeeze(1), N.squeeze(1)
+
+
+ def length_to_mask(self, lengths):
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
+ return mask
+
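+# DurationEncoder keeps the original design: bidirectional nn.LSTM layers
+# interleaved with AdaLayerNorm, with the style vector concatenated onto the
+# text features at each stage. Only the predictor's top-level recurrences above
+# were swapped for xLSTM block stacks.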
+class DurationEncoder(nn.Module):
+
+ def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
+ super().__init__()
+ self.lstms = nn.ModuleList()
+ for _ in range(nlayers):
+ self.lstms.append(nn.LSTM(d_model + sty_dim,
+ d_model // 2,
+ num_layers=1,
+ batch_first=True,
+ bidirectional=True,
+ dropout=dropout))
+ self.lstms.append(AdaLayerNorm(sty_dim, d_model))
+
+
+ self.dropout = dropout
+ self.d_model = d_model
+ self.sty_dim = sty_dim
+
+ def forward(self, x, style, text_lengths, m):
+ masks = m.to(text_lengths.device)
+
+ x = x.permute(2, 0, 1)
+ s = style.expand(x.shape[0], x.shape[1], -1)
+ x = torch.cat([x, s], axis=-1)
+ x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
+
+ x = x.transpose(0, 1)
+ input_lengths = text_lengths.cpu().numpy()
+ x = x.transpose(-1, -2)
+
+ for block in self.lstms:
+ if isinstance(block, AdaLayerNorm):
+ x = block(x.transpose(-1, -2), style).transpose(-1, -2)
+ x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
+ x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
+ else:
+ x = x.transpose(-1, -2)
+ x = nn.utils.rnn.pack_padded_sequence(
+ x, input_lengths, batch_first=True, enforce_sorted=False)
+ block.flatten_parameters()
+ x, _ = block(x)
+ x, _ = nn.utils.rnn.pad_packed_sequence(
+ x, batch_first=True)
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ x = x.transpose(-1, -2)
+
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
+
+ x_pad[:, :, :x.shape[-1]] = x
+ x = x_pad.to(x.device)
+
+ return x.transpose(-1, -2)
+
+ def inference(self, x, style):
+ x = self.embedding(x.transpose(-1, -2)) * math.sqrt(self.d_model)
+ style = style.expand(x.shape[0], x.shape[1], -1)
+ x = torch.cat([x, style], axis=-1)
+ src = self.pos_encoder(x)
+ output = self.transformer_encoder(src).transpose(0, 1)
+ return output
+
+ def length_to_mask(self, lengths):
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
+ return mask
+
+
+
+
+
+
+
+
+def load_F0_models(path):
+ # load F0 model
+
+ F0_model = JDCNet(num_class=1, seq_len=192)
+ params = torch.load(path, map_location='cpu')['model']
+ F0_model.load_state_dict(params)
+ _ = F0_model.train()
+
+ return F0_model
+
+def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
+ # load ASR model
+ def _load_config(path):
+ with open(path) as f:
+ config = yaml.safe_load(f)
+ model_config = config['model_params']
+ return model_config
+
+ def _load_model(model_config, model_path):
+ model = ASRCNN(**model_config)
+ params = torch.load(model_path, map_location='cpu')['model']
+ model.load_state_dict(params)
+ return model
+
+ asr_model_config = _load_config(ASR_MODEL_CONFIG)
+ asr_model = _load_model(asr_model_config, ASR_MODEL_PATH)
+ _ = asr_model.train()
+
+ return asr_model
+
+def build_model(args, text_aligner, pitch_extractor, bert):
+ assert args.decoder.type in ['istftnet', 'hifigan'], 'Decoder type unknown'
+
+ if args.decoder.type == "istftnet":
+ from Modules.istftnet import Decoder
+ decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
+ resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
+ upsample_rates=args.decoder.upsample_rates,
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
+ gen_istft_n_fft=args.decoder.gen_istft_n_fft,
+ gen_istft_hop_size=args.decoder.gen_istft_hop_size)
+ else:
+ from Modules.hifigan import Decoder
+ decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
+ resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
+ upsample_rates=args.decoder.upsample_rates,
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,)
+
+ text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
+
+ predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer,
+ max_dur=args.max_dur, dropout=args.dropout)
+
+ style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim,
+ max_conv_dim=args.hidden_dim) # acoustic style encoder
+ predictor_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim,
+ max_conv_dim=args.hidden_dim) # prosodic style encoder
+
+    # define the style diffusion model: it samples a style_dim*2 vector (acoustic + prosodic style) conditioned on the BERT text embeddings
+ if args.multispeaker:
+ transformer = StyleTransformer1d(channels=args.style_dim * 2,
+ context_embedding_features=bert.config.hidden_size,
+ context_features=args.style_dim * 2,
+ **args.diffusion.transformer)
+ else:
+ transformer = Transformer1d(channels=args.style_dim * 2,
+ context_embedding_features=bert.config.hidden_size,
+ **args.diffusion.transformer)
+
+ diffusion = AudioDiffusionConditional(
+ in_channels=1,
+ embedding_max_length=bert.config.max_position_embeddings,
+ embedding_features=bert.config.hidden_size,
+ embedding_mask_proba=args.diffusion.embedding_mask_proba, # Conditional dropout of batch elements,
+ channels=args.style_dim * 2,
+ context_features=args.style_dim * 2,
+ )
+
+ diffusion.diffusion = KDiffusion(
+ net=diffusion.unet,
+ sigma_distribution=LogNormalDistribution(mean=args.diffusion.dist.mean, std=args.diffusion.dist.std),
+ sigma_data=args.diffusion.dist.sigma_data,
+        # a placeholder; it is updated dynamically once training of the diffusion model starts
+ dynamic_threshold=0.0
+ )
+ diffusion.diffusion.net = transformer
+ diffusion.unet = transformer
+
+ nets = Munch(
+ bert=bert,
+ bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
+
+ predictor=predictor,
+ decoder=decoder,
+ text_encoder=text_encoder,
+
+ predictor_encoder=predictor_encoder,
+ style_encoder=style_encoder,
+ diffusion=diffusion,
+
+ text_aligner=text_aligner,
+ pitch_extractor=pitch_extractor,
+
+ mpd=MultiPeriodDiscriminator(),
+ msd=MultiResSpecDiscriminator(),
+
+ # slm discriminator head
+ wd=WavLMDiscriminator(args.slm.hidden, args.slm.nlayers, args.slm.initial_channel),
+ )
+
+ return nets
+
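+# Rough sketch of how these pieces fit together (paths and `args` come from the
+# training config; `bert` is loaded elsewhere, e.g. the PL-BERT model used by
+# StyleTTS2; the upper-case names are placeholders):
+#   text_aligner = load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG)
+#   pitch_extractor = load_F0_models(F0_PATH)
+#   nets = build_model(args, text_aligner, pitch_extractor, bert)
+#   nets, optimizer, epoch, iters = load_checkpoint(nets, optimizer, CKPT_PATH)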
+
+# def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[]):
+# state = torch.load(path, map_location='cpu')
+# params = state['net']
+# for key in model:
+# if key in params and key not in ignore_modules:
+# print('%s loaded' % key)
+# model[key].load_state_dict(params[key], strict=False)
+# _ = [model[key].eval() for key in model]
+
+# if not load_only_params:
+# epoch = state["epoch"]
+# iters = state["iters"]
+# optimizer.load_state_dict(state["optimizer"])
+# else:
+# epoch = 0
+# iters = 0
+
+# return model, optimizer, epoch, iters
+
+def load_checkpoint(model, optimizer, path, load_only_params=False, ignore_modules=[]):
+ state = torch.load(path, map_location='cpu')
+ params = state['net']
+    print('Loading checkpoint with the active load_checkpoint implementation.')
+
+ for key in model:
+ if key in params and key not in ignore_modules:
+            try:
+                model[key].load_state_dict(params[key], strict=True)
+            except Exception:
+                # Fallback: the checkpoint keys do not match the current module
+                # names, so remap the weights by position instead.
+                from collections import OrderedDict
+                state_dict = params[key]
+                new_state_dict = OrderedDict()
+                print(f'{key} key length: {len(model[key].state_dict().keys())}, state_dict key length: {len(state_dict.keys())}')
+                for (k_m, v_m), (k_c, v_c) in zip(model[key].state_dict().items(), state_dict.items()):
+                    new_state_dict[k_m] = v_c
+                model[key].load_state_dict(new_state_dict, strict=True)
+            print('%s loaded' % key)
+
+ if not load_only_params:
+ epoch = state["epoch"]
+ iters = state["iters"]
+ optimizer.load_state_dict(state["optimizer"])
+ else:
+ epoch = 0
+ iters = 0
+
+ return model, optimizer, epoch, iters
+
+
+################################################################################################
+################################################################################################
+################################################################################################
+
+# LSTM ORIGINAL
+
+################################################################################################
+################################################################################################
+
+
+# # import os
+# # import os.path as osp
+
+# # import copy
+# # import math
+
+# # import numpy as np
+# # import torch
+# # import torch.nn as nn
+# # import torch.nn.functional as F
+# # from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+
+# # from Utils.ASR.models import ASRCNN
+# # from Utils.JDC.model import JDCNet
+
+# # from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
+# # from Modules.diffusion.modules import Transformer1d, StyleTransformer1d
+# # from Modules.diffusion.diffusion import AudioDiffusionConditional
+
+# # from Modules.discriminators import MultiPeriodDiscriminator, MultiResSpecDiscriminator, WavLMDiscriminator
+
+# # from munch import Munch
+# # import yaml
+
+# # class LearnedDownSample(nn.Module):
+# # def __init__(self, layer_type, dim_in):
+# # super().__init__()
+# # self.layer_type = layer_type
+
+# # if self.layer_type == 'none':
+# # self.conv = nn.Identity()
+# # elif self.layer_type == 'timepreserve':
+# # self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0)))
+# # elif self.layer_type == 'half':
+# # self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1))
+# # else:
+# # raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
+
+# # def forward(self, x):
+# # return self.conv(x)
+
+# # class LearnedUpSample(nn.Module):
+# # def __init__(self, layer_type, dim_in):
+# # super().__init__()
+# # self.layer_type = layer_type
+
+# # if self.layer_type == 'none':
+# # self.conv = nn.Identity()
+# # elif self.layer_type == 'timepreserve':
+# # self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0))
+# # elif self.layer_type == 'half':
+# # self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1)
+# # else:
+# # raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
+
+
+# # def forward(self, x):
+# # return self.conv(x)
+
+# # class DownSample(nn.Module):
+# # def __init__(self, layer_type):
+# # super().__init__()
+# # self.layer_type = layer_type
+
+# # def forward(self, x):
+# # if self.layer_type == 'none':
+# # return x
+# # elif self.layer_type == 'timepreserve':
+# # return F.avg_pool2d(x, (2, 1))
+# # elif self.layer_type == 'half':
+# # if x.shape[-1] % 2 != 0:
+# # x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
+# # return F.avg_pool2d(x, 2)
+# # else:
+# # raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
+
+
+# # class UpSample(nn.Module):
+# # def __init__(self, layer_type):
+# # super().__init__()
+# # self.layer_type = layer_type
+
+# # def forward(self, x):
+# # if self.layer_type == 'none':
+# # return x
+# # elif self.layer_type == 'timepreserve':
+# # return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
+# # elif self.layer_type == 'half':
+# # return F.interpolate(x, scale_factor=2, mode='nearest')
+# # else:
+# # raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
+
+
+# # class ResBlk(nn.Module):
+# # def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
+# # normalize=False, downsample='none'):
+# # super().__init__()
+# # self.actv = actv
+# # self.normalize = normalize
+# # self.downsample = DownSample(downsample)
+# # self.downsample_res = LearnedDownSample(downsample, dim_in)
+# # self.learned_sc = dim_in != dim_out
+# # self._build_weights(dim_in, dim_out)
+
+# # def _build_weights(self, dim_in, dim_out):
+# # self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
+# # self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
+# # if self.normalize:
+# # self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
+# # self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
+# # if self.learned_sc:
+# # self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+# # def _shortcut(self, x):
+# # if self.learned_sc:
+# # x = self.conv1x1(x)
+# # if self.downsample:
+# # x = self.downsample(x)
+# # return x
+
+# # def _residual(self, x):
+# # if self.normalize:
+# # x = self.norm1(x)
+# # x = self.actv(x)
+# # x = self.conv1(x)
+# # x = self.downsample_res(x)
+# # if self.normalize:
+# # x = self.norm2(x)
+# # x = self.actv(x)
+# # x = self.conv2(x)
+# # return x
+
+# # def forward(self, x):
+# # x = self._shortcut(x) + self._residual(x)
+# # return x / math.sqrt(2) # unit variance
+
+# # class StyleEncoder(nn.Module):
+# # def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
+# # super().__init__()
+# # blocks = []
+# # blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
+
+# # repeat_num = 4
+# # for _ in range(repeat_num):
+# # dim_out = min(dim_in*2, max_conv_dim)
+# # blocks += [ResBlk(dim_in, dim_out, downsample='half')]
+# # dim_in = dim_out
+
+# # blocks += [nn.LeakyReLU(0.2)]
+# # blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
+# # blocks += [nn.AdaptiveAvgPool2d(1)]
+# # blocks += [nn.LeakyReLU(0.2)]
+# # self.shared = nn.Sequential(*blocks)
+
+# # self.unshared = nn.Linear(dim_out, style_dim)
+
+# # def forward(self, x):
+# # h = self.shared(x)
+# # h = h.view(h.size(0), -1)
+# # s = self.unshared(h)
+
+# # return s
+
+# # class LinearNorm(torch.nn.Module):
+# # def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
+# # super(LinearNorm, self).__init__()
+# # self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+
+# # torch.nn.init.xavier_uniform_(
+# # self.linear_layer.weight,
+# # gain=torch.nn.init.calculate_gain(w_init_gain))
+
+# # def forward(self, x):
+# # return self.linear_layer(x)
+
+# # class Discriminator2d(nn.Module):
+# # def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
+# # super().__init__()
+# # blocks = []
+# # blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
+
+# # for lid in range(repeat_num):
+# # dim_out = min(dim_in*2, max_conv_dim)
+# # blocks += [ResBlk(dim_in, dim_out, downsample='half')]
+# # dim_in = dim_out
+
+# # blocks += [nn.LeakyReLU(0.2)]
+# # blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
+# # blocks += [nn.LeakyReLU(0.2)]
+# # blocks += [nn.AdaptiveAvgPool2d(1)]
+# # blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
+# # self.main = nn.Sequential(*blocks)
+
+# # def get_feature(self, x):
+# # features = []
+# # for l in self.main:
+# # x = l(x)
+# # features.append(x)
+# # out = features[-1]
+# # out = out.view(out.size(0), -1) # (batch, num_domains)
+# # return out, features
+
+# # def forward(self, x):
+# # out, features = self.get_feature(x)
+# # out = out.squeeze() # (batch)
+# # return out, features
+
+# # class ResBlk1d(nn.Module):
+# # def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
+# # normalize=False, downsample='none', dropout_p=0.2):
+# # super().__init__()
+# # self.actv = actv
+# # self.normalize = normalize
+# # self.downsample_type = downsample
+# # self.learned_sc = dim_in != dim_out
+# # self._build_weights(dim_in, dim_out)
+# # self.dropout_p = dropout_p
+
+# # if self.downsample_type == 'none':
+# # self.pool = nn.Identity()
+# # else:
+# # self.pool = weight_norm(nn.Conv1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1))
+
+# # def _build_weights(self, dim_in, dim_out):
+# # self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
+# # self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
+# # if self.normalize:
+# # self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
+# # self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
+# # if self.learned_sc:
+# # self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+# # def downsample(self, x):
+# # if self.downsample_type == 'none':
+# # return x
+# # else:
+# # if x.shape[-1] % 2 != 0:
+# # x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
+# # return F.avg_pool1d(x, 2)
+
+# # def _shortcut(self, x):
+# # if self.learned_sc:
+# # x = self.conv1x1(x)
+# # x = self.downsample(x)
+# # return x
+
+# # def _residual(self, x):
+# # if self.normalize:
+# # x = self.norm1(x)
+# # x = self.actv(x)
+# # x = F.dropout(x, p=self.dropout_p, training=self.training)
+
+# # x = self.conv1(x)
+# # x = self.pool(x)
+# # if self.normalize:
+# # x = self.norm2(x)
+
+# # x = self.actv(x)
+# # x = F.dropout(x, p=self.dropout_p, training=self.training)
+
+# # x = self.conv2(x)
+# # return x
+
+# # def forward(self, x):
+# # x = self._shortcut(x) + self._residual(x)
+# # return x / math.sqrt(2) # unit variance
+
+# # class LayerNorm(nn.Module):
+# # def __init__(self, channels, eps=1e-5):
+# # super().__init__()
+# # self.channels = channels
+# # self.eps = eps
+
+# # self.gamma = nn.Parameter(torch.ones(channels))
+# # self.beta = nn.Parameter(torch.zeros(channels))
+
+# # def forward(self, x):
+# # x = x.transpose(1, -1)
+# # x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+# # return x.transpose(1, -1)
+
+# # class TextEncoder(nn.Module):
+# # def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
+# # super().__init__()
+# # self.embedding = nn.Embedding(n_symbols, channels)
+
+# # padding = (kernel_size - 1) // 2
+# # self.cnn = nn.ModuleList()
+# # for _ in range(depth):
+# # self.cnn.append(nn.Sequential(
+# # weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
+# # LayerNorm(channels),
+# # actv,
+# # nn.Dropout(0.2),
+# # ))
+# # # self.cnn = nn.Sequential(*self.cnn)
+
+# # self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
+
+# # def forward(self, x, input_lengths, m):
+# # x = self.embedding(x) # [B, T, emb]
+# # x = x.transpose(1, 2) # [B, emb, T]
+# # m = m.to(input_lengths.device).unsqueeze(1)
+# # x.masked_fill_(m, 0.0)
+
+# # for c in self.cnn:
+# # x = c(x)
+# # x.masked_fill_(m, 0.0)
+
+# # x = x.transpose(1, 2) # [B, T, chn]
+
+# # input_lengths = input_lengths.cpu().numpy()
+# # x = nn.utils.rnn.pack_padded_sequence(
+# # x, input_lengths, batch_first=True, enforce_sorted=False)
+
+# # self.lstm.flatten_parameters()
+# # x, _ = self.lstm(x)
+# # x, _ = nn.utils.rnn.pad_packed_sequence(
+# # x, batch_first=True)
+
+# # x = x.transpose(-1, -2)
+# # x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
+
+# # x_pad[:, :, :x.shape[-1]] = x
+# # x = x_pad.to(x.device)
+
+# # x.masked_fill_(m, 0.0)
+
+# # return x
+
+# # def inference(self, x):
+# # x = self.embedding(x)
+# # x = x.transpose(1, 2)
+# # x = self.cnn(x)
+# # x = x.transpose(1, 2)
+# # self.lstm.flatten_parameters()
+# # x, _ = self.lstm(x)
+# # return x
+
+# # def length_to_mask(self, lengths):
+# # mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+# # mask = torch.gt(mask+1, lengths.unsqueeze(1))
+# # return mask
+
+
+
+# # class AdaIN1d(nn.Module):
+# # def __init__(self, style_dim, num_features):
+# # super().__init__()
+# # self.norm = nn.InstanceNorm1d(num_features, affine=False)
+# # self.fc = nn.Linear(style_dim, num_features*2)
+
+# # def forward(self, x, s):
+# # h = self.fc(s)
+# # h = h.view(h.size(0), h.size(1), 1)
+# # gamma, beta = torch.chunk(h, chunks=2, dim=1)
+# # return (1 + gamma) * self.norm(x) + beta
+
+# # class UpSample1d(nn.Module):
+# # def __init__(self, layer_type):
+# # super().__init__()
+# # self.layer_type = layer_type
+
+# # def forward(self, x):
+# # if self.layer_type == 'none':
+# # return x
+# # else:
+# # return F.interpolate(x, scale_factor=2, mode='nearest')
+
+# # class AdainResBlk1d(nn.Module):
+# # def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
+# # upsample='none', dropout_p=0.0):
+# # super().__init__()
+# # self.actv = actv
+# # self.upsample_type = upsample
+# # self.upsample = UpSample1d(upsample)
+# # self.learned_sc = dim_in != dim_out
+# # self._build_weights(dim_in, dim_out, style_dim)
+# # self.dropout = nn.Dropout(dropout_p)
+
+# # if upsample == 'none':
+# # self.pool = nn.Identity()
+# # else:
+# # self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
+
+
+# # def _build_weights(self, dim_in, dim_out, style_dim):
+# # self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
+# # self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
+# # self.norm1 = AdaIN1d(style_dim, dim_in)
+# # self.norm2 = AdaIN1d(style_dim, dim_out)
+# # if self.learned_sc:
+# # self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+# # def _shortcut(self, x):
+# # x = self.upsample(x)
+# # if self.learned_sc:
+# # x = self.conv1x1(x)
+# # return x
+
+# # def _residual(self, x, s):
+# # x = self.norm1(x, s)
+# # x = self.actv(x)
+# # x = self.pool(x)
+# # x = self.conv1(self.dropout(x))
+# # x = self.norm2(x, s)
+# # x = self.actv(x)
+# # x = self.conv2(self.dropout(x))
+# # return x
+
+# # def forward(self, x, s):
+# # out = self._residual(x, s)
+# # out = (out + self._shortcut(x)) / math.sqrt(2)
+# # return out
+
+# # class AdaLayerNorm(nn.Module):
+# # def __init__(self, style_dim, channels, eps=1e-5):
+# # super().__init__()
+# # self.channels = channels
+# # self.eps = eps
+
+# # self.fc = nn.Linear(style_dim, channels*2)
+
+# # def forward(self, x, s):
+# # x = x.transpose(-1, -2)
+# # x = x.transpose(1, -1)
+
+# # h = self.fc(s)
+# # h = h.view(h.size(0), h.size(1), 1)
+# # gamma, beta = torch.chunk(h, chunks=2, dim=1)
+# # gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
+
+
+# # x = F.layer_norm(x, (self.channels,), eps=self.eps)
+# # x = (1 + gamma) * x + beta
+# # return x.transpose(1, -1).transpose(-1, -2)
+
+# # class ProsodyPredictor(nn.Module):
+
+# # def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
+# # super().__init__()
+
+# # self.text_encoder = DurationEncoder(sty_dim=style_dim,
+# # d_model=d_hid,
+# # nlayers=nlayers,
+# # dropout=dropout)
+
+# # self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
+# # self.duration_proj = LinearNorm(d_hid, max_dur)
+
+# # self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
+# # self.F0 = nn.ModuleList()
+# # self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
+# # self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
+# # self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
+
+# # self.N = nn.ModuleList()
+# # self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
+# # self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
+# # self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
+
+# # self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
+# # self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
+
+
+# # def forward(self, texts, style, text_lengths, alignment, m):
+# # d = self.text_encoder(texts, style, text_lengths, m)
+
+# # batch_size = d.shape[0]
+# # text_size = d.shape[1]
+
+# # # predict duration
+# # input_lengths = text_lengths.cpu().numpy()
+# # x = nn.utils.rnn.pack_padded_sequence(
+# # d, input_lengths, batch_first=True, enforce_sorted=False)
+
+# # m = m.to(text_lengths.device).unsqueeze(1)
+
+# # self.lstm.flatten_parameters()
+# # x, _ = self.lstm(x)
+# # x, _ = nn.utils.rnn.pad_packed_sequence(
+# # x, batch_first=True)
+
+# # x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
+
+# # x_pad[:, :x.shape[1], :] = x
+# # x = x_pad.to(x.device)
+
+# # duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
+
+# # en = (d.transpose(-1, -2) @ alignment)
+
+# # return duration.squeeze(-1), en
+
+# # def F0Ntrain(self, x, s):
+# # x, _ = self.shared(x.transpose(-1, -2))
+
+# # F0 = x.transpose(-1, -2)
+# # for block in self.F0:
+# # F0 = block(F0, s)
+# # F0 = self.F0_proj(F0)
+
+# # N = x.transpose(-1, -2)
+# # for block in self.N:
+# # N = block(N, s)
+# # N = self.N_proj(N)
+
+# # return F0.squeeze(1), N.squeeze(1)
+
+# # def length_to_mask(self, lengths):
+# # mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+# # mask = torch.gt(mask+1, lengths.unsqueeze(1))
+# # return mask
+
+# # class DurationEncoder(nn.Module):
+
+# # def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
+# # super().__init__()
+# # self.lstms = nn.ModuleList()
+# # for _ in range(nlayers):
+# # self.lstms.append(nn.LSTM(d_model + sty_dim,
+# # d_model // 2,
+# # num_layers=1,
+# # batch_first=True,
+# # bidirectional=True,
+# # dropout=dropout))
+# # self.lstms.append(AdaLayerNorm(sty_dim, d_model))
+
+
+# # self.dropout = dropout
+# # self.d_model = d_model
+# # self.sty_dim = sty_dim
+
+# # def forward(self, x, style, text_lengths, m):
+# # masks = m.to(text_lengths.device)
+
+# # x = x.permute(2, 0, 1)
+# # s = style.expand(x.shape[0], x.shape[1], -1)
+# # x = torch.cat([x, s], axis=-1)
+# # x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
+
+# # x = x.transpose(0, 1)
+# # input_lengths = text_lengths.cpu().numpy()
+# # x = x.transpose(-1, -2)
+
+# # for block in self.lstms:
+# # if isinstance(block, AdaLayerNorm):
+# # x = block(x.transpose(-1, -2), style).transpose(-1, -2)
+# # x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
+# # x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
+# # else:
+# # x = x.transpose(-1, -2)
+# # x = nn.utils.rnn.pack_padded_sequence(
+# # x, input_lengths, batch_first=True, enforce_sorted=False)
+# # block.flatten_parameters()
+# # x, _ = block(x)
+# # x, _ = nn.utils.rnn.pad_packed_sequence(
+# # x, batch_first=True)
+# # x = F.dropout(x, p=self.dropout, training=self.training)
+# # x = x.transpose(-1, -2)
+
+# # x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
+
+# # x_pad[:, :, :x.shape[-1]] = x
+# # x = x_pad.to(x.device)
+
+# # return x.transpose(-1, -2)
+
+# # def inference(self, x, style):
+# # x = self.embedding(x.transpose(-1, -2)) * math.sqrt(self.d_model)
+# # style = style.expand(x.shape[0], x.shape[1], -1)
+# # x = torch.cat([x, style], axis=-1)
+# # src = self.pos_encoder(x)
+# # output = self.transformer_encoder(src).transpose(0, 1)
+# # return output
+
+# # def length_to_mask(self, lengths):
+# # mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+# # mask = torch.gt(mask+1, lengths.unsqueeze(1))
+# # return mask
+
+# # def load_F0_models(path):
+# # # load F0 model
+
+# # F0_model = JDCNet(num_class=1, seq_len=192)
+# # params = torch.load(path, map_location='cpu')['net']
+# # F0_model.load_state_dict(params)
+# # _ = F0_model.train()
+
+# # return F0_model
+
+# # def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
+# # # load ASR model
+# # def _load_config(path):
+# # with open(path) as f:
+# # config = yaml.safe_load(f)
+# # model_config = config['model_params']
+# # return model_config
+
+# # def _load_model(model_config, model_path):
+# # model = ASRCNN(**model_config)
+# # params = torch.load(model_path, map_location='cpu')['model']
+# # model.load_state_dict(params)
+# # return model
+
+# # asr_model_config = _load_config(ASR_MODEL_CONFIG)
+# # asr_model = _load_model(asr_model_config, ASR_MODEL_PATH)
+# # _ = asr_model.train()
+
+# # return asr_model
+
+# # def build_model(args, text_aligner, pitch_extractor, bert):
+# # assert args.decoder.type in ['istftnet', 'hifigan'], 'Decoder type unknown'
+
+# # if args.decoder.type == "istftnet":
+# # from Modules.istftnet import Decoder
+# # decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
+# # resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
+# # upsample_rates = args.decoder.upsample_rates,
+# # upsample_initial_channel=args.decoder.upsample_initial_channel,
+# # resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
+# # upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
+# # gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
+# # else:
+# # from Modules.hifigan import Decoder
+# # decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
+# # resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
+# # upsample_rates = args.decoder.upsample_rates,
+# # upsample_initial_channel=args.decoder.upsample_initial_channel,
+# # resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
+# # upsample_kernel_sizes=args.decoder.upsample_kernel_sizes)
+
+# # text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
+
+# # predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
+
+# # style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # acoustic style encoder
+# # predictor_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # prosodic style encoder
+
+# # # define diffusion model
+# # if args.multispeaker:
+# # transformer = StyleTransformer1d(channels=args.style_dim*2,
+# # context_embedding_features=bert.config.hidden_size,
+# # context_features=args.style_dim*2,
+# # **args.diffusion.transformer)
+# # else:
+# # transformer = Transformer1d(channels=args.style_dim*2,
+# # context_embedding_features=bert.config.hidden_size,
+# # **args.diffusion.transformer)
+
+# # diffusion = AudioDiffusionConditional(
+# # in_channels=1,
+# # embedding_max_length=bert.config.max_position_embeddings,
+# # embedding_features=bert.config.hidden_size,
+# # embedding_mask_proba=args.diffusion.embedding_mask_proba, # Conditional dropout of batch elements,
+# # channels=args.style_dim*2,
+# # context_features=args.style_dim*2,
+# # )
+
+# # diffusion.diffusion = KDiffusion(
+# # net=diffusion.unet,
+# # sigma_distribution=LogNormalDistribution(mean = args.diffusion.dist.mean, std = args.diffusion.dist.std),
+# # sigma_data=args.diffusion.dist.sigma_data, # a placeholder, will be changed dynamically when start training diffusion model
+# # dynamic_threshold=0.0
+# # )
+# # diffusion.diffusion.net = transformer
+# # diffusion.unet = transformer
+
+
+# # nets = Munch(
+# # bert=bert,
+# # bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
+
+# # predictor=predictor,
+# # decoder=decoder,
+# # text_encoder=text_encoder,
+
+# # predictor_encoder=predictor_encoder,
+# # style_encoder=style_encoder,
+# # diffusion=diffusion,
+
+# # text_aligner = text_aligner,
+# # pitch_extractor=pitch_extractor,
+
+# # mpd = MultiPeriodDiscriminator(),
+# # msd = MultiResSpecDiscriminator(),
+
+# # # slm discriminator head
+# # wd = WavLMDiscriminator(args.slm.hidden, args.slm.nlayers, args.slm.initial_channel),
+# # )
+
+# # return nets
+
+# # def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[]):
+# # state = torch.load(path, map_location='cpu')
+# # params = state['net']
+
+# # for key in model:
+# # if key in params and key not in ignore_modules:
+# # print('%s loaded' % key)
+# # try:
+# # model[key].load_state_dict(params[key], strict=True)
+# # except:
+# # from collections import OrderedDict
+# # state_dict = params[key]
+# # new_state_dict = OrderedDict()
+# # print(f'{key} key length: {len(model[key].state_dict().keys())}, state_dict length: {len(state_dict.keys())}')
+# # for (k_m, v_m), (k_c, v_c) in zip(model[key].state_dict().items(), state_dict.items()):
+# # new_state_dict[k_m] = v_c
+# # model[key].load_state_dict(new_state_dict, strict=True)
+# # _ = [model[key].eval() for key in model]
+
+# # if not load_only_params:
+# # epoch = state["epoch"]
+# # iters = state["iters"]
+# # optimizer.load_state_dict(state["optimizer"])
+# # else:
+# # epoch = 0
+# # iters = 0
+
+# # return model, optimizer, epoch, iters
\ No newline at end of file
diff --git a/stts_48khz/StyleTTS2_48khz/optimizers.py b/stts_48khz/StyleTTS2_48khz/optimizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..032ae15d9dd5f0020d68d28440b25c3570bc8ed2
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/optimizers.py
@@ -0,0 +1,73 @@
+#coding:utf-8
+import os, sys
+import os.path as osp
+import numpy as np
+import torch
+from torch import nn
+from torch.optim import Optimizer
+from functools import reduce
+from torch.optim import AdamW
+
+class MultiOptimizer:
+ def __init__(self, optimizers=None, schedulers=None):
+ # avoid shared mutable default arguments
+ self.optimizers = optimizers if optimizers is not None else {}
+ self.schedulers = schedulers if schedulers is not None else {}
+ self.keys = list(self.optimizers.keys())
+ self.param_groups = reduce(lambda x, y: x + y, [v.param_groups for v in self.optimizers.values()])
+
+ def state_dict(self):
+ state_dicts = [(key, self.optimizers[key].state_dict())\
+ for key in self.keys]
+ return state_dicts
+
+ def load_state_dict(self, state_dict):
+ for key, val in state_dict:
+ try:
+ self.optimizers[key].load_state_dict(val)
+ except Exception:
+ print("Could not load optimizer state for %s" % key)
+
+ def step(self, key=None, scaler=None):
+ keys = [key] if key is not None else self.keys
+ _ = [self._step(key, scaler) for key in keys]
+
+ def _step(self, key, scaler=None):
+ if scaler is not None:
+ scaler.step(self.optimizers[key])
+ scaler.update()
+ else:
+ self.optimizers[key].step()
+
+ def zero_grad(self, key=None):
+ if key is not None:
+ self.optimizers[key].zero_grad()
+ else:
+ _ = [self.optimizers[key].zero_grad() for key in self.keys]
+
+ def scheduler(self, *args, key=None):
+ if key is not None:
+ self.schedulers[key].step(*args)
+ else:
+ _ = [self.schedulers[key].step(*args) for key in self.keys]
+
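+# Note: with pct_start=0, div_factor=1 and final_div_factor=1 below, OneCycleLR
+# effectively keeps the learning rate constant at max_lr over the whole run.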
+def define_scheduler(optimizer, params):
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
+ optimizer,
+ max_lr=params.get('max_lr', 2e-4),
+ epochs=params.get('epochs', 200),
+ steps_per_epoch=params.get('steps_per_epoch', 1000),
+ pct_start=params.get('pct_start', 0.0),
+ div_factor=1,
+ final_div_factor=1)
+
+ return scheduler
+
+def build_optimizer(parameters_dict, scheduler_params_dict, lr):
+ optim = dict([(key, AdamW(params, lr=lr, weight_decay=1e-4, betas=(0.0, 0.99), eps=1e-9))
+ for key, params in parameters_dict.items()])
+
+ schedulers = dict([(key, define_scheduler(opt, scheduler_params_dict[key])) \
+ for key, opt in optim.items()])
+
+ multi_optim = MultiOptimizer(optim, schedulers)
+ return multi_optim
\ No newline at end of file
diff --git a/stts_48khz/StyleTTS2_48khz/requirements.txt b/stts_48khz/StyleTTS2_48khz/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b8d11221143a7adfb17a08c6334d99b214b7e2c
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/requirements.txt
@@ -0,0 +1,17 @@
+SoundFile
+torchaudio
+munch
+torch
+pydub
+pyyaml
+librosa
+nltk
+matplotlib
+accelerate
+transformers
+einops
+einops-exts
+tqdm
+typing
+typing-extensions
+git+https://github.com/resemble-ai/monotonic_align.git
\ No newline at end of file
diff --git a/stts_48khz/StyleTTS2_48khz/text_utils.py b/stts_48khz/StyleTTS2_48khz/text_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..97c252ba5953910fc39fd0b25b51f01e17ef0f81
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/text_utils.py
@@ -0,0 +1,26 @@
+# IPA Phonemizer: https://github.com/bootphon/phonemizer
+
+_pad = "$"
+_punctuation = ';:,.!?¡¿—…"«»“” '
+_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
+_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
+
+# Export all symbols:
+symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
+
+dicts = {symbol: i for i, symbol in enumerate(symbols)}
+
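+# Maps a phonemized string to a list of integer symbol ids; characters outside the
+# symbol set are skipped and the offending text is printed for debugging.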
+class TextCleaner:
+ def __init__(self, dummy=None):
+ self.word_index_dictionary = dicts
+ print(len(dicts))
+ def __call__(self, text):
+ indexes = []
+ for char in text:
+ try:
+ indexes.append(self.word_index_dictionary[char])
+ except KeyError:
+ print(text)
+ return indexes
diff --git a/stts_48khz/StyleTTS2_48khz/train_finetune.py b/stts_48khz/StyleTTS2_48khz/train_finetune.py
new file mode 100644
index 0000000000000000000000000000000000000000..38184fc822dcf5a84903745fd81eb53130d2d0a3
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/train_finetune.py
@@ -0,0 +1,721 @@
+# load packages
+import os
+import os.path as osp
+import copy
+import random
+import yaml
+import time
+from munch import Munch
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+import librosa
+import click
+import shutil
+import warnings
+warnings.simplefilter('ignore')
+from torch.utils.tensorboard import SummaryWriter
+
+from meldataset import build_dataloader
+
+from Utils.ASR.models import ASRCNN
+from Utils.JDC.model import JDCNet
+from Utils.PLBERT.util import load_plbert
+
+from models import *
+from losses import *
+from utils import *
+
+from Modules.slmadv import SLMAdversarialLoss
+from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
+
+from optimizers import build_optimizer
+
+# simple fix for dataparallel that allows access to class attributes
+class MyDataParallel(torch.nn.DataParallel):
+ def __getattr__(self, name):
+ try:
+ return super().__getattr__(name)
+ except AttributeError:
+ return getattr(self.module, name)
+
+import logging
+from logging import StreamHandler
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+handler = StreamHandler()
+handler.setLevel(logging.DEBUG)
+logger.addHandler(handler)
+
+
+@click.command()
+@click.option('-p', '--config_path', default='Configs/config_ft.yml', type=str)
+def main(config_path):
+ config = yaml.safe_load(open(config_path))
+
+ log_dir = config['log_dir']
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
+ writer = SummaryWriter(log_dir + "/tensorboard")
+
+ # write logs
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
+ file_handler.setLevel(logging.DEBUG)
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
+ logger.addHandler(file_handler)
+
+
+ batch_size = config.get('batch_size', 10)
+
+ epochs = config.get('epochs', 200)
+ save_freq = config.get('save_freq', 2)
+ log_interval = config.get('log_interval', 10)
+ saving_epoch = config.get('save_freq', 2)
+
+ data_params = config.get('data_params', None)
+ sr = config['preprocess_params'].get('sr', 48000)
+ train_path = data_params['train_data']
+ val_path = data_params['val_data']
+ root_path = data_params['root_path']
+ min_length = data_params['min_length']
+ OOD_data = data_params['OOD_data']
+
+ max_len = config.get('max_len', 200)
+
+ loss_params = Munch(config['loss_params'])
+ diff_epoch = loss_params.diff_epoch
+ joint_epoch = loss_params.joint_epoch
+
+ optimizer_params = Munch(config['optimizer_params'])
+
+ train_list, val_list = get_data_path_list(train_path, val_path)
+ device = 'cuda'
+
+ train_dataloader = build_dataloader(train_list,
+ root_path,
+ OOD_data=OOD_data,
+ min_length=min_length,
+ batch_size=batch_size,
+ num_workers=2,
+ dataset_config={},
+ device=device)
+
+ val_dataloader = build_dataloader(val_list,
+ root_path,
+ OOD_data=OOD_data,
+ min_length=min_length,
+ batch_size=batch_size,
+ validation=True,
+ num_workers=0,
+ device=device,
+ dataset_config={})
+
+ # load pretrained ASR model
+ ASR_config = config.get('ASR_config', False)
+ ASR_path = config.get('ASR_path', False)
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
+
+ # load pretrained F0 model
+ F0_path = config.get('F0_path', False)
+ pitch_extractor = load_F0_models(F0_path)
+
+ # load PL-BERT model
+ BERT_path = config.get('PLBERT_dir', False)
+ plbert = load_plbert(BERT_path)
+
+ # build model
+ model_params = recursive_munch(config['model_params'])
+ multispeaker = model_params.multispeaker
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
+ _ = [model[key].to(device) for key in model]
+
+ # DP
+ for key in model:
+ if key != "mpd" and key != "msd" and key != "wd":
+ model[key] = MyDataParallel(model[key])
+
+ start_epoch = 0
+ iters = 0
+
+ load_pretrained = config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False)
+
+ if not load_pretrained:
+ if config.get('first_stage_path', '') != '':
+ first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
+ print('Loading the first stage model at %s ...' % first_stage_path)
+ model, _, start_epoch, iters = load_checkpoint(model,
+ None,
+ first_stage_path,
+ load_only_params=True,
+ ignore_modules=['bert', 'bert_encoder', 'predictor', 'predictor_encoder', 'msd', 'mpd', 'wd', 'diffusion']) # keep starting epoch for tensorboard log
+
+ # these epochs should be counted from the start epoch
+ diff_epoch += start_epoch
+ joint_epoch += start_epoch
+ epochs += start_epoch
+
+ model.predictor_encoder = copy.deepcopy(model.style_encoder)
+ else:
+ raise ValueError('You need to specify the path to the first stage model.')
+
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
+ wl = WavLMLoss(model_params.slm.model,
+ model.wd,
+ sr,
+ model_params.slm.sr).to(device)
+
+ gl = MyDataParallel(gl)
+ dl = MyDataParallel(dl)
+ wl = MyDataParallel(wl)
+
+ sampler = DiffusionSampler(
+ model.diffusion.diffusion,
+ sampler=ADPM2Sampler(),
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
+ clamp=False
+ )
+
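+ # one OneCycleLR schedule per module; BERT and the acoustic modules (decoder, style encoder)
+ # get their own max_lr overrides below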
+ scheduler_params = {
+ "max_lr": optimizer_params.lr,
+ "pct_start": float(0),
+ "epochs": epochs,
+ "steps_per_epoch": len(train_dataloader),
+ }
+ scheduler_params_dict= {key: scheduler_params.copy() for key in model}
+ scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2
+ scheduler_params_dict['decoder']['max_lr'] = optimizer_params.ft_lr * 2
+ scheduler_params_dict['style_encoder']['max_lr'] = optimizer_params.ft_lr * 2
+
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
+ scheduler_params_dict=scheduler_params_dict, lr=optimizer_params.lr)
+
+ # adjust BERT learning rate
+ for g in optimizer.optimizers['bert'].param_groups:
+ g['betas'] = (0.9, 0.99)
+ g['lr'] = optimizer_params.bert_lr
+ g['initial_lr'] = optimizer_params.bert_lr
+ g['min_lr'] = 0
+ g['weight_decay'] = 0.01
+
+ # adjust acoustic module learning rate
+ for module in ["decoder", "style_encoder"]:
+ for g in optimizer.optimizers[module].param_groups:
+ g['betas'] = (0.0, 0.99)
+ g['lr'] = optimizer_params.ft_lr
+ g['initial_lr'] = optimizer_params.ft_lr
+ g['min_lr'] = 0
+ g['weight_decay'] = 1e-4
+
+ # load models if there is a model
+ if load_pretrained:
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
+ load_only_params=config.get('load_only_params', True))
+
+ n_down = model.text_aligner.n_down
+
+ best_loss = float('inf') # best test loss
+ loss_train_record = list([])
+ loss_test_record = list([])
+ iters = 0
+
+ criterion = nn.L1Loss() # F0 loss (regression)
+ torch.cuda.empty_cache()
+
+ stft_loss = MultiResolutionSTFTLoss().to(device)
+
+ print('BERT', optimizer.optimizers['bert'])
+ print('decoder', optimizer.optimizers['decoder'])
+
+ start_ds = False
+
+ running_std = []
+
+ slmadv_params = Munch(config['slmadv_params'])
+ slmadv = SLMAdversarialLoss(model, wl, sampler,
+ slmadv_params.min_len,
+ slmadv_params.max_len,
+ batch_percentage=slmadv_params.batch_percentage,
+ skip_update=slmadv_params.iter,
+ sig=slmadv_params.sig
+ )
+
+
+ for epoch in range(start_epoch, epochs):
+ running_loss = 0
+ start_time = time.time()
+
+ _ = [model[key].eval() for key in model]
+
+ model.text_aligner.train()
+ model.text_encoder.train()
+
+ model.predictor.train()
+ model.bert_encoder.train()
+ model.bert.train()
+ model.msd.train()
+ model.mpd.train()
+
+ for i, batch in enumerate(train_dataloader):
+ waves = batch[0]
+ batch = [b.to(device) for b in batch[1:]]
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
+ with torch.no_grad():
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
+ mel_mask = length_to_mask(mel_input_length).to(device)
+ text_mask = length_to_mask(input_lengths).to(texts.device)
+
+ # compute reference styles
+ if multispeaker and epoch >= diff_epoch:
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
+ ref = torch.cat([ref_ss, ref_sp], dim=1)
+
+ try:
+ ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
+ s2s_attn = s2s_attn.transpose(-1, -2)
+ s2s_attn = s2s_attn[..., 1:]
+ s2s_attn = s2s_attn.transpose(-1, -2)
+ except:
+ continue
+
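+ # collapse the soft aligner attention into a hard monotonic alignment (maximum_path from monotonic_align)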
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
+
+ # encode
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
+
+ # 50% of chance of using monotonic version
+ if bool(random.getrandbits(1)):
+ asr = (t_en @ s2s_attn)
+ else:
+ asr = (t_en @ s2s_attn_mono)
+
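+ # ground-truth durations: number of downsampled mel frames aligned to each text token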
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
+
+ # compute the style of the entire utterance
+ # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
+ ss = []
+ gs = []
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item())
+ mel = mels[bib, :, :mel_input_length[bib]]
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
+ ss.append(s)
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
+ gs.append(s)
+
+ s_dur = torch.stack(ss).squeeze() # global prosodic styles
+ gs = torch.stack(gs).squeeze() # global acoustic styles
+ s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
+
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+
+ # denoiser training
+ if epoch >= diff_epoch:
+ num_steps = np.random.randint(3, 5)
+
+ if model_params.diffusion.dist.estimate_sigma_data:
+ model.diffusion.module.diffusion.sigma_data = s_trg.std(axis=-1).mean().item() # batch-wise std estimation
+ running_std.append(model.diffusion.module.diffusion.sigma_data)
+
+ if multispeaker:
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
+ embedding=bert_dur,
+ embedding_scale=1,
+ features=ref, # reference from the same speaker as the embedding
+ embedding_mask_proba=0.1,
+ num_steps=num_steps).squeeze(1)
+ loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
+ else:
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
+ embedding=bert_dur,
+ embedding_scale=1,
+ embedding_mask_proba=0.1,
+ num_steps=num_steps).squeeze(1)
+ loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur).mean() # EDM loss
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
+ else:
+ loss_sty = 0
+ loss_diff = 0
+
+
+ s_loss = 0
+
+
+ d, p = model.predictor(d_en, s_dur,
+ input_lengths,
+ s2s_attn_mono,
+ text_mask)
+
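+ # crop random segments (aligned features, mels and matching waveform) so the decoder/discriminator pass fits in memory; max_len caps the crop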
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
+ mel_len = min(int(mel_input_length.min().item() / 2 - 1), max_len // 2)
+
+ en = []
+ gt = []
+ p_en = []
+ wav = []
+ st = []
+
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item() / 2)
+
+ random_start = np.random.randint(0, mel_length - mel_len)
+
+
+ en.append(asr[bib, :, random_start:random_start+mel_len])
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
+
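+ # matching waveform chunk: the *2 undoes the aligner downsampling, and 512 is the mel hop length in samples assumed by this 48 kHz setup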
+ y = waves[bib][(random_start * 2) * 512:((random_start+mel_len) * 2) * 512]
+ wav.append(torch.from_numpy(y).to(device))
+
+ # style reference (better to be different from the GT)
+ random_start = np.random.randint(0, mel_length - mel_len_st)
+ st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
+
+ wav = torch.stack(wav).float().detach()
+
+ en = torch.stack(en)
+ p_en = torch.stack(p_en)
+ gt = torch.stack(gt).detach()
+ st = torch.stack(st).detach()
+
+
+ if gt.size(-1) < 80:
+ continue
+
+ s = model.style_encoder(gt.unsqueeze(1))
+ s_dur = model.predictor_encoder(gt.unsqueeze(1))
+
+ with torch.no_grad():
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
+ F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
+
+ N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
+
+ y_rec_gt = wav.unsqueeze(1)
+ y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
+
+ wav = y_rec_gt
+
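+ # predict F0 and energy from the prosody features and the prosodic style vector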
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
+
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
+
+ loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
+ loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
+
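+ # discriminator update: MPD/MSD losses on ground-truth vs. reconstructed (detached) waveforms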
+ optimizer.zero_grad()
+ d_loss = dl(wav.detach(), y_rec.detach()).mean()
+ d_loss.backward()
+ optimizer.step('msd')
+ optimizer.step('mpd')
+
+ # generator loss
+ optimizer.zero_grad()
+
+ loss_mel = stft_loss(y_rec, wav)
+ loss_gen_all = gl(wav, y_rec).mean()
+ loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
+
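+ # duration losses: binary cross-entropy on per-frame alignment targets plus L1 on the summed sigmoid durations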
+ loss_ce = 0
+ loss_dur = 0
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
+ _s2s_pred = _s2s_pred[:_text_length, :]
+ _text_input = _text_input[:_text_length].long()
+ _s2s_trg = torch.zeros_like(_s2s_pred)
+ for p in range(_s2s_trg.shape[0]):
+ _s2s_trg[p, :_text_input[p]] = 1
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
+
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
+ _text_input[1:_text_length-1])
+ loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())
+
+ loss_ce /= texts.size(0)
+ loss_dur /= texts.size(0)
+
+ loss_s2s = 0
+ for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
+ loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
+ loss_s2s /= texts.size(0)
+
+ loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
+
+ g_loss = loss_params.lambda_mel * loss_mel + \
+ loss_params.lambda_F0 * loss_F0_rec + \
+ loss_params.lambda_ce * loss_ce + \
+ loss_params.lambda_norm * loss_norm_rec + \
+ loss_params.lambda_dur * loss_dur + \
+ loss_params.lambda_gen * loss_gen_all + \
+ loss_params.lambda_slm * loss_lm + \
+ loss_params.lambda_sty * loss_sty + \
+ loss_params.lambda_diff * loss_diff + \
+ loss_params.lambda_mono * loss_mono + \
+ loss_params.lambda_s2s * loss_s2s
+
+ running_loss += loss_mel.item()
+ g_loss.backward()
+ if torch.isnan(g_loss):
+ from IPython.core.debugger import set_trace
+ set_trace()
+
+ optimizer.step('bert_encoder')
+ optimizer.step('bert')
+ optimizer.step('predictor')
+ optimizer.step('predictor_encoder')
+ optimizer.step('style_encoder')
+ optimizer.step('decoder')
+
+ optimizer.step('text_encoder')
+ optimizer.step('text_aligner')
+
+ if epoch >= diff_epoch:
+ optimizer.step('diffusion')
+
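+ # joint training stage: SLM (WavLM) adversarial loss on synthesized segments, with gradient-norm scaling below to keep it from overwhelming the other objectives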
+ d_loss_slm, loss_gen_lm = 0, 0
+ if epoch >= joint_epoch:
+ # randomly pick whether to use in-distribution text
+ if np.random.rand() < 0.5:
+ use_ind = True
+ else:
+ use_ind = False
+
+ if use_ind:
+ ref_lengths = input_lengths
+ ref_texts = texts
+
+ slm_out = slmadv(i,
+ y_rec_gt,
+ y_rec_gt_pred,
+ waves,
+ mel_input_length,
+ ref_texts,
+ ref_lengths, use_ind, s_trg.detach(), ref if multispeaker else None)
+
+ if slm_out is not None:
+ d_loss_slm, loss_gen_lm, y_pred = slm_out
+
+ # SLM generator loss
+ optimizer.zero_grad()
+ loss_gen_lm.backward()
+
+ # compute the gradient norm
+ total_norm = {}
+ for key in model.keys():
+ total_norm[key] = 0
+ parameters = [p for p in model[key].parameters() if p.grad is not None and p.requires_grad]
+ for p in parameters:
+ param_norm = p.grad.detach().data.norm(2)
+ total_norm[key] += param_norm.item() ** 2
+ total_norm[key] = total_norm[key] ** 0.5
+
+ # gradient scaling
+ if total_norm['predictor'] > slmadv_params.thresh:
+ for key in model.keys():
+ for p in model[key].parameters():
+ if p.grad is not None:
+ p.grad *= (1 / total_norm['predictor'])
+
+ for p in model.predictor.duration_proj.parameters():
+ if p.grad is not None:
+ p.grad *= slmadv_params.scale
+
+ for p in model.predictor.lstm.parameters():
+ if p.grad is not None:
+ p.grad *= slmadv_params.scale
+
+ for p in model.diffusion.parameters():
+ if p.grad is not None:
+ p.grad *= slmadv_params.scale
+
+ optimizer.step('bert_encoder')
+ optimizer.step('bert')
+ optimizer.step('predictor')
+ optimizer.step('diffusion')
+
+ # SLM discriminator loss
+ if d_loss_slm != 0:
+ optimizer.zero_grad()
+ d_loss_slm.backward(retain_graph=True)
+ optimizer.step('wd')
+
+ iters = iters + 1
+
+ if (i+1)%log_interval == 0:
+ logger.info ('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f, SLoss: %.5f, S2S Loss: %.5f, Mono Loss: %.5f'
+ %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, d_loss, loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff, d_loss_slm, loss_gen_lm, s_loss, loss_s2s, loss_mono))
+
+ writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
+ writer.add_scalar('train/gen_loss', loss_gen_all, iters)
+ writer.add_scalar('train/d_loss', d_loss, iters)
+ writer.add_scalar('train/ce_loss', loss_ce, iters)
+ writer.add_scalar('train/dur_loss', loss_dur, iters)
+ writer.add_scalar('train/slm_loss', loss_lm, iters)
+ writer.add_scalar('train/norm_loss', loss_norm_rec, iters)
+ writer.add_scalar('train/F0_loss', loss_F0_rec, iters)
+ writer.add_scalar('train/sty_loss', loss_sty, iters)
+ writer.add_scalar('train/diff_loss', loss_diff, iters)
+ writer.add_scalar('train/d_loss_slm', d_loss_slm, iters)
+ writer.add_scalar('train/gen_loss_slm', loss_gen_lm, iters)
+
+ running_loss = 0
+
+ print('Time elapsed:', time.time()-start_time)
+
+ if (i+1)%1000 == 0:
+ print(f'Saving on step {epoch*len(train_dataloader)+i}...')
+ state = {
+ 'net': {key: model[key].state_dict() for key in model},
+ 'optimizer': optimizer.state_dict(),
+ 'iters': iters,
+ 'epoch': epoch,
+ }
+ save_path = osp.join(log_dir, f'finetune_phase_{epoch*len(train_dataloader)+i}.pth')
+ torch.save(state, save_path)
+
+ loss_test = 0
+ loss_align = 0
+ loss_f = 0
+ _ = [model[key].eval() for key in model]
+
+ with torch.no_grad():
+ iters_test = 0
+ for batch_idx, batch in enumerate(val_dataloader):
+ optimizer.zero_grad()
+
+ try:
+ waves = batch[0]
+ batch = [b.to(device) for b in batch[1:]]
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
+ with torch.no_grad():
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
+ text_mask = length_to_mask(input_lengths).to(texts.device)
+
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
+ s2s_attn = s2s_attn.transpose(-1, -2)
+ s2s_attn = s2s_attn[..., 1:]
+ s2s_attn = s2s_attn.transpose(-1, -2)
+
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
+
+ # encode
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
+ asr = (t_en @ s2s_attn_mono)
+
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
+
+ ss = []
+ gs = []
+
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item())
+ mel = mels[bib, :, :mel_input_length[bib]]
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
+ ss.append(s)
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
+ gs.append(s)
+
+ s = torch.stack(ss).squeeze()
+ gs = torch.stack(gs).squeeze()
+ s_trg = torch.cat([s, gs], dim=-1).detach()
+
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+ d, p = model.predictor(d_en, s,
+ input_lengths,
+ s2s_attn_mono,
+ text_mask)
+ # get clips
+ mel_len = int(mel_input_length.min().item() / 2 - 1)
+ en = []
+ gt = []
+
+ p_en = []
+ wav = []
+
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item() / 2)
+
+ random_start = np.random.randint(0, mel_length - mel_len)
+ en.append(asr[bib, :, random_start:random_start+mel_len])
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
+
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
+ y = waves[bib][(random_start * 2) * 512:((random_start+mel_len) * 2) * 512]
+ wav.append(torch.from_numpy(y).to(device))
+
+ wav = torch.stack(wav).float().detach()
+
+ en = torch.stack(en)
+ p_en = torch.stack(p_en)
+ gt = torch.stack(gt).detach()
+ s = model.predictor_encoder(gt.unsqueeze(1))
+
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s)
+
+ loss_dur = 0
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
+ _s2s_pred = _s2s_pred[:_text_length, :]
+ _text_input = _text_input[:_text_length].long()
+ _s2s_trg = torch.zeros_like(_s2s_pred)
+ for bib in range(_s2s_trg.shape[0]):
+ _s2s_trg[bib, :_text_input[bib]] = 1
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
+ _text_input[1:_text_length-1])
+
+ loss_dur /= texts.size(0)
+
+ s = model.style_encoder(gt.unsqueeze(1))
+
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
+
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
+
+ loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
+
+ loss_test += (loss_mel).mean()
+ loss_align += (loss_dur).mean()
+ loss_f += (loss_F0).mean()
+
+ iters_test += 1
+ except:
+ continue
+
+ print('Epochs:', epoch + 1)
+ logger.info('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n\n\n')
+ print('\n\n\n')
+ writer.add_scalar('eval/mel_loss', loss_test / iters_test, epoch + 1)
+ writer.add_scalar('eval/dur_loss', loss_align / iters_test, epoch + 1)
+ writer.add_scalar('eval/F0_loss', loss_f / iters_test, epoch + 1)
+
+
+ if (epoch + 1) % save_freq == 0 :
+ if (loss_test / iters_test) < best_loss:
+ best_loss = loss_test / iters_test
+ print('Saving..')
+ state = {
+ 'net': {key: model[key].state_dict() for key in model},
+ 'optimizer': optimizer.state_dict(),
+ 'iters': iters,
+ 'val_loss': loss_test / iters_test,
+ 'epoch': epoch,
+ }
+ save_path = osp.join(log_dir, 'epoch_2nd_%05d.pth' % epoch)
+ torch.save(state, save_path)
+
+ # if sigma_data is estimated from data, write the estimate back into the saved config
+ if model_params.diffusion.dist.estimate_sigma_data:
+ config['model_params']['diffusion']['dist']['sigma_data'] = float(np.mean(running_std))
+
+ with open(osp.join(log_dir, osp.basename(config_path)), 'w') as outfile:
+ yaml.dump(config, outfile, default_flow_style=True)
+
+
+if __name__=="__main__":
+ main()
diff --git a/stts_48khz/StyleTTS2_48khz/train_first.py b/stts_48khz/StyleTTS2_48khz/train_first.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ff954c1f0f098a691d1df7899a5dd5c44d3b957
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/train_first.py
@@ -0,0 +1,466 @@
+import os
+import os.path as osp
+import re
+import sys
+import yaml
+import shutil
+import numpy as np
+import torch
+import click
+import warnings
+warnings.simplefilter('ignore')
+
+# load packages
+import random
+from munch import Munch
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+import librosa
+from accelerate.utils import GradientAccumulationPlugin
+
+from models import *
+from meldataset import build_dataloader
+from utils import *
+from losses import *
+from optimizers import build_optimizer
+import time
+
+from accelerate import Accelerator
+from accelerate.utils import LoggerType
+from accelerate import DistributedDataParallelKwargs
+
+from torch.utils.tensorboard import SummaryWriter
+
+import logging
+from accelerate.logging import get_logger
+logger = get_logger(__name__, log_level="DEBUG")
+
+@click.command()
+@click.option('-p', '--config_path', default='Configs/config.yml', type=str)
+def main(config_path):
+ config = yaml.safe_load(open(config_path))
+
+ save_iter = 2500
+
+ log_dir = config['log_dir']
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
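+ # multi-GPU training via Accelerate with bf16 mixed precision; find_unused_parameters=True since not every submodule contributes to every backward pass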
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(project_dir=log_dir, split_batches=True, kwargs_handlers=[ddp_kwargs], mixed_precision='bf16')
+ if accelerator.is_main_process:
+ writer = SummaryWriter(log_dir + "/tensorboard")
+
+ # write logs
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
+ file_handler.setLevel(logging.DEBUG)
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
+ logger.logger.addHandler(file_handler)
+
+ batch_size = config.get('batch_size', 10)
+ device = accelerator.device
+
+ epochs = config.get('epochs_1st', 200)
+ save_freq = config.get('save_freq', 2)
+ log_interval = config.get('log_interval', 10)
+ saving_epoch = config.get('save_freq', 2)
+
+ data_params = config.get('data_params', None)
+ sr = config['preprocess_params'].get('sr', 48000)
+ train_path = data_params['train_data']
+ val_path = data_params['val_data']
+ root_path = data_params['root_path']
+ min_length = data_params['min_length']
+ OOD_data = data_params['OOD_data']
+
+ max_len = config.get('max_len', 200)
+
+ # load data
+ train_list, val_list = get_data_path_list(train_path, val_path)
+
+ train_dataloader = build_dataloader(train_list,
+ root_path,
+ OOD_data=OOD_data,
+ min_length=min_length,
+ batch_size=batch_size,
+ num_workers=2,
+ dataset_config={},
+ device=device)
+
+ val_dataloader = build_dataloader(val_list,
+ root_path,
+ OOD_data=OOD_data,
+ min_length=min_length,
+ batch_size=batch_size,
+ validation=True,
+ num_workers=0,
+ device=device,
+ dataset_config={})
+
+ with accelerator.main_process_first():
+ # load pretrained ASR model
+ ASR_config = config.get('ASR_config', False)
+ ASR_path = config.get('ASR_path', False)
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
+
+ # load pretrained F0 model
+ F0_path = config.get('F0_path', False)
+ pitch_extractor = load_F0_models(F0_path)
+
+ # load BERT model
+ from Utils.PLBERT.util import load_plbert
+ BERT_path = config.get('PLBERT_dir', False)
+ plbert = load_plbert(BERT_path)
+
+ scheduler_params = {
+ "max_lr": float(config['optimizer_params'].get('lr', 1e-4)),
+ "pct_start": float(config['optimizer_params'].get('pct_start', 0.0)),
+ "epochs": epochs,
+ "steps_per_epoch": len(train_dataloader),
+ }
+
+ model_params = recursive_munch(config['model_params'])
+ multispeaker = model_params.multispeaker
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
+
+ best_loss = float('inf') # best test loss
+ loss_train_record = list([])
+ loss_test_record = list([])
+
+ loss_params = Munch(config['loss_params'])
+ TMA_epoch = loss_params.TMA_epoch
+
+ for k in model:
+ model[k] = accelerator.prepare(model[k])
+
+ train_dataloader, val_dataloader = accelerator.prepare(
+ train_dataloader, val_dataloader
+ )
+
+ _ = [model[key].to(device) for key in model]
+
+ # initialize optimizers after preparing models for compatibility with FSDP
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
+ scheduler_params_dict= {key: scheduler_params.copy() for key in model},
+ lr=float(config['optimizer_params'].get('lr', 1e-4)))
+
+ for k, v in optimizer.optimizers.items():
+ optimizer.optimizers[k] = accelerator.prepare(optimizer.optimizers[k])
+ optimizer.schedulers[k] = accelerator.prepare(optimizer.schedulers[k])
+
+ with accelerator.main_process_first():
+ if config.get('pretrained_model', '') != '':
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
+ load_only_params=config.get('load_only_params', True))
+ else:
+ start_epoch = 0
+ iters = 0
+
+ # in case not distributed
+ try:
+ n_down = model.text_aligner.module.n_down
+ except:
+ n_down = model.text_aligner.n_down
+
+ # wrapped losses for compatibility with mixed precision
+ stft_loss = MultiResolutionSTFTLoss().to(device)
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
+ wl = WavLMLoss(model_params.slm.model,
+ model.wd,
+ sr,
+ model_params.slm.sr).to(device)
+
+ for epoch in range(start_epoch, epochs):
+ running_loss = 0
+ start_time = time.time()
+
+ _ = [model[key].train() for key in model]
+
+ for i, batch in enumerate(train_dataloader):
+ waves = batch[0]
+ batch = [b.to(device) for b in batch[1:]]
+ texts, input_lengths, _, _, mels, mel_input_length, _ = batch
+
+ with torch.no_grad():
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
+ text_mask = length_to_mask(input_lengths).to(texts.device)
+
+ ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
+
+ s2s_attn = s2s_attn.transpose(-1, -2)
+ s2s_attn = s2s_attn[..., 1:]
+ s2s_attn = s2s_attn.transpose(-1, -2)
+
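+ # zero attention outside the valid mel/text lengths before extracting the alignment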
+ with torch.no_grad():
+ attn_mask = (~mask).unsqueeze(-1).expand(mask.shape[0], mask.shape[1], text_mask.shape[-1]).float().transpose(-1, -2)
+ attn_mask = attn_mask.float() * (~text_mask).unsqueeze(-1).expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1]).float()
+ attn_mask = (attn_mask < 1)
+
+ s2s_attn.masked_fill_(attn_mask, 0.0)
+
+ with torch.no_grad():
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
+
+ # encode
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
+
+ # 50% of chance of using monotonic version
+ if bool(random.getrandbits(1)):
+ asr = (t_en @ s2s_attn)
+ else:
+ asr = (t_en @ s2s_attn_mono)
+
+ # get clips
+ mel_input_length_all = accelerator.gather(mel_input_length) # for balanced load
+ mel_len = min([int(mel_input_length_all.min().item() / 2 - 1), max_len // 2])
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
+
+ en = []
+ gt = []
+ wav = []
+ st = []
+
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item() / 2)
+
+ random_start = np.random.randint(0, mel_length - mel_len)
+ en.append(asr[bib, :, random_start:random_start+mel_len])
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
+
+ y = waves[bib][(random_start * 2) * 512:((random_start+mel_len) * 2) * 512]
+ wav.append(torch.from_numpy(y).to(device))
+
+ # style reference (better to be different from the GT)
+ random_start = np.random.randint(0, mel_length - mel_len_st)
+ st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
+
+ en = torch.stack(en)
+ gt = torch.stack(gt).detach()
+ st = torch.stack(st).detach()
+
+ wav = torch.stack(wav).float().detach()
+
+ # clip too short to be used by the style encoder
+ if gt.shape[-1] < 80:
+ continue
+
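+ # first stage: the decoder is driven by ground-truth F0 (JDC pitch extractor) and log energy rather than predicted prosody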
+ with torch.no_grad():
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1).detach()
+ F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
+
+ s = model.style_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
+
+
+ y_rec = model.decoder(en, F0_real, real_norm, s)
+
+ # discriminator loss
+
+ if epoch >= TMA_epoch:
+ optimizer.zero_grad()
+
+ d_loss = dl(wav.detach().unsqueeze(1).float(), y_rec.detach()).mean()
+ accelerator.backward(d_loss)
+ optimizer.step('msd')
+ optimizer.step('mpd')
+ else:
+ d_loss = 0
+
+ # generator loss
+ optimizer.zero_grad()
+
+
+
+
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
+
+ if epoch >= TMA_epoch: # start TMA training
+ loss_s2s = 0
+ for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
+ loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
+ loss_s2s /= texts.size(0)
+
+ loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
+
+ loss_gen_all = gl(wav.detach().unsqueeze(1).float(), y_rec).mean()
+ loss_slm = wl(wav.detach(), y_rec).mean()
+
+ g_loss = loss_params.lambda_mel * loss_mel + \
+ loss_params.lambda_mono * loss_mono + \
+ loss_params.lambda_s2s * loss_s2s + \
+ loss_params.lambda_gen * loss_gen_all + \
+ loss_params.lambda_slm * loss_slm
+
+ else:
+ loss_s2s = 0
+ loss_mono = 0
+ loss_gen_all = 0
+ loss_slm = 0
+ g_loss = loss_mel
+
+ running_loss += accelerator.gather(loss_mel).mean().item()
+
+ accelerator.backward(g_loss)
+
+ optimizer.step('text_encoder')
+ optimizer.step('style_encoder')
+ optimizer.step('decoder')
+
+ if epoch >= TMA_epoch:
+ optimizer.step('text_aligner')
+ optimizer.step('pitch_extractor')
+
+ iters = iters + 1
+
+ if (i+1)%log_interval == 0 and accelerator.is_main_process:
+ log_print ('Epoch [%d/%d], Step [%d/%d], Mel Loss: %.5f, Gen Loss: %.5f, Disc Loss: %.5f, Mono Loss: %.5f, S2S Loss: %.5f, SLM Loss: %.5f'
+ %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, loss_gen_all, d_loss, loss_mono, loss_s2s, loss_slm), logger)
+
+ writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
+ writer.add_scalar('train/gen_loss', loss_gen_all, iters)
+ writer.add_scalar('train/d_loss', d_loss, iters)
+ writer.add_scalar('train/mono_loss', loss_mono, iters)
+ writer.add_scalar('train/s2s_loss', loss_s2s, iters)
+ writer.add_scalar('train/slm_loss', loss_slm, iters)
+
+ running_loss = 0
+
+ print('Time elapsed:', time.time()-start_time)
+
+ if (i+1)%save_iter == 0 and accelerator.is_main_process:
+
+ print(f'Saving on step {epoch*len(train_dataloader)+i}...')
+ state = {
+ 'net': {key: model[key].state_dict() for key in model},
+ 'optimizer': optimizer.state_dict(),
+ 'iters': iters,
+ 'epoch': epoch,
+ }
+ save_path = osp.join(log_dir, f'2nd_phase_{epoch*len(train_dataloader)+i}.pth')
+ torch.save(state, save_path)
+
+ loss_test = 0
+
+ _ = [model[key].eval() for key in model]
+
+ with torch.no_grad():
+ iters_test = 0
+ for batch_idx, batch in enumerate(val_dataloader):
+ optimizer.zero_grad()
+
+ waves = batch[0]
+ batch = [b.to(device) for b in batch[1:]]
+ texts, input_lengths, _, _, mels, mel_input_length, _ = batch
+
+ with torch.no_grad():
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
+ ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
+
+ s2s_attn = s2s_attn.transpose(-1, -2)
+ s2s_attn = s2s_attn[..., 1:]
+ s2s_attn = s2s_attn.transpose(-1, -2)
+
+ text_mask = length_to_mask(input_lengths).to(texts.device)
+ attn_mask = (~mask).unsqueeze(-1).expand(mask.shape[0], mask.shape[1], text_mask.shape[-1]).float().transpose(-1, -2)
+ attn_mask = attn_mask.float() * (~text_mask).unsqueeze(-1).expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1]).float()
+ attn_mask = (attn_mask < 1)
+ s2s_attn.masked_fill_(attn_mask, 0.0)
+
+ # encode
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
+
+ asr = (t_en @ s2s_attn)
+
+ # get clips
+ mel_input_length_all = accelerator.gather(mel_input_length) # for balanced load
+ mel_len = min([int(mel_input_length.min().item() / 2 - 1), max_len // 2])
+
+ en = []
+ gt = []
+ wav = []
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item() / 2)
+
+ random_start = np.random.randint(0, mel_length - mel_len)
+ en.append(asr[bib, :, random_start:random_start+mel_len])
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
+ y = waves[bib][(random_start * 2) * 512:((random_start+mel_len) * 2) * 512]
+ wav.append(torch.from_numpy(y).to(device))
+
+ wav = torch.stack(wav).float().detach()
+
+ en = torch.stack(en)
+ gt = torch.stack(gt).detach()
+
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
+ s = model.style_encoder(gt.unsqueeze(1))
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
+ y_rec = model.decoder(en, F0_real, real_norm, s)
+
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
+
+ loss_test += accelerator.gather(loss_mel).mean().item()
+ iters_test += 1
+
+ if accelerator.is_main_process:
+ print('Epochs:', epoch + 1)
+ log_print('Validation loss: %.3f' % (loss_test / iters_test) + '\n\n\n\n', logger)
+ print('\n\n\n')
+ writer.add_scalar('eval/mel_loss', loss_test / iters_test, epoch + 1)
+ attn_image = get_image(s2s_attn[0].cpu().numpy().squeeze())
+ writer.add_figure('eval/attn', attn_image, epoch)
+
+ with torch.no_grad():
+ for bib in range(len(asr)):
+ mel_length = int(mel_input_length[bib].item())
+ gt = mels[bib, :, :mel_length].unsqueeze(0)
+ en = asr[bib, :, :mel_length // 2].unsqueeze(0)
+
+ F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
+ F0_real = F0_real.unsqueeze(0)
+ s = model.style_encoder(gt.unsqueeze(1))
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
+
+ y_rec = model.decoder(en, F0_real, real_norm, s)
+
+ writer.add_audio('eval/y' + str(bib), y_rec.cpu().numpy().squeeze(), epoch, sample_rate=sr)
+ if epoch == 0:
+ writer.add_audio('gt/y' + str(bib), waves[bib].squeeze(), epoch, sample_rate=sr)
+
+ if bib >= 6:
+ break
+
+ if epoch % saving_epoch == 0:
+ if (loss_test / iters_test) < best_loss:
+ best_loss = loss_test / iters_test
+ print('Saving..')
+ state = {
+ 'net': {key: model[key].state_dict() for key in model},
+ 'optimizer': optimizer.state_dict(),
+ 'iters': iters,
+ 'val_loss': loss_test / iters_test,
+ 'epoch': epoch,
+ }
+ save_path = osp.join(log_dir, 'epoch_1st_%05d.pth' % epoch)
+ torch.save(state, save_path)
+
+ if accelerator.is_main_process:
+ print('Saving..')
+ state = {
+ 'net': {key: model[key].state_dict() for key in model},
+ 'optimizer': optimizer.state_dict(),
+ 'iters': iters,
+ 'val_loss': loss_test / iters_test,
+ 'epoch': epoch,
+ }
+ save_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
+ torch.save(state, save_path)
+
+
+
+if __name__=="__main__":
+ main()
diff --git a/stts_48khz/StyleTTS2_48khz/train_second.py b/stts_48khz/StyleTTS2_48khz/train_second.py
new file mode 100644
index 0000000000000000000000000000000000000000..4244a91ebe97721a8c691f4b8755f0f959f1e13b
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/train_second.py
@@ -0,0 +1,814 @@
+# load packages
+import os
+import os.path as osp
+import copy
+import random
+import yaml
+import time
+from munch import Munch
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+import librosa
+import click
+import shutil
+import traceback
+import warnings
+warnings.simplefilter('ignore')
+from torch.utils.tensorboard import SummaryWriter
+
+from meldataset import build_dataloader
+
+from Utils.ASR.models import ASRCNN
+from Utils.JDC.model import JDCNet
+from Utils.PLBERT.util import load_plbert
+
+from models import *
+from losses import *
+from utils import *
+
+from Modules.slmadv import SLMAdversarialLoss
+from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
+
+from optimizers import build_optimizer
+
+# simple fix for dataparallel that allows access to class attributes
+class MyDataParallel(torch.nn.DataParallel):
+ def __getattr__(self, name):
+ try:
+ return super().__getattr__(name)
+ except AttributeError:
+ return getattr(self.module, name)
+
+import logging
+from logging import StreamHandler
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+handler = StreamHandler()
+handler.setLevel(logging.DEBUG)
+logger.addHandler(handler)
+
+
+@click.command()
+@click.option('-p', '--config_path', default='Configs/config.yml', type=str)
+def main(config_path):
+ config = yaml.safe_load(open(config_path))
+
+ log_dir = config['log_dir']
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
+ writer = SummaryWriter(log_dir + "/tensorboard")
+
+ # write logs
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
+ file_handler.setLevel(logging.DEBUG)
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
+ logger.addHandler(file_handler)
+
+
+ batch_size = config.get('batch_size', 10)
+
+ save_iter = 1000
+
+ epochs = config.get('epochs_2nd', 200)
+ save_freq = config.get('save_freq', 2)
+ log_interval = config.get('log_interval', 10)
+ saving_epoch = config.get('save_freq', 2)
+
+ data_params = config.get('data_params', None)
+ sr = config['preprocess_params'].get('sr', 48000)
+ train_path = data_params['train_data']
+ val_path = data_params['val_data']
+ root_path = data_params['root_path']
+ min_length = data_params['min_length']
+ OOD_data = data_params['OOD_data']
+
+ max_len = config.get('max_len', 200)
+
+ loss_params = Munch(config['loss_params'])
+ diff_epoch = loss_params.diff_epoch
+ joint_epoch = loss_params.joint_epoch
+
+ optimizer_params = Munch(config['optimizer_params'])
+
+ train_list, val_list = get_data_path_list(train_path, val_path)
+ device = 'cuda'
+
+ train_dataloader = build_dataloader(train_list,
+ root_path,
+ OOD_data=OOD_data,
+ min_length=min_length,
+ batch_size=batch_size,
+ num_workers=2,
+ dataset_config={},
+ device=device)
+
+ val_dataloader = build_dataloader(val_list,
+ root_path,
+ OOD_data=OOD_data,
+ min_length=min_length,
+ batch_size=batch_size,
+ validation=True,
+ num_workers=0,
+ device=device,
+ dataset_config={})
+
+ # load pretrained ASR model
+ ASR_config = config.get('ASR_config', False)
+ ASR_path = config.get('ASR_path', False)
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
+
+ # load pretrained F0 model
+ F0_path = config.get('F0_path', False)
+ pitch_extractor = load_F0_models(F0_path)
+
+ # load PL-BERT model
+ BERT_path = config.get('PLBERT_dir', False)
+ plbert = load_plbert(BERT_path)
+
+ # build model
+ model_params = recursive_munch(config['model_params'])
+ multispeaker = model_params.multispeaker
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
+ _ = [model[key].to(device) for key in model]
+
+ # DP
+ for key in model:
+ if key != "mpd" and key != "msd" and key != "wd":
+ model[key] = MyDataParallel(model[key])
+
+ start_epoch = 0
+ iters = 0
+
+ load_pretrained = config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False)
+
+ if not load_pretrained:
+ if config.get('first_stage_path', '') != '':
+ first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
+ print('Loading the first stage model at %s ...' % first_stage_path)
+ model, _, start_epoch, iters = load_checkpoint(model,
+ None,
+ first_stage_path,
+ load_only_params=True,
+ ignore_modules=['bert', 'bert_encoder', 'predictor', 'predictor_encoder', 'msd', 'mpd', 'wd', 'diffusion']) # keep starting epoch for tensorboard log
+
+ # these epochs should be counted from the start epoch
+ diff_epoch += start_epoch
+ joint_epoch += start_epoch
+ epochs += start_epoch
+
+ model.predictor_encoder = copy.deepcopy(model.style_encoder)
+ else:
+ raise ValueError('You need to specify the path to the first stage model.')
+
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
+ wl = WavLMLoss(model_params.slm.model,
+ model.wd,
+ sr,
+ model_params.slm.sr).to(device)
+
+ gl = MyDataParallel(gl)
+ dl = MyDataParallel(dl)
+ wl = MyDataParallel(wl)
+
+ sampler = DiffusionSampler(
+ model.diffusion.diffusion,
+ sampler=ADPM2Sampler(),
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
+ clamp=False
+ )
+
+ scheduler_params = {
+ "max_lr": optimizer_params.lr,
+ "pct_start": float(0),
+ "epochs": epochs,
+ "steps_per_epoch": len(train_dataloader),
+ }
+ scheduler_params_dict= {key: scheduler_params.copy() for key in model}
+ scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2
+ scheduler_params_dict['decoder']['max_lr'] = optimizer_params.ft_lr * 2
+ scheduler_params_dict['style_encoder']['max_lr'] = optimizer_params.ft_lr * 2
+
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
+ scheduler_params_dict=scheduler_params_dict, lr=optimizer_params.lr)
+
+ # adjust BERT learning rate
+ for g in optimizer.optimizers['bert'].param_groups:
+ g['betas'] = (0.9, 0.99)
+ g['lr'] = optimizer_params.bert_lr
+ g['initial_lr'] = optimizer_params.bert_lr
+ g['min_lr'] = 0
+ g['weight_decay'] = 0.01
+
+ # adjust acoustic module learning rate
+ for module in ["decoder", "style_encoder"]:
+ for g in optimizer.optimizers[module].param_groups:
+ g['betas'] = (0.0, 0.99)
+ g['lr'] = optimizer_params.ft_lr
+ g['initial_lr'] = optimizer_params.ft_lr
+ g['min_lr'] = 0
+ g['weight_decay'] = 1e-4
+
+ # load models if there is a model
+ if load_pretrained:
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
+ load_only_params=config.get('load_only_params', True))
+
+ n_down = model.text_aligner.n_down
+
+ best_loss = float('inf') # best validation loss so far
+ loss_train_record = []
+ loss_test_record = []
+ iters = 0
+
+ criterion = nn.L1Loss() # F0 loss (regression)
+ torch.cuda.empty_cache()
+
+ stft_loss = MultiResolutionSTFTLoss().to(device)
+
+ print('BERT', optimizer.optimizers['bert'])
+ print('decoder', optimizer.optimizers['decoder'])
+
+ start_ds = False
+
+ running_std = []
+
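+ # SLM (WavLM) adversarial loss module; segment length range, batch fraction, and update interval all come from slmadv_params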
+ slmadv_params = Munch(config['slmadv_params'])
+ slmadv = SLMAdversarialLoss(model, wl, sampler,
+ slmadv_params.min_len,
+ slmadv_params.max_len,
+ batch_percentage=slmadv_params.batch_percentage,
+ skip_update=slmadv_params.iter,
+ sig=slmadv_params.sig
+ )
+
+
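+ # second-stage training loop: predictor, BERT modules, and predictor encoder are updated every step; the waveform discriminators and the style diffusion model join at diff_epoch, and the decoder / style encoder are only tuned after joint_epoch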
+ for epoch in range(start_epoch, epochs):
+ running_loss = 0
+ start_time = time.time()
+
+ _ = [model[key].eval() for key in model]
+
+ model.predictor.train()
+ model.bert_encoder.train()
+ model.bert.train()
+ model.msd.train()
+ model.mpd.train()
+
+
+ if epoch >= diff_epoch:
+ start_ds = True
+
+ for i, batch in enumerate(train_dataloader):
+ waves = batch[0]
+ batch = [b.to(device) for b in batch[1:]]
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
+
+ with torch.no_grad():
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
+ mel_mask = length_to_mask(mel_input_length).to(device)
+ text_mask = length_to_mask(input_lengths).to(texts.device)
+
+ try:
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
+ s2s_attn = s2s_attn.transpose(-1, -2)
+ s2s_attn = s2s_attn[..., 1:]
+ s2s_attn = s2s_attn.transpose(-1, -2)
+ except Exception: # the text aligner occasionally fails on a batch; skip it
+ continue
+
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
+
+ # encode
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
+ asr = (t_en @ s2s_attn_mono)
+
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
+
+ # compute reference styles
+ if multispeaker and epoch >= diff_epoch:
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
+ ref = torch.cat([ref_ss, ref_sp], dim=1)
+
+ # compute the style of the entire utterance
+ # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
+ ss = []
+ gs = []
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item())
+ mel = mels[bib, :, :mel_input_length[bib]]
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
+ ss.append(s)
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
+ gs.append(s)
+
+ s_dur = torch.stack(ss).squeeze() # global prosodic styles
+ gs = torch.stack(gs).squeeze() # global acoustic styles
+ s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
+
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
+
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+
+ # denoiser training
+ if epoch >= diff_epoch:
+ num_steps = np.random.randint(3, 5)
+
+ if model_params.diffusion.dist.estimate_sigma_data:
+ model.diffusion.module.diffusion.sigma_data = s_trg.std(axis=-1).mean().item() # batch-wise std estimation
+ running_std.append(model.diffusion.module.diffusion.sigma_data)
+
+ if multispeaker:
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
+ embedding=bert_dur,
+ embedding_scale=1,
+ features=ref, # reference from the same speaker as the embedding
+ embedding_mask_proba=0.1,
+ num_steps=num_steps).squeeze(1)
+ loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
+ else:
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
+ embedding=bert_dur,
+ embedding_scale=1,
+ embedding_mask_proba=0.1,
+ num_steps=num_steps).squeeze(1)
+ loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur).mean() # EDM loss
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
+ else:
+ loss_sty = 0
+ loss_diff = 0
+
+
+
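+ # predictor forward pass: d = duration logits per text token, p = frame-level prosody features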
+ d, p = model.predictor(d_en, s_dur,
+ input_lengths,
+ s2s_attn_mono,
+ text_mask)
+
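+ # randomly crop aligned feature / mel / waveform segments (capped at max_len // 2 downsampled frames) for reconstruction and discriminator training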
+ mel_len = min(int(mel_input_length.min().item() / 2 - 1), max_len // 2)
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
+ en = []
+ gt = []
+ st = []
+ p_en = []
+ wav = []
+
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item() / 2)
+
+ random_start = np.random.randint(0, mel_length - mel_len)
+ en.append(asr[bib, :, random_start:random_start+mel_len])
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
+
+ y = waves[bib][(random_start * 2) * 512:((random_start+mel_len) * 2) * 512]
+ wav.append(torch.from_numpy(y).to(device))
+
+ # style reference (better to be different from the GT)
+ random_start = np.random.randint(0, mel_length - mel_len_st)
+ st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
+
+ wav = torch.stack(wav).float().detach()
+
+ en = torch.stack(en)
+ p_en = torch.stack(p_en)
+ gt = torch.stack(gt).detach()
+ st = torch.stack(st).detach()
+
+ if gt.size(-1) < 80:
+ continue
+
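+ # prosodic style (predictor_encoder) and acoustic style (style_encoder); multispeaker training uses the separate style-reference segment instead of the ground-truth segment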
+ s_dur = model.predictor_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
+ s = model.style_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
+
+ with torch.no_grad():
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
+ F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
+
+ asr_real = model.text_aligner.get_feature(gt)
+
+ N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
+
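+ # two possible targets: the real recording (y_rec_gt) and the decoder's reconstruction from ground-truth F0 / energy (y_rec_gt_pred); which one is used depends on whether the decoder is being tuned yet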
+ y_rec_gt = wav.unsqueeze(1)
+ y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
+
+ if epoch >= joint_epoch:
+ # ground truth from recording
+ wav = y_rec_gt # use recording since decoder is tuned
+ else:
+ # ground truth from reconstruction
+ wav = y_rec_gt_pred # use reconstruction since decoder is fixed
+
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
+
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
+
+ loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
+ loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
+
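+ # discriminator update (MPD / MSD) on real vs. reconstructed waveforms once diffusion-stage training has begun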
+ if start_ds:
+ optimizer.zero_grad()
+ d_loss = dl(wav.detach(), y_rec.detach()).mean()
+ d_loss.backward()
+ optimizer.step('msd')
+ optimizer.step('mpd')
+ else:
+ d_loss = 0
+
+ # generator loss
+ optimizer.zero_grad()
+
+ loss_mel = stft_loss(y_rec, wav)
+ if start_ds:
+ loss_gen_all = gl(wav, y_rec).mean()
+ else:
+ loss_gen_all = 0
+ loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
+
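+ # duration losses per utterance: L1 on the summed sigmoid durations and binary cross-entropy on the raw duration logits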
+ loss_ce = 0
+ loss_dur = 0
+ for _s2s_pred, _text_input, _text_length in zip(d, d_gt, input_lengths):
+ _s2s_pred = _s2s_pred[:_text_length, :]
+ _text_input = _text_input[:_text_length].long()
+ _s2s_trg = torch.zeros_like(_s2s_pred)
+ for t in range(_s2s_trg.shape[0]):
+ _s2s_trg[t, :_text_input[t]] = 1
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
+
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
+ _text_input[1:_text_length-1])
+ loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())
+
+ loss_ce /= texts.size(0)
+ loss_dur /= texts.size(0)
+
+ g_loss = loss_params.lambda_mel * loss_mel + \
+ loss_params.lambda_F0 * loss_F0_rec + \
+ loss_params.lambda_ce * loss_ce + \
+ loss_params.lambda_norm * loss_norm_rec + \
+ loss_params.lambda_dur * loss_dur + \
+ loss_params.lambda_gen * loss_gen_all + \
+ loss_params.lambda_slm * loss_lm + \
+ loss_params.lambda_sty * loss_sty + \
+ loss_params.lambda_diff * loss_diff
+
+ running_loss += loss_mel.item()
+ g_loss.backward()
+ # if torch.isnan(g_loss):
+ # from IPython.core.debugger import set_trace
+ # set_trace()
+
+ optimizer.step('bert_encoder')
+ optimizer.step('bert')
+ optimizer.step('predictor')
+ optimizer.step('predictor_encoder')
+
+ if epoch >= diff_epoch:
+ optimizer.step('diffusion')
+
+ if epoch >= joint_epoch:
+ optimizer.step('style_encoder')
+ optimizer.step('decoder')
+
+ # randomly pick whether to use in-distribution text
+ if np.random.rand() < 0.5:
+ use_ind = True
+ else:
+ use_ind = False
+
+ if use_ind:
+ ref_lengths = input_lengths
+ ref_texts = texts
+
+ try:
+ slm_out = slmadv(i,
+ y_rec_gt,
+ y_rec_gt_pred,
+ waves,
+ mel_input_length,
+ ref_texts,
+ ref_lengths, use_ind, s_trg.detach(), ref if multispeaker else None)
+
+ except Exception as e:
+ print('SLM adversarial step failed, skipping:', e)
+ slm_out = None
+
+ if slm_out is None:
+ continue
+
+ d_loss_slm, loss_gen_lm, y_pred = slm_out
+
+ # SLM generator loss
+ optimizer.zero_grad()
+ loss_gen_lm.backward()
+
+ # compute the gradient norm
+ total_norm = {}
+ for key in model.keys():
+ total_norm[key] = 0
+ parameters = [p for p in model[key].parameters() if p.grad is not None and p.requires_grad]
+ for p in parameters:
+ param_norm = p.grad.detach().data.norm(2)
+ total_norm[key] += param_norm.item() ** 2
+ total_norm[key] = total_norm[key] ** 0.5
+
+ # gradient scaling
+ if total_norm['predictor'] > slmadv_params.thresh:
+ for key in model.keys():
+ for p in model[key].parameters():
+ if p.grad is not None:
+ p.grad *= (1 / total_norm['predictor'])
+
+ for p in model.predictor.duration_proj.parameters():
+ if p.grad is not None:
+ p.grad *= slmadv_params.scale
+
+ for p in model.predictor.lstm.parameters():
+ if p.grad is not None:
+ p.grad *= slmadv_params.scale
+
+ for p in model.diffusion.parameters():
+ if p.grad is not None:
+ p.grad *= slmadv_params.scale
+
+ optimizer.step('bert_encoder')
+ optimizer.step('bert')
+ optimizer.step('predictor')
+ optimizer.step('diffusion')
+
+ # SLM discriminator loss
+ if d_loss_slm != 0:
+ optimizer.zero_grad()
+ d_loss_slm.backward(retain_graph=True)
+ optimizer.step('wd')
+
+ else:
+ d_loss_slm, loss_gen_lm = 0, 0
+
+ iters = iters + 1
+
+ if (i+1)%log_interval == 0:
+ logger.info('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f'
+ %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, d_loss, loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff, d_loss_slm, loss_gen_lm))
+
+ writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
+ writer.add_scalar('train/gen_loss', loss_gen_all, iters)
+ writer.add_scalar('train/d_loss', d_loss, iters)
+ writer.add_scalar('train/ce_loss', loss_ce, iters)
+ writer.add_scalar('train/dur_loss', loss_dur, iters)
+ writer.add_scalar('train/slm_loss', loss_lm, iters)
+ writer.add_scalar('train/norm_loss', loss_norm_rec, iters)
+ writer.add_scalar('train/F0_loss', loss_F0_rec, iters)
+ writer.add_scalar('train/sty_loss', loss_sty, iters)
+ writer.add_scalar('train/diff_loss', loss_diff, iters)
+ writer.add_scalar('train/d_loss_slm', d_loss_slm, iters)
+ writer.add_scalar('train/gen_loss_slm', loss_gen_lm, iters)
+
+ running_loss = 0
+
+ print('Time elapsed:', time.time()-start_time)
+
+ if (i+1)%save_iter == 0:
+ print(f'Saving on step {epoch*len(train_dataloader)+i}...')
+ state = {
+ 'net': {key: model[key].state_dict() for key in model},
+ 'optimizer': optimizer.state_dict(),
+ 'iters': iters,
+ 'epoch': epoch,
+ }
+ save_path = osp.join(log_dir, f'2nd_phase_{epoch*len(train_dataloader)+i}.pth')
+ torch.save(state, save_path)
+
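+ # validation: mel reconstruction (multi-resolution STFT), duration, and F0 losses, averaged over the validation set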
+ loss_test = 0
+ loss_align = 0
+ loss_f = 0
+ _ = [model[key].eval() for key in model]
+
+ with torch.no_grad():
+ iters_test = 0
+ for batch_idx, batch in enumerate(val_dataloader):
+ optimizer.zero_grad()
+
+ try:
+ waves = batch[0]
+ batch = [b.to(device) for b in batch[1:]]
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
+ with torch.no_grad():
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
+ text_mask = length_to_mask(input_lengths).to(texts.device)
+
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
+ s2s_attn = s2s_attn.transpose(-1, -2)
+ s2s_attn = s2s_attn[..., 1:]
+ s2s_attn = s2s_attn.transpose(-1, -2)
+
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
+
+ # encode
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
+ asr = (t_en @ s2s_attn_mono)
+
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
+
+ ss = []
+ gs = []
+
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item())
+ mel = mels[bib, :, :mel_input_length[bib]]
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
+ ss.append(s)
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
+ gs.append(s)
+
+ s = torch.stack(ss).squeeze()
+ gs = torch.stack(gs).squeeze()
+ s_trg = torch.cat([s, gs], dim=-1).detach()
+
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+ d, p = model.predictor(d_en, s,
+ input_lengths,
+ s2s_attn_mono,
+ text_mask)
+ # get clips
+ mel_len = int(mel_input_length.min().item() / 2 - 1)
+ en = []
+ gt = []
+ p_en = []
+ wav = []
+
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item() / 2)
+
+ random_start = np.random.randint(0, mel_length - mel_len)
+ en.append(asr[bib, :, random_start:random_start+mel_len])
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
+
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
+
+ y = waves[bib][(random_start * 2) * 512:((random_start+mel_len) * 2) * 512]
+ wav.append(torch.from_numpy(y).to(device))
+
+ wav = torch.stack(wav).float().detach()
+
+ en = torch.stack(en)
+ p_en = torch.stack(p_en)
+ gt = torch.stack(gt).detach()
+
+ s = model.predictor_encoder(gt.unsqueeze(1))
+
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s)
+
+ loss_dur = 0
+ for _s2s_pred, _text_input, _text_length in zip(d, d_gt, input_lengths):
+ _s2s_pred = _s2s_pred[:_text_length, :]
+ _text_input = _text_input[:_text_length].long()
+ _s2s_trg = torch.zeros_like(_s2s_pred)
+ for bib in range(_s2s_trg.shape[0]):
+ _s2s_trg[bib, :_text_input[bib]] = 1
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
+ _text_input[1:_text_length-1])
+
+ loss_dur /= texts.size(0)
+
+ s = model.style_encoder(gt.unsqueeze(1))
+
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
+
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
+
+ loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
+
+ loss_test += (loss_mel).mean()
+ loss_align += (loss_dur).mean()
+ loss_f += (loss_F0).mean()
+
+ iters_test += 1
+ except Exception as e:
+ print(f"run into exception", e)
+ traceback.print_exc()
+ continue
+
+ print('Epoch:', epoch + 1)
+ logger.info('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n\n\n')
+ print('\n\n\n')
+ writer.add_scalar('eval/mel_loss', loss_test / iters_test, epoch + 1)
+ writer.add_scalar('eval/dur_loss', loss_align / iters_test, epoch + 1)
+ writer.add_scalar('eval/F0_loss', loss_f / iters_test, epoch + 1)
+
+ if epoch < joint_epoch:
+ # generating reconstruction examples with GT duration
+
+ with torch.no_grad():
+ for bib in range(len(asr)):
+ mel_length = int(mel_input_length[bib].item())
+ gt = mels[bib, :, :mel_length].unsqueeze(0)
+ en = asr[bib, :, :mel_length // 2].unsqueeze(0)
+
+ F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
+ F0_real = F0_real.unsqueeze(0)
+ s = model.style_encoder(gt.unsqueeze(1))
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
+
+ y_rec = model.decoder(en, F0_real, real_norm, s)
+
+ writer.add_audio('eval/y' + str(bib), y_rec.cpu().numpy().squeeze(), epoch, sample_rate=sr)
+
+ s_dur = model.predictor_encoder(gt.unsqueeze(1))
+ p_en = p[bib, :, :mel_length // 2].unsqueeze(0)
+
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
+
+ y_pred = model.decoder(en, F0_fake, N_fake, s)
+
+ writer.add_audio('pred/y' + str(bib), y_pred.cpu().numpy().squeeze(), epoch, sample_rate=sr)
+
+ if epoch == 0:
+ writer.add_audio('gt/y' + str(bib), waves[bib].squeeze(), epoch, sample_rate=sr)
+
+ if bib >= 5:
+ break
+ else:
+ # generating sampled speech from text directly
+ with torch.no_grad():
+ # compute reference styles
+ if multispeaker and epoch >= diff_epoch:
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
+ ref_s = torch.cat([ref_ss, ref_sp], dim=1)
+
+ for bib in range(len(d_en)):
+ if multispeaker:
+ s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(texts.device),
+ embedding=bert_dur[bib].unsqueeze(0),
+ embedding_scale=1,
+ features=ref_s[bib].unsqueeze(0), # reference from the same speaker as the embedding
+ num_steps=5).squeeze(1)
+ else:
+ s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(texts.device),
+ embedding=bert_dur[bib].unsqueeze(0),
+ embedding_scale=1,
+ num_steps=5).squeeze(1)
+
+ s = s_pred[:, 128:]
+ ref = s_pred[:, :128]
+
+ d = model.predictor.text_encoder(d_en[bib, :, :input_lengths[bib]].unsqueeze(0),
+ s, input_lengths[bib, ...].unsqueeze(0), text_mask[bib, :input_lengths[bib]].unsqueeze(0))
+
+ x = model.predictor.lstm(d)
+ x_mod = model.predictor.prepare_projection(x) # 640 -> 512
+ duration = model.predictor.duration_proj(x_mod)
+
+ duration = torch.sigmoid(duration).sum(axis=-1)
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
+
+ pred_dur[-1] += 5
+
+ pred_aln_trg = torch.zeros(input_lengths[bib], int(pred_dur.sum().data))
+ c_frame = 0
+ for i in range(pred_aln_trg.size(0)):
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
+ c_frame += int(pred_dur[i].data)
+
+ # encode prosody
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(texts.device))
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
+ out = model.decoder((t_en[bib, :, :input_lengths[bib]].unsqueeze(0) @ pred_aln_trg.unsqueeze(0).to(texts.device)),
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
+
+ writer.add_audio('pred/y' + str(bib), out.cpu().numpy().squeeze(), epoch, sample_rate=sr)
+
+ if bib >= 5:
+ break
+
+ if epoch % saving_epoch == 0:
+ if (loss_test / iters_test) < best_loss:
+ best_loss = loss_test / iters_test
+ print('Saving..')
+ state = {
+ 'net': {key: model[key].state_dict() for key in model},
+ 'optimizer': optimizer.state_dict(),
+ 'iters': iters,
+ 'val_loss': loss_test / iters_test,
+ 'epoch': epoch,
+ }
+ save_path = osp.join(log_dir, 'epoch_2nd_%05d.pth' % epoch)
+ torch.save(state, save_path)
+
+ # if sigma_data was estimated, save the running estimate back into the config
+ if model_params.diffusion.dist.estimate_sigma_data:
+ config['model_params']['diffusion']['dist']['sigma_data'] = float(np.mean(running_std))
+
+ with open(osp.join(log_dir, osp.basename(config_path)), 'w') as outfile:
+ yaml.dump(config, outfile, default_flow_style=True)
+
+if __name__=="__main__":
+ main()
diff --git a/stts_48khz/StyleTTS2_48khz/utils.py b/stts_48khz/StyleTTS2_48khz/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2206d9277879b5abbd6d29be86eb2f181a8c1db
--- /dev/null
+++ b/stts_48khz/StyleTTS2_48khz/utils.py
@@ -0,0 +1,74 @@
+from monotonic_align import maximum_path  # note: rebound by the local maximum_path wrapper defined below
+from monotonic_align import mask_from_lens
+from monotonic_align.core import maximum_path_c
+import numpy as np
+import torch
+import copy
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+import librosa
+import matplotlib.pyplot as plt
+from munch import Munch
+
+def maximum_path(neg_cent, mask):
+ """ Cython optimized version.
+ neg_cent: [b, t_t, t_s]
+ mask: [b, t_t, t_s]
+ """
+ device = neg_cent.device
+ dtype = neg_cent.dtype
+ neg_cent = np.ascontiguousarray(neg_cent.data.cpu().numpy().astype(np.float32))
+ path = np.ascontiguousarray(np.zeros(neg_cent.shape, dtype=np.int32))
+
+ t_t_max = np.ascontiguousarray(mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32))
+ t_s_max = np.ascontiguousarray(mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32))
+ maximum_path_c(path, neg_cent, t_t_max, t_s_max)
+ return torch.from_numpy(path).to(device=device, dtype=dtype)
+
+def get_data_path_list(train_path=None, val_path=None):
+ if train_path is None:
+ train_path = "Data/train_list.txt"
+ if val_path is None:
+ val_path = "Data/val_list.txt"
+
+ with open(train_path, 'r', encoding='utf-8', errors='ignore') as f:
+ train_list = f.readlines()
+ with open(val_path, 'r', encoding='utf-8', errors='ignore') as f:
+ val_list = f.readlines()
+
+ return train_list, val_list
+
+def length_to_mask(lengths):
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
+ return mask
+
+# for norm consistency loss
+def log_norm(x, mean=-4, std=4, dim=2):
+ """
+ normalized log mel -> mel -> norm -> log(norm)
+ """
+ x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
+ return x
+
+def get_image(arrs):
+ plt.switch_backend('agg')
+ fig = plt.figure()
+ ax = plt.gca()
+ ax.imshow(arrs)
+
+ return fig
+
+def recursive_munch(d):
+ if isinstance(d, dict):
+ return Munch((k, recursive_munch(v)) for k, v in d.items())
+ elif isinstance(d, list):
+ return [recursive_munch(v) for v in d]
+ else:
+ return d
+
+def log_print(message, logger):
+ logger.info(message)
+ print(message)
+
\ No newline at end of file