primepake committed
Commit d1b8469 · 1 Parent(s): 5805255

update notebook

Files changed (4):
  1. speech/.gitignore +0 -52
  2. speech/config.yaml +5 -4
  3. speech/dev.ipynb +717 -65
  4. speech/test_train.sh +2 -4
speech/.gitignore DELETED
@@ -1,52 +0,0 @@
- # Byte-compiled / optimized / DLL files
- __pycache__/
- *.py[cod]
- *$py.class
-
- # Visual Studio Code files
- .vscode
- .vs
-
- # PyCharm files
- .idea
-
- # Eclipse Project settings
- *.*project
- .settings
-
- # Sublime Text settings
- *.sublime-workspace
- *.sublime-project
-
- # Editor temporaries
- *.swn
- *.swo
- *.swp
- *.swm
- *~
-
- # IPython notebook checkpoints
- .ipynb_checkpoints
-
- # macOS dir files
- .DS_Store
-
- exp
- data
- raw_wav
- tensorboard
- **/*build*
-
- # Clangd files
- .cache
- compile_commands.json
-
- # train/inference files
- *.wav
- *.m4a
- *.aac
- *.pt
- pretrained_models/*
- *_pb2_grpc.py
- *_pb2.py
- *.tar
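With speech/.gitignore deleted, patterns such as `*.pt`, `*.wav`, and `pretrained_models/*` no longer shield this subtree from `git add`. A small pre-commit check is one way to catch stray artifacts; the sketch below is a hypothetical helper (its name and the hard-coded pattern list come from the deleted file, not from the repo):

```python
from pathlib import Path

# Subset of the patterns the deleted speech/.gitignore used to exclude.
ARTIFACT_GLOBS = ["*.pt", "*.wav", "*.m4a", "*.aac", "*.tar"]

def find_artifacts(root: str = "speech"):
    """Yield files that would previously have been ignored."""
    for pattern in ARTIFACT_GLOBS:
        yield from Path(root).rglob(pattern)

if __name__ == "__main__":
    for f in find_artifacts():
        print(f)
```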
speech/config.yaml CHANGED
@@ -200,12 +200,13 @@ data_pipeline: [
 train_conf:
     optim: adamw
     optim_conf:
-        lr: 2e-6 # change to 1e-5 during sft
+        lr: 5e-5 # change to 1e-5 during sft
     scheduler: constantlr # change to constantlr during sft
     scheduler_conf:
-        warmup_steps: 2500
-        max_epoch: 200
+        warmup_steps: 500
+        max_epoch: 2000
     grad_clip: 1
     accum_grad: 1
     log_interval: 5
-    save_per_step: -1
+    save_per_step: 2000
+    total_iters: 1000000000
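The new settings describe a short linear warmup (warmup_steps: 500) into a constant 5e-5 learning rate, with a checkpoint every 2000 steps; the dev notebook below prototypes a resumable scheduler of the same shape. A minimal sketch of that schedule with stock PyTorch schedulers (the model and optimizer here are placeholders, not the training script's objects):

```python
import torch
from torch.optim.lr_scheduler import ConstantLR, LinearLR, SequentialLR

model = torch.nn.Linear(10, 1)                               # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)   # lr: 5e-5 from config.yaml

# 500-step linear warmup, then hold the LR constant, mirroring
# warmup_steps: 500 combined with scheduler: constantlr.
scheduler = SequentialLR(
    optimizer,
    schedulers=[
        LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=500),
        ConstantLR(optimizer, factor=1.0, total_iters=1),
    ],
    milestones=[500],
)

for step in range(1000):
    optimizer.step()   # actual training step elided
    scheduler.step()
```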
speech/dev.ipynb CHANGED
@@ -2,111 +2,243 @@
  "cells": [
  {
   "cell_type": "code",
-  "execution_count": 2,
+  "execution_count": 12,
   "id": "4effe69f",
   "metadata": {},
   "outputs": [],
   "source": [
-   "from __future__ import print_function\n",
-   "\n",
-   "import argparse\n",
-   "import datetime\n",
    "import os\n",
-   "from copy import deepcopy\n",
-   "\n",
-   "import deepspeed\n",
+   "import sys\n",
    "import torch\n",
-   "import torch.distributed as dist\n",
-   "from hyperpyyaml import load_hyperpyyaml\n",
-   "from loguru import logger\n",
-   "from torch.distributed.elastic.multiprocessing.errors import record\n",
+   "import torchaudio\n",
+   "import random\n",
+   "import numpy as np\n",
+   "import torchaudio\n",
+   "from omegaconf import OmegaConf\n",
+   "from torch.nn import functional as F\n",
+   "\n",
+   "from cosyvoice.flow.decoder import ConditionalDecoder, CausalConditionalDecoder\n",
+   "from cosyvoice.flow.flow import CausalMaskedDiffWithXvec\n",
+   "from cosyvoice.flow.flow_matching import CausalConditionalCFM\n",
+   "from cosyvoice.hifigan.f0_predictor import ConvRNNF0Predictor\n",
+   "from cosyvoice.hifigan.generator import HiFTGenerator\n",
+   "from cosyvoice.llm.llm import Qwen2Encoder, Qwen2LM\n",
+   "from cosyvoice.tokenizer.tokenizer import get_qwen_tokenizer\n",
+   "from cosyvoice.transformer.upsample_encoder import UpsampleConformerEncoder\n",
+   "from cosyvoice.utils.common import ras_sampling\n",
+   "\n",
+   "# Set CUDA device\n",
+   "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\" # Use GPU 0\n",
+   "device = \"cuda:0\"\n",
+   "\n",
    "\n",
-   "from comet_ml import Experiment\n",
-   "from cosyvoice.utils.executor import Executor\n",
-   "from cosyvoice.utils.losses import DPOLoss\n",
-   "from cosyvoice.utils.train_utils import (check_modify_and_save_config,\n",
-   "                                         init_dataset_and_dataloader,\n",
-   "                                         init_optimizer_and_scheduler,\n",
-   "                                         save_model)"
+   "def set_deterministic_behavior(seed=42):\n",
+   "    \"\"\"Set seeds for reproducibility across all random libraries\"\"\"\n",
+   "    random.seed(seed)\n",
+   "    np.random.seed(seed)\n",
+   "    torch.manual_seed(seed)\n",
+   "    torch.cuda.manual_seed_all(seed)\n",
+   "    torch.backends.cudnn.deterministic = True\n",
+   "    torch.backends.cudnn.benchmark = False\n",
+   "    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
+   "\n",
+   "\n",
+   "# Call this function at the beginning of your script\n",
+   "set_deterministic_behavior(70000)"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 3,
+  "execution_count": 13,
   "id": "0322c8f4",
   "metadata": {},
-  "outputs": [
-   {
-    "name": "stderr",
-    "output_type": "stream",
-    "text": [
-     "/home/mas/anaconda3/envs/learnable/lib/python3.10/site-packages/diffusers/models/lora.py:393: FutureWarning: `LoRACompatibleLinear` is deprecated and will be removed in version 1.0.0. Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`.\n",
-     " deprecate(\"LoRACompatibleLinear\", \"1.0.0\", deprecation_message)\n",
-     "2025-07-14 13:59:59,637 INFO input frame rate=25\n"
-    ]
-   }
-  ],
+  "outputs": [],
   "source": [
-   "override_dict = {\n",
-   "    k: None for k in [\"llm\", \"flow\", \"hift\", \"hifigan\"] if k != 'flow'\n",
-   "}\n",
-   "config = 'cosyvoice2.yaml'\n",
-   "qwen_pretrain_path = './pretrained_models/CosyVoice2-0.5B/CosyVoice-BlankEN'\n",
-   "try:\n",
-   "    with open(config, \"r\", encoding=\"utf-8\") as f:\n",
-   "        configs = load_hyperpyyaml(\n",
-   "            f,\n",
-   "            overrides={\n",
-   "                **override_dict,\n",
-   "                \"qwen_pretrain_path\": qwen_pretrain_path,\n",
-   "            },\n",
-   "        )\n",
-   "except Exception as e:\n",
-   "    logger.error(f\"Error loading config: {e}\")\n",
-   "    with open(config, \"r\", encoding=\"utf-8\") as f:\n",
-   "        configs = load_hyperpyyaml(f, overrides=override_dict)\n",
-   "\n"
+   "model_dir = './pretrained_models/CosyVoice2-0.5B'\n",
+   "allowed_special = 'all'\n",
+   "sample_rate = 24000\n",
+   "fp16 = False"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 6,
+  "execution_count": null,
   "id": "a0ba457c",
   "metadata": {},
   "outputs": [],
   "source": [
-   "data_pipeline = configs['data_pipeline']\n",
-   "train_data = 'data/data.list'"
+   "llm_config = {\n",
+   "    'llm_input_size': 896,\n",
+   "    'llm_output_size': 896,\n",
+   "    'speech_token_size': 6561,\n",
+   "    'length_normalized_loss': True,\n",
+   "    'lsm_weight': 0,\n",
+   "    'mix_ratio': [5, 15]\n",
+   "}\n",
+   "\n",
+   "llm_encoder_config = {\n",
+   "    'pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')\n",
+   "}\n",
+   "\n",
+   "sampling_config = {\n",
+   "    'top_p': 0.8,\n",
+   "    'top_k': 25,\n",
+   "    'win_size': 10,\n",
+   "    'tau_r': 0.1\n",
+   "}\n",
+   "\n",
+   "flow_config = {\n",
+   "    'input_size': 512,\n",
+   "    'output_size': 80,\n",
+   "    'spk_embed_dim': 192,\n",
+   "    'output_type': 'mel',\n",
+   "    'vocab_size': 6561,\n",
+   "    'input_frame_rate': 25,\n",
+   "    'only_mask_loss': True,\n",
+   "    'token_mel_ratio': 2,\n",
+   "    'pre_lookahead_len': 3\n",
+   "}\n",
+   "\n",
+   "encoder_config = {\n",
+   "    'output_size': 512,\n",
+   "    'attention_heads': 8,\n",
+   "    'linear_units': 2048,\n",
+   "    'num_blocks': 6,\n",
+   "    'dropout_rate': 0.1,\n",
+   "    'positional_dropout_rate': 0.1,\n",
+   "    'attention_dropout_rate': 0.1,\n",
+   "    'normalize_before': True,\n",
+   "    'input_layer': 'linear',\n",
+   "    'pos_enc_layer_type': 'rel_pos_espnet',\n",
+   "    'selfattention_layer_type': 'rel_selfattn',\n",
+   "    'input_size': 512,\n",
+   "    'use_cnn_module': False,\n",
+   "    'macaron_style': False\n",
+   "}\n",
+   "\n",
+   "decoder_config = {\n",
+   "    'in_channels': 240,\n",
+   "    'n_spks': 1,\n",
+   "    'spk_emb_dim': 80,\n",
+   "    'cfm_params': {\n",
+   "        'sigma_min': 1e-06,\n",
+   "        'solver': 'euler',\n",
+   "        't_scheduler': 'cosine',\n",
+   "        'training_cfg_rate': 0.2,\n",
+   "        'inference_cfg_rate': 0.7,\n",
+   "        'reg_loss_type': 'l1',\n",
+   "        'use_immiscible': True,\n",
+   "        'immiscible_k': 8,\n",
+   "        'use_contrastive_fm': True,\n",
+   "        'contrastive_lambda': 0.05,\n",
+   "    }\n",
+   "}\n",
+   "decoder_config['cfm_params'] = OmegaConf.create(decoder_config['cfm_params'])\n",
+   "\n",
+   "estimator_config = {\n",
+   "    'in_channels': 320,\n",
+   "    'out_channels': 80,\n",
+   "    'channels': [256],\n",
+   "    'dropout': 0.0,\n",
+   "    'attention_head_dim': 64,\n",
+   "    'n_blocks': 4,\n",
+   "    'num_mid_blocks': 12,\n",
+   "    'num_heads': 8,\n",
+   "    'act_fn': 'gelu',\n",
+   "    'static_chunk_size': 50,\n",
+   "    'num_decoding_left_chunks': 2\n",
+   "}\n",
+   "\n",
+   "f0_predictor_config = {\n",
+   "    'num_class': 1,\n",
+   "    'in_channels': 80,\n",
+   "    'cond_channels': 512,\n",
+   "}\n",
+   "\n",
+   "hift_config = {\n",
+   "    'in_channels': 80,\n",
+   "    'base_channels': 512,\n",
+   "    'nb_harmonics': 8,\n",
+   "    'sampling_rate': 24000,\n",
+   "    'nsf_alpha': 0.1,\n",
+   "    'nsf_sigma': 0.003,\n",
+   "    'nsf_voiced_threshold': 10,\n",
+   "    'upsample_rates': [8, 5, 3],\n",
+   "    'upsample_kernel_sizes': [16, 11, 7],\n",
+   "    'istft_params': {\n",
+   "        'n_fft': 16,\n",
+   "        'hop_len': 4,\n",
+   "    },\n",
+   "    'resblock_kernel_sizes': [3, 7, 11],\n",
+   "    'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]],\n",
+   "    'source_resblock_kernel_sizes': [7, 7, 11],\n",
+   "    'source_resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]],\n",
+   "    'lrelu_slope': 0.1,\n",
+   "    'audio_limit': 0.99,\n",
+   "}"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 7,
+  "execution_count": 15,
   "id": "03fe8925",
   "metadata": {},
   "outputs": [],
   "source": [
-   "from cosyvoice.dataset.dataset import Dataset\n",
-   "train_dataset = Dataset(train_data, data_pipeline=data_pipeline, mode='train', gan=False, dpo=False, shuffle=True, partition=True)"
+   "llm_encoder = Qwen2Encoder(**llm_encoder_config)\n",
+   "llm_model = Qwen2LM(llm=llm_encoder, **llm_config, sampling=ras_sampling)"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 28,
+  "execution_count": 16,
   "id": "41bc6b44",
   "metadata": {},
-  "outputs": [],
+  "outputs": [
+   {
+    "name": "stderr",
+    "output_type": "stream",
+    "text": [
+     "/home/mas/anaconda3/envs/learnable/lib/python3.10/site-packages/diffusers/models/lora.py:393: FutureWarning: `LoRACompatibleLinear` is deprecated and will be removed in version 1.0.0. Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`.\n",
+     " deprecate(\"LoRACompatibleLinear\", \"1.0.0\", deprecation_message)\n"
+    ]
+   },
+   {
+    "ename": "ConfigAttributeError",
+    "evalue": "Missing key use_immiscible\n full_key: use_immiscible\n object_type=dict",
+    "output_type": "error",
+    "traceback": [
+     "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+     "\u001b[0;31mConfigAttributeError\u001b[0m Traceback (most recent call last)",
+     "Cell \u001b[0;32mIn[16], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m flow_encoder \u001b[38;5;241m=\u001b[39m UpsampleConformerEncoder(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mencoder_config)\n\u001b[1;32m 2\u001b[0m estimator \u001b[38;5;241m=\u001b[39m CausalConditionalDecoder(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mestimator_config)\n\u001b[0;32m----> 3\u001b[0m flow_decoder \u001b[38;5;241m=\u001b[39m \u001b[43mCausalConditionalCFM\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mdecoder_config\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mestimator\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mestimator\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m flow \u001b[38;5;241m=\u001b[39m CausalMaskedDiffWithXvec(\n\u001b[1;32m 5\u001b[0m encoder\u001b[38;5;241m=\u001b[39mflow_encoder,\n\u001b[1;32m 6\u001b[0m decoder\u001b[38;5;241m=\u001b[39mflow_decoder,\n\u001b[1;32m 7\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mflow_config\n\u001b[1;32m 8\u001b[0m )\n",
+     "File \u001b[0;32m/data/learnable-speech/speech/cosyvoice/flow/flow_matching.py:329\u001b[0m, in \u001b[0;36mCausalConditionalCFM.__init__\u001b[0;34m(self, in_channels, cfm_params, n_spks, spk_emb_dim, estimator)\u001b[0m\n\u001b[1;32m 328\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, in_channels, cfm_params, n_spks\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, spk_emb_dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m64\u001b[39m, estimator: torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mModule \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m--> 329\u001b[0m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43min_channels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcfm_params\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_spks\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mspk_emb_dim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 330\u001b[0m set_all_random_seed(\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 331\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrand_noise \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandn([\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m80\u001b[39m, \u001b[38;5;241m50\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m300\u001b[39m])\n",
+     "File \u001b[0;32m/data/learnable-speech/speech/cosyvoice/flow/flow_matching.py:35\u001b[0m, in \u001b[0;36mConditionalCFM.__init__\u001b[0;34m(self, in_channels, cfm_params, n_spks, spk_emb_dim, estimator)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[38;5;66;03m# Just change the architecture of the estimator here\u001b[39;00m\n\u001b[1;32m 34\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mestimator \u001b[38;5;241m=\u001b[39m estimator\n\u001b[0;32m---> 35\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_immiscible \u001b[38;5;241m=\u001b[39m \u001b[43mcfm_params\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43muse_immiscible\u001b[49m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mimmiscible_k \u001b[38;5;241m=\u001b[39m cfm_params\u001b[38;5;241m.\u001b[39mimmiscible_k\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlambda_weight \u001b[38;5;241m=\u001b[39m cfm_params\u001b[38;5;241m.\u001b[39mcontrastive_lambda\n",
+     "File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/dictconfig.py:355\u001b[0m, in \u001b[0;36mDictConfig.__getattr__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 351\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_impl(\n\u001b[1;32m 352\u001b[0m key\u001b[38;5;241m=\u001b[39mkey, default_value\u001b[38;5;241m=\u001b[39m_DEFAULT_MARKER_, validate_key\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m 353\u001b[0m )\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ConfigKeyError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m--> 355\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_format_and_raise\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 356\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcause\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43me\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtype_override\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mConfigAttributeError\u001b[49m\n\u001b[1;32m 357\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 358\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 359\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_and_raise(key\u001b[38;5;241m=\u001b[39mkey, value\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, cause\u001b[38;5;241m=\u001b[39me)\n",
+     "File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/base.py:231\u001b[0m, in \u001b[0;36mNode._format_and_raise\u001b[0;34m(self, key, value, cause, msg, type_override)\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_format_and_raise\u001b[39m(\n\u001b[1;32m 224\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 225\u001b[0m key: Any,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 229\u001b[0m type_override: Any \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 230\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 231\u001b[0m \u001b[43mformat_and_raise\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 232\u001b[0m \u001b[43m \u001b[49m\u001b[43mnode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 233\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 234\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 235\u001b[0m \u001b[43m \u001b[49m\u001b[43mmsg\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mstr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcause\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mmsg\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mmsg\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 236\u001b[0m \u001b[43m \u001b[49m\u001b[43mcause\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcause\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 237\u001b[0m \u001b[43m \u001b[49m\u001b[43mtype_override\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtype_override\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 238\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 239\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
+     "File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/_utils.py:899\u001b[0m, in \u001b[0;36mformat_and_raise\u001b[0;34m(node, key, value, msg, cause, type_override)\u001b[0m\n\u001b[1;32m 896\u001b[0m ex\u001b[38;5;241m.\u001b[39mref_type \u001b[38;5;241m=\u001b[39m ref_type\n\u001b[1;32m 897\u001b[0m ex\u001b[38;5;241m.\u001b[39mref_type_str \u001b[38;5;241m=\u001b[39m ref_type_str\n\u001b[0;32m--> 899\u001b[0m \u001b[43m_raise\u001b[49m\u001b[43m(\u001b[49m\u001b[43mex\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcause\u001b[49m\u001b[43m)\u001b[49m\n",
+     "File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/_utils.py:797\u001b[0m, in \u001b[0;36m_raise\u001b[0;34m(ex, cause)\u001b[0m\n\u001b[1;32m 795\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 796\u001b[0m ex\u001b[38;5;241m.\u001b[39m__cause__ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 797\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m ex\u001b[38;5;241m.\u001b[39mwith_traceback(sys\u001b[38;5;241m.\u001b[39mexc_info()[\u001b[38;5;241m2\u001b[39m])\n",
+     "File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/dictconfig.py:351\u001b[0m, in \u001b[0;36mDictConfig.__getattr__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 348\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m()\n\u001b[1;32m 350\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 351\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_impl\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 352\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdefault_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_DEFAULT_MARKER_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalidate_key\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\n\u001b[1;32m 353\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ConfigKeyError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 355\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_and_raise(\n\u001b[1;32m 356\u001b[0m key\u001b[38;5;241m=\u001b[39mkey, value\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, cause\u001b[38;5;241m=\u001b[39me, type_override\u001b[38;5;241m=\u001b[39mConfigAttributeError\n\u001b[1;32m 357\u001b[0m )\n",
+     "File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/dictconfig.py:442\u001b[0m, in \u001b[0;36mDictConfig._get_impl\u001b[0;34m(self, key, default_value, validate_key)\u001b[0m\n\u001b[1;32m 438\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_get_impl\u001b[39m(\n\u001b[1;32m 439\u001b[0m \u001b[38;5;28mself\u001b[39m, key: DictKeyType, default_value: Any, validate_key: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 440\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m 441\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 442\u001b[0m node \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_child\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 443\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mthrow_on_missing_key\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalidate_key\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalidate_key\u001b[49m\n\u001b[1;32m 444\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 445\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (ConfigAttributeError, ConfigKeyError):\n\u001b[1;32m 446\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m default_value \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _DEFAULT_MARKER_:\n",
+     "File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/basecontainer.py:73\u001b[0m, in \u001b[0;36mBaseContainer._get_child\u001b[0;34m(self, key, validate_access, validate_key, throw_on_missing_value, throw_on_missing_key)\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_get_child\u001b[39m(\n\u001b[1;32m 65\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 66\u001b[0m key: Any,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 70\u001b[0m throw_on_missing_key: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 71\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Union[Optional[Node], List[Optional[Node]]]:\n\u001b[1;32m 72\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Like _get_node, passing through to the nearest concrete Node.\"\"\"\u001b[39;00m\n\u001b[0;32m---> 73\u001b[0m child \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_node\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 74\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 75\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalidate_access\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalidate_access\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 76\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalidate_key\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalidate_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 77\u001b[0m \u001b[43m \u001b[49m\u001b[43mthrow_on_missing_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mthrow_on_missing_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 78\u001b[0m \u001b[43m \u001b[49m\u001b[43mthrow_on_missing_key\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mthrow_on_missing_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 79\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 80\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(child, UnionNode) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _is_special(child):\n\u001b[1;32m 81\u001b[0m value \u001b[38;5;241m=\u001b[39m child\u001b[38;5;241m.\u001b[39m_value()\n",
+     "File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/dictconfig.py:480\u001b[0m, in \u001b[0;36mDictConfig._get_node\u001b[0;34m(self, key, validate_access, validate_key, throw_on_missing_value, throw_on_missing_key)\u001b[0m\n\u001b[1;32m 478\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m value \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m throw_on_missing_key:\n\u001b[0;32m--> 480\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m ConfigKeyError(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMissing key \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mkey\u001b[38;5;132;01m!s}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 481\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m throw_on_missing_value \u001b[38;5;129;01mand\u001b[39;00m value\u001b[38;5;241m.\u001b[39m_is_missing():\n\u001b[1;32m 482\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m MissingMandatoryValue(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMissing mandatory value: $KEY\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
+     "\u001b[0;31mConfigAttributeError\u001b[0m: Missing key use_immiscible\n full_key: use_immiscible\n object_type=dict"
+    ]
+   }
+  ],
   "source": [
-   "cnt = 0\n",
-   "for data in train_dataset:\n",
-   "    if cnt==2:\n",
-   "        break\n",
-   "    cnt += 1"
+   "flow_encoder = UpsampleConformerEncoder(**encoder_config)\n",
+   "estimator = CausalConditionalDecoder(**estimator_config)\n",
+   "flow_decoder = CausalConditionalCFM(**decoder_config, estimator=estimator)\n",
+   "flow = CausalMaskedDiffWithXvec(\n",
+   "    encoder=flow_encoder,\n",
+   "    decoder=flow_decoder,\n",
+   "    **flow_config\n",
+   ")"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 29,
+  "execution_count": null,
   "id": "6f689e0b",
   "metadata": {},
   "outputs": [
@@ -122,7 +254,8 @@
   }
  ],
  "source": [
-  "data.keys()"
+  "f0_predictor = ConvRNNF0Predictor(**f0_predictor_config)\n",
+  "hifi = HiFTGenerator(**hift_config, f0_predictor=f0_predictor)"
  ]
 },
 {
@@ -260,6 +393,525 @@
  "token_len"
  ]
 },
+ {
+  "cell_type": "code",
+  "execution_count": 5,
+  "id": "2dcfa795",
+  "metadata": {},
+  "outputs": [
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "Testing ResumableSequentialLR:\n",
+     "--------------------------------------------------\n",
+     "Step LR Expected Match \n",
+     "--------------------------------------------------\n",
+     "0 1.000000e-04 1.000000e-04 ✓ \n",
+     "1 2.800000e-04 2.800000e-04 ✓ \n",
+     "2 4.600000e-04 4.600000e-04 ✓ \n",
+     "3 6.400000e-04 6.400000e-04 ✓ \n",
+     "4 8.200000e-04 8.200000e-04 ✓ \n",
+     "5 1.000000e-03 1.000000e-03 ✓ \n",
+     "6 1.000000e-03 1.000000e-03 ✓ \n",
+     "7 1.000000e-03 1.000000e-03 ✓ \n",
+     "8 1.000000e-03 1.000000e-03 ✓ \n",
+     "9 1.000000e-03 1.000000e-03 ✓ \n",
+     "\n",
+     "Testing resume from step 7:\n",
+     "--------------------------------------------------\n",
+     "7 1.000000e-03 1.000000e-03 ✓ \n",
+     "8 1.000000e-03 1.000000e-03 ✓ \n",
+     "9 1.000000e-03 1.000000e-03 ✓ \n"
+    ]
+   }
+  ],
+  "source": [
+   "from torch.optim.lr_scheduler import _LRScheduler\n",
+   "import warnings\n",
+   "\n",
+   "class ResumableSequentialLR(_LRScheduler):\n",
+   "    \"\"\"A resumable version of SequentialLR that properly manages child schedulers\"\"\"\n",
+   "\n",
+   "    def __init__(self, optimizer, schedulers, milestones, last_epoch=-1):\n",
+   "        \"\"\"\n",
+   "        Args:\n",
+   "            optimizer: Wrapped optimizer\n",
+   "            schedulers: List of schedulers to sequentially use\n",
+   "            milestones: List of epoch/step numbers when to switch schedulers\n",
+   "            last_epoch: The index of last epoch/step\n",
+   "        \"\"\"\n",
+   "        # Validate inputs\n",
+   "        if len(schedulers) != len(milestones) + 1:\n",
+   "            raise ValueError(\"Expected len(schedulers) == len(milestones) + 1\")\n",
+   "\n",
+   "        self.schedulers = schedulers\n",
+   "        self.milestones = milestones\n",
+   "        self._scheduler_idx = 0\n",
+   "\n",
+   "        # Initialize parent class (this sets last_epoch and calls step())\n",
+   "        super().__init__(optimizer, last_epoch)\n",
+   "\n",
+   "    def _get_scheduler_info(self, epoch):\n",
+   "        \"\"\"Determine which scheduler to use and its relative epoch\"\"\"\n",
+   "        scheduler_idx = 0\n",
+   "        relative_epoch = epoch\n",
+   "\n",
+   "        for i, milestone in enumerate(self.milestones):\n",
+   "            if epoch >= milestone:\n",
+   "                scheduler_idx = i + 1\n",
+   "                if i == 0:\n",
+   "                    relative_epoch = epoch - milestone\n",
+   "                else:\n",
+   "                    relative_epoch = epoch - milestone\n",
+   "            else:\n",
+   "                break\n",
+   "\n",
+   "        # Calculate relative epoch for the current scheduler\n",
+   "        if scheduler_idx == 0:\n",
+   "            relative_epoch = epoch\n",
+   "        elif scheduler_idx < len(self.milestones):\n",
+   "            if scheduler_idx == 1:\n",
+   "                relative_epoch = epoch - self.milestones[0]\n",
+   "            else:\n",
+   "                relative_epoch = epoch - self.milestones[scheduler_idx - 1]\n",
+   "\n",
+   "        return scheduler_idx, relative_epoch\n",
+   "\n",
+   "    def get_lr(self):\n",
+   "        \"\"\"Get learning rate from the appropriate scheduler\"\"\"\n",
+   "        if not self._get_lr_called_within_step:\n",
+   "            warnings.warn(\"To get the last learning rate computed by the scheduler, \"\n",
+   "                          \"please use `get_last_lr()`.\", UserWarning)\n",
+   "\n",
+   "        # Get current scheduler and its relative epoch\n",
+   "        scheduler_idx, relative_epoch = self._get_scheduler_info(self.last_epoch)\n",
+   "        scheduler = self.schedulers[scheduler_idx]\n",
+   "\n",
+   "        # Set the scheduler's last_epoch to match relative progress\n",
+   "        scheduler.last_epoch = relative_epoch\n",
+   "\n",
+   "        # Get LR from the scheduler\n",
+   "        if hasattr(scheduler, '_get_closed_form_lr'):\n",
+   "            return scheduler._get_closed_form_lr()\n",
+   "        else:\n",
+   "            # Temporarily set the flag to avoid warning from child scheduler\n",
+   "            scheduler._get_lr_called_within_step = True\n",
+   "            lrs = scheduler.get_lr()\n",
+   "            scheduler._get_lr_called_within_step = False\n",
+   "            return lrs\n",
+   "\n",
+   "    def step(self, epoch=None):\n",
+   "        \"\"\"Step the scheduler\"\"\"\n",
+   "        # Step the parent class (updates last_epoch and sets _get_lr_called_within_step)\n",
+   "        super().step(epoch)\n",
+   "\n",
+   "    def set_step(self, step):\n",
+   "        \"\"\"Set the current step for resuming training\"\"\"\n",
+   "        self.last_epoch = step - 1\n",
+   "\n",
+   "        # Update child schedulers' state\n",
+   "        scheduler_idx, relative_epoch = self._get_scheduler_info(step - 1)\n",
+   "\n",
+   "        # Set all previous schedulers to their final state\n",
+   "        for i in range(scheduler_idx):\n",
+   "            if i < len(self.milestones):\n",
+   "                if i == 0:\n",
+   "                    self.schedulers[i].last_epoch = self.milestones[i] - 1\n",
+   "                else:\n",
+   "                    self.schedulers[i].last_epoch = self.milestones[i] - self.milestones[i-1] - 1\n",
+   "\n",
+   "        # Set current scheduler to its relative position\n",
+   "        self.schedulers[scheduler_idx].last_epoch = relative_epoch\n",
+   "\n",
+   "        # Update optimizer's learning rates\n",
+   "        for param_group, lr in zip(self.optimizer.param_groups, self.get_last_lr()):\n",
+   "            param_group['lr'] = lr\n",
+   "\n",
+   "\n",
+   "# Alternative simpler implementation that's more robust\n",
+   "class SimpleResumableSequentialLR(_LRScheduler):\n",
+   "    \"\"\"Simpler implementation that manually tracks scheduler states\"\"\"\n",
+   "\n",
+   "    def __init__(self, optimizer, schedulers, milestones, last_epoch=-1):\n",
+   "        self.schedulers = schedulers\n",
+   "        self.milestones = milestones\n",
+   "        super().__init__(optimizer, last_epoch)\n",
+   "\n",
+   "    def get_lr(self):\n",
+   "        \"\"\"Calculate learning rate based on current epoch\"\"\"\n",
+   "        epoch = self.last_epoch\n",
+   "\n",
+   "        # For LinearLR with warmup\n",
+   "        if epoch < self.milestones[0]:\n",
+   "            # We're in warmup phase\n",
+   "            warmup_scheduler = self.schedulers[0]\n",
+   "            start_factor = warmup_scheduler.start_factor\n",
+   "            end_factor = warmup_scheduler.end_factor\n",
+   "            total_iters = warmup_scheduler.total_iters\n",
+   "\n",
+   "            # Calculate factor\n",
+   "            if epoch >= total_iters:\n",
+   "                factor = end_factor\n",
+   "            else:\n",
+   "                factor = start_factor + (end_factor - start_factor) * epoch / total_iters\n",
+   "\n",
+   "            # Apply factor to base learning rates\n",
+   "            return [base_lr * factor for base_lr in self.base_lrs]\n",
+   "        else:\n",
+   "            # We're in constant phase - just return base LRs\n",
+   "            return [base_lr * 1.0 for base_lr in self.base_lrs]\n",
+   "\n",
+   "\n",
+   "# Test function to verify the scheduler works correctly\n",
+   "def test_resumable_scheduler():\n",
+   "    \"\"\"Test the ResumableSequentialLR implementation\"\"\"\n",
+   "    import torch\n",
+   "    import torch.optim as optim\n",
+   "    from torch.optim.lr_scheduler import LinearLR, ConstantLR\n",
+   "\n",
+   "    # Create dummy model and optimizer\n",
+   "    model = torch.nn.Linear(10, 1)\n",
+   "    base_lr = 1e-3\n",
+   "    optimizer = optim.Adam(model.parameters(), lr=base_lr)\n",
+   "\n",
+   "    # Create schedulers\n",
+   "    warmup_steps = 5\n",
+   "    warmup_scheduler = LinearLR(\n",
+   "        optimizer,\n",
+   "        start_factor=0.1,\n",
+   "        end_factor=1.0,\n",
+   "        total_iters=warmup_steps\n",
+   "    )\n",
+   "\n",
+   "    constant_scheduler = ConstantLR(\n",
+   "        optimizer,\n",
+   "        factor=1.0,\n",
+   "        total_iters=float('inf')\n",
+   "    )\n",
+   "\n",
+   "    # Test both implementations\n",
+   "    print(\"Testing ResumableSequentialLR:\")\n",
+   "    print(\"-\" * 50)\n",
+   "\n",
+   "    # Reset optimizer\n",
+   "    for param_group in optimizer.param_groups:\n",
+   "        param_group['lr'] = base_lr\n",
+   "\n",
+   "    scheduler = ResumableSequentialLR(\n",
+   "        optimizer,\n",
+   "        schedulers=[warmup_scheduler, constant_scheduler],\n",
+   "        milestones=[warmup_steps]\n",
+   "    )\n",
+   "\n",
+   "    print(f\"{'Step':<10} {'LR':<15} {'Expected':<15} {'Match':<10}\")\n",
+   "    print(\"-\" * 50)\n",
+   "\n",
+   "    for step in range(10):\n",
+   "        current_lr = optimizer.param_groups[0]['lr']\n",
+   "\n",
+   "        # Calculate expected LR\n",
+   "        if step < warmup_steps:\n",
+   "            expected_lr = base_lr * (0.1 + 0.9 * step / warmup_steps)\n",
+   "        else:\n",
+   "            expected_lr = base_lr\n",
+   "\n",
+   "        match = \"✓\" if abs(current_lr - expected_lr) < 1e-10 else \"✗\"\n",
+   "        print(f\"{step:<10} {current_lr:<15.6e} {expected_lr:<15.6e} {match:<10}\")\n",
+   "\n",
+   "        scheduler.step()\n",
+   "\n",
+   "    # Test resuming\n",
+   "    print(\"\\nTesting resume from step 7:\")\n",
+   "    print(\"-\" * 50)\n",
+   "\n",
+   "    # Reset and jump to step 7\n",
+   "    for param_group in optimizer.param_groups:\n",
+   "        param_group['lr'] = base_lr\n",
+   "\n",
+   "    scheduler = ResumableSequentialLR(\n",
+   "        optimizer,\n",
+   "        schedulers=[warmup_scheduler, constant_scheduler],\n",
+   "        milestones=[warmup_steps]\n",
+   "    )\n",
+   "    scheduler.set_step(7)\n",
+   "\n",
+   "    for step in range(7, 10):\n",
+   "        scheduler.step()\n",
+   "        current_lr = optimizer.param_groups[0]['lr']\n",
+   "        expected_lr = base_lr  # Should be constant phase\n",
+   "        match = \"✓\" if abs(current_lr - expected_lr) < 1e-10 else \"✗\"\n",
+   "        print(f\"{step:<10} {current_lr:<15.6e} {expected_lr:<15.6e} {match:<10}\")\n",
+   "\n",
+   "\n",
+   "if __name__ == \"__main__\":\n",
+   "    test_resumable_scheduler()"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "ce71bea4",
+  "metadata": {},
+  "outputs": [],
+  "source": []
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "42b9b936",
+  "metadata": {},
+  "outputs": [],
+  "source": []
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 3,
+  "id": "e3d4d5a1",
+  "metadata": {},
+  "outputs": [
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "=== Learning Rate Source Verification ===\n",
+     "\n",
+     "Comparing LR sources during warmup:\n",
+     "\n",
+     "Step Optimizer LR Scheduler LR Match? \n",
+     "--------------------------------------------------\n",
+     "0 1.00e-04 1.00e-04 ✓ \n",
+     "1 2.80e-04 2.80e-04 ✓ \n",
+     "2 4.60e-04 4.60e-04 ✓ \n",
+     "3 6.40e-04 6.40e-04 ✓ \n",
+     "4 8.20e-04 8.20e-04 ✓ \n",
+     "5 1.00e-03 1.00e-03 ✓ \n",
+     "6 1.00e-03 1.00e-03 ✓ \n",
+     "7 1.00e-03 1.00e-03 ✓ \n",
+     "8 1.00e-03 1.00e-03 ✓ \n",
+     "9 1.00e-03 1.00e-03 ✓ \n",
+     "\n",
+     "Conclusion: optimizer.param_groups[0]['lr'] is the authoritative source!\n",
+     "\n",
+     "\n",
+     "Manual LR change test:\n",
+     "Current optimizer LR: 1.00e-03\n",
+     "After manual change: 1.00e-02\n",
+     "This confirms the optimizer holds the actual LR being used.\n",
+     "\n",
+     "==================================================\n",
+     "\n",
+     "\n",
+     "Different ways to access learning rate:\n",
+     "\n",
+     "Initial state:\n",
+     "  optimizer.param_groups[0]['lr']: 1.00e-04\n",
+     "  scheduler.get_last_lr(): 1.00e-04\n",
+     "\n",
+     "After scheduler.step():\n",
+     "  optimizer.param_groups[0]['lr']: 2.80e-04\n",
+     "  scheduler.get_last_lr(): 2.80e-04\n",
+     "\n",
+     "Key insights:\n",
+     "1. optimizer.param_groups[0]['lr'] - Always current, used by optimizer\n",
+     "2. scheduler.get_last_lr() - What scheduler set on last step()\n",
+     "3. scheduler.get_lr() - Internal method, calculates next LR (don't use directly)\n",
+     "\n",
+     "==================================================\n",
+     "\n",
+     "\n",
+     "Multiple parameter groups:\n",
+     "  Group 0: lr = 1.00e-03\n",
+     "  Group 1: lr = 1.00e-04\n",
+     "\n",
+     "After scheduler step:\n",
+     "  Group 0: lr = 2.80e-04\n",
+     "  Group 1: lr = 2.80e-05\n"
+    ]
+   }
+  ],
+  "source": [
+   "import torch\n",
+   "import torch.optim as optim\n",
+   "from torch.optim.lr_scheduler import LinearLR, ConstantLR, SequentialLR\n",
+   "\n",
+   "def verify_lr_sources():\n",
+   "    \"\"\"Verify that optimizer.param_groups[0]['lr'] is the correct source\"\"\"\n",
+   "\n",
+   "    # Create a simple model and optimizer\n",
+   "    model = torch.nn.Linear(10, 1)\n",
+   "    optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
+   "\n",
+   "    # Create schedulers\n",
+   "    warmup_scheduler = LinearLR(\n",
+   "        optimizer,\n",
+   "        start_factor=0.1,  # Start at 10% of base LR\n",
+   "        end_factor=1.0,    # End at 100% of base LR\n",
+   "        total_iters=5      # 5 warmup steps\n",
+   "    )\n",
+   "\n",
+   "    constant_scheduler = ConstantLR(\n",
+   "        optimizer,\n",
+   "        factor=1.0,\n",
+   "        total_iters=float('inf')\n",
+   "    )\n",
+   "\n",
+   "    scheduler = SequentialLR(\n",
+   "        optimizer,\n",
+   "        schedulers=[warmup_scheduler, constant_scheduler],\n",
+   "        milestones=[5]\n",
+   "    )\n",
+   "\n",
+   "    print(\"Comparing LR sources during warmup:\\n\")\n",
+   "    print(f\"{'Step':<6} {'Optimizer LR':<15} {'Scheduler LR':<15} {'Match?':<10}\")\n",
+   "    print(\"-\" * 50)\n",
+   "\n",
+   "    for step in range(10):\n",
+   "        # Get LR from optimizer\n",
+   "        optimizer_lr = optimizer.param_groups[0]['lr']\n",
+   "\n",
+   "        # Get LR from scheduler (if available)\n",
+   "        # Note: scheduler.get_last_lr() returns the LR after the last step\n",
+   "        scheduler_lr = scheduler.get_last_lr()[0] if hasattr(scheduler, 'get_last_lr') else None\n",
+   "\n",
+   "        # Print comparison\n",
+   "        match = \"✓\" if scheduler_lr is None or abs(optimizer_lr - scheduler_lr) < 1e-10 else \"✗\"\n",
+   "        print(f\"{step:<6} {optimizer_lr:<15.2e} {scheduler_lr:<15.2e} {match:<10}\")\n",
+   "\n",
+   "        # Step the scheduler\n",
+   "        scheduler.step()\n",
+   "\n",
+   "    print(\"\\nConclusion: optimizer.param_groups[0]['lr'] is the authoritative source!\")\n",
+   "\n",
+   "    # Additional verification: what happens if we manually change the optimizer's LR?\n",
+   "    print(\"\\n\\nManual LR change test:\")\n",
+   "    print(f\"Current optimizer LR: {optimizer.param_groups[0]['lr']:.2e}\")\n",
+   "\n",
+   "    # Manually change it\n",
+   "    for param_group in optimizer.param_groups:\n",
+   "        param_group['lr'] = 0.01\n",
+   "\n",
+   "    print(f\"After manual change: {optimizer.param_groups[0]['lr']:.2e}\")\n",
+   "    print(\"This confirms the optimizer holds the actual LR being used.\")\n",
+   "\n",
+   "\n",
+   "def compare_lr_access_methods():\n",
+   "    \"\"\"Compare different ways to access the learning rate\"\"\"\n",
+   "\n",
+   "    model = torch.nn.Linear(10, 1)\n",
+   "    optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
+   "\n",
+   "    scheduler = LinearLR(\n",
+   "        optimizer,\n",
+   "        start_factor=0.1,\n",
+   "        end_factor=1.0,\n",
+   "        total_iters=5\n",
+   "    )\n",
+   "\n",
+   "    print(\"\\nDifferent ways to access learning rate:\\n\")\n",
+   "\n",
+   "    # Before any steps\n",
+   "    print(\"Initial state:\")\n",
+   "    print(f\"  optimizer.param_groups[0]['lr']: {optimizer.param_groups[0]['lr']:.2e}\")\n",
+   "    print(f\"  scheduler.get_last_lr(): {scheduler.get_last_lr()[0]:.2e}\")\n",
+   "\n",
+   "    # After stepping\n",
+   "    scheduler.step()\n",
+   "    print(\"\\nAfter scheduler.step():\")\n",
+   "    print(f\"  optimizer.param_groups[0]['lr']: {optimizer.param_groups[0]['lr']:.2e}\")\n",
+   "    print(f\"  scheduler.get_last_lr(): {scheduler.get_last_lr()[0]:.2e}\")\n",
+   "\n",
+   "    # Key insight\n",
+   "    print(\"\\nKey insights:\")\n",
+   "    print(\"1. optimizer.param_groups[0]['lr'] - Always current, used by optimizer\")\n",
+   "    print(\"2. scheduler.get_last_lr() - What scheduler set on last step()\")\n",
+   "    print(\"3. scheduler.get_lr() - Internal method, calculates next LR (don't use directly)\")\n",
+   "\n",
+   "\n",
+   "def check_multiple_param_groups():\n",
+   "    \"\"\"Check how LR works with multiple parameter groups\"\"\"\n",
+   "\n",
+   "    model = torch.nn.Sequential(\n",
+   "        torch.nn.Linear(10, 20),\n",
+   "        torch.nn.Linear(20, 1)\n",
+   "    )\n",
+   "\n",
+   "    # Different LRs for different layers\n",
+   "    optimizer = optim.Adam([\n",
+   "        {'params': model[0].parameters(), 'lr': 1e-3},\n",
+   "        {'params': model[1].parameters(), 'lr': 1e-4}\n",
+   "    ])\n",
+   "\n",
+   "    print(\"\\nMultiple parameter groups:\")\n",
+   "    for i, param_group in enumerate(optimizer.param_groups):\n",
+   "        print(f\"  Group {i}: lr = {param_group['lr']:.2e}\")\n",
+   "\n",
+   "    # Scheduler affects all groups\n",
+   "    scheduler = LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=5)\n",
+   "    scheduler.step()\n",
+   "\n",
+   "    print(\"\\nAfter scheduler step:\")\n",
+   "    for i, param_group in enumerate(optimizer.param_groups):\n",
+   "        print(f\"  Group {i}: lr = {param_group['lr']:.2e}\")\n",
+   "\n",
+   "\n",
+   "if __name__ == \"__main__\":\n",
+   "    print(\"=== Learning Rate Source Verification ===\\n\")\n",
+   "    verify_lr_sources()\n",
+   "    print(\"\\n\" + \"=\"*50 + \"\\n\")\n",
+   "    compare_lr_access_methods()\n",
+   "    print(\"\\n\" + \"=\"*50 + \"\\n\")\n",
+   "    check_multiple_param_groups()"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "918d3322",
+  "metadata": {},
+  "outputs": [],
+  "source": []
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "eb19ac5e",
+  "metadata": {},
+  "outputs": [],
+  "source": []
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "7f2c3038",
+  "metadata": {},
+  "outputs": [],
+  "source": []
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "4f528b78",
+  "metadata": {},
+  "outputs": [],
+  "source": []
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "f0fcea90",
+  "metadata": {},
+  "outputs": [],
+  "source": []
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "bb5de4ae",
+  "metadata": {},
+  "outputs": [],
+  "source": []
+ },
 {
  "cell_type": "markdown",
  "id": "fbf1de4d",
speech/test_train.sh CHANGED
@@ -66,14 +66,12 @@ torchrun --nnodes=1 --nproc_per_node=$num_gpus --rdzv_id=$job_id --rdzv_backend=
     --cv_data data/data.list \
     --qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
     --model $model \
-    --checkpoint $pretrained_model_dir/$model.pt \
-    --model_dir /mnt/nvme/speech/$model/ \
+    --model_dir /data/checkpoint/$model/ \
     --num_workers ${num_workers} \
     --prefetch ${prefetch} \
     --pin_memory \
     --use_amp \
-    --comet_disabled
-
+    --checkpoint /data/checkpoint/flow/epoch_88_step_14001.pt
 # # average model
 # average_num=5
 # if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
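
Separately, the dev.ipynb diff above records a `ConfigAttributeError`: `ConditionalCFM.__init__` reads `cfm_params.use_immiscible` via attribute access, which raises whenever an older config omits that key. One defensive pattern (a sketch, not the repo's actual fix) reads optional keys with `DictConfig.get()` and a default:

```python
from omegaconf import OmegaConf

# A cfm_params block as an older config might define it, without the
# newer immiscible/contrastive keys.
cfm_params = OmegaConf.create({"sigma_min": 1e-06, "solver": "euler"})

# Attribute access (cfm_params.use_immiscible) raises ConfigAttributeError
# on a missing key; .get() with a default degrades gracefully instead.
use_immiscible = cfm_params.get("use_immiscible", False)
immiscible_k = cfm_params.get("immiscible_k", 8)
contrastive_lambda = cfm_params.get("contrastive_lambda", 0.0)
print(use_immiscible, immiscible_k, contrastive_lambda)  # False 8 0.0
```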
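The script now resumes from `/data/checkpoint/flow/epoch_88_step_14001.pt` via `--checkpoint` instead of starting from the pretrained weights. A minimal sketch of the kind of resume logic this implies; the checkpoint is assumed here to be a plain state dict, which the diff does not confirm:

```python
import torch

def resume_from_checkpoint(model: torch.nn.Module, path: str) -> None:
    """Load saved weights and report any key mismatches.

    strict=False tolerates parameters added since the checkpoint was
    written (e.g. new flow-matching options).
    """
    state = torch.load(path, map_location="cpu")
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(f"missing={len(missing)} unexpected={len(unexpected)}")

# resume_from_checkpoint(flow, "/data/checkpoint/flow/epoch_88_step_14001.pt")
```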