Spaces:
Sleeping
Sleeping
primepake
committed on
Commit
·
d1b8469
1
Parent(s):
5805255
update notebook
Browse files
- speech/.gitignore +0 -52
- speech/config.yaml +5 -4
- speech/dev.ipynb +717 -65
- speech/test_train.sh +2 -4
speech/.gitignore
DELETED
|
@@ -1,52 +0,0 @@
|
|
| 1 |
-
# Byte-compiled / optimized / DLL files
|
| 2 |
-
__pycache__/
|
| 3 |
-
*.py[cod]
|
| 4 |
-
*$py.class
|
| 5 |
-
|
| 6 |
-
# Visual Studio Code files
|
| 7 |
-
.vscode
|
| 8 |
-
.vs
|
| 9 |
-
|
| 10 |
-
# PyCharm files
|
| 11 |
-
.idea
|
| 12 |
-
|
| 13 |
-
# Eclipse Project settings
|
| 14 |
-
*.*project
|
| 15 |
-
.settings
|
| 16 |
-
|
| 17 |
-
# Sublime Text settings
|
| 18 |
-
*.sublime-workspace
|
| 19 |
-
*.sublime-project
|
| 20 |
-
|
| 21 |
-
# Editor temporaries
|
| 22 |
-
*.swn
|
| 23 |
-
*.swo
|
| 24 |
-
*.swp
|
| 25 |
-
*.swm
|
| 26 |
-
*~
|
| 27 |
-
|
| 28 |
-
# IPython notebook checkpoints
|
| 29 |
-
.ipynb_checkpoints
|
| 30 |
-
|
| 31 |
-
# macOS dir files
|
| 32 |
-
.DS_Store
|
| 33 |
-
|
| 34 |
-
exp
|
| 35 |
-
data
|
| 36 |
-
raw_wav
|
| 37 |
-
tensorboard
|
| 38 |
-
**/*build*
|
| 39 |
-
|
| 40 |
-
# Clangd files
|
| 41 |
-
.cache
|
| 42 |
-
compile_commands.json
|
| 43 |
-
|
| 44 |
-
# train/inference files
|
| 45 |
-
*.wav
|
| 46 |
-
*.m4a
|
| 47 |
-
*.aac
|
| 48 |
-
*.pt
|
| 49 |
-
pretrained_models/*
|
| 50 |
-
*_pb2_grpc.py
|
| 51 |
-
*_pb2.py
|
| 52 |
-
*.tar
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
speech/config.yaml
CHANGED
|
@@ -200,12 +200,13 @@ data_pipeline: [
|
|
| 200 |
train_conf:
|
| 201 |
optim: adamw
|
| 202 |
optim_conf:
|
| 203 |
-
lr:
|
| 204 |
scheduler: constantlr # change to constantlr during sft
|
| 205 |
scheduler_conf:
|
| 206 |
-
warmup_steps:
|
| 207 |
-
max_epoch:
|
| 208 |
grad_clip: 1
|
| 209 |
accum_grad: 1
|
| 210 |
log_interval: 5
|
| 211 |
-
save_per_step:
|
|
|
|
|
|
| 200 |
train_conf:
|
| 201 |
optim: adamw
|
| 202 |
optim_conf:
|
| 203 |
+
lr: 5e-5 # change to 1e-5 during sft
|
| 204 |
scheduler: constantlr # change to constantlr during sft
|
| 205 |
scheduler_conf:
|
| 206 |
+
warmup_steps: 500
|
| 207 |
+
max_epoch: 2000
|
| 208 |
grad_clip: 1
|
| 209 |
accum_grad: 1
|
| 210 |
log_interval: 5
|
| 211 |
+
save_per_step: 2000
|
| 212 |
+
total_iters: 1000000000
|
speech/dev.ipynb
CHANGED
|
@@ -2,111 +2,243 @@
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
-
"execution_count":
|
| 6 |
"id": "4effe69f",
|
| 7 |
"metadata": {},
|
| 8 |
"outputs": [],
|
| 9 |
"source": [
|
| 10 |
-
"from __future__ import print_function\n",
|
| 11 |
-
"\n",
|
| 12 |
-
"import argparse\n",
|
| 13 |
-
"import datetime\n",
|
| 14 |
"import os\n",
|
| 15 |
-
"
|
| 16 |
-
"\n",
|
| 17 |
-
"import deepspeed\n",
|
| 18 |
"import torch\n",
|
| 19 |
-
"import
|
| 20 |
-
"
|
| 21 |
-
"
|
| 22 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
"\n",
|
| 24 |
-
"
|
| 25 |
-
"
|
| 26 |
-
"
|
| 27 |
-
"
|
| 28 |
-
"
|
| 29 |
-
"
|
| 30 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
]
|
| 32 |
},
|
| 33 |
{
|
| 34 |
"cell_type": "code",
|
| 35 |
-
"execution_count":
|
| 36 |
"id": "0322c8f4",
|
| 37 |
"metadata": {},
|
| 38 |
-
"outputs": [
|
| 39 |
-
{
|
| 40 |
-
"name": "stderr",
|
| 41 |
-
"output_type": "stream",
|
| 42 |
-
"text": [
|
| 43 |
-
"/home/mas/anaconda3/envs/learnable/lib/python3.10/site-packages/diffusers/models/lora.py:393: FutureWarning: `LoRACompatibleLinear` is deprecated and will be removed in version 1.0.0. Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`.\n",
|
| 44 |
-
" deprecate(\"LoRACompatibleLinear\", \"1.0.0\", deprecation_message)\n",
|
| 45 |
-
"2025-07-14 13:59:59,637 INFO input frame rate=25\n"
|
| 46 |
-
]
|
| 47 |
-
}
|
| 48 |
-
],
|
| 49 |
"source": [
|
| 50 |
-
"
|
| 51 |
-
"
|
| 52 |
-
"
|
| 53 |
-
"
|
| 54 |
-
"qwen_pretrain_path = './pretrained_models/CosyVoice2-0.5B/CosyVoice-BlankEN'\n",
|
| 55 |
-
"try:\n",
|
| 56 |
-
" with open(config, \"r\", encoding=\"utf-8\") as f:\n",
|
| 57 |
-
" configs = load_hyperpyyaml(\n",
|
| 58 |
-
" f,\n",
|
| 59 |
-
" overrides={\n",
|
| 60 |
-
" **override_dict,\n",
|
| 61 |
-
" \"qwen_pretrain_path\": qwen_pretrain_path,\n",
|
| 62 |
-
" },\n",
|
| 63 |
-
" )\n",
|
| 64 |
-
"except Exception as e:\n",
|
| 65 |
-
" logger.error(f\"Error loading config: {e}\")\n",
|
| 66 |
-
" with open(config, \"r\", encoding=\"utf-8\") as f:\n",
|
| 67 |
-
" configs = load_hyperpyyaml(f, overrides=override_dict)\n",
|
| 68 |
-
"\n"
|
| 69 |
]
|
| 70 |
},
|
| 71 |
{
|
| 72 |
"cell_type": "code",
|
| 73 |
-
"execution_count":
|
| 74 |
"id": "a0ba457c",
|
| 75 |
"metadata": {},
|
| 76 |
"outputs": [],
|
| 77 |
"source": [
|
| 78 |
-
"
|
| 79 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
]
|
| 81 |
},
|
| 82 |
{
|
| 83 |
"cell_type": "code",
|
| 84 |
-
"execution_count":
|
| 85 |
"id": "03fe8925",
|
| 86 |
"metadata": {},
|
| 87 |
"outputs": [],
|
| 88 |
"source": [
|
| 89 |
-
"
|
| 90 |
-
"
|
| 91 |
]
|
| 92 |
},
|
| 93 |
{
|
| 94 |
"cell_type": "code",
|
| 95 |
-
"execution_count":
|
| 96 |
"id": "41bc6b44",
|
| 97 |
"metadata": {},
|
| 98 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
"source": [
|
| 100 |
-
"
|
| 101 |
-
"
|
| 102 |
-
"
|
| 103 |
-
"
|
| 104 |
-
"
|
|
|
|
|
|
|
|
|
|
| 105 |
]
|
| 106 |
},
|
| 107 |
{
|
| 108 |
"cell_type": "code",
|
| 109 |
-
"execution_count":
|
| 110 |
"id": "6f689e0b",
|
| 111 |
"metadata": {},
|
| 112 |
"outputs": [
|
|
@@ -122,7 +254,8 @@
|
|
| 122 |
}
|
| 123 |
],
|
| 124 |
"source": [
|
| 125 |
-
"
|
|
|
|
| 126 |
]
|
| 127 |
},
|
| 128 |
{
|
|
@@ -260,6 +393,525 @@
|
|
| 260 |
"token_len"
|
| 261 |
]
|
| 262 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
{
|
| 264 |
"cell_type": "markdown",
|
| 265 |
"id": "fbf1de4d",
|
|
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
+
"execution_count": 12,
|
| 6 |
"id": "4effe69f",
|
| 7 |
"metadata": {},
|
| 8 |
"outputs": [],
|
| 9 |
"source": [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
"import os\n",
|
| 11 |
+
"import sys\n",
|
|
|
|
|
|
|
| 12 |
"import torch\n",
|
| 13 |
+
"import torchaudio\n",
|
| 14 |
+
"import random\n",
|
| 15 |
+
"import numpy as np\n",
|
| 16 |
+
"import torchaudio\n",
|
| 17 |
+
"from omegaconf import OmegaConf\n",
|
| 18 |
+
"from torch.nn import functional as F\n",
|
| 19 |
+
"\n",
|
| 20 |
+
"from cosyvoice.flow.decoder import ConditionalDecoder, CausalConditionalDecoder\n",
|
| 21 |
+
"from cosyvoice.flow.flow import CausalMaskedDiffWithXvec\n",
|
| 22 |
+
"from cosyvoice.flow.flow_matching import CausalConditionalCFM\n",
|
| 23 |
+
"from cosyvoice.hifigan.f0_predictor import ConvRNNF0Predictor\n",
|
| 24 |
+
"from cosyvoice.hifigan.generator import HiFTGenerator\n",
|
| 25 |
+
"from cosyvoice.llm.llm import Qwen2Encoder, Qwen2LM\n",
|
| 26 |
+
"from cosyvoice.tokenizer.tokenizer import get_qwen_tokenizer\n",
|
| 27 |
+
"from cosyvoice.transformer.upsample_encoder import UpsampleConformerEncoder\n",
|
| 28 |
+
"from cosyvoice.utils.common import ras_sampling\n",
|
| 29 |
+
"\n",
|
| 30 |
+
"# Set CUDA device\n",
|
| 31 |
+
"# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\" # Use GPU 0\n",
|
| 32 |
+
"device = \"cuda:0\"\n",
|
| 33 |
+
"\n",
|
| 34 |
"\n",
|
| 35 |
+
"def set_deterministic_behavior(seed=42):\n",
|
| 36 |
+
" \"\"\"Set seeds for reproducibility across all random libraries\"\"\"\n",
|
| 37 |
+
" random.seed(seed)\n",
|
| 38 |
+
" np.random.seed(seed)\n",
|
| 39 |
+
" torch.manual_seed(seed)\n",
|
| 40 |
+
" torch.cuda.manual_seed_all(seed)\n",
|
| 41 |
+
" torch.backends.cudnn.deterministic = True\n",
|
| 42 |
+
" torch.backends.cudnn.benchmark = False\n",
|
| 43 |
+
" os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"\n",
|
| 46 |
+
"# Call this function at the beginning of your script\n",
|
| 47 |
+
"set_deterministic_behavior(70000)"
|
| 48 |
]
|
| 49 |
},
|
| 50 |
{
|
| 51 |
"cell_type": "code",
|
| 52 |
+
"execution_count": 13,
|
| 53 |
"id": "0322c8f4",
|
| 54 |
"metadata": {},
|
| 55 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
"source": [
|
| 57 |
+
"model_dir = './pretrained_models/CosyVoice2-0.5B'\n",
|
| 58 |
+
"allowed_special = 'all'\n",
|
| 59 |
+
"sample_rate = 24000\n",
|
| 60 |
+
"fp16 = False"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
]
|
| 62 |
},
|
| 63 |
{
|
| 64 |
"cell_type": "code",
|
| 65 |
+
"execution_count": null,
|
| 66 |
"id": "a0ba457c",
|
| 67 |
"metadata": {},
|
| 68 |
"outputs": [],
|
| 69 |
"source": [
|
| 70 |
+
"llm_config = {\n",
|
| 71 |
+
" 'llm_input_size': 896,\n",
|
| 72 |
+
" 'llm_output_size': 896,\n",
|
| 73 |
+
" 'speech_token_size': 6561,\n",
|
| 74 |
+
" 'length_normalized_loss': True,\n",
|
| 75 |
+
" 'lsm_weight': 0,\n",
|
| 76 |
+
" 'mix_ratio': [5, 15]\n",
|
| 77 |
+
"}\n",
|
| 78 |
+
"\n",
|
| 79 |
+
"llm_encoder_config = {\n",
|
| 80 |
+
" 'pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')\n",
|
| 81 |
+
"}\n",
|
| 82 |
+
"\n",
|
| 83 |
+
"sampling_config = {\n",
|
| 84 |
+
" 'top_p': 0.8,\n",
|
| 85 |
+
" 'top_k': 25,\n",
|
| 86 |
+
" 'win_size': 10,\n",
|
| 87 |
+
" 'tau_r': 0.1\n",
|
| 88 |
+
"}\n",
|
| 89 |
+
"\n",
|
| 90 |
+
"flow_config = {\n",
|
| 91 |
+
" 'input_size': 512,\n",
|
| 92 |
+
" 'output_size': 80,\n",
|
| 93 |
+
" 'spk_embed_dim': 192,\n",
|
| 94 |
+
" 'output_type': 'mel',\n",
|
| 95 |
+
" 'vocab_size': 6561,\n",
|
| 96 |
+
" 'input_frame_rate': 25,\n",
|
| 97 |
+
" 'only_mask_loss': True,\n",
|
| 98 |
+
" 'token_mel_ratio': 2,\n",
|
| 99 |
+
" 'pre_lookahead_len': 3\n",
|
| 100 |
+
"}\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"encoder_config = {\n",
|
| 103 |
+
" 'output_size': 512,\n",
|
| 104 |
+
" 'attention_heads': 8,\n",
|
| 105 |
+
" 'linear_units': 2048,\n",
|
| 106 |
+
" 'num_blocks': 6,\n",
|
| 107 |
+
" 'dropout_rate': 0.1,\n",
|
| 108 |
+
" 'positional_dropout_rate': 0.1,\n",
|
| 109 |
+
" 'attention_dropout_rate': 0.1,\n",
|
| 110 |
+
" 'normalize_before': True,\n",
|
| 111 |
+
" 'input_layer': 'linear',\n",
|
| 112 |
+
" 'pos_enc_layer_type': 'rel_pos_espnet',\n",
|
| 113 |
+
" 'selfattention_layer_type': 'rel_selfattn',\n",
|
| 114 |
+
" 'input_size': 512,\n",
|
| 115 |
+
" 'use_cnn_module': False,\n",
|
| 116 |
+
" 'macaron_style': False\n",
|
| 117 |
+
"}\n",
|
| 118 |
+
"\n",
|
| 119 |
+
"decoder_config = {\n",
|
| 120 |
+
" 'in_channels': 240,\n",
|
| 121 |
+
" 'n_spks': 1,\n",
|
| 122 |
+
" 'spk_emb_dim': 80,\n",
|
| 123 |
+
" 'cfm_params': {\n",
|
| 124 |
+
" 'sigma_min': 1e-06,\n",
|
| 125 |
+
" 'solver': 'euler',\n",
|
| 126 |
+
" 't_scheduler': 'cosine',\n",
|
| 127 |
+
" 'training_cfg_rate': 0.2,\n",
|
| 128 |
+
" 'inference_cfg_rate': 0.7,\n",
|
| 129 |
+
" 'reg_loss_type': 'l1',\n",
|
| 130 |
+
" 'use_immiscible': True,\n",
|
| 131 |
+
" 'immiscible_k': 8,\n",
|
| 132 |
+
" 'use_contrastive_fm': True,\n",
|
| 133 |
+
" 'contrastive_lambda': 0.05,\n",
|
| 134 |
+
" }\n",
|
| 135 |
+
"}\n",
|
| 136 |
+
"decoder_config['cfm_params'] = OmegaConf.create(decoder_config['cfm_params'])\n",
|
| 137 |
+
"\n",
|
| 138 |
+
"estimator_config = {\n",
|
| 139 |
+
" 'in_channels': 320,\n",
|
| 140 |
+
" 'out_channels': 80,\n",
|
| 141 |
+
" 'channels': [256],\n",
|
| 142 |
+
" 'dropout': 0.0,\n",
|
| 143 |
+
" 'attention_head_dim': 64,\n",
|
| 144 |
+
" 'n_blocks': 4,\n",
|
| 145 |
+
" 'num_mid_blocks': 12,\n",
|
| 146 |
+
" 'num_heads': 8,\n",
|
| 147 |
+
" 'act_fn': 'gelu',\n",
|
| 148 |
+
" 'static_chunk_size': 50,\n",
|
| 149 |
+
" 'num_decoding_left_chunks': 2\n",
|
| 150 |
+
" }\n",
|
| 151 |
+
"\n",
|
| 152 |
+
"f0_predictor_config = {\n",
|
| 153 |
+
" 'num_class': 1,\n",
|
| 154 |
+
" 'in_channels': 80,\n",
|
| 155 |
+
" 'cond_channels': 512,\n",
|
| 156 |
+
"}\n",
|
| 157 |
+
"\n",
|
| 158 |
+
"hift_config = {\n",
|
| 159 |
+
" 'in_channels': 80,\n",
|
| 160 |
+
" 'base_channels': 512,\n",
|
| 161 |
+
" 'nb_harmonics': 8,\n",
|
| 162 |
+
" 'sampling_rate': 24000,\n",
|
| 163 |
+
" 'nsf_alpha': 0.1,\n",
|
| 164 |
+
" 'nsf_sigma': 0.003,\n",
|
| 165 |
+
" 'nsf_voiced_threshold': 10,\n",
|
| 166 |
+
" 'upsample_rates': [8, 5, 3],\n",
|
| 167 |
+
" 'upsample_kernel_sizes': [16, 11, 7],\n",
|
| 168 |
+
" 'istft_params': {\n",
|
| 169 |
+
" 'n_fft': 16,\n",
|
| 170 |
+
" 'hop_len': 4,\n",
|
| 171 |
+
" },\n",
|
| 172 |
+
" 'resblock_kernel_sizes': [3, 7, 11],\n",
|
| 173 |
+
" 'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]],\n",
|
| 174 |
+
" 'source_resblock_kernel_sizes': [7, 7, 11],\n",
|
| 175 |
+
" 'source_resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]],\n",
|
| 176 |
+
" 'lrelu_slope': 0.1,\n",
|
| 177 |
+
" 'audio_limit': 0.99,\n",
|
| 178 |
+
"}"
|
| 179 |
]
|
| 180 |
},
|
| 181 |
{
|
| 182 |
"cell_type": "code",
|
| 183 |
+
"execution_count": 15,
|
| 184 |
"id": "03fe8925",
|
| 185 |
"metadata": {},
|
| 186 |
"outputs": [],
|
| 187 |
"source": [
|
| 188 |
+
"llm_encoder = Qwen2Encoder(**llm_encoder_config)\n",
|
| 189 |
+
"llm_model = Qwen2LM(llm=llm_encoder, **llm_config, sampling=ras_sampling)"
|
| 190 |
]
|
| 191 |
},
|
| 192 |
{
|
| 193 |
"cell_type": "code",
|
| 194 |
+
"execution_count": 16,
|
| 195 |
"id": "41bc6b44",
|
| 196 |
"metadata": {},
|
| 197 |
+
"outputs": [
|
| 198 |
+
{
|
| 199 |
+
"name": "stderr",
|
| 200 |
+
"output_type": "stream",
|
| 201 |
+
"text": [
|
| 202 |
+
"/home/mas/anaconda3/envs/learnable/lib/python3.10/site-packages/diffusers/models/lora.py:393: FutureWarning: `LoRACompatibleLinear` is deprecated and will be removed in version 1.0.0. Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`.\n",
|
| 203 |
+
" deprecate(\"LoRACompatibleLinear\", \"1.0.0\", deprecation_message)\n"
|
| 204 |
+
]
|
| 205 |
+
},
|
| 206 |
+
{
|
| 207 |
+
"ename": "ConfigAttributeError",
|
| 208 |
+
"evalue": "Missing key use_immiscible\n full_key: use_immiscible\n object_type=dict",
|
| 209 |
+
"output_type": "error",
|
| 210 |
+
"traceback": [
|
| 211 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 212 |
+
"\u001b[0;31mConfigAttributeError\u001b[0m Traceback (most recent call last)",
|
| 213 |
+
"Cell \u001b[0;32mIn[16], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m flow_encoder \u001b[38;5;241m=\u001b[39m UpsampleConformerEncoder(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mencoder_config)\n\u001b[1;32m 2\u001b[0m estimator \u001b[38;5;241m=\u001b[39m CausalConditionalDecoder(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mestimator_config)\n\u001b[0;32m----> 3\u001b[0m flow_decoder \u001b[38;5;241m=\u001b[39m \u001b[43mCausalConditionalCFM\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mdecoder_config\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mestimator\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mestimator\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m flow \u001b[38;5;241m=\u001b[39m CausalMaskedDiffWithXvec(\n\u001b[1;32m 5\u001b[0m encoder\u001b[38;5;241m=\u001b[39mflow_encoder,\n\u001b[1;32m 6\u001b[0m decoder\u001b[38;5;241m=\u001b[39mflow_decoder,\n\u001b[1;32m 7\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mflow_config\n\u001b[1;32m 8\u001b[0m )\n",
|
| 214 |
+
"File \u001b[0;32m/data/learnable-speech/speech/cosyvoice/flow/flow_matching.py:329\u001b[0m, in \u001b[0;36mCausalConditionalCFM.__init__\u001b[0;34m(self, in_channels, cfm_params, n_spks, spk_emb_dim, estimator)\u001b[0m\n\u001b[1;32m 328\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, in_channels, cfm_params, n_spks\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, spk_emb_dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m64\u001b[39m, estimator: torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mModule \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m--> 329\u001b[0m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43min_channels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcfm_params\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_spks\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mspk_emb_dim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 330\u001b[0m set_all_random_seed(\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 331\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrand_noise \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandn([\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m80\u001b[39m, \u001b[38;5;241m50\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m300\u001b[39m])\n",
|
| 215 |
+
"File \u001b[0;32m/data/learnable-speech/speech/cosyvoice/flow/flow_matching.py:35\u001b[0m, in \u001b[0;36mConditionalCFM.__init__\u001b[0;34m(self, in_channels, cfm_params, n_spks, spk_emb_dim, estimator)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[38;5;66;03m# Just change the architecture of the estimator here\u001b[39;00m\n\u001b[1;32m 34\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mestimator \u001b[38;5;241m=\u001b[39m estimator\n\u001b[0;32m---> 35\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_immiscible \u001b[38;5;241m=\u001b[39m \u001b[43mcfm_params\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43muse_immiscible\u001b[49m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mimmiscible_k \u001b[38;5;241m=\u001b[39m cfm_params\u001b[38;5;241m.\u001b[39mimmiscible_k\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlambda_weight \u001b[38;5;241m=\u001b[39m cfm_params\u001b[38;5;241m.\u001b[39mcontrastive_lambda\n",
|
| 216 |
+
"File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/dictconfig.py:355\u001b[0m, in \u001b[0;36mDictConfig.__getattr__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 351\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_impl(\n\u001b[1;32m 352\u001b[0m key\u001b[38;5;241m=\u001b[39mkey, default_value\u001b[38;5;241m=\u001b[39m_DEFAULT_MARKER_, validate_key\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m 353\u001b[0m )\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ConfigKeyError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m--> 355\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_format_and_raise\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 356\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcause\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43me\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtype_override\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mConfigAttributeError\u001b[49m\n\u001b[1;32m 357\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 358\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 359\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_and_raise(key\u001b[38;5;241m=\u001b[39mkey, value\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, cause\u001b[38;5;241m=\u001b[39me)\n",
|
| 217 |
+
"File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/base.py:231\u001b[0m, in \u001b[0;36mNode._format_and_raise\u001b[0;34m(self, key, value, cause, msg, type_override)\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_format_and_raise\u001b[39m(\n\u001b[1;32m 224\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 225\u001b[0m key: Any,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 229\u001b[0m type_override: Any \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 230\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 231\u001b[0m \u001b[43mformat_and_raise\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 232\u001b[0m \u001b[43m \u001b[49m\u001b[43mnode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 233\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 234\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 235\u001b[0m \u001b[43m \u001b[49m\u001b[43mmsg\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mstr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcause\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mmsg\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mmsg\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 236\u001b[0m \u001b[43m \u001b[49m\u001b[43mcause\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcause\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 237\u001b[0m 
\u001b[43m \u001b[49m\u001b[43mtype_override\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtype_override\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 238\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 239\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
|
| 218 |
+
"File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/_utils.py:899\u001b[0m, in \u001b[0;36mformat_and_raise\u001b[0;34m(node, key, value, msg, cause, type_override)\u001b[0m\n\u001b[1;32m 896\u001b[0m ex\u001b[38;5;241m.\u001b[39mref_type \u001b[38;5;241m=\u001b[39m ref_type\n\u001b[1;32m 897\u001b[0m ex\u001b[38;5;241m.\u001b[39mref_type_str \u001b[38;5;241m=\u001b[39m ref_type_str\n\u001b[0;32m--> 899\u001b[0m \u001b[43m_raise\u001b[49m\u001b[43m(\u001b[49m\u001b[43mex\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcause\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 219 |
+
"File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/_utils.py:797\u001b[0m, in \u001b[0;36m_raise\u001b[0;34m(ex, cause)\u001b[0m\n\u001b[1;32m 795\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 796\u001b[0m ex\u001b[38;5;241m.\u001b[39m__cause__ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 797\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m ex\u001b[38;5;241m.\u001b[39mwith_traceback(sys\u001b[38;5;241m.\u001b[39mexc_info()[\u001b[38;5;241m2\u001b[39m])\n",
|
| 220 |
+
"File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/dictconfig.py:351\u001b[0m, in \u001b[0;36mDictConfig.__getattr__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 348\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m()\n\u001b[1;32m 350\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 351\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_impl\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 352\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdefault_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_DEFAULT_MARKER_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalidate_key\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\n\u001b[1;32m 353\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ConfigKeyError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 355\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_and_raise(\n\u001b[1;32m 356\u001b[0m key\u001b[38;5;241m=\u001b[39mkey, value\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, cause\u001b[38;5;241m=\u001b[39me, type_override\u001b[38;5;241m=\u001b[39mConfigAttributeError\n\u001b[1;32m 357\u001b[0m )\n",
|
| 221 |
+
"File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/dictconfig.py:442\u001b[0m, in \u001b[0;36mDictConfig._get_impl\u001b[0;34m(self, key, default_value, validate_key)\u001b[0m\n\u001b[1;32m 438\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_get_impl\u001b[39m(\n\u001b[1;32m 439\u001b[0m \u001b[38;5;28mself\u001b[39m, key: DictKeyType, default_value: Any, validate_key: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 440\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m 441\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 442\u001b[0m node \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_child\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 443\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mthrow_on_missing_key\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalidate_key\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalidate_key\u001b[49m\n\u001b[1;32m 444\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 445\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (ConfigAttributeError, ConfigKeyError):\n\u001b[1;32m 446\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m default_value \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _DEFAULT_MARKER_:\n",
|
| 222 |
+
"File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/basecontainer.py:73\u001b[0m, in \u001b[0;36mBaseContainer._get_child\u001b[0;34m(self, key, validate_access, validate_key, throw_on_missing_value, throw_on_missing_key)\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_get_child\u001b[39m(\n\u001b[1;32m 65\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 66\u001b[0m key: Any,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 70\u001b[0m throw_on_missing_key: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 71\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Union[Optional[Node], List[Optional[Node]]]:\n\u001b[1;32m 72\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Like _get_node, passing through to the nearest concrete Node.\"\"\"\u001b[39;00m\n\u001b[0;32m---> 73\u001b[0m child \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_node\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 74\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 75\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalidate_access\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalidate_access\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 76\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalidate_key\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalidate_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 77\u001b[0m \u001b[43m \u001b[49m\u001b[43mthrow_on_missing_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mthrow_on_missing_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 78\u001b[0m \u001b[43m \u001b[49m\u001b[43mthrow_on_missing_key\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mthrow_on_missing_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 
79\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 80\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(child, UnionNode) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _is_special(child):\n\u001b[1;32m 81\u001b[0m value \u001b[38;5;241m=\u001b[39m child\u001b[38;5;241m.\u001b[39m_value()\n",
|
| 223 |
+
"File \u001b[0;32m~/anaconda3/envs/learnable/lib/python3.10/site-packages/omegaconf/dictconfig.py:480\u001b[0m, in \u001b[0;36mDictConfig._get_node\u001b[0;34m(self, key, validate_access, validate_key, throw_on_missing_value, throw_on_missing_key)\u001b[0m\n\u001b[1;32m 478\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m value \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m throw_on_missing_key:\n\u001b[0;32m--> 480\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m ConfigKeyError(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMissing key \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mkey\u001b[38;5;132;01m!s}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 481\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m throw_on_missing_value \u001b[38;5;129;01mand\u001b[39;00m value\u001b[38;5;241m.\u001b[39m_is_missing():\n\u001b[1;32m 482\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m MissingMandatoryValue(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMissing mandatory value: $KEY\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
| 224 |
+
"\u001b[0;31mConfigAttributeError\u001b[0m: Missing key use_immiscible\n full_key: use_immiscible\n object_type=dict"
|
| 225 |
+
]
|
| 226 |
+
}
|
| 227 |
+
],
|
| 228 |
"source": [
|
| 229 |
+
"flow_encoder = UpsampleConformerEncoder(**encoder_config)\n",
|
| 230 |
+
"estimator = CausalConditionalDecoder(**estimator_config)\n",
|
| 231 |
+
"flow_decoder = CausalConditionalCFM(**decoder_config, estimator=estimator)\n",
|
| 232 |
+
"flow = CausalMaskedDiffWithXvec(\n",
|
| 233 |
+
" encoder=flow_encoder,\n",
|
| 234 |
+
" decoder=flow_decoder,\n",
|
| 235 |
+
" **flow_config\n",
|
| 236 |
+
")"
|
| 237 |
]
|
| 238 |
},
|
| 239 |
{
|
| 240 |
"cell_type": "code",
|
| 241 |
+
"execution_count": null,
|
| 242 |
"id": "6f689e0b",
|
| 243 |
"metadata": {},
|
| 244 |
"outputs": [
|
|
|
|
| 254 |
}
|
| 255 |
],
|
| 256 |
"source": [
|
| 257 |
+
"f0_predictor = ConvRNNF0Predictor(**f0_predictor_config)\n",
|
| 258 |
+
"hifi = HiFTGenerator(**hift_config, f0_predictor=f0_predictor)"
|
| 259 |
]
|
| 260 |
},
|
| 261 |
{
|
|
|
|
| 393 |
"token_len"
|
| 394 |
]
|
| 395 |
},
|
| 396 |
+
{
|
| 397 |
+
"cell_type": "code",
|
| 398 |
+
"execution_count": 5,
|
| 399 |
+
"id": "2dcfa795",
|
| 400 |
+
"metadata": {},
|
| 401 |
+
"outputs": [
|
| 402 |
+
{
|
| 403 |
+
"name": "stdout",
|
| 404 |
+
"output_type": "stream",
|
| 405 |
+
"text": [
|
| 406 |
+
"Testing ResumableSequentialLR:\n",
|
| 407 |
+
"--------------------------------------------------\n",
|
| 408 |
+
"Step LR Expected Match \n",
|
| 409 |
+
"--------------------------------------------------\n",
|
| 410 |
+
"0 1.000000e-04 1.000000e-04 ✓ \n",
|
| 411 |
+
"1 2.800000e-04 2.800000e-04 ✓ \n",
|
| 412 |
+
"2 4.600000e-04 4.600000e-04 ✓ \n",
|
| 413 |
+
"3 6.400000e-04 6.400000e-04 ✓ \n",
|
| 414 |
+
"4 8.200000e-04 8.200000e-04 ✓ \n",
|
| 415 |
+
"5 1.000000e-03 1.000000e-03 ✓ \n",
|
| 416 |
+
"6 1.000000e-03 1.000000e-03 ✓ \n",
|
| 417 |
+
"7 1.000000e-03 1.000000e-03 ✓ \n",
|
| 418 |
+
"8 1.000000e-03 1.000000e-03 ✓ \n",
|
| 419 |
+
"9 1.000000e-03 1.000000e-03 ✓ \n",
|
| 420 |
+
"\n",
|
| 421 |
+
"Testing resume from step 7:\n",
|
| 422 |
+
"--------------------------------------------------\n",
|
| 423 |
+
"7 1.000000e-03 1.000000e-03 ✓ \n",
|
| 424 |
+
"8 1.000000e-03 1.000000e-03 ✓ \n",
|
| 425 |
+
"9 1.000000e-03 1.000000e-03 ✓ \n"
|
| 426 |
+
]
|
| 427 |
+
}
|
| 428 |
+
],
|
| 429 |
+
"source": [
|
| 430 |
+
"from torch.optim.lr_scheduler import _LRScheduler\n",
|
| 431 |
+
"import warnings\n",
|
| 432 |
+
"\n",
|
| 433 |
+
"class ResumableSequentialLR(_LRScheduler):\n",
|
| 434 |
+
" \"\"\"A resumable version of SequentialLR that properly manages child schedulers\"\"\"\n",
|
| 435 |
+
" \n",
|
| 436 |
+
" def __init__(self, optimizer, schedulers, milestones, last_epoch=-1):\n",
|
| 437 |
+
" \"\"\"\n",
|
| 438 |
+
" Args:\n",
|
| 439 |
+
" optimizer: Wrapped optimizer\n",
|
| 440 |
+
" schedulers: List of schedulers to sequentially use\n",
|
| 441 |
+
" milestones: List of epoch/step numbers when to switch schedulers\n",
|
| 442 |
+
" last_epoch: The index of last epoch/step\n",
|
| 443 |
+
" \"\"\"\n",
|
| 444 |
+
" # Validate inputs\n",
|
| 445 |
+
" if len(schedulers) != len(milestones) + 1:\n",
|
| 446 |
+
" raise ValueError(\"Expected len(schedulers) == len(milestones) + 1\")\n",
|
| 447 |
+
" \n",
|
| 448 |
+
" self.schedulers = schedulers\n",
|
| 449 |
+
" self.milestones = milestones\n",
|
| 450 |
+
" self._scheduler_idx = 0\n",
|
| 451 |
+
" \n",
|
| 452 |
+
" # Initialize parent class (this sets last_epoch and calls step())\n",
|
| 453 |
+
" super().__init__(optimizer, last_epoch)\n",
|
| 454 |
+
" \n",
|
| 455 |
+
" def _get_scheduler_info(self, epoch):\n",
|
| 456 |
+
" \"\"\"Determine which scheduler to use and its relative epoch\"\"\"\n",
|
| 457 |
+
" scheduler_idx = 0\n",
|
| 458 |
+
" relative_epoch = epoch\n",
|
| 459 |
+
" \n",
|
| 460 |
+
" for i, milestone in enumerate(self.milestones):\n",
|
| 461 |
+
" if epoch >= milestone:\n",
|
| 462 |
+
" scheduler_idx = i + 1\n",
|
| 463 |
+
" if i == 0:\n",
|
| 464 |
+
" relative_epoch = epoch - milestone\n",
|
| 465 |
+
" else:\n",
|
| 466 |
+
" relative_epoch = epoch - milestone\n",
|
| 467 |
+
" else:\n",
|
| 468 |
+
" break\n",
|
| 469 |
+
" \n",
|
| 470 |
+
" # Calculate relative epoch for the current scheduler\n",
|
| 471 |
+
" if scheduler_idx == 0:\n",
|
| 472 |
+
" relative_epoch = epoch\n",
|
| 473 |
+
" elif scheduler_idx < len(self.milestones):\n",
|
| 474 |
+
" if scheduler_idx == 1:\n",
|
| 475 |
+
" relative_epoch = epoch - self.milestones[0]\n",
|
| 476 |
+
" else:\n",
|
| 477 |
+
" relative_epoch = epoch - self.milestones[scheduler_idx - 1]\n",
|
| 478 |
+
" \n",
|
| 479 |
+
" return scheduler_idx, relative_epoch\n",
|
| 480 |
+
" \n",
|
| 481 |
+
" def get_lr(self):\n",
|
| 482 |
+
" \"\"\"Get learning rate from the appropriate scheduler\"\"\"\n",
|
| 483 |
+
" if not self._get_lr_called_within_step:\n",
|
| 484 |
+
" warnings.warn(\"To get the last learning rate computed by the scheduler, \"\n",
|
| 485 |
+
" \"please use `get_last_lr()`.\", UserWarning)\n",
|
| 486 |
+
" \n",
|
| 487 |
+
" # Get current scheduler and its relative epoch\n",
|
| 488 |
+
" scheduler_idx, relative_epoch = self._get_scheduler_info(self.last_epoch)\n",
|
| 489 |
+
" scheduler = self.schedulers[scheduler_idx]\n",
|
| 490 |
+
" \n",
|
| 491 |
+
" # Set the scheduler's last_epoch to match relative progress\n",
|
| 492 |
+
" scheduler.last_epoch = relative_epoch\n",
|
| 493 |
+
" \n",
|
| 494 |
+
" # Get LR from the scheduler\n",
|
| 495 |
+
" if hasattr(scheduler, '_get_closed_form_lr'):\n",
|
| 496 |
+
" return scheduler._get_closed_form_lr()\n",
|
| 497 |
+
" else:\n",
|
| 498 |
+
" # Temporarily set the flag to avoid warning from child scheduler\n",
|
| 499 |
+
" scheduler._get_lr_called_within_step = True\n",
|
| 500 |
+
" lrs = scheduler.get_lr()\n",
|
| 501 |
+
" scheduler._get_lr_called_within_step = False\n",
|
| 502 |
+
" return lrs\n",
|
| 503 |
+
" \n",
|
| 504 |
+
" def step(self, epoch=None):\n",
|
| 505 |
+
" \"\"\"Step the scheduler\"\"\"\n",
|
| 506 |
+
" # Step the parent class (updates last_epoch and sets _get_lr_called_within_step)\n",
|
| 507 |
+
" super().step(epoch)\n",
|
| 508 |
+
" \n",
|
| 509 |
+
" def set_step(self, step):\n",
|
| 510 |
+
" \"\"\"Set the current step for resuming training\"\"\"\n",
|
| 511 |
+
" self.last_epoch = step - 1\n",
|
| 512 |
+
" \n",
|
| 513 |
+
" # Update child schedulers' state\n",
|
| 514 |
+
" scheduler_idx, relative_epoch = self._get_scheduler_info(step - 1)\n",
|
| 515 |
+
" \n",
|
| 516 |
+
" # Set all previous schedulers to their final state\n",
|
| 517 |
+
" for i in range(scheduler_idx):\n",
|
| 518 |
+
" if i < len(self.milestones):\n",
|
| 519 |
+
" if i == 0:\n",
|
| 520 |
+
" self.schedulers[i].last_epoch = self.milestones[i] - 1\n",
|
| 521 |
+
" else:\n",
|
| 522 |
+
" self.schedulers[i].last_epoch = self.milestones[i] - self.milestones[i-1] - 1\n",
|
| 523 |
+
" \n",
|
| 524 |
+
" # Set current scheduler to its relative position\n",
|
| 525 |
+
" self.schedulers[scheduler_idx].last_epoch = relative_epoch\n",
|
| 526 |
+
" \n",
|
| 527 |
+
" # Update optimizer's learning rates\n",
|
| 528 |
+
" for param_group, lr in zip(self.optimizer.param_groups, self.get_last_lr()):\n",
|
| 529 |
+
" param_group['lr'] = lr\n",
|
| 530 |
+
"\n",
|
| 531 |
+
"\n",
|
| 532 |
+
"# Alternative simpler implementation that's more robust\n",
|
| 533 |
+
"class SimpleResumableSequentialLR(_LRScheduler):\n",
|
| 534 |
+
" \"\"\"Simpler implementation that manually tracks scheduler states\"\"\"\n",
|
| 535 |
+
" \n",
|
| 536 |
+
" def __init__(self, optimizer, schedulers, milestones, last_epoch=-1):\n",
|
| 537 |
+
" self.schedulers = schedulers\n",
|
| 538 |
+
" self.milestones = milestones\n",
|
| 539 |
+
" super().__init__(optimizer, last_epoch)\n",
|
| 540 |
+
" \n",
|
| 541 |
+
" def get_lr(self):\n",
|
| 542 |
+
" \"\"\"Calculate learning rate based on current epoch\"\"\"\n",
|
| 543 |
+
" epoch = self.last_epoch\n",
|
| 544 |
+
" \n",
|
| 545 |
+
" # For LinearLR with warmup\n",
|
| 546 |
+
" if epoch < self.milestones[0]:\n",
|
| 547 |
+
" # We're in warmup phase\n",
|
| 548 |
+
" warmup_scheduler = self.schedulers[0]\n",
|
| 549 |
+
" start_factor = warmup_scheduler.start_factor\n",
|
| 550 |
+
" end_factor = warmup_scheduler.end_factor\n",
|
| 551 |
+
" total_iters = warmup_scheduler.total_iters\n",
|
| 552 |
+
" \n",
|
| 553 |
+
" # Calculate factor\n",
|
| 554 |
+
" if epoch >= total_iters:\n",
|
| 555 |
+
" factor = end_factor\n",
|
| 556 |
+
" else:\n",
|
| 557 |
+
" factor = start_factor + (end_factor - start_factor) * epoch / total_iters\n",
|
| 558 |
+
" \n",
|
| 559 |
+
" # Apply factor to base learning rates\n",
|
| 560 |
+
" return [base_lr * factor for base_lr in self.base_lrs]\n",
|
| 561 |
+
" else:\n",
|
| 562 |
+
" # We're in constant phase - just return base LRs\n",
|
| 563 |
+
" return [base_lr * 1.0 for base_lr in self.base_lrs]\n",
|
| 564 |
+
"\n",
|
| 565 |
+
"\n",
|
| 566 |
+
"# Test function to verify the scheduler works correctly\n",
|
| 567 |
+
"def test_resumable_scheduler():\n",
|
| 568 |
+
" \"\"\"Test the ResumableSequentialLR implementation\"\"\"\n",
|
| 569 |
+
" import torch\n",
|
| 570 |
+
" import torch.optim as optim\n",
|
| 571 |
+
" from torch.optim.lr_scheduler import LinearLR, ConstantLR\n",
|
| 572 |
+
" \n",
|
| 573 |
+
" # Create dummy model and optimizer\n",
|
| 574 |
+
" model = torch.nn.Linear(10, 1)\n",
|
| 575 |
+
" base_lr = 1e-3\n",
|
| 576 |
+
" optimizer = optim.Adam(model.parameters(), lr=base_lr)\n",
|
| 577 |
+
" \n",
|
| 578 |
+
" # Create schedulers\n",
|
| 579 |
+
" warmup_steps = 5\n",
|
| 580 |
+
" warmup_scheduler = LinearLR(\n",
|
| 581 |
+
" optimizer,\n",
|
| 582 |
+
" start_factor=0.1,\n",
|
| 583 |
+
" end_factor=1.0,\n",
|
| 584 |
+
" total_iters=warmup_steps\n",
|
| 585 |
+
" )\n",
|
| 586 |
+
" \n",
|
| 587 |
+
" constant_scheduler = ConstantLR(\n",
|
| 588 |
+
" optimizer,\n",
|
| 589 |
+
" factor=1.0,\n",
|
| 590 |
+
" total_iters=float('inf')\n",
|
| 591 |
+
" )\n",
|
| 592 |
+
" \n",
|
| 593 |
+
" # Test both implementations\n",
|
| 594 |
+
" print(\"Testing ResumableSequentialLR:\")\n",
|
| 595 |
+
" print(\"-\" * 50)\n",
|
| 596 |
+
" \n",
|
| 597 |
+
" # Reset optimizer\n",
|
| 598 |
+
" for param_group in optimizer.param_groups:\n",
|
| 599 |
+
" param_group['lr'] = base_lr\n",
|
| 600 |
+
" \n",
|
| 601 |
+
" scheduler = ResumableSequentialLR(\n",
|
| 602 |
+
" optimizer,\n",
|
| 603 |
+
" schedulers=[warmup_scheduler, constant_scheduler],\n",
|
| 604 |
+
" milestones=[warmup_steps]\n",
|
| 605 |
+
" )\n",
|
| 606 |
+
" \n",
|
| 607 |
+
" print(f\"{'Step':<10} {'LR':<15} {'Expected':<15} {'Match':<10}\")\n",
|
| 608 |
+
" print(\"-\" * 50)\n",
|
| 609 |
+
" \n",
|
| 610 |
+
" for step in range(10):\n",
|
| 611 |
+
" current_lr = optimizer.param_groups[0]['lr']\n",
|
| 612 |
+
" \n",
|
| 613 |
+
" # Calculate expected LR\n",
|
| 614 |
+
" if step < warmup_steps:\n",
|
| 615 |
+
" expected_lr = base_lr * (0.1 + 0.9 * step / warmup_steps)\n",
|
| 616 |
+
" else:\n",
|
| 617 |
+
" expected_lr = base_lr\n",
|
| 618 |
+
" \n",
|
| 619 |
+
" match = \"✓\" if abs(current_lr - expected_lr) < 1e-10 else \"✗\"\n",
|
| 620 |
+
" print(f\"{step:<10} {current_lr:<15.6e} {expected_lr:<15.6e} {match:<10}\")\n",
|
| 621 |
+
" \n",
|
| 622 |
+
" scheduler.step()\n",
|
| 623 |
+
" \n",
|
| 624 |
+
" # Test resuming\n",
|
| 625 |
+
" print(\"\\nTesting resume from step 7:\")\n",
|
| 626 |
+
" print(\"-\" * 50)\n",
|
| 627 |
+
" \n",
|
| 628 |
+
" # Reset and jump to step 7\n",
|
| 629 |
+
" for param_group in optimizer.param_groups:\n",
|
| 630 |
+
" param_group['lr'] = base_lr\n",
|
| 631 |
+
" \n",
|
| 632 |
+
" scheduler = ResumableSequentialLR(\n",
|
| 633 |
+
" optimizer,\n",
|
| 634 |
+
" schedulers=[warmup_scheduler, constant_scheduler],\n",
|
| 635 |
+
" milestones=[warmup_steps]\n",
|
| 636 |
+
" )\n",
|
| 637 |
+
" scheduler.set_step(7)\n",
|
| 638 |
+
" \n",
|
| 639 |
+
" for step in range(7, 10):\n",
|
| 640 |
+
" scheduler.step()\n",
|
| 641 |
+
" current_lr = optimizer.param_groups[0]['lr']\n",
|
| 642 |
+
" expected_lr = base_lr # Should be constant phase\n",
|
| 643 |
+
" match = \"✓\" if abs(current_lr - expected_lr) < 1e-10 else \"✗\"\n",
|
| 644 |
+
" print(f\"{step:<10} {current_lr:<15.6e} {expected_lr:<15.6e} {match:<10}\")\n",
|
| 645 |
+
"\n",
|
| 646 |
+
"\n",
|
| 647 |
+
"if __name__ == \"__main__\":\n",
|
| 648 |
+
" test_resumable_scheduler()"
|
| 649 |
+
]
|
| 650 |
+
},
|
| 651 |
+
{
|
| 652 |
+
"cell_type": "code",
|
| 653 |
+
"execution_count": null,
|
| 654 |
+
"id": "ce71bea4",
|
| 655 |
+
"metadata": {},
|
| 656 |
+
"outputs": [],
|
| 657 |
+
"source": []
|
| 658 |
+
},
|
| 659 |
+
{
|
| 660 |
+
"cell_type": "code",
|
| 661 |
+
"execution_count": null,
|
| 662 |
+
"id": "42b9b936",
|
| 663 |
+
"metadata": {},
|
| 664 |
+
"outputs": [],
|
| 665 |
+
"source": []
|
| 666 |
+
},
|
| 667 |
+
{
|
| 668 |
+
"cell_type": "code",
|
| 669 |
+
"execution_count": 3,
|
| 670 |
+
"id": "e3d4d5a1",
|
| 671 |
+
"metadata": {},
|
| 672 |
+
"outputs": [
|
| 673 |
+
{
|
| 674 |
+
"name": "stdout",
|
| 675 |
+
"output_type": "stream",
|
| 676 |
+
"text": [
|
| 677 |
+
"=== Learning Rate Source Verification ===\n",
|
| 678 |
+
"\n",
|
| 679 |
+
"Comparing LR sources during warmup:\n",
|
| 680 |
+
"\n",
|
| 681 |
+
"Step Optimizer LR Scheduler LR Match? \n",
|
| 682 |
+
"--------------------------------------------------\n",
|
| 683 |
+
"0 1.00e-04 1.00e-04 ✓ \n",
|
| 684 |
+
"1 2.80e-04 2.80e-04 ✓ \n",
|
| 685 |
+
"2 4.60e-04 4.60e-04 ✓ \n",
|
| 686 |
+
"3 6.40e-04 6.40e-04 ✓ \n",
|
| 687 |
+
"4 8.20e-04 8.20e-04 ✓ \n",
|
| 688 |
+
"5 1.00e-03 1.00e-03 ✓ \n",
|
| 689 |
+
"6 1.00e-03 1.00e-03 ✓ \n",
|
| 690 |
+
"7 1.00e-03 1.00e-03 ✓ \n",
|
| 691 |
+
"8 1.00e-03 1.00e-03 ✓ \n",
|
| 692 |
+
"9 1.00e-03 1.00e-03 ✓ \n",
|
| 693 |
+
"\n",
|
| 694 |
+
"Conclusion: optimizer.param_groups[0]['lr'] is the authoritative source!\n",
|
| 695 |
+
"\n",
|
| 696 |
+
"\n",
|
| 697 |
+
"Manual LR change test:\n",
|
| 698 |
+
"Current optimizer LR: 1.00e-03\n",
|
| 699 |
+
"After manual change: 1.00e-02\n",
|
| 700 |
+
"This confirms the optimizer holds the actual LR being used.\n",
|
| 701 |
+
"\n",
|
| 702 |
+
"==================================================\n",
|
| 703 |
+
"\n",
|
| 704 |
+
"\n",
|
| 705 |
+
"Different ways to access learning rate:\n",
|
| 706 |
+
"\n",
|
| 707 |
+
"Initial state:\n",
|
| 708 |
+
" optimizer.param_groups[0]['lr']: 1.00e-04\n",
|
| 709 |
+
" scheduler.get_last_lr(): 1.00e-04\n",
|
| 710 |
+
"\n",
|
| 711 |
+
"After scheduler.step():\n",
|
| 712 |
+
" optimizer.param_groups[0]['lr']: 2.80e-04\n",
|
| 713 |
+
" scheduler.get_last_lr(): 2.80e-04\n",
|
| 714 |
+
"\n",
|
| 715 |
+
"Key insights:\n",
|
| 716 |
+
"1. optimizer.param_groups[0]['lr'] - Always current, used by optimizer\n",
|
| 717 |
+
"2. scheduler.get_last_lr() - What scheduler set on last step()\n",
|
| 718 |
+
"3. scheduler.get_lr() - Internal method, calculates next LR (don't use directly)\n",
|
| 719 |
+
"\n",
|
| 720 |
+
"==================================================\n",
|
| 721 |
+
"\n",
|
| 722 |
+
"\n",
|
| 723 |
+
"Multiple parameter groups:\n",
|
| 724 |
+
" Group 0: lr = 1.00e-03\n",
|
| 725 |
+
" Group 1: lr = 1.00e-04\n",
|
| 726 |
+
"\n",
|
| 727 |
+
"After scheduler step:\n",
|
| 728 |
+
" Group 0: lr = 2.80e-04\n",
|
| 729 |
+
" Group 1: lr = 2.80e-05\n"
|
| 730 |
+
]
|
| 731 |
+
}
|
| 732 |
+
],
|
| 733 |
+
"source": [
|
| 734 |
+
"import torch\n",
|
| 735 |
+
"import torch.optim as optim\n",
|
| 736 |
+
"from torch.optim.lr_scheduler import LinearLR, ConstantLR, SequentialLR\n",
|
| 737 |
+
"\n",
|
| 738 |
+
"def verify_lr_sources():\n",
|
| 739 |
+
" \"\"\"Verify that optimizer.param_groups[0]['lr'] is the correct source\"\"\"\n",
|
| 740 |
+
" \n",
|
| 741 |
+
" # Create a simple model and optimizer\n",
|
| 742 |
+
" model = torch.nn.Linear(10, 1)\n",
|
| 743 |
+
" optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
|
| 744 |
+
" \n",
|
| 745 |
+
" # Create schedulers\n",
|
| 746 |
+
" warmup_scheduler = LinearLR(\n",
|
| 747 |
+
" optimizer,\n",
|
| 748 |
+
" start_factor=0.1, # Start at 10% of base LR\n",
|
| 749 |
+
" end_factor=1.0, # End at 100% of base LR\n",
|
| 750 |
+
" total_iters=5 # 5 warmup steps\n",
|
| 751 |
+
" )\n",
|
| 752 |
+
" \n",
|
| 753 |
+
" constant_scheduler = ConstantLR(\n",
|
| 754 |
+
" optimizer,\n",
|
| 755 |
+
" factor=1.0,\n",
|
| 756 |
+
" total_iters=float('inf')\n",
|
| 757 |
+
" )\n",
|
| 758 |
+
" \n",
|
| 759 |
+
" scheduler = SequentialLR(\n",
|
| 760 |
+
" optimizer,\n",
|
| 761 |
+
" schedulers=[warmup_scheduler, constant_scheduler],\n",
|
| 762 |
+
" milestones=[5]\n",
|
| 763 |
+
" )\n",
|
| 764 |
+
" \n",
|
| 765 |
+
" print(\"Comparing LR sources during warmup:\\n\")\n",
|
| 766 |
+
" print(f\"{'Step':<6} {'Optimizer LR':<15} {'Scheduler LR':<15} {'Match?':<10}\")\n",
|
| 767 |
+
" print(\"-\" * 50)\n",
|
| 768 |
+
" \n",
|
| 769 |
+
" for step in range(10):\n",
|
| 770 |
+
" # Get LR from optimizer\n",
|
| 771 |
+
" optimizer_lr = optimizer.param_groups[0]['lr']\n",
|
| 772 |
+
" \n",
|
| 773 |
+
" # Get LR from scheduler (if available)\n",
|
| 774 |
+
" # Note: scheduler.get_last_lr() returns the LR after the last step\n",
|
| 775 |
+
" scheduler_lr = scheduler.get_last_lr()[0] if hasattr(scheduler, 'get_last_lr') else None\n",
|
| 776 |
+
" \n",
|
| 777 |
+
" # Print comparison\n",
|
| 778 |
+
" match = \"✓\" if scheduler_lr is None or abs(optimizer_lr - scheduler_lr) < 1e-10 else \"✗\"\n",
|
| 779 |
+
" print(f\"{step:<6} {optimizer_lr:<15.2e} {scheduler_lr:<15.2e} {match:<10}\")\n",
|
| 780 |
+
" \n",
|
| 781 |
+
" # Step the scheduler\n",
|
| 782 |
+
" scheduler.step()\n",
|
| 783 |
+
" \n",
|
| 784 |
+
" print(\"\\nConclusion: optimizer.param_groups[0]['lr'] is the authoritative source!\")\n",
|
| 785 |
+
" \n",
|
| 786 |
+
" # Additional verification: what happens if we manually change the optimizer's LR?\n",
|
| 787 |
+
" print(\"\\n\\nManual LR change test:\")\n",
|
| 788 |
+
" print(f\"Current optimizer LR: {optimizer.param_groups[0]['lr']:.2e}\")\n",
|
| 789 |
+
" \n",
|
| 790 |
+
" # Manually change it\n",
|
| 791 |
+
" for param_group in optimizer.param_groups:\n",
|
| 792 |
+
" param_group['lr'] = 0.01\n",
|
| 793 |
+
" \n",
|
| 794 |
+
" print(f\"After manual change: {optimizer.param_groups[0]['lr']:.2e}\")\n",
|
| 795 |
+
" print(\"This confirms the optimizer holds the actual LR being used.\")\n",
|
| 796 |
+
"\n",
|
| 797 |
+
"\n",
|
| 798 |
+
"def compare_lr_access_methods():\n",
|
| 799 |
+
" \"\"\"Compare different ways to access the learning rate\"\"\"\n",
|
| 800 |
+
" \n",
|
| 801 |
+
" model = torch.nn.Linear(10, 1)\n",
|
| 802 |
+
" optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
|
| 803 |
+
" \n",
|
| 804 |
+
" scheduler = LinearLR(\n",
|
| 805 |
+
" optimizer,\n",
|
| 806 |
+
" start_factor=0.1,\n",
|
| 807 |
+
" end_factor=1.0,\n",
|
| 808 |
+
" total_iters=5\n",
|
| 809 |
+
" )\n",
|
| 810 |
+
" \n",
|
| 811 |
+
" print(\"\\nDifferent ways to access learning rate:\\n\")\n",
|
| 812 |
+
" \n",
|
| 813 |
+
" # Before any steps\n",
|
| 814 |
+
" print(\"Initial state:\")\n",
|
| 815 |
+
" print(f\" optimizer.param_groups[0]['lr']: {optimizer.param_groups[0]['lr']:.2e}\")\n",
|
| 816 |
+
" print(f\" scheduler.get_last_lr(): {scheduler.get_last_lr()[0]:.2e}\")\n",
|
| 817 |
+
" \n",
|
| 818 |
+
" # After stepping\n",
|
| 819 |
+
" scheduler.step()\n",
|
| 820 |
+
" print(\"\\nAfter scheduler.step():\")\n",
|
| 821 |
+
" print(f\" optimizer.param_groups[0]['lr']: {optimizer.param_groups[0]['lr']:.2e}\")\n",
|
| 822 |
+
" print(f\" scheduler.get_last_lr(): {scheduler.get_last_lr()[0]:.2e}\")\n",
|
| 823 |
+
" \n",
|
| 824 |
+
" # Key insight\n",
|
| 825 |
+
" print(\"\\nKey insights:\")\n",
|
| 826 |
+
" print(\"1. optimizer.param_groups[0]['lr'] - Always current, used by optimizer\")\n",
|
| 827 |
+
" print(\"2. scheduler.get_last_lr() - What scheduler set on last step()\")\n",
|
| 828 |
+
" print(\"3. scheduler.get_lr() - Internal method, calculates next LR (don't use directly)\")\n",
|
| 829 |
+
"\n",
|
| 830 |
+
"\n",
|
| 831 |
+
"def check_multiple_param_groups():\n",
|
| 832 |
+
" \"\"\"Check how LR works with multiple parameter groups\"\"\"\n",
|
| 833 |
+
" \n",
|
| 834 |
+
" model = torch.nn.Sequential(\n",
|
| 835 |
+
" torch.nn.Linear(10, 20),\n",
|
| 836 |
+
" torch.nn.Linear(20, 1)\n",
|
| 837 |
+
" )\n",
|
| 838 |
+
" \n",
|
| 839 |
+
" # Different LRs for different layers\n",
|
| 840 |
+
" optimizer = optim.Adam([\n",
|
| 841 |
+
" {'params': model[0].parameters(), 'lr': 1e-3},\n",
|
| 842 |
+
" {'params': model[1].parameters(), 'lr': 1e-4}\n",
|
| 843 |
+
" ])\n",
|
| 844 |
+
" \n",
|
| 845 |
+
" print(\"\\nMultiple parameter groups:\")\n",
|
| 846 |
+
" for i, param_group in enumerate(optimizer.param_groups):\n",
|
| 847 |
+
" print(f\" Group {i}: lr = {param_group['lr']:.2e}\")\n",
|
| 848 |
+
" \n",
|
| 849 |
+
" # Scheduler affects all groups\n",
|
| 850 |
+
" scheduler = LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=5)\n",
|
| 851 |
+
" scheduler.step()\n",
|
| 852 |
+
" \n",
|
| 853 |
+
" print(\"\\nAfter scheduler step:\")\n",
|
| 854 |
+
" for i, param_group in enumerate(optimizer.param_groups):\n",
|
| 855 |
+
" print(f\" Group {i}: lr = {param_group['lr']:.2e}\")\n",
|
| 856 |
+
"\n",
|
| 857 |
+
"\n",
|
| 858 |
+
"if __name__ == \"__main__\":\n",
|
| 859 |
+
" print(\"=== Learning Rate Source Verification ===\\n\")\n",
|
| 860 |
+
" verify_lr_sources()\n",
|
| 861 |
+
" print(\"\\n\" + \"=\"*50 + \"\\n\")\n",
|
| 862 |
+
" compare_lr_access_methods()\n",
|
| 863 |
+
" print(\"\\n\" + \"=\"*50 + \"\\n\")\n",
|
| 864 |
+
" check_multiple_param_groups()"
|
| 865 |
+
]
|
| 866 |
+
},
|
| 867 |
+
{
|
| 868 |
+
"cell_type": "code",
|
| 869 |
+
"execution_count": null,
|
| 870 |
+
"id": "918d3322",
|
| 871 |
+
"metadata": {},
|
| 872 |
+
"outputs": [],
|
| 873 |
+
"source": []
|
| 874 |
+
},
|
| 875 |
+
{
|
| 876 |
+
"cell_type": "code",
|
| 877 |
+
"execution_count": null,
|
| 878 |
+
"id": "eb19ac5e",
|
| 879 |
+
"metadata": {},
|
| 880 |
+
"outputs": [],
|
| 881 |
+
"source": []
|
| 882 |
+
},
|
| 883 |
+
{
|
| 884 |
+
"cell_type": "code",
|
| 885 |
+
"execution_count": null,
|
| 886 |
+
"id": "7f2c3038",
|
| 887 |
+
"metadata": {},
|
| 888 |
+
"outputs": [],
|
| 889 |
+
"source": []
|
| 890 |
+
},
|
| 891 |
+
{
|
| 892 |
+
"cell_type": "code",
|
| 893 |
+
"execution_count": null,
|
| 894 |
+
"id": "4f528b78",
|
| 895 |
+
"metadata": {},
|
| 896 |
+
"outputs": [],
|
| 897 |
+
"source": []
|
| 898 |
+
},
|
| 899 |
+
{
|
| 900 |
+
"cell_type": "code",
|
| 901 |
+
"execution_count": null,
|
| 902 |
+
"id": "f0fcea90",
|
| 903 |
+
"metadata": {},
|
| 904 |
+
"outputs": [],
|
| 905 |
+
"source": []
|
| 906 |
+
},
|
| 907 |
+
{
|
| 908 |
+
"cell_type": "code",
|
| 909 |
+
"execution_count": null,
|
| 910 |
+
"id": "bb5de4ae",
|
| 911 |
+
"metadata": {},
|
| 912 |
+
"outputs": [],
|
| 913 |
+
"source": []
|
| 914 |
+
},
|
| 915 |
{
|
| 916 |
"cell_type": "markdown",
|
| 917 |
"id": "fbf1de4d",
|
speech/test_train.sh
CHANGED
|
@@ -66,14 +66,12 @@ torchrun --nnodes=1 --nproc_per_node=$num_gpus --rdzv_id=$job_id --rdzv_backend=
|
|
| 66 |
--cv_data data/data.list \
|
| 67 |
--qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
|
| 68 |
--model $model \
|
| 69 |
-
--checkpoint
|
| 70 |
-
--model_dir /mnt/nvme/speech/$model/ \
|
| 71 |
--num_workers ${num_workers} \
|
| 72 |
--prefetch ${prefetch} \
|
| 73 |
--pin_memory \
|
| 74 |
--use_amp \
|
| 75 |
-
--
|
| 76 |
-
|
| 77 |
# # average model
|
| 78 |
# average_num=5
|
| 79 |
# if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|
|
|
|
| 66 |
--cv_data data/data.list \
|
| 67 |
--qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
|
| 68 |
--model $model \
|
| 69 |
+
--model_dir /data/checkpoint/$model/ \
|
|
|
|
| 70 |
--num_workers ${num_workers} \
|
| 71 |
--prefetch ${prefetch} \
|
| 72 |
--pin_memory \
|
| 73 |
--use_amp \
|
| 74 |
+
--checkpoint /data/checkpoint/flow/epoch_88_step_14001.pt
|
|
|
|
| 75 |
# # average model
|
| 76 |
# average_num=5
|
| 77 |
# if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|