Spaces:
Sleeping
Sleeping
primepake
commited on
Commit
·
55ac664
1
Parent(s):
9f4fc9f
add contrastive loss
Browse files- speech/config.yaml +8 -3
- speech/cosyvoice/flow/flow.py +203 -112
- speech/cosyvoice/flow/flow_matching.py +88 -37
- speech/cosyvoice/utils/executor.py +7 -0
- speech/cosyvoice/utils/train_utils.py +1 -0
speech/config.yaml
CHANGED
|
@@ -73,6 +73,10 @@ flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
|
|
| 73 |
training_cfg_rate: 0.2
|
| 74 |
inference_cfg_rate: 0.7
|
| 75 |
reg_loss_type: 'l1'
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
estimator: !new:cosyvoice.flow.decoder.CausalConditionalDecoder
|
| 77 |
in_channels: 320
|
| 78 |
out_channels: 80
|
|
@@ -161,6 +165,7 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
|
|
| 161 |
center: False
|
| 162 |
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
|
| 163 |
feat_extractor: !ref <feat_extractor>
|
|
|
|
| 164 |
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
|
| 165 |
sample_rate: !ref <sample_rate>
|
| 166 |
hop_size: 480
|
|
@@ -172,7 +177,7 @@ sort: !name:cosyvoice.dataset.processor.sort
|
|
| 172 |
sort_size: 500 # sort_size should be less than shuffle_size
|
| 173 |
batch: !name:cosyvoice.dataset.processor.batch
|
| 174 |
batch_type: 'dynamic'
|
| 175 |
-
max_frames_in_batch:
|
| 176 |
padding: !name:cosyvoice.dataset.processor.padding
|
| 177 |
use_spk_embedding: False # change to True during sft
|
| 178 |
|
|
@@ -195,12 +200,12 @@ data_pipeline: [
|
|
| 195 |
train_conf:
|
| 196 |
optim: adamw
|
| 197 |
optim_conf:
|
| 198 |
-
lr:
|
| 199 |
scheduler: constantlr # change to constantlr during sft
|
| 200 |
scheduler_conf:
|
| 201 |
warmup_steps: 2500
|
| 202 |
max_epoch: 200
|
| 203 |
grad_clip: 1
|
| 204 |
accum_grad: 1
|
| 205 |
-
log_interval:
|
| 206 |
save_per_step: -1
|
|
|
|
| 73 |
training_cfg_rate: 0.2
|
| 74 |
inference_cfg_rate: 0.7
|
| 75 |
reg_loss_type: 'l1'
|
| 76 |
+
use_immiscible: True
|
| 77 |
+
immiscible_k: 8
|
| 78 |
+
use_contrastive_fm: True
|
| 79 |
+
contrastive_lambda: 0.05
|
| 80 |
estimator: !new:cosyvoice.flow.decoder.CausalConditionalDecoder
|
| 81 |
in_channels: 320
|
| 82 |
out_channels: 80
|
|
|
|
| 165 |
center: False
|
| 166 |
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
|
| 167 |
feat_extractor: !ref <feat_extractor>
|
| 168 |
+
token_mel_ratio: !ref <token_mel_ratio>
|
| 169 |
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
|
| 170 |
sample_rate: !ref <sample_rate>
|
| 171 |
hop_size: 480
|
|
|
|
| 177 |
sort_size: 500 # sort_size should be less than shuffle_size
|
| 178 |
batch: !name:cosyvoice.dataset.processor.batch
|
| 179 |
batch_type: 'dynamic'
|
| 180 |
+
max_frames_in_batch: 25000
|
| 181 |
padding: !name:cosyvoice.dataset.processor.padding
|
| 182 |
use_spk_embedding: False # change to True during sft
|
| 183 |
|
|
|
|
| 200 |
train_conf:
|
| 201 |
optim: adamw
|
| 202 |
optim_conf:
|
| 203 |
+
lr: 2e-6 # change to 1e-5 during sft
|
| 204 |
scheduler: constantlr # change to constantlr during sft
|
| 205 |
scheduler_conf:
|
| 206 |
warmup_steps: 2500
|
| 207 |
max_epoch: 200
|
| 208 |
grad_clip: 1
|
| 209 |
accum_grad: 1
|
| 210 |
+
log_interval: 5
|
| 211 |
save_per_step: -1
|
speech/cosyvoice/flow/flow.py
CHANGED
|
@@ -11,9 +11,10 @@
|
|
| 11 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
|
|
|
| 14 |
import logging
|
| 15 |
import random
|
| 16 |
-
from typing import Dict, Optional
|
| 17 |
import torch
|
| 18 |
import torch.nn as nn
|
| 19 |
from torch.nn import functional as F
|
|
@@ -22,24 +23,57 @@ from cosyvoice.utils.mask import make_pad_mask
|
|
| 22 |
|
| 23 |
|
| 24 |
class MaskedDiffWithXvec(torch.nn.Module):
|
| 25 |
-
def __init__(
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
super().__init__()
|
| 44 |
self.input_size = input_size
|
| 45 |
self.output_size = output_size
|
|
@@ -58,22 +92,22 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
|
| 58 |
self.only_mask_loss = only_mask_loss
|
| 59 |
|
| 60 |
def forward(
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
) -> Dict[str, Optional[torch.Tensor]]:
|
| 65 |
-
token = batch[
|
| 66 |
-
token_len = batch[
|
| 67 |
-
feat = batch[
|
| 68 |
-
feat_len = batch[
|
| 69 |
-
embedding = batch[
|
| 70 |
|
| 71 |
# xvec projection
|
| 72 |
embedding = F.normalize(embedding, dim=1)
|
| 73 |
embedding = self.spk_embed_affine_layer(embedding)
|
| 74 |
|
| 75 |
# concat text and prompt_text
|
| 76 |
-
print(
|
| 77 |
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
|
| 78 |
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
| 79 |
|
|
@@ -98,20 +132,22 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
|
| 98 |
mask.unsqueeze(1),
|
| 99 |
h.transpose(1, 2).contiguous(),
|
| 100 |
embedding,
|
| 101 |
-
cond=conds
|
| 102 |
)
|
| 103 |
-
return {
|
| 104 |
|
| 105 |
@torch.inference_mode()
|
| 106 |
-
def inference(
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
| 115 |
assert token.shape[0] == 1
|
| 116 |
# xvec projection
|
| 117 |
embedding = F.normalize(embedding, dim=1)
|
|
@@ -119,18 +155,31 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
|
| 119 |
|
| 120 |
# concat speech token and prompt speech token
|
| 121 |
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
| 122 |
-
token, token_len =
|
|
|
|
|
|
|
|
|
|
| 123 |
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
| 124 |
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
| 125 |
|
| 126 |
# text encode
|
| 127 |
h, h_lengths = self.encoder(token, token_len)
|
| 128 |
h = self.encoder_proj(h)
|
| 129 |
-
mel_len1, mel_len2 = prompt_feat.shape[1], int(
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
# get conditions
|
| 133 |
-
conds = torch.zeros(
|
|
|
|
|
|
|
| 134 |
conds[:, :mel_len1] = prompt_feat
|
| 135 |
conds = conds.transpose(1, 2)
|
| 136 |
|
|
@@ -142,7 +191,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
|
| 142 |
cond=conds,
|
| 143 |
n_timesteps=10,
|
| 144 |
prompt_len=mel_len1,
|
| 145 |
-
cache=flow_cache
|
| 146 |
)
|
| 147 |
feat = feat[:, :, mel_len1:]
|
| 148 |
assert feat.shape[2] == mel_len2
|
|
@@ -150,25 +199,58 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
|
| 150 |
|
| 151 |
|
| 152 |
class CausalMaskedDiffWithXvec(torch.nn.Module):
|
| 153 |
-
def __init__(
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
super().__init__()
|
| 173 |
self.input_size = input_size
|
| 174 |
self.output_size = output_size
|
|
@@ -186,32 +268,26 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
|
| 186 |
self.only_mask_loss = only_mask_loss
|
| 187 |
self.token_mel_ratio = token_mel_ratio
|
| 188 |
self.pre_lookahead_len = pre_lookahead_len
|
|
|
|
|
|
|
| 189 |
|
| 190 |
def forward(
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
) -> Dict[str, Optional[torch.Tensor]]:
|
| 195 |
-
token = batch[
|
| 196 |
-
token_len = batch[
|
| 197 |
-
feat = batch[
|
| 198 |
-
feat_len = batch[
|
| 199 |
-
embedding = batch[
|
| 200 |
-
|
| 201 |
-
# print('token: ', token.shape)
|
| 202 |
-
# print('token_len: ', token_len.shape)
|
| 203 |
-
# print('feat: ', feat.shape)
|
| 204 |
-
# print('feat_len: ', feat_len.shape)
|
| 205 |
-
# print('embedding: ', embedding.shape)
|
| 206 |
|
| 207 |
# NOTE unified training, static_chunk_size > 0 or = 0
|
| 208 |
-
streaming = False# if random.random() < 0.5 else False
|
| 209 |
|
| 210 |
# xvec projection
|
| 211 |
embedding = F.normalize(embedding, dim=1)
|
| 212 |
embedding = self.spk_embed_affine_layer(embedding)
|
| 213 |
-
# print('token_len values: ', token_len)
|
| 214 |
-
# concat text and prompt_text
|
| 215 |
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
|
| 216 |
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
| 217 |
|
|
@@ -229,42 +305,50 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
|
| 229 |
conds = conds.transpose(1, 2)
|
| 230 |
|
| 231 |
mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
|
|
|
|
|
|
|
|
|
| 249 |
|
| 250 |
@torch.inference_mode()
|
| 251 |
-
def inference(
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
|
|
|
|
|
|
| 261 |
assert token.shape[0] == 1
|
| 262 |
# xvec projection
|
| 263 |
embedding = F.normalize(embedding, dim=1)
|
| 264 |
embedding = self.spk_embed_affine_layer(embedding)
|
| 265 |
|
| 266 |
# concat text and prompt_text
|
| 267 |
-
token, token_len =
|
|
|
|
|
|
|
|
|
|
| 268 |
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
| 269 |
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
| 270 |
|
|
@@ -272,13 +356,20 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
|
| 272 |
if finalize is True:
|
| 273 |
h, h_lengths = self.encoder(token, token_len, streaming=streaming)
|
| 274 |
else:
|
| 275 |
-
token, context =
|
| 276 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
|
| 278 |
h = self.encoder_proj(h)
|
| 279 |
|
| 280 |
# get conditions
|
| 281 |
-
conds = torch.zeros(
|
|
|
|
|
|
|
| 282 |
conds[:, :mel_len1] = prompt_feat
|
| 283 |
conds = conds.transpose(1, 2)
|
| 284 |
|
|
@@ -289,7 +380,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
|
| 289 |
spks=embedding,
|
| 290 |
cond=conds,
|
| 291 |
n_timesteps=10,
|
| 292 |
-
streaming=streaming
|
| 293 |
)
|
| 294 |
feat = feat[:, :, mel_len1:]
|
| 295 |
assert feat.shape[2] == mel_len2
|
|
|
|
| 11 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
| 14 |
+
from ast import List
|
| 15 |
import logging
|
| 16 |
import random
|
| 17 |
+
from typing import Dict, Optional, Tuple
|
| 18 |
import torch
|
| 19 |
import torch.nn as nn
|
| 20 |
from torch.nn import functional as F
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
class MaskedDiffWithXvec(torch.nn.Module):
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
input_size: int = 512,
|
| 29 |
+
output_size: int = 80,
|
| 30 |
+
spk_embed_dim: int = 192,
|
| 31 |
+
output_type: str = "mel",
|
| 32 |
+
vocab_size: int = 4096,
|
| 33 |
+
input_frame_rate: int = 50,
|
| 34 |
+
only_mask_loss: bool = True,
|
| 35 |
+
encoder: torch.nn.Module = None,
|
| 36 |
+
length_regulator: torch.nn.Module = None,
|
| 37 |
+
decoder: torch.nn.Module = None,
|
| 38 |
+
decoder_conf: Dict = {
|
| 39 |
+
"in_channels": 240,
|
| 40 |
+
"out_channel": 80,
|
| 41 |
+
"spk_emb_dim": 80,
|
| 42 |
+
"n_spks": 1,
|
| 43 |
+
"cfm_params": DictConfig(
|
| 44 |
+
{
|
| 45 |
+
"sigma_min": 1e-06,
|
| 46 |
+
"solver": "euler",
|
| 47 |
+
"t_scheduler": "cosine",
|
| 48 |
+
"training_cfg_rate": 0.2,
|
| 49 |
+
"inference_cfg_rate": 0.7,
|
| 50 |
+
"reg_loss_type": "l1",
|
| 51 |
+
"use_immiscible": True,
|
| 52 |
+
"immiscible_k": 8,
|
| 53 |
+
"use_contrastive_fm": False,
|
| 54 |
+
"contrastive_lambda": 0.05
|
| 55 |
+
}
|
| 56 |
+
),
|
| 57 |
+
"decoder_params": {
|
| 58 |
+
"channels": [256, 256],
|
| 59 |
+
"dropout": 0.0,
|
| 60 |
+
"attention_head_dim": 64,
|
| 61 |
+
"n_blocks": 4,
|
| 62 |
+
"num_mid_blocks": 12,
|
| 63 |
+
"num_heads": 8,
|
| 64 |
+
"act_fn": "gelu",
|
| 65 |
+
},
|
| 66 |
+
},
|
| 67 |
+
mel_feat_conf: Dict = {
|
| 68 |
+
"n_fft": 1024,
|
| 69 |
+
"num_mels": 80,
|
| 70 |
+
"sampling_rate": 22050,
|
| 71 |
+
"hop_size": 256,
|
| 72 |
+
"win_size": 1024,
|
| 73 |
+
"fmin": 0,
|
| 74 |
+
"fmax": 8000,
|
| 75 |
+
},
|
| 76 |
+
):
|
| 77 |
super().__init__()
|
| 78 |
self.input_size = input_size
|
| 79 |
self.output_size = output_size
|
|
|
|
| 92 |
self.only_mask_loss = only_mask_loss
|
| 93 |
|
| 94 |
def forward(
|
| 95 |
+
self,
|
| 96 |
+
batch: dict,
|
| 97 |
+
device: torch.device,
|
| 98 |
) -> Dict[str, Optional[torch.Tensor]]:
|
| 99 |
+
token = batch["speech_token"].to(device)
|
| 100 |
+
token_len = batch["speech_token_len"].to(device)
|
| 101 |
+
feat = batch["speech_feat"].to(device)
|
| 102 |
+
feat_len = batch["speech_feat_len"].to(device)
|
| 103 |
+
embedding = batch["embedding"].to(device)
|
| 104 |
|
| 105 |
# xvec projection
|
| 106 |
embedding = F.normalize(embedding, dim=1)
|
| 107 |
embedding = self.spk_embed_affine_layer(embedding)
|
| 108 |
|
| 109 |
# concat text and prompt_text
|
| 110 |
+
print("token_len values: ", token_len)
|
| 111 |
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
|
| 112 |
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
| 113 |
|
|
|
|
| 132 |
mask.unsqueeze(1),
|
| 133 |
h.transpose(1, 2).contiguous(),
|
| 134 |
embedding,
|
| 135 |
+
cond=conds,
|
| 136 |
)
|
| 137 |
+
return {"loss": loss}
|
| 138 |
|
| 139 |
@torch.inference_mode()
|
| 140 |
+
def inference(
|
| 141 |
+
self,
|
| 142 |
+
token,
|
| 143 |
+
token_len,
|
| 144 |
+
prompt_token,
|
| 145 |
+
prompt_token_len,
|
| 146 |
+
prompt_feat,
|
| 147 |
+
prompt_feat_len,
|
| 148 |
+
embedding,
|
| 149 |
+
flow_cache,
|
| 150 |
+
):
|
| 151 |
assert token.shape[0] == 1
|
| 152 |
# xvec projection
|
| 153 |
embedding = F.normalize(embedding, dim=1)
|
|
|
|
| 155 |
|
| 156 |
# concat speech token and prompt speech token
|
| 157 |
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
| 158 |
+
token, token_len = (
|
| 159 |
+
torch.concat([prompt_token, token], dim=1),
|
| 160 |
+
prompt_token_len + token_len,
|
| 161 |
+
)
|
| 162 |
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
| 163 |
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
| 164 |
|
| 165 |
# text encode
|
| 166 |
h, h_lengths = self.encoder(token, token_len)
|
| 167 |
h = self.encoder_proj(h)
|
| 168 |
+
mel_len1, mel_len2 = prompt_feat.shape[1], int(
|
| 169 |
+
token_len2 / self.input_frame_rate * 22050 / 256
|
| 170 |
+
)
|
| 171 |
+
h, h_lengths = self.length_regulator.inference(
|
| 172 |
+
h[:, :token_len1],
|
| 173 |
+
h[:, token_len1:],
|
| 174 |
+
mel_len1,
|
| 175 |
+
mel_len2,
|
| 176 |
+
self.input_frame_rate,
|
| 177 |
+
)
|
| 178 |
|
| 179 |
# get conditions
|
| 180 |
+
conds = torch.zeros(
|
| 181 |
+
[1, mel_len1 + mel_len2, self.output_size], device=token.device
|
| 182 |
+
).to(h.dtype)
|
| 183 |
conds[:, :mel_len1] = prompt_feat
|
| 184 |
conds = conds.transpose(1, 2)
|
| 185 |
|
|
|
|
| 191 |
cond=conds,
|
| 192 |
n_timesteps=10,
|
| 193 |
prompt_len=mel_len1,
|
| 194 |
+
cache=flow_cache,
|
| 195 |
)
|
| 196 |
feat = feat[:, :, mel_len1:]
|
| 197 |
assert feat.shape[2] == mel_len2
|
|
|
|
| 199 |
|
| 200 |
|
| 201 |
class CausalMaskedDiffWithXvec(torch.nn.Module):
|
| 202 |
+
def __init__(
|
| 203 |
+
self,
|
| 204 |
+
input_size: int = 512,
|
| 205 |
+
output_size: int = 80,
|
| 206 |
+
spk_embed_dim: int = 192,
|
| 207 |
+
output_type: str = "mel",
|
| 208 |
+
vocab_size: int = 4096,
|
| 209 |
+
input_frame_rate: int = 50,
|
| 210 |
+
only_mask_loss: bool = True,
|
| 211 |
+
token_mel_ratio: int = 2,
|
| 212 |
+
pre_lookahead_len: int = 3,
|
| 213 |
+
encoder: torch.nn.Module = None,
|
| 214 |
+
decoder: torch.nn.Module = None,
|
| 215 |
+
decoder_conf: Dict = {
|
| 216 |
+
"in_channels": 240,
|
| 217 |
+
"out_channel": 80,
|
| 218 |
+
"spk_emb_dim": 80,
|
| 219 |
+
"n_spks": 1,
|
| 220 |
+
"cfm_params": DictConfig(
|
| 221 |
+
{
|
| 222 |
+
"sigma_min": 1e-06,
|
| 223 |
+
"solver": "euler",
|
| 224 |
+
"t_scheduler": "cosine",
|
| 225 |
+
"training_cfg_rate": 0.2,
|
| 226 |
+
"inference_cfg_rate": 0.7,
|
| 227 |
+
"reg_loss_type": "l1",
|
| 228 |
+
"use_immiscible": True,
|
| 229 |
+
"immiscible_k": 8,
|
| 230 |
+
"use_contrastive_fm": True,
|
| 231 |
+
"contrastive_lambda": 0.05
|
| 232 |
+
}
|
| 233 |
+
),
|
| 234 |
+
"decoder_params": {
|
| 235 |
+
"channels": [256, 256],
|
| 236 |
+
"dropout": 0.0,
|
| 237 |
+
"attention_head_dim": 64,
|
| 238 |
+
"n_blocks": 4,
|
| 239 |
+
"num_mid_blocks": 12,
|
| 240 |
+
"num_heads": 8,
|
| 241 |
+
"act_fn": "gelu",
|
| 242 |
+
},
|
| 243 |
+
},
|
| 244 |
+
mel_feat_conf: Dict = {
|
| 245 |
+
"n_fft": 1024,
|
| 246 |
+
"num_mels": 80,
|
| 247 |
+
"sampling_rate": 22050,
|
| 248 |
+
"hop_size": 256,
|
| 249 |
+
"win_size": 1024,
|
| 250 |
+
"fmin": 0,
|
| 251 |
+
"fmax": 8000,
|
| 252 |
+
},
|
| 253 |
+
):
|
| 254 |
super().__init__()
|
| 255 |
self.input_size = input_size
|
| 256 |
self.output_size = output_size
|
|
|
|
| 268 |
self.only_mask_loss = only_mask_loss
|
| 269 |
self.token_mel_ratio = token_mel_ratio
|
| 270 |
self.pre_lookahead_len = pre_lookahead_len
|
| 271 |
+
print(" decoder_conf['cfm_params']: ", decoder_conf["cfm_params"])
|
| 272 |
+
self.use_contrastive_fm = decoder_conf["cfm_params"]["use_contrastive_fm"]
|
| 273 |
|
| 274 |
def forward(
|
| 275 |
+
self,
|
| 276 |
+
batch: dict,
|
| 277 |
+
device: torch.device,
|
| 278 |
) -> Dict[str, Optional[torch.Tensor]]:
|
| 279 |
+
token = batch["speech_token"].to(device)
|
| 280 |
+
token_len = batch["speech_token_len"].to(device)
|
| 281 |
+
feat = batch["speech_feat"].to(device)
|
| 282 |
+
feat_len = batch["speech_feat_len"].to(device)
|
| 283 |
+
embedding = batch["embedding"].to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
|
| 285 |
# NOTE unified training, static_chunk_size > 0 or = 0
|
| 286 |
+
streaming = False # if random.random() < 0.5 else False
|
| 287 |
|
| 288 |
# xvec projection
|
| 289 |
embedding = F.normalize(embedding, dim=1)
|
| 290 |
embedding = self.spk_embed_affine_layer(embedding)
|
|
|
|
|
|
|
| 291 |
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
|
| 292 |
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
| 293 |
|
|
|
|
| 305 |
conds = conds.transpose(1, 2)
|
| 306 |
|
| 307 |
mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
|
| 308 |
+
if not self.use_contrastive_fm:
|
| 309 |
+
loss, _ = self.decoder.compute_loss(
|
| 310 |
+
feat.transpose(1, 2).contiguous(),
|
| 311 |
+
mask.unsqueeze(1),
|
| 312 |
+
h.transpose(1, 2).contiguous(),
|
| 313 |
+
embedding,
|
| 314 |
+
cond=conds,
|
| 315 |
+
streaming=streaming,
|
| 316 |
+
)
|
| 317 |
+
else:
|
| 318 |
+
# print("use contrastive fm")
|
| 319 |
+
loss, _ = self.decoder.compute_loss_contrastive(
|
| 320 |
+
feat.transpose(1, 2).contiguous(),
|
| 321 |
+
mask.unsqueeze(1),
|
| 322 |
+
h.transpose(1, 2).contiguous(),
|
| 323 |
+
embedding,
|
| 324 |
+
cond=conds,
|
| 325 |
+
streaming=streaming,
|
| 326 |
+
)
|
| 327 |
+
return {"loss": loss}
|
| 328 |
|
| 329 |
@torch.inference_mode()
|
| 330 |
+
def inference(
|
| 331 |
+
self,
|
| 332 |
+
token,
|
| 333 |
+
token_len,
|
| 334 |
+
prompt_token,
|
| 335 |
+
prompt_token_len,
|
| 336 |
+
prompt_feat,
|
| 337 |
+
prompt_feat_len,
|
| 338 |
+
embedding,
|
| 339 |
+
streaming,
|
| 340 |
+
finalize,
|
| 341 |
+
):
|
| 342 |
assert token.shape[0] == 1
|
| 343 |
# xvec projection
|
| 344 |
embedding = F.normalize(embedding, dim=1)
|
| 345 |
embedding = self.spk_embed_affine_layer(embedding)
|
| 346 |
|
| 347 |
# concat text and prompt_text
|
| 348 |
+
token, token_len = (
|
| 349 |
+
torch.concat([prompt_token, token], dim=1),
|
| 350 |
+
prompt_token_len + token_len,
|
| 351 |
+
)
|
| 352 |
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
| 353 |
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
| 354 |
|
|
|
|
| 356 |
if finalize is True:
|
| 357 |
h, h_lengths = self.encoder(token, token_len, streaming=streaming)
|
| 358 |
else:
|
| 359 |
+
token, context = (
|
| 360 |
+
token[:, : -self.pre_lookahead_len],
|
| 361 |
+
token[:, -self.pre_lookahead_len :],
|
| 362 |
+
)
|
| 363 |
+
h, h_lengths = self.encoder(
|
| 364 |
+
token, token_len, context=context, streaming=streaming
|
| 365 |
+
)
|
| 366 |
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
|
| 367 |
h = self.encoder_proj(h)
|
| 368 |
|
| 369 |
# get conditions
|
| 370 |
+
conds = torch.zeros(
|
| 371 |
+
[1, mel_len1 + mel_len2, self.output_size], device=token.device
|
| 372 |
+
).to(h.dtype)
|
| 373 |
conds[:, :mel_len1] = prompt_feat
|
| 374 |
conds = conds.transpose(1, 2)
|
| 375 |
|
|
|
|
| 380 |
spks=embedding,
|
| 381 |
cond=conds,
|
| 382 |
n_timesteps=10,
|
| 383 |
+
streaming=streaming,
|
| 384 |
)
|
| 385 |
feat = feat[:, :, mel_len1:]
|
| 386 |
assert feat.shape[2] == mel_len2
|
speech/cosyvoice/flow/flow_matching.py
CHANGED
|
@@ -34,6 +34,7 @@ class ConditionalCFM(BASECFM):
|
|
| 34 |
self.estimator = estimator
|
| 35 |
self.use_immiscible = cfm_params.use_immiscible
|
| 36 |
self.immiscible_k = cfm_params.immiscible_k
|
|
|
|
| 37 |
|
| 38 |
@torch.inference_mode()
|
| 39 |
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
|
|
@@ -177,14 +178,6 @@ class ConditionalCFM(BASECFM):
|
|
| 177 |
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
| 178 |
if self.t_scheduler == 'cosine':
|
| 179 |
t = 1 - torch.cos(t * 0.5 * torch.pi)
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
print(f"\n=== Immiscible Diffusion Debug ===")
|
| 183 |
-
print(f"x1 shape: {x1.shape}")
|
| 184 |
-
print(f"mu shape: {mu.shape}")
|
| 185 |
-
print(f"t shape: {t.shape}")
|
| 186 |
-
print(f"Device: {x1.device}")
|
| 187 |
-
print(f"Dtype: {x1.dtype}")
|
| 188 |
|
| 189 |
# Apply immiscible diffusion with KNN
|
| 190 |
if self.use_immiscible:
|
|
@@ -192,49 +185,87 @@ class ConditionalCFM(BASECFM):
|
|
| 192 |
|
| 193 |
# Generate k noise samples for each data point
|
| 194 |
z_candidates = torch.randn(b, k, d, T, device=x1.device, dtype=x1.dtype)
|
| 195 |
-
|
| 196 |
-
print(f"z_candidates shape: {z_candidates.shape}")
|
| 197 |
-
print(f"z_candidates stats - mean: {z_candidates.mean():.4f}, std: {z_candidates.std():.4f}")
|
| 198 |
-
|
| 199 |
-
# Flatten for distance computation
|
| 200 |
x1_flat = x1.flatten(start_dim=1).to(torch.float16)
|
| 201 |
z_candidates_flat = z_candidates.flatten(start_dim=2).to(torch.float16)
|
| 202 |
|
| 203 |
-
|
| 204 |
-
print(f"z_candidates_flat shape: {z_candidates_flat.shape}")
|
| 205 |
-
|
| 206 |
-
# Calculate distances
|
| 207 |
distances = torch.norm(x1_flat.unsqueeze(1) - z_candidates_flat, dim=2)
|
|
|
|
|
|
|
| 208 |
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
-
|
| 214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
-
|
| 218 |
-
|
|
|
|
|
|
|
| 219 |
|
| 220 |
-
|
| 221 |
z = torch.gather(
|
| 222 |
z_candidates,
|
| 223 |
1,
|
| 224 |
min_indices.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(-1, 1, d, T)
|
| 225 |
)[:, 0, :, :]
|
| 226 |
|
| 227 |
-
print(f"Selected z shape: {z.shape}")
|
| 228 |
-
print(f"Selected z stats - mean: {z.mean():.4f}, std: {z.std():.4f}")
|
| 229 |
-
|
| 230 |
-
# Calculate distance reduction
|
| 231 |
-
with torch.no_grad():
|
| 232 |
-
orig_distance = distances[:, 0].mean()
|
| 233 |
-
selected_distance = min_distances.mean()
|
| 234 |
-
reduction_rate = (orig_distance - selected_distance) / orig_distance
|
| 235 |
-
print(f"Distance reduction: {reduction_rate:.3%}")
|
| 236 |
-
print(f"Original distance: {orig_distance:.4f}")
|
| 237 |
-
print(f"Selected distance: {selected_distance:.4f}")
|
| 238 |
else:
|
| 239 |
# sample noise p(x_0)
|
| 240 |
z = torch.randn_like(x1)
|
|
@@ -250,7 +281,27 @@ class ConditionalCFM(BASECFM):
|
|
| 250 |
cond = cond * cfg_mask.view(-1, 1, 1)
|
| 251 |
|
| 252 |
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
|
| 253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
return loss, y
|
| 255 |
|
| 256 |
|
|
|
|
| 34 |
self.estimator = estimator
|
| 35 |
self.use_immiscible = cfm_params.use_immiscible
|
| 36 |
self.immiscible_k = cfm_params.immiscible_k
|
| 37 |
+
self.lambda_weight = cfm_params.contrastive_lambda
|
| 38 |
|
| 39 |
@torch.inference_mode()
|
| 40 |
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
|
|
|
|
| 178 |
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
| 179 |
if self.t_scheduler == 'cosine':
|
| 180 |
t = 1 - torch.cos(t * 0.5 * torch.pi)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
# Apply immiscible diffusion with KNN
|
| 183 |
if self.use_immiscible:
|
|
|
|
| 185 |
|
| 186 |
# Generate k noise samples for each data point
|
| 187 |
z_candidates = torch.randn(b, k, d, T, device=x1.device, dtype=x1.dtype)
|
| 188 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
x1_flat = x1.flatten(start_dim=1).to(torch.float16)
|
| 190 |
z_candidates_flat = z_candidates.flatten(start_dim=2).to(torch.float16)
|
| 191 |
|
| 192 |
+
|
|
|
|
|
|
|
|
|
|
| 193 |
distances = torch.norm(x1_flat.unsqueeze(1) - z_candidates_flat, dim=2)
|
| 194 |
+
|
| 195 |
+
min_distances, min_indices = torch.min(distances, dim=1)
|
| 196 |
|
| 197 |
+
|
| 198 |
+
z = torch.gather(
|
| 199 |
+
z_candidates,
|
| 200 |
+
1,
|
| 201 |
+
min_indices.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(-1, 1, d, T)
|
| 202 |
+
)[:, 0, :, :]
|
| 203 |
|
| 204 |
+
else:
|
| 205 |
+
# sample noise p(x_0)
|
| 206 |
+
z = torch.randn_like(x1)
|
| 207 |
+
|
| 208 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
| 209 |
+
u = x1 - (1 - self.sigma_min) * z
|
| 210 |
+
|
| 211 |
+
# during training, we randomly drop condition to trade off mode coverage and sample fidelity
|
| 212 |
+
if self.training_cfg_rate > 0:
|
| 213 |
+
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
|
| 214 |
+
mu = mu * cfg_mask.view(-1, 1, 1)
|
| 215 |
+
spks = spks * cfg_mask.view(-1, 1)
|
| 216 |
+
cond = cond * cfg_mask.view(-1, 1, 1)
|
| 217 |
+
|
| 218 |
+
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
|
| 219 |
+
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
| 220 |
+
return loss, y
|
| 221 |
+
|
| 222 |
+
def compute_loss_contrastive(self, x1, mask, mu, spks=None, cond=None, streaming=False):
|
| 223 |
+
"""Computes diffusion loss
|
| 224 |
+
|
| 225 |
+
Args:
|
| 226 |
+
x1 (torch.Tensor): Target
|
| 227 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 228 |
+
mask (torch.Tensor): target mask
|
| 229 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 230 |
+
mu (torch.Tensor): output of encoder
|
| 231 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 232 |
+
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
| 233 |
+
shape: (batch_size, spk_emb_dim)
|
| 234 |
+
|
| 235 |
+
Returns:
|
| 236 |
+
loss: conditional flow matching loss
|
| 237 |
+
y: conditional flow
|
| 238 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 239 |
+
"""
|
| 240 |
+
b, d, T = mu.shape
|
| 241 |
+
|
| 242 |
+
# random timestep
|
| 243 |
+
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
| 244 |
+
if self.t_scheduler == 'cosine':
|
| 245 |
+
t = 1 - torch.cos(t * 0.5 * torch.pi)
|
| 246 |
+
|
| 247 |
+
# Apply immiscible diffusion with KNN
|
| 248 |
+
if self.use_immiscible:
|
| 249 |
+
k = getattr(self, 'immiscible_k', 4)
|
| 250 |
|
| 251 |
+
# Generate k noise samples for each data point
|
| 252 |
+
z_candidates = torch.randn(b, k, d, T, device=x1.device, dtype=x1.dtype)
|
| 253 |
+
|
| 254 |
+
x1_flat = x1.flatten(start_dim=1).to(torch.float16)
|
| 255 |
+
z_candidates_flat = z_candidates.flatten(start_dim=2).to(torch.float16)
|
| 256 |
|
| 257 |
+
|
| 258 |
+
distances = torch.norm(x1_flat.unsqueeze(1) - z_candidates_flat, dim=2)
|
| 259 |
+
|
| 260 |
+
min_distances, min_indices = torch.min(distances, dim=1)
|
| 261 |
|
| 262 |
+
|
| 263 |
z = torch.gather(
|
| 264 |
z_candidates,
|
| 265 |
1,
|
| 266 |
min_indices.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(-1, 1, d, T)
|
| 267 |
)[:, 0, :, :]
|
| 268 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
else:
|
| 270 |
# sample noise p(x_0)
|
| 271 |
z = torch.randn_like(x1)
|
|
|
|
| 281 |
cond = cond * cfg_mask.view(-1, 1, 1)
|
| 282 |
|
| 283 |
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
|
| 284 |
+
fm_loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
| 285 |
+
|
| 286 |
+
neg_indices = torch.roll(torch.arange(b, device=x1.device), shifts=1)
|
| 287 |
+
|
| 288 |
+
# Get negative targets from shifted indices
|
| 289 |
+
if b > 1:
|
| 290 |
+
u_neg = u[neg_indices]
|
| 291 |
+
neg_mask = mask[neg_indices]
|
| 292 |
+
|
| 293 |
+
# Contrastive loss
|
| 294 |
+
contrastive_loss = F.mse_loss(
|
| 295 |
+
pred * neg_mask,
|
| 296 |
+
u_neg * neg_mask,
|
| 297 |
+
reduction="sum"
|
| 298 |
+
) / (torch.sum(neg_mask) * d)
|
| 299 |
+
print('contrastive_loss: ', contrastive_loss)
|
| 300 |
+
else:
|
| 301 |
+
contrastive_loss = torch.tensor(0.0, device=fm_loss.device)
|
| 302 |
+
|
| 303 |
+
loss = fm_loss - self.lambda_weight * contrastive_loss
|
| 304 |
+
|
| 305 |
return loss, y
|
| 306 |
|
| 307 |
|
speech/cosyvoice/utils/executor.py
CHANGED
|
@@ -33,6 +33,7 @@ class Executor:
|
|
| 33 |
gan: bool = False,
|
| 34 |
ref_model: torch.nn.Module = None,
|
| 35 |
dpo_loss: torch.nn.Module = None,
|
|
|
|
| 36 |
):
|
| 37 |
self.gan = gan
|
| 38 |
self.ref_model = ref_model
|
|
@@ -41,6 +42,7 @@ class Executor:
|
|
| 41 |
self.epoch = 0
|
| 42 |
self.rank = int(os.environ.get("RANK", 0))
|
| 43 |
self.device = torch.device(f"cuda:{self.rank}")
|
|
|
|
| 44 |
|
| 45 |
def train_one_epoc(
|
| 46 |
self,
|
|
@@ -69,16 +71,20 @@ class Executor:
|
|
| 69 |
|
| 70 |
use_ddp = info_dict["train_engine"] == "torch_ddp"
|
| 71 |
|
|
|
|
| 72 |
for batch_idx, batch_dict in enumerate(train_data_loader):
|
| 73 |
info_dict["tag"] = "TRAIN"
|
| 74 |
info_dict["step"] = self.step
|
| 75 |
info_dict["epoch"] = self.epoch
|
| 76 |
info_dict["batch_idx"] = batch_idx
|
| 77 |
|
|
|
|
| 78 |
if use_ddp and (batch_idx + 1) % info_dict["accum_grad"] != 0:
|
| 79 |
context = model.no_sync
|
| 80 |
else:
|
| 81 |
context = nullcontext
|
|
|
|
|
|
|
| 82 |
with context():
|
| 83 |
info_dict = batch_forward(
|
| 84 |
model,
|
|
@@ -88,6 +94,7 @@ class Executor:
|
|
| 88 |
ref_model=self.ref_model,
|
| 89 |
dpo_loss=self.dpo_loss,
|
| 90 |
)
|
|
|
|
| 91 |
info_dict = batch_backward(model, scaler, info_dict)
|
| 92 |
|
| 93 |
info_dict = update_parameter_and_lr(
|
|
|
|
| 33 |
gan: bool = False,
|
| 34 |
ref_model: torch.nn.Module = None,
|
| 35 |
dpo_loss: torch.nn.Module = None,
|
| 36 |
+
use_contrastive_fm: bool = False
|
| 37 |
):
|
| 38 |
self.gan = gan
|
| 39 |
self.ref_model = ref_model
|
|
|
|
| 42 |
self.epoch = 0
|
| 43 |
self.rank = int(os.environ.get("RANK", 0))
|
| 44 |
self.device = torch.device(f"cuda:{self.rank}")
|
| 45 |
+
self.use_contrastive_fm = use_contrastive_fm
|
| 46 |
|
| 47 |
def train_one_epoc(
|
| 48 |
self,
|
|
|
|
| 71 |
|
| 72 |
use_ddp = info_dict["train_engine"] == "torch_ddp"
|
| 73 |
|
| 74 |
+
|
| 75 |
for batch_idx, batch_dict in enumerate(train_data_loader):
|
| 76 |
info_dict["tag"] = "TRAIN"
|
| 77 |
info_dict["step"] = self.step
|
| 78 |
info_dict["epoch"] = self.epoch
|
| 79 |
info_dict["batch_idx"] = batch_idx
|
| 80 |
|
| 81 |
+
|
| 82 |
if use_ddp and (batch_idx + 1) % info_dict["accum_grad"] != 0:
|
| 83 |
context = model.no_sync
|
| 84 |
else:
|
| 85 |
context = nullcontext
|
| 86 |
+
|
| 87 |
+
|
| 88 |
with context():
|
| 89 |
info_dict = batch_forward(
|
| 90 |
model,
|
|
|
|
| 94 |
ref_model=self.ref_model,
|
| 95 |
dpo_loss=self.dpo_loss,
|
| 96 |
)
|
| 97 |
+
|
| 98 |
info_dict = batch_backward(model, scaler, info_dict)
|
| 99 |
|
| 100 |
info_dict = update_parameter_and_lr(
|
speech/cosyvoice/utils/train_utils.py
CHANGED
|
@@ -250,6 +250,7 @@ def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None
|
|
| 250 |
|
| 251 |
with autocast:
|
| 252 |
info_dict['loss_dict'] = model(batch, device)
|
|
|
|
| 253 |
if ref_model is not None and dpo_loss is not None:
|
| 254 |
chosen_logps = info_dict['loss_dict']["chosen_logps"]
|
| 255 |
rejected_logps = info_dict['loss_dict']["rejected_logps"]
|
|
|
|
| 250 |
|
| 251 |
with autocast:
|
| 252 |
info_dict['loss_dict'] = model(batch, device)
|
| 253 |
+
# print('infor_dict loss_dict : ', info_dict['loss_dict'])
|
| 254 |
if ref_model is not None and dpo_loss is not None:
|
| 255 |
chosen_logps = info_dict['loss_dict']["chosen_logps"]
|
| 256 |
rejected_logps = info_dict['loss_dict']["rejected_logps"]
|