primepake committed · Commit ce1aa4b · Parent(s): 19f775a

update scheduler

speech/cosyvoice/flow/flow_matching.py CHANGED

@@ -311,13 +311,13 @@ class ConditionalCFM(BASECFM):
                 reduction="sum"
             ) / (torch.sum(mask_neg) * d)
 
-            print('before contrastive_loss: ', contrastive_loss)
+            # print('before contrastive_loss: ', contrastive_loss)
         else:
             contrastive_loss = torch.tensor(0.0, device=fm_loss.device)
-        print("fm_loss: ", fm_loss)
+        # print("fm_loss: ", fm_loss)
 
         contrastive_loss = self.lambda_weight * contrastive_loss
-        print('contrastive_loss: ', contrastive_loss)
+        # print('contrastive_loss: ', contrastive_loss)
 
         loss = fm_loss - contrastive_loss
 
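Note: this hunk only comments out the ad-hoc print debugging around the contrastive loss. If the diagnostics are worth keeping switchable rather than deleted, a loguru-based sketch (not part of this commit; fm_loss and contrastive_loss are the tensors computed above) would be:

    from loguru import logger

    # Silent unless a sink is configured with level="DEBUG".
    logger.debug("fm_loss={:.6f} contrastive_loss={:.6f}",
                 fm_loss.item(), contrastive_loss.item())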

speech/cosyvoice/utils/train_utils.py CHANGED

@@ -37,44 +37,102 @@ from torch.optim.lr_scheduler import LinearLR, ConstantLR, SequentialLR, _LRScheduler
 from loguru import logger
 
 class ResumableSequentialLR(_LRScheduler):
-    """A resumable version of SequentialLR that
+    """A resumable version of SequentialLR that properly manages child schedulers"""
+
     def __init__(self, optimizer, schedulers, milestones, last_epoch=-1):
+        """
+        Args:
+            optimizer: Wrapped optimizer
+            schedulers: List of schedulers to sequentially use
+            milestones: List of epoch/step numbers when to switch schedulers
+            last_epoch: The index of last epoch/step
+        """
+        # Validate inputs
+        if len(schedulers) != len(milestones) + 1:
+            raise ValueError("Expected len(schedulers) == len(milestones) + 1")
+
         self.schedulers = schedulers
         self.milestones = milestones
-        self.
+        self._scheduler_idx = 0
+
+        # Initialize parent class (this sets last_epoch and calls step())
         super().__init__(optimizer, last_epoch)
 
-    def
-
-
+    def _get_scheduler_info(self, epoch):
+        """Determine which scheduler to use and its relative epoch"""
+        scheduler_idx = 0
+        relative_epoch = epoch
+
         for i, milestone in enumerate(self.milestones):
-            if
-
+            if epoch >= milestone:
+                scheduler_idx = i + 1
+                if i == 0:
+                    relative_epoch = epoch - milestone
+                else:
+                    relative_epoch = epoch - milestone
+            else:
+                break
 
-
-
-
-
-
+        # Calculate relative epoch for the current scheduler
+        if scheduler_idx == 0:
+            relative_epoch = epoch
+        elif scheduler_idx < len(self.milestones):
+            if scheduler_idx == 1:
+                relative_epoch = epoch - self.milestones[0]
+            else:
+                relative_epoch = epoch - self.milestones[scheduler_idx - 1]
+
+        return scheduler_idx, relative_epoch
+
+    def get_lr(self):
+        """Get learning rate from the appropriate scheduler"""
+        if not self._get_lr_called_within_step:
+            warnings.warn("To get the last learning rate computed by the scheduler, "
+                          "please use `get_last_lr()`.", UserWarning)
+
+        # Get current scheduler and its relative epoch
+        scheduler_idx, relative_epoch = self._get_scheduler_info(self.last_epoch)
+        scheduler = self.schedulers[scheduler_idx]
+
+        # Set the scheduler's last_epoch to match relative progress
+        scheduler.last_epoch = relative_epoch
+
+        # Get LR from the scheduler
         if hasattr(scheduler, '_get_closed_form_lr'):
             return scheduler._get_closed_form_lr()
         else:
-
+            # Temporarily set the flag to avoid warning from child scheduler
+            scheduler._get_lr_called_within_step = True
+            lrs = scheduler.get_lr()
+            scheduler._get_lr_called_within_step = False
+            return lrs
 
     def step(self, epoch=None):
-
-
-
-
-
-        # Update learning rates
-        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
-            param_group['lr'] = lr
-        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
-
+        """Step the scheduler"""
+        # Step the parent class (updates last_epoch and sets _get_lr_called_within_step)
+        super().step(epoch)
+
     def set_step(self, step):
         """Set the current step for resuming training"""
-        self.last_epoch = step - 1
+        self.last_epoch = step - 1
+
+        # Update child schedulers' state
+        scheduler_idx, relative_epoch = self._get_scheduler_info(step - 1)
+
+        # Set all previous schedulers to their final state
+        for i in range(scheduler_idx):
+            if i < len(self.milestones):
+                if i == 0:
+                    self.schedulers[i].last_epoch = self.milestones[i] - 1
+                else:
+                    self.schedulers[i].last_epoch = self.milestones[i] - self.milestones[i-1] - 1
+
+        # Set current scheduler to its relative position
+        self.schedulers[scheduler_idx].last_epoch = relative_epoch
+
+        # Update optimizer's learning rates
+        for param_group, lr in zip(self.optimizer.param_groups, self.get_last_lr()):
+            param_group['lr'] = lr
 
 def init_distributed(args):
     world_size = int(os.environ.get('WORLD_SIZE', 1))

@@ -151,32 +209,36 @@ def wrap_cuda_model(args, model):
 
 def init_optimizer_and_scheduler(configs, model):
     """Init optimizer and scheduler"""
+    lr = configs['train_conf']['optim_conf']['lr']
+    logger.info(f"lr base: {lr}")
     if configs['train_conf']['optim'] == 'adam':
-        optimizer = optim.Adam(model.parameters(),
+        optimizer = optim.Adam(model.parameters(), lr=lr)
     elif configs['train_conf']['optim'] == 'adamw':
-        optimizer = optim.AdamW(model.parameters(),
+        optimizer = optim.AdamW(model.parameters(), lr=lr)
     else:
         raise ValueError("unknown optimizer: " + configs['train_conf'])
-
+
+    warm_up_steps = configs['train_conf']['scheduler_conf']['warmup_steps']
+    total_iters = configs['train_conf']['total_iters']
     # Create schedulers
     warmup_scheduler = LinearLR(
         optimizer,
-        start_factor=1e-
+        start_factor=1e-4, # Start at nearly 0
         end_factor=1.0, # End at base learning rate
-        total_iters=
+        total_iters=warm_up_steps # 5k warmup steps
    )
 
     constant_scheduler = ConstantLR(
         optimizer,
         factor=1.0, # Keep learning rate constant
-        total_iters=
+        total_iters=total_iters # Run indefinitely
    )
 
     # Combine schedulers: warmup for 5k steps, then constant
     scheduler = ResumableSequentialLR(
         optimizer,
         schedulers=[warmup_scheduler, constant_scheduler],
-        milestones=[
+        milestones=[warm_up_steps]
    )
 
 
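
Note: init_optimizer_and_scheduler now reads every hyperparameter from the config instead of hard-coding it. A minimal configs dict covering the keys this function touches (a sketch; the values are illustrative, not from the commit):

    configs = {
        'train_conf': {
            'optim': 'adamw',                          # or 'adam'
            'optim_conf': {'lr': 1e-4},                # base learning rate
            'scheduler_conf': {'warmup_steps': 5000},  # warmup length and milestone
            'total_iters': 10_000_000,                 # ConstantLR horizon
        }
    }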

@@ -188,7 +250,7 @@ def save_model(model, model_name, info_dict):
     """Save model"""
     rank = int(os.environ.get('RANK', 0))
     model_dir = info_dict["model_dir"]
-
+    os.makedirs(model_dir, exist_ok=True)
     save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name))
 
 

@@ -292,54 +354,41 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict, model_type):
 
     #Define key components based on model type
     if model_type == 'llm':
-
-
-            '
-            '
-            '
-
-
-            '
-            '
-            'llm_decoder': [],
-
-            # Speech components
-            'speech_embedding': [],
-            'spk_embed_affine': [],
-
-            # Other components
-            'other': []
+        component_patterns = {
+            'text_embedding': r'^text_embedding\.',
+            'text_encoder': r'^text_encoder\.',
+            'text_encoder_affine': r'^text_encoder_affine\.',
+            'llm_embedding': r'^llm_embedding\.',
+            'llm.model': r'^llm\.model\.',
+            'llm_decoder': r'^llm_decoder\.',
+            'speech_embedding': r'^speech_embedding\.',
+            'spk_embed_affine': r'^spk_embed_affine\.',
         }
     elif model_type == 'flow':
-
-
-            '
-            '
-
-
-            '
-            '
-
-
-            'decoder.
-            '
-            'decoder.estimator': [], # Score/velocity estimator
-            'decoder.time_embedding': [], # Time embeddings
-            'decoder.conv': [], # Convolutional layers
-            'decoder.attention': [], # Attention layers
-
-            # Length regulation
-            'length_regulator': [],
-
-            # Other components
-            'other': []
+        component_patterns = {
+            'input_embedding': r'^input_embedding\.',
+            'spk_embed_affine': r'^spk_embed_affine\.',
+            'encoder': r'^encoder\.',
+            'encoder_proj': r'^encoder_proj\.',
+            'decoder.cfm': r'^decoder\..*cfm',
+            'decoder.unet': r'^decoder\..*unet',
+            'decoder.estimator': r'^decoder\..*estimator',
+            'decoder.time_embedding': r'^decoder\..*time_embedding',
+            'decoder.conv': r'^decoder\..*conv',
+            'decoder.attention': r'^decoder\..*attention',
+            'length_regulator': r'^length_regulator\.',
         }
+    else:
+        raise ValueError(f"Unknown model_type: {model_type}")
+
+    key_components = {key: [] for key in component_patterns}
+    key_components['other'] = []
 
     grad_norm = 0.0
     layer_grad_norms = {}
 
     if (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
-
+        # logger.info('start to calculate grad norm')
         for name, param in model.named_parameters():
             if param.grad is not None:
                 # Calculate gradient norm for this parameter
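
Note: the hard-coded per-component lists become regex tables, and key_components is now built from component_patterns plus an 'other' fallback. A sketch of the bucketing these patterns imply (the matching loop itself sits outside the visible hunk):

    import re

    def bucket_parameter(name, component_patterns):
        # First pattern that matches the parameter name wins;
        # unmatched parameters fall through to 'other'.
        for component, pattern in component_patterns.items():
            if re.match(pattern, name):
                return component
        return 'other'

    # With the 'flow' table above:
    #   bucket_parameter('decoder.estimator.layers.0.weight', ...) -> 'decoder.estimator'
    #   bucket_parameter('proj_out.weight', ...)                   -> 'other'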

@@ -381,6 +430,7 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict, model_type):
         logger.warning('get infinite grad_norm, check your code/data if it appears frequently')
     optimizer.zero_grad()
     scheduler.step()
+    logger.info(f"lr after step {optimizer.param_groups[0]['lr']}")
     info_dict["lr"] = optimizer.param_groups[0]['lr']
     info_dict["grad_norm"] = grad_norm
     info_dict["layer_grad_norms"] = layer_grad_norms

@@ -413,13 +463,13 @@ def log_per_step(experiment, info_dict):
 
     # TRAIN & CV, Shell log (stdout)
     if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0:
-        log_str = f'{tag} Batch {epoch}/{batch_idx + 1} '
+        log_str = f'{tag} Batch {epoch}/{batch_idx + 1} step {step} '
         for name, value in loss_dict.items():
             if isinstance(value, torch.Tensor):
                 value = value.item()
             log_str += f'{name} {value:.6f} '
         if tag == "TRAIN":
-            log_str += f'lr {info_dict["lr"]:.
+            log_str += f'lr {info_dict["lr"]:.15f} grad_norm {info_dict["grad_norm"]:.6f}'
         log_str += f' rank {rank}'
         logger.info(log_str)
 