Commit ce1aa4b · 1 Parent(s): 19f775a · committed by primepake

update scheduler
speech/cosyvoice/flow/flow_matching.py CHANGED
@@ -311,13 +311,13 @@ class ConditionalCFM(BASECFM):
                 reduction="sum"
             ) / (torch.sum(mask_neg) * d)
 
-            print('before contrastive_loss: ', contrastive_loss)
+            # print('before contrastive_loss: ', contrastive_loss)
         else:
             contrastive_loss = torch.tensor(0.0, device=fm_loss.device)
-        print("fm_loss: ", fm_loss)
+        # print("fm_loss: ", fm_loss)
 
         contrastive_loss = self.lambda_weight * contrastive_loss
-        print('contrastive_loss: ', contrastive_loss)
+        # print('contrastive_loss: ', contrastive_loss)
 
         loss = fm_loss - contrastive_loss
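The commit keeps the debug prints but comments them out. A lighter-weight alternative (a sketch only, not what this commit does, and it assumes loguru is available in this module as it is in train_utils.py) is debug-level logging, which stays silent unless a sink is configured with level="DEBUG":

from loguru import logger

# Sketch: debug-level logging instead of commented-out prints.
# These lines emit nothing unless a DEBUG sink is configured.
logger.debug("fm_loss: {}", fm_loss)
logger.debug("contrastive_loss: {}", contrastive_loss)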
speech/cosyvoice/utils/train_utils.py CHANGED
@@ -37,44 +37,102 @@ from torch.optim.lr_scheduler import LinearLR, ConstantLR, SequentialLR, _LRScheduler
 from loguru import logger
+import warnings  # for the warning raised in get_lr() below
 
 class ResumableSequentialLR(_LRScheduler):
-    """A resumable version of SequentialLR that supports set_step"""
+    """A resumable version of SequentialLR that properly manages child schedulers"""
+
     def __init__(self, optimizer, schedulers, milestones, last_epoch=-1):
+        """
+        Args:
+            optimizer: Wrapped optimizer
+            schedulers: List of schedulers to use sequentially
+            milestones: List of step counts at which to switch schedulers
+            last_epoch: The index of the last epoch/step
+        """
+        # Validate inputs
+        if len(schedulers) != len(milestones) + 1:
+            raise ValueError("Expected len(schedulers) == len(milestones) + 1")
+
         self.schedulers = schedulers
         self.milestones = milestones
-        self._last_lr = [group['lr'] for group in optimizer.param_groups]
+
+        # Initialize parent class (this sets last_epoch and calls step())
         super().__init__(optimizer, last_epoch)
 
-    def get_lr(self):
-        # Find which scheduler to use based on last_epoch
-        idx = 0
+    def _get_scheduler_info(self, epoch):
+        """Determine the active scheduler and the epoch relative to it"""
+        scheduler_idx = 0
+        relative_epoch = epoch
         for i, milestone in enumerate(self.milestones):
-            if self.last_epoch >= milestone:
-                idx = i + 1
+            if epoch >= milestone:
+                scheduler_idx = i + 1
+                relative_epoch = epoch - milestone
+            else:
+                break
+        return scheduler_idx, relative_epoch
 
-        if idx >= len(self.schedulers):
-            idx = len(self.schedulers) - 1
-
-        # Get lr from the appropriate scheduler
-        scheduler = self.schedulers[idx]
+    def get_lr(self):
+        """Get learning rate from the appropriate scheduler"""
+        if not self._get_lr_called_within_step:
+            warnings.warn("To get the last learning rate computed by the scheduler, "
+                          "please use `get_last_lr()`.", UserWarning)
+
+        # Pick the active scheduler and align its last_epoch with our progress
+        scheduler_idx, relative_epoch = self._get_scheduler_info(self.last_epoch)
+        scheduler = self.schedulers[scheduler_idx]
+        scheduler.last_epoch = relative_epoch
+
         if hasattr(scheduler, '_get_closed_form_lr'):
             return scheduler._get_closed_form_lr()
         else:
-            return scheduler.get_lr()
+            # Temporarily set the flag so the child scheduler does not warn
+            scheduler._get_lr_called_within_step = True
+            lrs = scheduler.get_lr()
+            scheduler._get_lr_called_within_step = False
+            return lrs
 
     def step(self, epoch=None):
-        if epoch is None:
-            self.last_epoch += 1
-        else:
-            self.last_epoch = epoch
-
-        # Update learning rates
-        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
-            param_group['lr'] = lr
-        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
-
+        """Step the scheduler"""
+        # The parent class increments last_epoch, calls get_lr() with
+        # _get_lr_called_within_step set, and writes the result back to
+        # the optimizer's param groups.
+        super().step(epoch)
+
     def set_step(self, step):
         """Set the current step for resuming training"""
-        self.last_epoch = step - 1  # -1 because step() will increment it
+        self.last_epoch = step - 1  # -1 because the next step() increments it
+
+        # Re-align the child schedulers' state
+        scheduler_idx, relative_epoch = self._get_scheduler_info(self.last_epoch)
+
+        # Put every finished scheduler at its final position
+        for i in range(scheduler_idx):
+            if i == 0:
+                self.schedulers[i].last_epoch = self.milestones[i] - 1
+            else:
+                self.schedulers[i].last_epoch = self.milestones[i] - self.milestones[i - 1] - 1
+
+        # Put the active scheduler at its relative position
+        self.schedulers[scheduler_idx].last_epoch = relative_epoch
+
+        # Recompute the learning rates at this step (get_last_lr() would be
+        # stale here) and restore them on the optimizer
+        self._get_lr_called_within_step = True
+        self._last_lr = self.get_lr()
+        self._get_lr_called_within_step = False
+        for param_group, lr in zip(self.optimizer.param_groups, self._last_lr):
+            param_group['lr'] = lr
 
 def init_distributed(args):
     world_size = int(os.environ.get('WORLD_SIZE', 1))
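A minimal usage sketch of the scheduler above (the step counts are illustrative, and the import path is an assumption based on this repo's layout):

import torch
from torch import optim
from torch.optim.lr_scheduler import LinearLR, ConstantLR

from cosyvoice.utils.train_utils import ResumableSequentialLR  # assumed path

model = torch.nn.Linear(8, 8)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
scheduler = ResumableSequentialLR(
    optimizer,
    schedulers=[
        LinearLR(optimizer, start_factor=1e-4, end_factor=1.0, total_iters=100),
        ConstantLR(optimizer, factor=1.0, total_iters=10000),
    ],
    milestones=[100],  # switch from warmup to constant after 100 steps
)

scheduler.set_step(50)  # resume mid-warmup: lr is restored to ~0.5 * 1e-4
for _ in range(5):      # training then resumes as usual
    optimizer.step()
    scheduler.step()

Note that set_step recomputes the learning rate at the restored position rather than reusing the last cached value, which is what makes a mid-warmup resume land on the right lr.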
@@ -151,32 +209,36 @@ def wrap_cuda_model(args, model):
 
 def init_optimizer_and_scheduler(configs, model):
     """Init optimizer and scheduler"""
+    lr = configs['train_conf']['optim_conf']['lr']
+    logger.info(f"lr base: {lr}")
     if configs['train_conf']['optim'] == 'adam':
-        optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
+        optimizer = optim.Adam(model.parameters(), lr=lr)
     elif configs['train_conf']['optim'] == 'adamw':
-        optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
+        optimizer = optim.AdamW(model.parameters(), lr=lr)
     else:
         raise ValueError("unknown optimizer: " + configs['train_conf']['optim'])
 
+    warm_up_steps = configs['train_conf']['scheduler_conf']['warmup_steps']
+    total_iters = configs['train_conf']['total_iters']
     # Create schedulers
     warmup_scheduler = LinearLR(
        optimizer,
-        start_factor=1e-9,  # Start at nearly 0
+        start_factor=1e-4,  # Start near zero
         end_factor=1.0,  # End at base learning rate
-        total_iters=5000  # 5k warmup steps
+        total_iters=warm_up_steps  # Warmup length from the config
    )
 
     constant_scheduler = ConstantLR(
         optimizer,
         factor=1.0,  # Keep learning rate constant
-        total_iters=float('inf')  # Run indefinitely
+        total_iters=total_iters  # Bounded by the configured total iterations
     )
 
-    # Combine schedulers: warmup for 5k steps, then constant
+    # Combine schedulers: linear warmup, then constant
     scheduler = ResumableSequentialLR(
         optimizer,
         schedulers=[warmup_scheduler, constant_scheduler],
-        milestones=[5000]
+        milestones=[warm_up_steps]
     )
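For reference, a sketch of the configs shape this function now reads. The key names come straight from the code above; the values are illustrative placeholders, not the repo's defaults:

# Illustrative config; only the keys accessed by init_optimizer_and_scheduler.
configs = {
    'train_conf': {
        'optim': 'adamw',                          # 'adam' or 'adamw'
        'optim_conf': {'lr': 1e-4},                # base learning rate
        'scheduler_conf': {'warmup_steps': 5000},  # warmup length and milestone
        'total_iters': 200000,                     # bound handed to ConstantLR
    }
}

LinearLR requires start_factor in (0, 1], so warmup starts near zero rather than at zero; the commit raises it from 1e-9 to 1e-4 of the base lr.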
@@ -188,7 +250,7 @@ def save_model(model, model_name, info_dict):
     """Save model"""
     rank = int(os.environ.get('RANK', 0))
     model_dir = info_dict["model_dir"]
-    # os.makedirs(model_dir, exist_ok=True)
+    os.makedirs(model_dir, exist_ok=True)
     save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name))
 
@@ -292,54 +354,41 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict, model_type):
 
     # Define key components based on model type
     if model_type == 'llm':
-        key_components = {
-            # Text processing components
-            'text_embedding': [],
-            'text_encoder': [],
-            'text_encoder_affine': [],
-
-            # LLM core components
-            'llm_embedding': [],
-            'llm.model': [],  # Qwen2 model layers
-            'llm_decoder': [],
-
-            # Speech components
-            'speech_embedding': [],
-            'spk_embed_affine': [],
-
-            # Other components
-            'other': []
+        component_patterns = {
+            'text_embedding': r'^text_embedding\.',
+            'text_encoder': r'^text_encoder\.',
+            'text_encoder_affine': r'^text_encoder_affine\.',
+            'llm_embedding': r'^llm_embedding\.',
+            'llm.model': r'^llm\.model\.',  # Qwen2 model layers
+            'llm_decoder': r'^llm_decoder\.',
+            'speech_embedding': r'^speech_embedding\.',
+            'spk_embed_affine': r'^spk_embed_affine\.',
         }
     elif model_type == 'flow':
-        key_components = {
-            # Input processing
-            'input_embedding': [],
-            'spk_embed_affine': [],
-
-            # Encoder components
-            'encoder': [],
-            'encoder_proj': [],
-
-            # Flow/Diffusion components
-            'decoder.cfm': [],  # Conditional Flow Matching
-            'decoder.unet': [],  # UNet backbone
-            'decoder.estimator': [],  # Score/velocity estimator
-            'decoder.time_embedding': [],  # Time embeddings
-            'decoder.conv': [],  # Convolutional layers
-            'decoder.attention': [],  # Attention layers
-
-            # Length regulation
-            'length_regulator': [],
-
-            # Other components
-            'other': []
+        component_patterns = {
+            'input_embedding': r'^input_embedding\.',
+            'spk_embed_affine': r'^spk_embed_affine\.',
+            'encoder': r'^encoder\.',
+            'encoder_proj': r'^encoder_proj\.',
+            'decoder.cfm': r'^decoder\..*cfm',  # Conditional Flow Matching
+            'decoder.unet': r'^decoder\..*unet',  # UNet backbone
+            'decoder.estimator': r'^decoder\..*estimator',  # Score/velocity estimator
+            'decoder.time_embedding': r'^decoder\..*time_embedding',  # Time embeddings
+            'decoder.conv': r'^decoder\..*conv',  # Convolutional layers
+            'decoder.attention': r'^decoder\..*attention',  # Attention layers
+            'length_regulator': r'^length_regulator\.',
         }
+    else:
+        raise ValueError(f"Unknown model_type: {model_type}")
+
+    key_components = {key: [] for key in component_patterns}
+    key_components['other'] = []
 
     grad_norm = 0.0
     layer_grad_norms = {}
 
     if (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
-
+        # logger.info('start to calculate grad norm')
         for name, param in model.named_parameters():
             if param.grad is not None:
                 # Calculate gradient norm for this parameter
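The hunk above only defines the pattern tables; the matching logic that fills key_components sits further down in the function and is elided by the diff. A plausible sketch of that bucketing, for orientation only (the helper name and the exact reduction are assumptions, not the repo's code):

import re

def bucket_grad_norms(model, component_patterns):
    """Illustrative helper: group per-parameter gradient norms by component."""
    key_components = {key: [] for key in component_patterns}
    key_components['other'] = []
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        norm = param.grad.data.norm(2).item()
        for key, pattern in component_patterns.items():
            if re.match(pattern, name):
                key_components[key].append(norm)
                break
        else:  # no pattern matched
            key_components['other'].append(norm)
    return key_components

Switching from substring-style keys to anchored regexes keeps a name like decoder.estimator.conv from being double-counted under both a conv and an estimator bucket, since the first matching pattern wins.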
@@ -381,6 +430,7 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict, model_type):
             logger.warning('get infinite grad_norm, check your code/data if it appears frequently')
         optimizer.zero_grad()
         scheduler.step()
+        logger.info(f"lr after step {optimizer.param_groups[0]['lr']}")
     info_dict["lr"] = optimizer.param_groups[0]['lr']
     info_dict["grad_norm"] = grad_norm
     info_dict["layer_grad_norms"] = layer_grad_norms
@@ -413,13 +463,13 @@ def log_per_step(experiment, info_dict):
 
     # TRAIN & CV, Shell log (stdout)
     if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0:
-        log_str = f'{tag} Batch {epoch}/{batch_idx + 1} '
+        log_str = f'{tag} Batch {epoch}/{batch_idx + 1} step {step} '
         for name, value in loss_dict.items():
             if isinstance(value, torch.Tensor):
                 value = value.item()
             log_str += f'{name} {value:.6f} '
         if tag == "TRAIN":
-            log_str += f'lr {info_dict["lr"]:.8f} grad_norm {info_dict["grad_norm"]:.6f}'
+            log_str += f'lr {info_dict["lr"]:.15f} grad_norm {info_dict["grad_norm"]:.6f}'
         log_str += f' rank {rank}'
         logger.info(log_str)
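With both changes, a TRAIN log line carries the global step and prints the lr at higher precision, so tiny warmup lrs no longer round to zero. Roughly (values illustrative):

TRAIN Batch 2/1500 step 6000 loss 1.234567 lr 0.000100000000000 grad_norm 0.456789 rank 0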
 