primepake committed
Commit 434855f · 1 Parent(s): ba2c5eb
speech/cosyvoice/utils/executor.py CHANGED
@@ -235,8 +235,8 @@ class Executor:
         info_dict["loss_dict"] = total_loss_dict
         log_per_save(writer, info_dict)
         model_name = (
-            "epoch_{}_whole".format(self.epoch)
+            f"epoch_{self.epoch}_whole"
             if on_batch_end
-            else "epoch_{}_step_{}".format(self.epoch, self.step + 1)
+            else f"epoch_{self.epoch}_step_{self.step + 1}"
         )
         save_model(model, model_name, info_dict)
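
The rewrite above is behavior-preserving: the f-strings emit byte-identical checkpoint names to the old .format calls. A quick standalone check, with hypothetical epoch/step values:

# Hypothetical values; the asserts confirm old and new naming match.
epoch, step = 3, 4999
assert "epoch_{}_whole".format(epoch) == f"epoch_{epoch}_whole" == "epoch_3_whole"
assert ("epoch_{}_step_{}".format(epoch, step + 1)
        == f"epoch_{epoch}_step_{step + 1}"
        == "epoch_3_step_5000")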
speech/cosyvoice/utils/train_utils.py CHANGED
@@ -187,7 +187,7 @@ def init_optimizer_and_scheduler(args, configs, model, gan):


 def init_summarywriter(args):
-
+    """Init summary writer"""
     writer = None
     if int(os.environ.get('RANK', 0)) == 0:
         os.makedirs(args.model_dir, exist_ok=True)
@@ -196,6 +196,7 @@ def init_summarywriter(args):


 def save_model(model, model_name, info_dict):
+    """Save model"""
     rank = int(os.environ.get('RANK', 0))
     model_dir = info_dict["model_dir"]
     save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name))
@@ -280,6 +281,7 @@ def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None


 def batch_backward(model, scaler, info_dict):
+    """Backward batch"""
     if info_dict["train_engine"] == "deepspeed":
         scaled_loss = model.backward(info_dict['loss_dict']['loss'])
     else:
@@ -294,6 +296,7 @@ def batch_backward(model, scaler, info_dict):


 def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
+    """Update parameters and learning rate"""
     grad_norm = 0.0
     if info_dict['train_engine'] == "deepspeed":
         info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary()
@@ -326,6 +329,7 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):


 def log_per_step(writer, info_dict):
+    """Log per step"""
     tag = info_dict["tag"]
     epoch = info_dict.get('epoch', 0)
     step = info_dict["step"]
@@ -338,23 +342,23 @@ def log_per_step(writer, info_dict):
     if (info_dict['train_engine'] == 'deepspeed' and info_dict['is_gradient_accumulation_boundary'] is True) or \
        (info_dict['train_engine'] == 'torch_ddp' and (info_dict['batch_idx'] + 1) % info_dict['accum_grad'] == 0):
         for k in ['epoch', 'lr', 'grad_norm']:
-            writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
+            writer.add_scalar(f'{tag}/{k}', info_dict[k], step + 1)
         for k, v in loss_dict.items():
-            writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
+            writer.add_scalar(f'{tag}/{k}', v, step + 1)

     # TRAIN & CV, Shell log (stdout)
     if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0:
-        log_str = '{} Batch {}/{} '.format(tag, epoch, batch_idx + 1)
+        log_str = f'{tag} Batch {epoch}/{batch_idx + 1} '
         for name, value in loss_dict.items():
-            log_str += '{} {:.6f} '.format(name, value)
+            log_str += f'{name} {value:.6f} '
         if tag == "TRAIN":
-            log_str += 'lr {:.8f} grad_norm {:.6f}'.format(
-                info_dict["lr"], info_dict['grad_norm'])
-            log_str += ' rank {}'.format(rank)
+            log_str += f'lr {info_dict["lr"]:.8f} grad_norm {info_dict["grad_norm"]:.6f}'
+            log_str += f' rank {rank}'
         logging.debug(log_str)


 def log_per_save(writer, info_dict):
+    """Log per save"""
     tag = info_dict["tag"]
     epoch = info_dict["epoch"]
     step = info_dict["step"]
@@ -366,6 +370,6 @@ def log_per_save(writer, info_dict):

     if writer is not None:
         for k in ['epoch', 'lr']:
-            writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
+            writer.add_scalar(f'{tag}/{k}', info_dict[k], step + 1)
         for k, v in loss_dict.items():
-            writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
+            writer.add_scalar(f'{tag}/{k}', v, step + 1)
 
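
For context on the logging hunks: init_summarywriter creates a writer only on rank 0 and returns None elsewhere, which is why log_per_save guards on writer is not None. A minimal sketch of the same rank-0 pattern, assuming torch.utils.tensorboard and a hypothetical log directory:

import os
from torch.utils.tensorboard import SummaryWriter

# Only rank 0 creates a writer; other ranks keep writer = None and skip logging.
writer = SummaryWriter('exp/tensorboard') if int(os.environ.get('RANK', 0)) == 0 else None

tag, step = 'CV', 99
loss_dict = {'loss': 0.123, 'acc': 0.95}  # hypothetical values
if writer is not None:
    for k, v in loss_dict.items():
        writer.add_scalar(f'{tag}/{k}', v, step + 1)  # same call shape as the diff
    writer.close()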
speech/train.py CHANGED
@@ -13,82 +13,97 @@
 # limitations under the License.

 from __future__ import print_function
+
 import argparse
 import datetime
 import logging
-logging.getLogger('matplotlib').setLevel(logging.WARNING)
-from copy import deepcopy
+
+logging.getLogger("matplotlib").setLevel(logging.WARNING)
 import os
+from copy import deepcopy
+
+import deepspeed
 import torch
 import torch.distributed as dist
-import deepspeed
-from loguru import logger
-
 from hyperpyyaml import load_hyperpyyaml
-
+from loguru import logger
 from torch.distributed.elastic.multiprocessing.errors import record

-from cosyvoice.utils.losses import DPOLoss
 from cosyvoice.utils.executor import Executor
-from cosyvoice.utils.train_utils import (
-    init_distributed,
-    init_dataset_and_dataloader,
-    init_optimizer_and_scheduler,
-    init_summarywriter, save_model,
-    check_modify_and_save_config)
+from cosyvoice.utils.losses import DPOLoss
+from cosyvoice.utils.train_utils import (check_modify_and_save_config,
+                                         init_dataset_and_dataloader,
+                                         init_distributed,
+                                         init_optimizer_and_scheduler,
+                                         init_summarywriter, save_model)


 def get_args():
-    parser = argparse.ArgumentParser(description='training your network')
-    parser.add_argument('--train_engine',
-                        default='torch_ddp',
-                        choices=['torch_ddp', 'deepspeed'],
-                        help='Engine for paralleled training')
-    parser.add_argument('--model', required=True, help='model which will be trained')
-    parser.add_argument('--ref_model', required=False, help='ref model used in dpo')
-    parser.add_argument('--config', required=True, help='config file')
-    parser.add_argument('--train_data', required=True, help='train data file')
-    parser.add_argument('--cv_data', required=True, help='cv data file')
-    parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
-    parser.add_argument('--checkpoint', help='checkpoint model')
-    parser.add_argument('--model_dir', required=True, help='save model dir')
-    parser.add_argument('--tensorboard_dir',
-                        default='tensorboard',
-                        help='tensorboard log dir')
-    parser.add_argument('--ddp.dist_backend',
-                        dest='dist_backend',
-                        default='nccl',
-                        choices=['nccl', 'gloo'],
-                        help='distributed backend')
-    parser.add_argument('--num_workers',
-                        default=0,
-                        type=int,
-                        help='num of subprocess workers for reading')
-    parser.add_argument('--prefetch',
-                        default=100,
-                        type=int,
-                        help='prefetch number')
-    parser.add_argument('--pin_memory',
-                        action='store_true',
-                        default=False,
-                        help='Use pinned memory buffers used for reading')
-    parser.add_argument('--use_amp',
-                        action='store_true',
-                        default=False,
-                        help='Use automatic mixed precision training')
-    parser.add_argument('--dpo',
-                        action='store_true',
-                        default=False,
-                        help='Use Direct Preference Optimization')
-    parser.add_argument('--deepspeed.save_states',
-                        dest='save_states',
-                        default='model_only',
-                        choices=['model_only', 'model+optimizer'],
-                        help='save model/optimizer states')
-    parser.add_argument('--timeout',
-                        default=60,
-                        type=int,
-                        help='timeout (in seconds) of cosyvoice_join.')
+    parser = argparse.ArgumentParser(description="training your network")
+    parser.add_argument(
+        "--train_engine",
+        default="torch_ddp",
+        choices=["torch_ddp", "deepspeed"],
+        help="Engine for paralleled training",
+    )
+    parser.add_argument("--model", required=True, help="model which will be trained")
+    parser.add_argument("--ref_model", required=False, help="ref model used in dpo")
+    parser.add_argument("--config", required=True, help="config file")
+    parser.add_argument("--train_data", required=True, help="train data file")
+    parser.add_argument("--cv_data", required=True, help="cv data file")
+    parser.add_argument(
+        "--qwen_pretrain_path", required=False, help="qwen pretrain path"
+    )
+    parser.add_argument("--checkpoint", help="checkpoint model")
+    parser.add_argument("--model_dir", required=True, help="save model dir")
+    parser.add_argument(
+        "--tensorboard_dir", default="tensorboard", help="tensorboard log dir"
+    )
+    parser.add_argument(
+        "--ddp.dist_backend",
+        dest="dist_backend",
+        default="nccl",
+        choices=["nccl", "gloo"],
+        help="distributed backend",
+    )
+    parser.add_argument(
+        "--num_workers",
+        default=0,
+        type=int,
+        help="num of subprocess workers for reading",
+    )
+    parser.add_argument("--prefetch", default=100, type=int, help="prefetch number")
+    parser.add_argument(
+        "--pin_memory",
+        action="store_true",
+        default=False,
+        help="Use pinned memory buffers used for reading",
+    )
+    parser.add_argument(
+        "--use_amp",
+        action="store_true",
+        default=False,
+        help="Use automatic mixed precision training",
+    )
+    parser.add_argument(
+        "--dpo",
+        action="store_true",
+        default=False,
+        help="Use Direct Preference Optimization",
+    )
+    parser.add_argument(
+        "--deepspeed.save_states",
+        dest="save_states",
+        default="model_only",
+        choices=["model_only", "model+optimizer"],
+        help="save model/optimizer states",
+    )
+    parser.add_argument(
+        "--timeout",
+        default=60,
+        type=int,
+        help="timeout (in seconds) of cosyvoice_join.",
+    )
     parser = deepspeed.add_config_arguments(parser)
     args = parser.parse_args()
     return args
@@ -97,30 +112,41 @@ def get_args():
 @record
 def main():
     args = get_args()
-    logging.basicConfig(level=logging.DEBUG,
-                        format='%(asctime)s %(levelname)s %(message)s')
+    logging.basicConfig(
+        level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s"
+    )
     # gan train has some special initialization logic
-    gan = True if args.model == 'hifigan' else False
+    gan = True if args.model == "hifigan" else False

-    override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
+    override_dict = {
+        k: None for k in ["llm", "flow", "hift", "hifigan"] if k != args.model
+    }
     if gan is True:
-        override_dict.pop('hift')
+        override_dict.pop("hift")
     try:
-        with open(args.config, 'r') as f:
-            configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path})
-    except Exception:
-        with open(args.config, 'r') as f:
+        with open(args.config, "r", encoding="utf-8") as f:
+            configs = load_hyperpyyaml(
+                f,
+                overrides={
+                    **override_dict,
+                    "qwen_pretrain_path": args.qwen_pretrain_path,
+                },
+            )
+    except Exception as e:
+        logger.error(f"Error loading config: {e}")
+        with open(args.config, "r", encoding="utf-8") as f:
             configs = load_hyperpyyaml(f, overrides=override_dict)
     if gan is True:
-        configs['train_conf'] = configs['train_conf_gan']
-    configs['train_conf'].update(vars(args))
+        configs["train_conf"] = configs["train_conf_gan"]
+    configs["train_conf"].update(vars(args))

     # Init env for ddp
     init_distributed(args)

     # Get dataset & dataloader
-    train_dataset, _, train_data_loader, cv_data_loader = \
-        init_dataset_and_dataloader(args, configs, gan, args.dpo)
+    train_dataset, _, train_data_loader, cv_data_loader = init_dataset_and_dataloader(
+        args, configs, gan, args.dpo
+    )

     # Do some sanity checks and save config to arsg.model_dir
     configs = check_modify_and_save_config(args, configs)
@@ -136,40 +162,45 @@ def main():
     start_step, start_epoch = 0, -1
     if args.checkpoint is not None:
         if os.path.exists(args.checkpoint):
-            state_dict = torch.load(args.checkpoint, map_location='cpu')
+            state_dict = torch.load(args.checkpoint, map_location="cpu")
             model.load_state_dict(state_dict, strict=False)
-            if 'step' in state_dict:
-                start_step = state_dict['step']
-            if 'epoch' in state_dict:
-                start_epoch = state_dict['epoch']
+            if "step" in state_dict:
+                start_step = state_dict["step"]
+            if "epoch" in state_dict:
+                start_epoch = state_dict["epoch"]
         else:
-            logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
+            logger.warning(f"checkpoint {args.checkpoint} do not exsist!")

     # Dispatch model from cpu to gpu
     model = model.cuda()
-    model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
-
+    model = torch.nn.parallel.DistributedDataParallel(
+        model, find_unused_parameters=True
+    )

     # Get optimizer & scheduler
-    model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
+    model, optimizer, scheduler, optimizer_d, scheduler_d = (
+        init_optimizer_and_scheduler(args, configs, model, gan)
+    )
     scheduler.set_step(start_step)
     if scheduler_d is not None:
         scheduler_d.set_step(start_step)

     # Save init checkpoints
-    info_dict = deepcopy(configs['train_conf'])
-    info_dict['step'] = start_step
-    info_dict['epoch'] = start_epoch
-    save_model(model, 'init', info_dict)
+    info_dict = deepcopy(configs["train_conf"])
+    info_dict["step"] = start_step
+    info_dict["epoch"] = start_epoch
+    save_model(model, "init", info_dict)

     # DPO related
     if args.dpo is True:
         ref_model = deepcopy(configs[args.model])
-        state_dict = torch.load(args.ref_model, map_location='cpu')
+        state_dict = torch.load(args.ref_model, map_location="cpu")
         ref_model.load_state_dict(state_dict, strict=False)
         dpo_loss = DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False)
         ref_model = ref_model.cuda()
-        ref_model = torch.nn.parallel.DistributedDataParallel(ref_model, find_unused_parameters=True)
+        ref_model = torch.nn.parallel.DistributedDataParallel(
+            ref_model, find_unused_parameters=True
+        )
     else:
         ref_model, dpo_loss = None, None

@@ -179,21 +210,44 @@ def main():

     # Init scaler, used for pytorch amp mixed precision training
     scaler = torch.amp.GradScaler() if args.use_amp else None
-    logger.info(f'start step {start_step} start epoch {start_epoch}')
+    logger.info(f"start step {start_step} start epoch {start_epoch}")

     # Start training loop
-    for epoch in range(start_epoch + 1, info_dict['max_epoch']):
+    for epoch in range(start_epoch + 1, info_dict["max_epoch"]):
         executor.epoch = epoch
         train_dataset.set_epoch(epoch)
         dist.barrier()
-        group_join = dist.new_group(backend="nccl", timeout=datetime.timedelta(seconds=args.timeout))
+        group_join = dist.new_group(
+            backend="nccl", timeout=datetime.timedelta(seconds=args.timeout)
+        )
         if gan is True:
-            executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
-                                        writer, info_dict, scaler, group_join)
+            executor.train_one_epoc_gan(
+                model,
+                optimizer,
+                scheduler,
+                optimizer_d,
+                scheduler_d,
+                train_data_loader,
+                cv_data_loader,
+                writer,
+                info_dict,
+                scaler,
+                group_join,
+            )
         else:
-            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
+            executor.train_one_epoc(
+                model,
+                optimizer,
+                scheduler,
+                train_data_loader,
+                cv_data_loader,
+                writer,
+                info_dict,
+                scaler,
+                group_join,
+            )
         dist.destroy_process_group(group_join)


-if __name__ == '__main__':
-    main()
+if __name__ == "__main__":
+    main()
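
A note on the config-loading change: load_hyperpyyaml accepts an overrides mapping that replaces top-level YAML keys before resolution, which is how qwen_pretrain_path and the None-ed model sections are injected. A minimal sketch, assuming the hyperpyyaml package and a made-up two-key config:

from io import StringIO
from hyperpyyaml import load_hyperpyyaml

# Hypothetical config body; the override replaces the default path.
yaml_text = "lr: 0.001\nqwen_pretrain_path: /default/path\n"
configs = load_hyperpyyaml(StringIO(yaml_text),
                           overrides={"qwen_pretrain_path": "/my/qwen"})
print(configs["lr"], configs["qwen_pretrain_path"])  # 0.001 /my/qwen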
 
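
One non-obvious detail the reformat preserves: flags such as --ddp.dist_backend and --deepspeed.save_states contain dots, so they pass dest= explicitly to get a normal attribute name on args. A minimal standalone illustration:

import argparse

# Without dest=, the parsed value would only be reachable via
# getattr(args, 'ddp.dist_backend'); dest= gives it a valid identifier.
parser = argparse.ArgumentParser()
parser.add_argument('--ddp.dist_backend', dest='dist_backend',
                    default='nccl', choices=['nccl', 'gloo'])
args = parser.parse_args(['--ddp.dist_backend', 'gloo'])
print(args.dist_backend)  # -> gloo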