{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Extend the symbol set of a pretrained checkpoint\n",
    "\n",
    "This notebook loads the checkpoint at `model_path`, grows the symbol-dependent layers (the text-encoder embedding and the text-aligner embedding / output projections) to `extend_to` rows, and writes the result to `<save_path>/extended.pth`.\n",
    "\n",
    "Existing rows are copied unchanged; new weight rows are drawn from a normal distribution (mean 0, std 0.01) and new bias entries are zero."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "extend_to = None #<= CHANGE THIS to the desired total number of symbols (an integer). The original is 178 symbols\n",
    "\n",
    "save_path = \"./Extend/New_Weights\"\n",
    "config_path = \"./Configs/config.yaml\"\n",
    "model_path = \"./Models/Finetune/base_model.pth\"\n",
    "\n",
    "#⚠️ Must run this notebook first before adding any symbol to the config file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load packages\n",
    "# NOTE: `%cd ..` is relative, so re-running this cell moves up one more directory.\n",
    "# Restart the kernel before running the notebook again.\n",
    "%cd ..\n",
    "import yaml\n",
    "import torch\n",
    "from torch import nn\n",
    "import os\n",
    "from models import *\n",
    "from utils import *\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "device = 'cpu'\n",
    "\n",
    "# Read the config through a context manager so the file handle is always closed.\n",
    "with open(config_path, \"r\", encoding=\"utf-8\") as config_file:\n",
    "    config = yaml.safe_load(config_file)\n",
    "\n",
    "try:\n",
    "    symbols = (\n",
    "        list(config['symbol']['pad']) +\n",
    "        list(config['symbol']['punctuation']) +\n",
    "        list(config['symbol']['letters']) +\n",
    "        list(config['symbol']['letters_ipa']) +\n",
    "        list(config['symbol']['extend'])\n",
    "    )\n",
    "    # Map every symbol to its position in the concatenated list.\n",
    "    symbol_dict = {symbol: index for index, symbol in enumerate(symbols)}\n",
    "\n",
    "    n_token = len(symbol_dict) + 1\n",
    "    print(\"\\nFound\", n_token, \"symbols in the original config file\")\n",
    "except Exception as e:\n",
    "    print(f\"\\nERROR: Cannot find {e} in config file!\\nYour config file is likely outdated, please download updated version from the repository.\")\n",
    "    raise SystemExit(1)\n",
    "\n",
    "# `extend_to` ships as a `None` placeholder; report that clearly instead of\n",
    "# failing with a TypeError in the comparison below.\n",
    "if isinstance(extend_to, bool) or not isinstance(extend_to, int):\n",
    "    print(f\"\\nERROR: `extend_to` must be an integer, got {extend_to!r}. Set it in the first code cell.\")\n",
    "    raise SystemExit(1)\n",
    "\n",
    "if (extend_to-n_token) <= 0:\n",
    "    print(f\"\\nERROR: Cannot extend from {n_token} to {extend_to}.\")\n",
    "    raise SystemExit(1)\n",
    "\n",
    "model_params = recursive_munch(config['model_params'])\n",
    "model_params['n_token'] = n_token\n",
    "model = build_model(model_params)\n",
    "\n",
    "keys_to_keep = {'predictor', 'decoder', 'text_encoder', 'style_encoder', 'text_aligner', 'pitch_extractor', 'mpd', 'msd'}\n",
    "# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "params_whole = torch.load(model_path, map_location=device)\n",
    "params = params_whole['net']\n",
    "params = {key: value for key, value in params.items() if key in keys_to_keep}\n",
    "\n",
    "# Drop every sub-model that is not part of the saved checkpoint layout.\n",
    "for key in list(model.keys()):\n",
    "    if key not in keys_to_keep:\n",
    "        del model[key]\n",
    "\n",
    "for key in model:\n",
    "    if key in params:\n",
    "        print('%s loaded' % key)\n",
    "        try:\n",
    "            model[key].load_state_dict(params[key])\n",
    "        except RuntimeError:\n",
    "            # Retry with a leading `module.` prefix removed from the parameter names.\n",
    "            # Names without that prefix are kept as they are, so a genuine mismatch\n",
    "            # is not hidden by mangled key names.\n",
    "            prefix = 'module.'\n",
    "            new_state_dict = {\n",
    "                (k[len(prefix):] if k.startswith(prefix) else k): v\n",
    "                for k, v in params[key].items()\n",
    "            }\n",
    "            # load params\n",
    "            model[key].load_state_dict(new_state_dict, strict=False)\n",
    "\n",
    "# Layers whose first weight dimension depends on the number of symbols.\n",
    "old_weight = [\n",
    "    model['text_encoder'].embedding,\n",
    "    model['text_aligner'].ctc_linear[2].linear_layer,\n",
    "    model['text_aligner'].asr_s2s.embedding,\n",
    "    model['text_aligner'].asr_s2s.project_to_n_symbols\n",
    "]\n",
    "print(\"\\nOld shape:\")\n",
    "for module in old_weight:\n",
    "    print(module, module.weight.shape)\n",
    "\n",
    "for module in old_weight:\n",
    "    old_rows, width = module.weight.shape\n",
    "\n",
    "    # New rows are initialised with mean=0, std=0.01; existing rows are copied over.\n",
    "    new_weight = torch.randn(\n",
    "        extend_to, width, dtype=module.weight.dtype, device=module.weight.device\n",
    "    ) * 0.01\n",
    "    with torch.no_grad():\n",
    "        new_weight[:old_rows, :] = module.weight.detach()\n",
    "\n",
    "    if isinstance(module, nn.Embedding):\n",
    "        module.num_embeddings = extend_to\n",
    "\n",
    "    if isinstance(module, nn.Linear):\n",
    "        module.out_features = extend_to\n",
    "        # update bias (a Linear layer built with bias=False has none to extend)\n",
    "        if module.bias is not None:\n",
    "            old_bias = module.bias.detach()\n",
    "            new_bias = torch.zeros(extend_to, dtype=old_bias.dtype, device=old_bias.device)\n",
    "            new_bias[:old_bias.shape[0]] = old_bias\n",
    "            module.bias = nn.Parameter(new_bias, requires_grad=module.bias.requires_grad)\n",
    "\n",
    "    module.weight = nn.Parameter(new_weight, requires_grad=True)\n",
    "\n",
    "print(\"\\nNew shape:\")\n",
    "for module in old_weight:\n",
    "    print(module, module.weight.shape)\n",
    "\n",
    "# Create the output directory together with any missing parent directories.\n",
    "os.makedirs(save_path, exist_ok=True)\n",
    "\n",
    "#save new weights\n",
    "state = {\n",
    "    'net': {key: model[key].state_dict() for key in model},\n",
    "    'optimizer': None,\n",
    "    'iters': 0,\n",
    "    'val_loss': 0,\n",
    "    'epoch': 0,\n",
    "}\n",
    "torch.save(state, os.path.join(save_path, 'extended.pth'))\n",
    "\n",
    "# Report success only once the checkpoint has actually been written.\n",
    "print(f\"\\n\\n✅ Successfully extended the token set to a maximum of {extend_to} symbols.\")\n",
    "print(f\"You can now add {extend_to - n_token} additional symbols in the config file.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}