{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4effe69f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchaudio\n",
    "from omegaconf import OmegaConf\n",
    "from torch.nn import functional as F\n",
    "\n",
    "from cosyvoice.flow.decoder import ConditionalDecoder, CausalConditionalDecoder\n",
    "from cosyvoice.flow.flow import CausalMaskedDiffWithXvec\n",
    "from cosyvoice.flow.flow_matching import CausalConditionalCFM\n",
    "from cosyvoice.hifigan.f0_predictor import ConvRNNF0Predictor\n",
    "from cosyvoice.hifigan.generator import HiFTGenerator\n",
    "from cosyvoice.llm.llm import Qwen2Encoder, Qwen2LM\n",
    "from cosyvoice.tokenizer.tokenizer import get_qwen_tokenizer\n",
    "from cosyvoice.transformer.upsample_encoder import UpsampleConformerEncoder\n",
    "from cosyvoice.utils.common import ras_sampling\n",
    "\n",
    "# Set CUDA device\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"  # Use GPU 0\n",
    "device = \"cuda:0\"\n",
    "\n",
    "\n",
    "def set_deterministic_behavior(seed=42):\n",
    "    \"\"\"Seed all RNGs used here and force deterministic cuDNN kernels.\n",
    "\n",
    "    Args:\n",
    "        seed: Seed for Python's `random`, NumPy's global RNG and PyTorch\n",
    "            (CPU and every visible CUDA device).\n",
    "\n",
    "    Note:\n",
    "        PYTHONHASHSEED is read only when an interpreter starts, so setting it\n",
    "        here does not change str/bytes hashing in the running kernel; it only\n",
    "        affects child processes started afterwards.\n",
    "    \"\"\"\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
    "\n",
    "\n",
    "# Call this function at the beginning of your script\n",
    "set_deterministic_behavior(70000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "0322c8f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_dir = './pretrained_models/CosyVoice2-0.5B'\n",
    "allowed_special = 'all'\n",
    "sample_rate = 24000\n",
    "fp16 = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0ba457c",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_config = {\n",
    "    'llm_input_size': 896,\n",
    "    'llm_output_size': 896,\n",
    "    'speech_token_size': 6561,\n",
    "    'length_normalized_loss': True,\n",
    "    'lsm_weight': 0,\n",
    "    'mix_ratio': [5, 15]\n",
    "}\n",
    "\n",
    "llm_encoder_config = {\n",
    "    'pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')\n",
    "}\n",
    "\n",
    "sampling_config = {\n",
    "    'top_p': 0.8,\n",
    "    'top_k': 25,\n",
    "    'win_size': 10,\n",
    "    'tau_r': 0.1\n",
    "}\n",
    "\n",
    "flow_config = {\n",
    "    'input_size': 512,\n",
    "    'output_size': 80,\n",
    "    'spk_embed_dim': 192,\n",
    "    'output_type': 'mel',\n",
    "    'vocab_size': 6561,\n",
    "    'input_frame_rate': 25,\n",
    "    'only_mask_loss': True,\n",
    "    'token_mel_ratio': 2,\n",
    "    'pre_lookahead_len': 3\n",
    "}\n",
    "\n",
    "encoder_config = {\n",
    "    'output_size': 512,\n",
    "    'attention_heads': 8,\n",
    "    'linear_units': 2048,\n",
    "    'num_blocks': 6,\n",
    "    'dropout_rate': 0.1,\n",
    "    'positional_dropout_rate': 0.1,\n",
    "    'attention_dropout_rate': 0.1,\n",
    "    'normalize_before': True,\n",
    "    'input_layer': 'linear',\n",
    "    'pos_enc_layer_type': 'rel_pos_espnet',\n",
    "    'selfattention_layer_type': 'rel_selfattn',\n",
    "    'input_size': 512,\n",
    "    'use_cnn_module': False,\n",
    "    'macaron_style': False\n",
    "}\n",
    "\n",
    "decoder_config = {\n",
    "    'in_channels': 240,\n",
    "    'n_spks': 1,\n",
    "    'spk_emb_dim': 80,\n",
    "    'cfm_params': {\n",
    "        'sigma_min': 1e-06,\n",
    "        'solver': 'euler',\n",
    "        't_scheduler': 'cosine',\n",
    "        'training_cfg_rate': 0.2,\n",
    "        'inference_cfg_rate': 0.7,\n",
    "        'reg_loss_type': 'l1',\n",
    "        'use_immiscible': True,\n",
    "        'immiscible_k': 8,\n",
    "        'use_contrastive_fm': True,\n",
    "        'contrastive_lambda': 0.05,\n",
    "    }\n",
    "}\n",
    "# CFM reads these with attribute access (cfm_params.use_immiscible), so wrap in a DictConfig.\n",
    "decoder_config['cfm_params'] = OmegaConf.create(decoder_config['cfm_params'])\n",
    "\n",
    "estimator_config = {\n",
    "    'in_channels': 320,\n",
    "    'out_channels': 80,\n",
    "    'channels': [256],\n",
    "    'dropout': 0.0,\n",
    "    'attention_head_dim': 64,\n",
    "    'n_blocks': 4,\n",
    "    'num_mid_blocks': 12,\n",
    "    'num_heads': 8,\n",
    "    'act_fn': 'gelu',\n",
    "    'static_chunk_size': 50,\n",
    "    'num_decoding_left_chunks': 2\n",
    "}\n",
    "\n",
    "f0_predictor_config = {\n",
    "    'num_class': 1,\n",
    "    'in_channels': 80,\n",
    "    'cond_channels': 512,\n",
    "}\n",
    "\n",
    "hift_config = {\n",
    "    'in_channels': 80,\n",
    "    'base_channels': 512,\n",
    "    'nb_harmonics': 8,\n",
    "    'sampling_rate': sample_rate,  # keep the vocoder in sync with the notebook-wide constant\n",
    "    'nsf_alpha': 0.1,\n",
    "    'nsf_sigma': 0.003,\n",
    "    'nsf_voiced_threshold': 10,\n",
    "    'upsample_rates': [8, 5, 3],\n",
    "    'upsample_kernel_sizes': [16, 11, 7],\n",
    "    'istft_params': {\n",
    "        'n_fft': 16,\n",
    "        'hop_len': 4,\n",
    "    },\n",
    "    'resblock_kernel_sizes': [3, 7, 11],\n",
    "    'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]],\n",
    "    'source_resblock_kernel_sizes': [7, 7, 11],\n",
    "    'source_resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]],\n",
    "    'lrelu_slope': 0.1,\n",
    "    'audio_limit': 0.99,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "03fe8925",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_encoder = Qwen2Encoder(**llm_encoder_config)\n",
    "llm_model = Qwen2LM(llm=llm_encoder, **llm_config, sampling=ras_sampling)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "41bc6b44",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mas/anaconda3/envs/learnable/lib/python3.10/site-packages/diffusers/models/lora.py:393: FutureWarning: `LoRACompatibleLinear` is deprecated and will be removed in version 1.0.0. Use of `LoRACompatibleLinear` is deprecated.
Please switch to PEFT backend by installing PEFT: `pip install peft`.\n", " deprecate(\"LoRACompatibleLinear\", \"1.0.0\", deprecation_message)\n" ] }, { "ename": "ConfigAttributeError", "evalue": "Missing key use_immiscible\n full_key: use_immiscible\n object_type=dict", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mConfigAttributeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[16], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m flow_encoder \u001b[38;5;241m=\u001b[39m UpsampleConformerEncoder(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mencoder_config)\n\u001b[1;32m 2\u001b[0m estimator \u001b[38;5;241m=\u001b[39m CausalConditionalDecoder(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mestimator_config)\n\u001b[0;32m----> 3\u001b[0m flow_decoder \u001b[38;5;241m=\u001b[39m \u001b[43mCausalConditionalCFM\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mdecoder_config\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mestimator\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mestimator\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m flow \u001b[38;5;241m=\u001b[39m CausalMaskedDiffWithXvec(\n\u001b[1;32m 5\u001b[0m encoder\u001b[38;5;241m=\u001b[39mflow_encoder,\n\u001b[1;32m 6\u001b[0m decoder\u001b[38;5;241m=\u001b[39mflow_decoder,\n\u001b[1;32m 7\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mflow_config\n\u001b[1;32m 8\u001b[0m )\n", "File \u001b[0;32m/data/learnable-speech/speech/cosyvoice/flow/flow_matching.py:329\u001b[0m, in \u001b[0;36mCausalConditionalCFM.__init__\u001b[0;34m(self, in_channels, cfm_params, n_spks, spk_emb_dim, estimator)\u001b[0m\n\u001b[1;32m 328\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, in_channels, cfm_params, 
n_spks\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, spk_emb_dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m64\u001b[39m, estimator: torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mModule \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m--> 329\u001b[0m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43min_channels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcfm_params\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_spks\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mspk_emb_dim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 330\u001b[0m set_all_random_seed(\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 331\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrand_noise \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandn([\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m80\u001b[39m, \u001b[38;5;241m50\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m300\u001b[39m])\n", "File \u001b[0;32m/data/learnable-speech/speech/cosyvoice/flow/flow_matching.py:35\u001b[0m, in \u001b[0;36mConditionalCFM.__init__\u001b[0;34m(self, in_channels, cfm_params, n_spks, spk_emb_dim, estimator)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[38;5;66;03m# Just change the architecture of the estimator here\u001b[39;00m\n\u001b[1;32m 34\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mestimator \u001b[38;5;241m=\u001b[39m estimator\n\u001b[0;32m---> 35\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_immiscible \u001b[38;5;241m=\u001b[39m \u001b[43mcfm_params\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43muse_immiscible\u001b[49m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mimmiscible_k 
\u001b[38;5;241m=\u001b[39m cfm_params\u001b[38;5;241m.\u001b[39mimmiscible_k\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlambda_weight \u001b[38;5;241m=\u001b[39m cfm_params\u001b[38;5;241m.\u001b[39mcontrastive_lambda\n", "File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/dictconfig.py:355\u001b[0m, in \u001b[0;36mDictConfig.__getattr__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 351\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_impl(\n\u001b[1;32m 352\u001b[0m key\u001b[38;5;241m=\u001b[39mkey, default_value\u001b[38;5;241m=\u001b[39m_DEFAULT_MARKER_, validate_key\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m 353\u001b[0m )\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ConfigKeyError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m--> 355\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_format_and_raise\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 356\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcause\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43me\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtype_override\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mConfigAttributeError\u001b[49m\n\u001b[1;32m 357\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 358\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 359\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_and_raise(key\u001b[38;5;241m=\u001b[39mkey, 
value\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, cause\u001b[38;5;241m=\u001b[39me)\n", "File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/base.py:231\u001b[0m, in \u001b[0;36mNode._format_and_raise\u001b[0;34m(self, key, value, cause, msg, type_override)\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_format_and_raise\u001b[39m(\n\u001b[1;32m 224\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 225\u001b[0m key: Any,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 229\u001b[0m type_override: Any \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 230\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 231\u001b[0m \u001b[43mformat_and_raise\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 232\u001b[0m \u001b[43m \u001b[49m\u001b[43mnode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 233\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 234\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 235\u001b[0m \u001b[43m \u001b[49m\u001b[43mmsg\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mstr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcause\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mmsg\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mmsg\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 236\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mcause\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcause\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 237\u001b[0m \u001b[43m \u001b[49m\u001b[43mtype_override\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtype_override\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 238\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 239\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28;01mFalse\u001b[39;00m\n", "File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/_utils.py:899\u001b[0m, in \u001b[0;36mformat_and_raise\u001b[0;34m(node, key, value, msg, cause, type_override)\u001b[0m\n\u001b[1;32m 896\u001b[0m ex\u001b[38;5;241m.\u001b[39mref_type \u001b[38;5;241m=\u001b[39m ref_type\n\u001b[1;32m 897\u001b[0m ex\u001b[38;5;241m.\u001b[39mref_type_str \u001b[38;5;241m=\u001b[39m ref_type_str\n\u001b[0;32m--> 899\u001b[0m \u001b[43m_raise\u001b[49m\u001b[43m(\u001b[49m\u001b[43mex\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcause\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/_utils.py:797\u001b[0m, in \u001b[0;36m_raise\u001b[0;34m(ex, cause)\u001b[0m\n\u001b[1;32m 795\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 796\u001b[0m ex\u001b[38;5;241m.\u001b[39m__cause__ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 797\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m ex\u001b[38;5;241m.\u001b[39mwith_traceback(sys\u001b[38;5;241m.\u001b[39mexc_info()[\u001b[38;5;241m2\u001b[39m])\n", "File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/dictconfig.py:351\u001b[0m, in \u001b[0;36mDictConfig.__getattr__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 348\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m()\n\u001b[1;32m 350\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 351\u001b[0m 
\u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_impl\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 352\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdefault_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_DEFAULT_MARKER_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalidate_key\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\n\u001b[1;32m 353\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ConfigKeyError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 355\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_and_raise(\n\u001b[1;32m 356\u001b[0m key\u001b[38;5;241m=\u001b[39mkey, value\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, cause\u001b[38;5;241m=\u001b[39me, type_override\u001b[38;5;241m=\u001b[39mConfigAttributeError\n\u001b[1;32m 357\u001b[0m )\n", "File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/dictconfig.py:442\u001b[0m, in \u001b[0;36mDictConfig._get_impl\u001b[0;34m(self, key, default_value, validate_key)\u001b[0m\n\u001b[1;32m 438\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_get_impl\u001b[39m(\n\u001b[1;32m 439\u001b[0m \u001b[38;5;28mself\u001b[39m, key: DictKeyType, default_value: Any, validate_key: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 440\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m 441\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 442\u001b[0m node \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_child\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 443\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mthrow_on_missing_key\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalidate_key\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalidate_key\u001b[49m\n\u001b[1;32m 444\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 445\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (ConfigAttributeError, ConfigKeyError):\n\u001b[1;32m 446\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m default_value \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _DEFAULT_MARKER_:\n", "File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/basecontainer.py:73\u001b[0m, in \u001b[0;36mBaseContainer._get_child\u001b[0;34m(self, key, validate_access, validate_key, throw_on_missing_value, throw_on_missing_key)\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_get_child\u001b[39m(\n\u001b[1;32m 65\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 66\u001b[0m key: Any,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 70\u001b[0m throw_on_missing_key: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 71\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Union[Optional[Node], List[Optional[Node]]]:\n\u001b[1;32m 72\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Like _get_node, passing through to the nearest concrete Node.\"\"\"\u001b[39;00m\n\u001b[0;32m---> 73\u001b[0m child \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_node\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 74\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 75\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalidate_access\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalidate_access\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 76\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalidate_key\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalidate_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 77\u001b[0m \u001b[43m \u001b[49m\u001b[43mthrow_on_missing_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mthrow_on_missing_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 78\u001b[0m \u001b[43m \u001b[49m\u001b[43mthrow_on_missing_key\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mthrow_on_missing_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 79\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 80\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(child, UnionNode) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _is_special(child):\n\u001b[1;32m 81\u001b[0m value \u001b[38;5;241m=\u001b[39m child\u001b[38;5;241m.\u001b[39m_value()\n", "File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/dictconfig.py:480\u001b[0m, in \u001b[0;36mDictConfig._get_node\u001b[0;34m(self, key, validate_access, validate_key, throw_on_missing_value, throw_on_missing_key)\u001b[0m\n\u001b[1;32m 478\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m value \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m throw_on_missing_key:\n\u001b[0;32m--> 480\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m ConfigKeyError(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMissing key 
\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mkey\u001b[38;5;132;01m!s}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 481\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m throw_on_missing_value \u001b[38;5;129;01mand\u001b[39;00m value\u001b[38;5;241m.\u001b[39m_is_missing():\n\u001b[1;32m 482\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m MissingMandatoryValue(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMissing mandatory value: $KEY\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "\u001b[0;31mConfigAttributeError\u001b[0m: Missing key use_immiscible\n full_key: use_immiscible\n object_type=dict"
     ]
    }
   ],
   "source": [
    "# NOTE(review): the saved ConfigAttributeError appears to predate 'use_immiscible'\n",
    "# being added to decoder_config['cfm_params'] (that config cell has no execution\n",
    "# count); re-run the config cell and this one to refresh the output.\n",
    "flow_encoder = UpsampleConformerEncoder(**encoder_config)\n",
    "estimator = CausalConditionalDecoder(**estimator_config)\n",
    "flow_decoder = CausalConditionalCFM(**decoder_config, estimator=estimator)\n",
    "flow = CausalMaskedDiffWithXvec(\n",
    "    encoder=flow_encoder,\n",
    "    decoder=flow_decoder,\n",
    "    **flow_config\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f689e0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "f0_predictor = ConvRNNF0Predictor(**f0_predictor_config)\n",
    "hifi = HiFTGenerator(**hift_config, f0_predictor=f0_predictor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "cfbef316",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(47, dtype=torch.int32),\n",
       " tensor([47, 50, 49, 49, 49, 48, 48, 48, 48, 47, 43, 47, 46, 46, 46, 45, 45, 45,\n",
       " 45, 43], dtype=torch.int32))"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): `data` is not defined anywhere above this cell. It presumably holds a\n",
    "# batch dict (keys such as 'speech_token_len', 'speech_feat', 'embedding') loaded by a\n",
    "# cell that has since been deleted, so Restart & Run All fails from here on until\n",
    "# that loading cell is restored.\n",
    "data['speech_token_len'][0], data['speech_token_len']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "d0942196",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(20, 20, 20)"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(data['utts']), len(data['text']), len(data['speech_token_len'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "622100eb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([20]),\n",
       " torch.Size([20, 192]),\n",
       " torch.Size([20, 98, 80]),\n",
       " torch.Size([20, 192]),\n",
       " torch.Size([20]))"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(\n",
    "    data['speech_token_len'].shape,\n",
    "    data['spk_embedding'].shape,\n",
    "    data['speech_feat'].shape,\n",
    "    data['embedding'].shape,\n",
    "    data['speech_feat_len'].shape,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "0adc02f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "token_len = data['speech_token_len']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "7aea884b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from cosyvoice.utils.mask import make_pad_mask\n",
    "mask = (~make_pad_mask(token_len)).float().unsqueeze(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "45422efa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([20, 50, 1])"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mask.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "0f2b0b77",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([47, 50, 49, 49, 49, 48, 48, 48, 48, 47, 43, 47, 46, 46, 46, 45, 45, 45,\n",
       " 45, 43], dtype=torch.int32)"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "token_len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2dcfa795",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing ResumableSequentialLR:\n",
      "--------------------------------------------------\n",
      "Step LR Expected Match \n",
      "--------------------------------------------------\n",
      "0 1.000000e-04 1.000000e-04 ✓ \n",
      "1 2.800000e-04 2.800000e-04 ✓ \n",
      "2 4.600000e-04 4.600000e-04 ✓ \n",
      "3 6.400000e-04 6.400000e-04 ✓ \n",
      "4 8.200000e-04 8.200000e-04 ✓ \n",
      "5 1.000000e-03 1.000000e-03 ✓ \n",
      "6 1.000000e-03 1.000000e-03 ✓ \n",
      "7 1.000000e-03 1.000000e-03 ✓ \n",
      "8 1.000000e-03 1.000000e-03 ✓ \n",
      "9 1.000000e-03 1.000000e-03 ✓ \n",
      "\n",
      "Testing resume from step 7:\n",
      "--------------------------------------------------\n",
      "7 1.000000e-03 1.000000e-03 ✓ \n",
      "8 1.000000e-03 1.000000e-03 ✓ \n",
      "9 1.000000e-03 1.000000e-03 ✓ \n"
     ]
    }
   ],
   "source": [
    "import bisect\n",
    "import warnings\n",
    "\n",
    "from torch.optim.lr_scheduler import _LRScheduler\n",
    "\n",
    "\n",
    "class ResumableSequentialLR(_LRScheduler):\n",
    "    \"\"\"A resumable version of SequentialLR that properly manages child schedulers.\n",
    "\n",
    "    The active child scheduler and its relative epoch are derived from\n",
    "    ``self.last_epoch`` on every ``get_lr`` call, so the scheduler can be moved\n",
    "    to any step with ``set_step`` without replaying the steps in between.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, optimizer, schedulers, milestones, last_epoch=-1):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            optimizer: Wrapped optimizer\n",
    "            schedulers: List of schedulers to sequentially use\n",
    "            milestones: Increasing epoch/step numbers at which to switch to the\n",
    "                next scheduler\n",
    "            last_epoch: The index of last epoch/step\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If ``len(schedulers) != len(milestones) + 1`` or the\n",
    "                milestones are not sorted.\n",
    "        \"\"\"\n",
    "        # Validate inputs\n",
    "        if len(schedulers) != len(milestones) + 1:\n",
    "            raise ValueError(\"Expected len(schedulers) == len(milestones) + 1\")\n",
    "        if list(milestones) != sorted(milestones):\n",
    "            raise ValueError(\"Expected milestones to be sorted in increasing order\")\n",
    "\n",
    "        self.schedulers = schedulers\n",
    "        self.milestones = milestones\n",
    "        self._scheduler_idx = 0\n",
    "\n",
    "        # Initialize parent class (this sets last_epoch and calls step())\n",
    "        super().__init__(optimizer, last_epoch)\n",
    "\n",
    "    def _get_scheduler_info(self, epoch):\n",
    "        \"\"\"Return ``(scheduler_idx, relative_epoch)`` for an absolute ``epoch``.\n",
    "\n",
    "        Scheduler ``i`` is active on ``[milestones[i - 1], milestones[i])``; the\n",
    "        first one starts at epoch 0 and the last one never ends.\n",
    "        \"\"\"\n",
    "        # The number of milestones already reached is the active scheduler's index.\n",
    "        scheduler_idx = bisect.bisect_right(self.milestones, epoch)\n",
    "        start = self.milestones[scheduler_idx - 1] if scheduler_idx > 0 else 0\n",
    "        return scheduler_idx, epoch - start\n",
    "\n",
    "    def get_lr(self):\n",
    "        \"\"\"Get learning rate from the appropriate scheduler\"\"\"\n",
    "        if not self._get_lr_called_within_step:\n",
    "            warnings.warn(\"To get the last learning rate computed by the scheduler, \"\n",
    "                          \"please use `get_last_lr()`.\", UserWarning)\n",
    "\n",
    "        # Get current scheduler and its relative epoch\n",
    "        scheduler_idx, relative_epoch = self._get_scheduler_info(self.last_epoch)\n",
    "        scheduler = self.schedulers[scheduler_idx]\n",
    "\n",
    "        # Set the scheduler's last_epoch to match relative progress\n",
    "        scheduler.last_epoch = relative_epoch\n",
    "\n",
    "        # Prefer the closed form: it does not read the optimizer's current lr,\n",
    "        # so it stays correct after set_step() jumps.\n",
    "        if hasattr(scheduler, '_get_closed_form_lr'):\n",
    "            return scheduler._get_closed_form_lr()\n",
    "\n",
    "        # Temporarily set the flag to avoid warning from child scheduler\n",
    "        scheduler._get_lr_called_within_step = True\n",
    "        try:\n",
    "            return scheduler.get_lr()\n",
    "        finally:\n",
    "            scheduler._get_lr_called_within_step = False\n",
    "\n",
    "    def step(self, epoch=None):\n",
    "        \"\"\"Step the scheduler\"\"\"\n",
    "        # Step the parent class (updates last_epoch and sets _get_lr_called_within_step)\n",
    "        super().step(epoch)\n",
    "\n",
    "    def set_step(self, step):\n",
    "        \"\"\"Set the current step for resuming training.\n",
    "\n",
    "        Afterwards ``last_epoch == step - 1`` (so the next ``step()`` moves to\n",
    "        ``step``), child schedulers are synchronised, and the optimizer's\n",
    "        learning rates are set for the new position.\n",
    "        \"\"\"\n",
    "        self.last_epoch = step - 1\n",
    "        scheduler_idx, relative_epoch = self._get_scheduler_info(self.last_epoch)\n",
    "\n",
    "        # Every scheduler before the active one has run to completion.\n",
    "        for i in range(scheduler_idx):\n",
    "            start = self.milestones[i - 1] if i > 0 else 0\n",
    "            self.schedulers[i].last_epoch = self.milestones[i] - start - 1\n",
    "\n",
    "        # Set current scheduler to its relative position\n",
    "        self.schedulers[scheduler_idx].last_epoch = relative_epoch\n",
    "\n",
    "        # Recompute the LRs for the new position. get_last_lr() would return the\n",
    "        # value cached by the previous step(), i.e. the LR of the OLD position.\n",
    "        if self.last_epoch < 0:\n",
    "            # No step taken yet: use the base (initial) learning rates.\n",
    "            lrs = list(self.base_lrs)\n",
    "        else:\n",
    "            self._get_lr_called_within_step = True\n",
    "            try:\n",
    "                lrs = self.get_lr()\n",
    "            finally:\n",
    "                self._get_lr_called_within_step = False\n",
    "\n",
    "        for param_group, lr in zip(self.optimizer.param_groups, lrs):\n",
    "            param_group['lr'] = lr\n",
    "        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]\n",
    "\n",
    "\n",
    "# Alternative, simpler implementation (warmup + constant only)\n",
    "class SimpleResumableSequentialLR(_LRScheduler):\n",
    "    \"\"\"Simpler implementation that computes the LR directly from ``last_epoch``.\n",
    "\n",
    "    Only the warmup-then-constant pattern is supported: ``schedulers[0]`` must\n",
    "    expose LinearLR-style ``start_factor``/``end_factor``/``total_iters``, only\n",
    "    ``milestones[0]`` is consulted, and from it onwards the base learning\n",
    "    rates are returned unchanged. Any further schedulers are ignored.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, optimizer, schedulers, milestones, last_epoch=-1):\n",
    "        self.schedulers = schedulers\n",
    "        self.milestones = milestones\n",
    "        super().__init__(optimizer, last_epoch)\n",
    "\n",
    "    def get_lr(self):\n",
    "        \"\"\"Calculate learning rate based on current epoch\"\"\"\n",
    "        epoch = self.last_epoch\n",
    "\n",
    "        if epoch >= self.milestones[0]:\n",
    "            # We're in constant phase - just return base LRs\n",
    "            return [base_lr * 1.0 for base_lr in self.base_lrs]\n",
    "\n",
    "        # We're in warmup phase: replicate LinearLR's linear factor ramp\n",
    "        warmup_scheduler = self.schedulers[0]\n",
    "        start_factor = warmup_scheduler.start_factor\n",
    "        end_factor = warmup_scheduler.end_factor\n",
    "        total_iters = warmup_scheduler.total_iters\n",
    "        if epoch >= total_iters:\n",
    "            factor = end_factor\n",
    "        else:\n",
    "            factor = start_factor + (end_factor - start_factor) * epoch / total_iters\n",
    "        return [base_lr * factor for base_lr in self.base_lrs]\n",
    "\n",
    "\n",
    "# Test function to verify the scheduler works correctly\n",
    "def test_resumable_scheduler():\n",
    "    \"\"\"Test the ResumableSequentialLR implementation\"\"\"\n",
    "    import torch\n",
    "    import torch.optim as optim\n",
    "    from torch.optim.lr_scheduler import LinearLR, ConstantLR\n",
    "\n",
    "    # Create dummy model and optimizer\n",
    "    model = torch.nn.Linear(10, 1)\n",
    "    base_lr = 1e-3\n",
    "    optimizer = optim.Adam(model.parameters(), lr=base_lr)\n",
    "\n",
    "    # Create schedulers\n",
    "    warmup_steps = 5\n",
    "    warmup_scheduler = LinearLR(\n",
    "        optimizer,\n",
    "        start_factor=0.1,\n",
    "        end_factor=1.0,\n",
    "        total_iters=warmup_steps\n",
    "    )\n",
    "\n",
    "    constant_scheduler = ConstantLR(\n",
    "        optimizer,\n",
    "        factor=1.0,\n",
    "        total_iters=float('inf')\n",
    "    )\n",
    "\n",
    "    # Test both implementations\n",
    "    print(\"Testing ResumableSequentialLR:\")\n",
    "    print(\"-\" * 50)\n",
    "\n",
    "    # Reset optimizer\n",
    "    for param_group in optimizer.param_groups:\n",
    "        param_group['lr'] = base_lr\n",
    "\n",
    "    scheduler = ResumableSequentialLR(\n",
    "        optimizer,\n",
    "        schedulers=[warmup_scheduler, constant_scheduler],\n",
    "        milestones=[warmup_steps]\n",
    "    )\n",
    "\n",
    "    print(f\"{'Step':<10} {'LR':<15} {'Expected':<15} {'Match':<10}\")\n",
    "    print(\"-\" * 50)\n",
    "\n",
    "    for step in range(10):\n",
    "        current_lr = optimizer.param_groups[0]['lr']\n",
    "\n",
    "        # Calculate expected LR\n",
    "        if step < warmup_steps:\n",
    "            expected_lr = base_lr * (0.1 + 0.9 * step / warmup_steps)\n",
    "        else:\n",
    "            expected_lr = base_lr\n",
    "\n",
    "        match = \"✓\" if abs(current_lr - expected_lr) < 1e-10 else \"✗\"\n",
    "        print(f\"{step:<10} {current_lr:<15.6e} {expected_lr:<15.6e} {match:<10}\")\n",
    "\n",
    "        scheduler.step()\n",
    "\n",
    "    # Test resuming\n",
    "    print(\"\\nTesting resume from step 7:\")\n",
    "    print(\"-\" * 50)\n",
    "\n",
    "    # Reset and jump to step 7\n",
    "    for param_group in optimizer.param_groups:\n",
    "        param_group['lr'] = base_lr\n",
    "\n",
    "    scheduler = ResumableSequentialLR(\n",
    "        optimizer,\n",
    "        schedulers=[warmup_scheduler, constant_scheduler],\n",
    "        milestones=[warmup_steps]\n",
    "    )\n",
    "    scheduler.set_step(7)\n",
    "    # set_step must already apply the LR of step 6 (constant phase), not the\n",
    "    # warmup LR cached when the scheduler was constructed.\n",
    "    assert abs(optimizer.param_groups[0]['lr'] - base_lr) < 1e-10\n",
    "    assert abs(scheduler.get_last_lr()[0] - base_lr) < 1e-10\n",
    "\n",
    "    for step in range(7, 10):\n",
    "        scheduler.step()\n",
    "        current_lr = optimizer.param_groups[0]['lr']\n",
    "        expected_lr = base_lr  # Should be constant phase\n",
    "        match = \"✓\" if abs(current_lr - expected_lr) < 1e-10 else \"✗\"\n",
    "        print(f\"{step:<10} {current_lr:<15.6e} {expected_lr:<15.6e} {match:<10}\")\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    test_resumable_scheduler()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e3d4d5a1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Learning Rate Source Verification ===\n",
      "\n",
      "Comparing LR sources during warmup:\n",
      "\n",
      "Step Optimizer LR Scheduler LR Match? \n",
      "--------------------------------------------------\n",
      "0 1.00e-04 1.00e-04 ✓ \n",
      "1 2.80e-04 2.80e-04 ✓ \n",
      "2 4.60e-04 4.60e-04 ✓ \n",
      "3 6.40e-04 6.40e-04 ✓ \n",
      "4 8.20e-04 8.20e-04 ✓ \n",
      "5 1.00e-03 1.00e-03 ✓ \n",
      "6 1.00e-03 1.00e-03 ✓ \n",
      "7 1.00e-03 1.00e-03 ✓ \n",
      "8 1.00e-03 1.00e-03 ✓ \n",
      "9 1.00e-03 1.00e-03 ✓ \n",
      "\n",
      "Conclusion: optimizer.param_groups[0]['lr'] is the authoritative source!\n",
      "\n",
      "\n",
      "Manual LR change test:\n",
      "Current optimizer LR: 1.00e-03\n",
      "After manual change: 1.00e-02\n",
      "This confirms the optimizer holds the actual LR being used.\n",
      "\n",
      "==================================================\n",
      "\n",
      "\n",
      "Different ways to access learning rate:\n",
      "\n",
      "Initial state:\n",
      " optimizer.param_groups[0]['lr']: 1.00e-04\n",
      " scheduler.get_last_lr(): 1.00e-04\n",
      "\n",
      "After scheduler.step():\n",
      " optimizer.param_groups[0]['lr']: 2.80e-04\n",
      " scheduler.get_last_lr(): 2.80e-04\n",
      "\n",
      "Key insights:\n",
      "1. optimizer.param_groups[0]['lr'] - Always current, used by optimizer\n",
      "2. scheduler.get_last_lr() - What scheduler set on last step()\n",
      "3.
scheduler.get_lr() - Internal method, calculates next LR (don't use directly)\n", "\n", "==================================================\n", "\n", "\n", "Multiple parameter groups:\n", " Group 0: lr = 1.00e-03\n", " Group 1: lr = 1.00e-04\n", "\n", "After scheduler step:\n", " Group 0: lr = 2.80e-04\n", " Group 1: lr = 2.80e-05\n" ] } ], "source": [ "import torch\n", "import torch.optim as optim\n", "from torch.optim.lr_scheduler import LinearLR, ConstantLR, SequentialLR\n", "\n", "def verify_lr_sources():\n", " \"\"\"Verify that optimizer.param_groups[0]['lr'] is the correct source\"\"\"\n", " \n", " # Create a simple model and optimizer\n", " model = torch.nn.Linear(10, 1)\n", " optimizer = optim.Adam(model.parameters(), lr=1e-3)\n", " \n", " # Create schedulers\n", " warmup_scheduler = LinearLR(\n", " optimizer,\n", " start_factor=0.1, # Start at 10% of base LR\n", " end_factor=1.0, # End at 100% of base LR\n", " total_iters=5 # 5 warmup steps\n", " )\n", " \n", " constant_scheduler = ConstantLR(\n", " optimizer,\n", " factor=1.0,\n", " total_iters=float('inf')\n", " )\n", " \n", " scheduler = SequentialLR(\n", " optimizer,\n", " schedulers=[warmup_scheduler, constant_scheduler],\n", " milestones=[5]\n", " )\n", " \n", " print(\"Comparing LR sources during warmup:\\n\")\n", " print(f\"{'Step':<6} {'Optimizer LR':<15} {'Scheduler LR':<15} {'Match?':<10}\")\n", " print(\"-\" * 50)\n", " \n", " for step in range(10):\n", " # Get LR from optimizer\n", " optimizer_lr = optimizer.param_groups[0]['lr']\n", " \n", " # Get LR from scheduler (if available)\n", " # Note: scheduler.get_last_lr() returns the LR after the last step\n", " scheduler_lr = scheduler.get_last_lr()[0] if hasattr(scheduler, 'get_last_lr') else None\n", " \n", " # Print comparison\n", " match = \"✓\" if scheduler_lr is None or abs(optimizer_lr - scheduler_lr) < 1e-10 else \"✗\"\n", " print(f\"{step:<6} {optimizer_lr:<15.2e} {scheduler_lr:<15.2e} {match:<10}\")\n", " \n", " # Step the scheduler\n", 
" scheduler.step()\n", " \n", " print(\"\\nConclusion: optimizer.param_groups[0]['lr'] is the authoritative source!\")\n", " \n", " # Additional verification: what happens if we manually change the optimizer's LR?\n", " print(\"\\n\\nManual LR change test:\")\n", " print(f\"Current optimizer LR: {optimizer.param_groups[0]['lr']:.2e}\")\n", " \n", " # Manually change it\n", " for param_group in optimizer.param_groups:\n", " param_group['lr'] = 0.01\n", " \n", " print(f\"After manual change: {optimizer.param_groups[0]['lr']:.2e}\")\n", " print(\"This confirms the optimizer holds the actual LR being used.\")\n", "\n", "\n", "def compare_lr_access_methods():\n", " \"\"\"Compare different ways to access the learning rate\"\"\"\n", " \n", " model = torch.nn.Linear(10, 1)\n", " optimizer = optim.Adam(model.parameters(), lr=1e-3)\n", " \n", " scheduler = LinearLR(\n", " optimizer,\n", " start_factor=0.1,\n", " end_factor=1.0,\n", " total_iters=5\n", " )\n", " \n", " print(\"\\nDifferent ways to access learning rate:\\n\")\n", " \n", " # Before any steps\n", " print(\"Initial state:\")\n", " print(f\" optimizer.param_groups[0]['lr']: {optimizer.param_groups[0]['lr']:.2e}\")\n", " print(f\" scheduler.get_last_lr(): {scheduler.get_last_lr()[0]:.2e}\")\n", " \n", " # After stepping\n", " scheduler.step()\n", " print(\"\\nAfter scheduler.step():\")\n", " print(f\" optimizer.param_groups[0]['lr']: {optimizer.param_groups[0]['lr']:.2e}\")\n", " print(f\" scheduler.get_last_lr(): {scheduler.get_last_lr()[0]:.2e}\")\n", " \n", " # Key insight\n", " print(\"\\nKey insights:\")\n", " print(\"1. optimizer.param_groups[0]['lr'] - Always current, used by optimizer\")\n", " print(\"2. scheduler.get_last_lr() - What scheduler set on last step()\")\n", " print(\"3. 
scheduler.get_lr() - Internal method, calculates next LR (don't use directly)\")\n", "\n", "\n", "def check_multiple_param_groups():\n", " \"\"\"Check how LR works with multiple parameter groups\"\"\"\n", " \n", " model = torch.nn.Sequential(\n", " torch.nn.Linear(10, 20),\n", " torch.nn.Linear(20, 1)\n", " )\n", " \n", " # Different LRs for different layers\n", " optimizer = optim.Adam([\n", " {'params': model[0].parameters(), 'lr': 1e-3},\n", " {'params': model[1].parameters(), 'lr': 1e-4}\n", " ])\n", " \n", " print(\"\\nMultiple parameter groups:\")\n", " for i, param_group in enumerate(optimizer.param_groups):\n", " print(f\" Group {i}: lr = {param_group['lr']:.2e}\")\n", " \n", " # Scheduler affects all groups\n", " scheduler = LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=5)\n", " scheduler.step()\n", " \n", " print(\"\\nAfter scheduler step:\")\n", " for i, param_group in enumerate(optimizer.param_groups):\n", " print(f\" Group {i}: lr = {param_group['lr']:.2e}\")\n", "\n", "\n", "if __name__ == \"__main__\":\n", " print(\"=== Learning Rate Source Verification ===\\n\")\n", " verify_lr_sources()\n", " print(\"\\n\" + \"=\"*50 + \"\\n\")\n", " compare_lr_access_methods()\n", " print(\"\\n\" + \"=\"*50 + \"\\n\")\n", " check_multiple_param_groups()" ] }, { "cell_type": "code", "execution_count": null, "id": "918d3322", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "eb19ac5e", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "7f2c3038", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "4f528b78", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "f0fcea90", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "bb5de4ae", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": 
"markdown", "id": "fbf1de4d", "metadata": {}, "source": [] } ], "metadata": { "kernelspec": { "display_name": "learnable", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 5 }