primepake committed
Commit ba2c5eb · 1 Parent(s): 32d5b2b

update train
speech/cosyvoice/utils/executor.py CHANGED
@@ -13,42 +13,63 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
-from contextlib import nullcontext
 import os
+from contextlib import nullcontext
 
 import torch
 import torch.distributed as dist
-
-from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join
+from cosyvoice.utils.train_utils import (batch_backward, batch_forward,
+                                         cosyvoice_join, log_per_save,
+                                         log_per_step, save_model,
+                                         update_parameter_and_lr)
+
+from loguru import logger
 
 
 class Executor:
-
-    def __init__(self, gan: bool = False, ref_model: torch.nn.Module = None, dpo_loss: torch.nn.Module = None):
+    """Executor for training and cross validation"""
+    def __init__(
+        self,
+        gan: bool = False,
+        ref_model: torch.nn.Module = None,
+        dpo_loss: torch.nn.Module = None,
+    ):
         self.gan = gan
         self.ref_model = ref_model
         self.dpo_loss = dpo_loss
         self.step = 0
         self.epoch = 0
-        self.rank = int(os.environ.get('RANK', 0))
-        self.device = torch.device('cuda:{}'.format(self.rank))
-
-    def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=None):
-        ''' Train one epoch
-        '''
-
-        lr = optimizer.param_groups[0]['lr']
-        logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
-        logging.info('using accumulate grad, new batch size is {} times'
-                     ' larger than before'.format(info_dict['accum_grad']))
-        # A context manager to be used in conjunction with an instance of
-        # torch.nn.parallel.DistributedDataParallel to be able to train
-        # with uneven inputs across participating processes.
+        self.rank = int(os.environ.get("RANK", 0))
+        self.device = torch.device(f"cuda:{self.rank}")
+
+    def train_one_epoc(
+        self,
+        model,
+        optimizer,
+        scheduler,
+        train_data_loader,
+        cv_data_loader,
+        writer,
+        info_dict,
+        scaler,
+        group_join,
+    ):
+        """Train one epoch"""
+
+        lr = optimizer.param_groups[0]["lr"]
+        logger.info(
+            f"Epoch {self.epoch} TRAIN info lr {lr} rank {self.rank}"
+        )
+        logger.info(
+            f"using accumulate grad, new batch size is {info_dict['accum_grad']} times larger than before"
+        )
+
         model.train()
         if self.ref_model is not None:
             self.ref_model.eval()
-        model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
+        model_context = (
+            model.join if info_dict["train_engine"] == "torch_ddp" else nullcontext
+        )
         with model_context():
             for batch_idx, batch_dict in enumerate(train_data_loader):
                 info_dict["tag"] = "TRAIN"
@@ -58,47 +79,77 @@ class Executor:
                 if cosyvoice_join(group_join, info_dict):
                     break
 
-                # Disable gradient synchronizations across DDP processes.
-                # Within this context, gradients will be accumulated on module
-                # variables, which will later be synchronized.
-                if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
+
+                if (
+                    info_dict["train_engine"] == "torch_ddp"
+                    and (batch_idx + 1) % info_dict["accum_grad"] != 0
+                ):
                     context = model.no_sync
-                # Used for single gpu training and DDP gradient synchronization
-                # processes.
+
                 else:
                     context = nullcontext
 
                 with context():
-                    info_dict = batch_forward(model, batch_dict, scaler, info_dict, ref_model=self.ref_model, dpo_loss=self.dpo_loss)
+                    info_dict = batch_forward(
+                        model,
+                        batch_dict,
+                        scaler,
+                        info_dict,
+                        ref_model=self.ref_model,
+                        dpo_loss=self.dpo_loss,
+                    )
                     info_dict = batch_backward(model, scaler, info_dict)
 
-                info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
+                info_dict = update_parameter_and_lr(
+                    model, optimizer, scheduler, scaler, info_dict
+                )
                 log_per_step(writer, info_dict)
                 # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
-                if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
-                        (batch_idx + 1) % info_dict["accum_grad"] == 0:
+                if (
+                    info_dict["save_per_step"] > 0
+                    and (self.step + 1) % info_dict["save_per_step"] == 0
+                    and (batch_idx + 1) % info_dict["accum_grad"] == 0
+                ):
                     dist.barrier()
-                    self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
+                    self.cv(
+                        model, cv_data_loader, writer, info_dict, on_batch_end=False
+                    )
                     model.train()
                 if (batch_idx + 1) % info_dict["accum_grad"] == 0:
                     self.step += 1
         dist.barrier()
         self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
 
-    def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
-                           writer, info_dict, scaler, group_join):
-        ''' Train one epoch
-        '''
-
-        lr = optimizer.param_groups[0]['lr']
-        logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
-        logging.info('using accumulate grad, new batch size is {} times'
-                     ' larger than before'.format(info_dict['accum_grad']))
+    def train_one_epoc_gan(
+        self,
+        model,
+        optimizer,
+        scheduler,
+        optimizer_d,
+        scheduler_d,
+        train_data_loader,
+        cv_data_loader,
+        writer,
+        info_dict,
+        scaler,
+        group_join,
+    ):
+        """Train one epoch"""
+
+        lr = optimizer.param_groups[0]["lr"]
+        logger.info(
+            f"Epoch {self.epoch} TRAIN info lr {lr} rank {self.rank}"
+        )
+        logger.info(
+            f"using accumulate grad, new batch size is {info_dict['accum_grad']} times larger than before"
+        )
         # A context manager to be used in conjunction with an instance of
         # torch.nn.parallel.DistributedDataParallel to be able to train
         # with uneven inputs across participating processes.
         model.train()
-        model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
+        model_context = (
+            model.join if info_dict["train_engine"] == "torch_ddp" else nullcontext
+        )
        with model_context():
             for batch_idx, batch_dict in enumerate(train_data_loader):
                 info_dict["tag"] = "TRAIN"
@@ -111,7 +162,10 @@ class Executor:
                 # Disable gradient synchronizations across DDP processes.
                 # Within this context, gradients will be accumulated on module
                 # variables, which will later be synchronized.
-                if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
+                if (
+                    info_dict["train_engine"] == "torch_ddp"
+                    and (batch_idx + 1) % info_dict["accum_grad"] != 0
+                ):
                     context = model.no_sync
                 # Used for single gpu training and DDP gradient synchronization
                 # processes.
@@ -119,35 +173,43 @@ class Executor:
                     context = nullcontext
 
                 with context():
-                    batch_dict['turn'] = 'discriminator'
+                    batch_dict["turn"] = "discriminator"
                     info_dict = batch_forward(model, batch_dict, scaler, info_dict)
                     info_dict = batch_backward(model, scaler, info_dict)
-                    info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict)
+                    info_dict = update_parameter_and_lr(
+                        model, optimizer_d, scheduler_d, scaler, info_dict
+                    )
                 optimizer.zero_grad()
                 log_per_step(writer, info_dict)
                 with context():
-                    batch_dict['turn'] = 'generator'
+                    batch_dict["turn"] = "generator"
                     info_dict = batch_forward(model, batch_dict, scaler, info_dict)
                     info_dict = batch_backward(model, scaler, info_dict)
-                    info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
+                    info_dict = update_parameter_and_lr(
+                        model, optimizer, scheduler, scaler, info_dict
+                    )
                 optimizer_d.zero_grad()
                 log_per_step(writer, info_dict)
                 # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
-                if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
-                        (batch_idx + 1) % info_dict["accum_grad"] == 0:
+                if (
+                    info_dict["save_per_step"] > 0
+                    and (self.step + 1) % info_dict["save_per_step"] == 0
+                    and (batch_idx + 1) % info_dict["accum_grad"] == 0
+                ):
                     dist.barrier()
-                    self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
+                    self.cv(
+                        model, cv_data_loader, writer, info_dict, on_batch_end=False
+                    )
                     model.train()
                 if (batch_idx + 1) % info_dict["accum_grad"] == 0:
                     self.step += 1
         dist.barrier()
-        self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
+        # self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
 
     @torch.inference_mode()
     def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
-        ''' Cross validation on
-        '''
-        logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank))
+        """Cross validation on"""
+        logger.info(f"Epoch {self.epoch} Step {self.step + 1} on_batch_end {on_batch_end} CV rank {self.rank}")
         model.eval()
         total_num_utts, total_loss_dict = 0, {}  # avoid division by 0
         for batch_idx, batch_dict in enumerate(cv_data_loader):
@@ -160,17 +222,21 @@ class Executor:
             total_num_utts += num_utts
 
             if self.gan is True:
-                batch_dict['turn'] = 'generator'
+                batch_dict["turn"] = "generator"
             info_dict = batch_forward(model, batch_dict, None, info_dict)
 
-            for k, v in info_dict['loss_dict'].items():
+            for k, v in info_dict["loss_dict"].items():
                 if k not in total_loss_dict:
                     total_loss_dict[k] = []
                 total_loss_dict[k].append(v.item() * num_utts)
             log_per_step(None, info_dict)
         for k, v in total_loss_dict.items():
             total_loss_dict[k] = sum(v) / total_num_utts
-        info_dict['loss_dict'] = total_loss_dict
+        info_dict["loss_dict"] = total_loss_dict
         log_per_save(writer, info_dict)
-        model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1)
+        model_name = (
+            "epoch_{}_whole".format(self.epoch)
+            if on_batch_end
+            else "epoch_{}_step_{}".format(self.epoch, self.step + 1)
+        )
         save_model(model, model_name, info_dict)
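
Note on the loop structure above: two DDP context managers do the heavy lifting, and they are easy to conflate. `model.join()` lets the epoch survive uneven batch counts across ranks, while `model.no_sync()` suppresses the gradient all-reduce on accumulation micro-steps so only the final micro-step pays communication cost. A minimal, self-contained sketch of the same pattern, assuming an already-initialized process group, a DDP-wrapped model, and an illustrative `accum_grad` of 4 (none of these values come from the commit):

from contextlib import nullcontext

import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP


def train_loop(model: DDP, loader, optimizer, accum_grad: int = 4):
    """Gradient accumulation under DDP: all-reduce only on the last micro-step."""
    model.train()
    with model.join():  # tolerate ranks whose loader is exhausted early
        for batch_idx, (x, y) in enumerate(loader):
            sync_now = (batch_idx + 1) % accum_grad == 0
            # no_sync() skips the all-reduce, so local grads just accumulate.
            context = nullcontext if sync_now else model.no_sync
            with context():
                loss = F.mse_loss(model(x), y) / accum_grad
                loss.backward()
            if sync_now:
                optimizer.step()
                optimizer.zero_grad()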
speech/cosyvoice/utils/train_utils.py CHANGED
@@ -29,7 +29,7 @@ import torch.distributed as dist
 from torch.utils.tensorboard import SummaryWriter
 from torch.utils.data import DataLoader
 from torch.nn.utils import clip_grad_norm_
-
+from loguru import logger
 from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
 
 from cosyvoice.dataset.dataset import Dataset
@@ -40,8 +40,7 @@ def init_distributed(args):
     world_size = int(os.environ.get('WORLD_SIZE', 1))
     local_rank = int(os.environ.get('LOCAL_RANK', 0))
     rank = int(os.environ.get('RANK', 0))
-    logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
-                 ', rank {}, world_size {}'.format(rank, world_size))
+    logger.info(f'training on multiple gpus, this gpu {local_rank}, rank {rank}, world_size {world_size}')
     if args.train_engine == 'torch_ddp':
         torch.cuda.set_device(local_rank)
         dist.init_process_group(args.dist_backend)
@@ -70,6 +69,7 @@ def init_dataset_and_dataloader(args, configs, gan, dpo):
 
 
 def check_modify_and_save_config(args, configs):
+    """Check and modify config"""
     if args.train_engine == "torch_ddp":
         configs['train_conf']["dtype"] = 'fp32'
     else:
@@ -92,6 +92,7 @@ def check_modify_and_save_config(args, configs):
 
 
 def wrap_cuda_model(args, model):
+    """Wrap model to cuda"""
     local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
     world_size = int(os.environ.get('WORLD_SIZE', 1))
     if args.train_engine == "torch_ddp":  # native pytorch ddp
@@ -109,6 +110,7 @@ def wrap_cuda_model(args, model):
 
 
 def init_optimizer_and_scheduler(args, configs, model, gan):
+    """Init optimizer and scheduler"""
     if gan is False:
         if configs['train_conf']['optim'] == 'adam':
             optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
@@ -185,6 +187,7 @@ def init_optimizer_and_scheduler(args, configs, model, gan):
 
 
 def init_summarywriter(args):
+
     writer = None
     if int(os.environ.get('RANK', 0)) == 0:
         os.makedirs(args.model_dir, exist_ok=True)
@@ -215,6 +218,7 @@ def save_model(model, model_name, info_dict):
 
 
 def cosyvoice_join(group_join, info_dict):
+    """Join all ranks"""
     world_size = int(os.environ.get('WORLD_SIZE', 1))
     local_rank = int(os.environ.get('LOCAL_RANK', 0))
     rank = int(os.environ.get('RANK', 0))
@@ -236,6 +240,7 @@ def cosyvoice_join(group_join, info_dict):
 
 
 def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None):
+    """Forward batch and compute loss"""
     device = int(os.environ.get('LOCAL_RANK', 0))
 
     dtype = info_dict["dtype"]
@@ -276,7 +281,7 @@ def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None
 
 def batch_backward(model, scaler, info_dict):
     if info_dict["train_engine"] == "deepspeed":
-        scaled_loss = model.backward(info_dict['loss_dict']['loss'])
+        scaled_loss = model.backward(info_dict['loss_dict']['loss'])
     else:
         scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
         if scaler is not None:
@@ -356,9 +361,8 @@ def log_per_save(writer, info_dict):
     loss_dict = info_dict["loss_dict"]
     lr = info_dict['lr']
     rank = int(os.environ.get('RANK', 0))
-    logging.info(
-        'Epoch {} Step {} CV info lr {} {} rank {}'.format(
-            epoch, step + 1, lr, rank, ' '.join(['{} {}'.format(k, v) for k, v in loss_dict.items()])))
+    logger.info(
+        f"Epoch {epoch} Step {step + 1} CV info lr {lr} rank {rank} {' '.join([f'{k} {v}' for k, v in loss_dict.items()])}")
 
     if writer is not None:
         for k in ['epoch', 'lr']:
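
A note on the `log_per_save` hunk: inside an f-string, reusing the outer quote character (e.g. `f'{''.join(xs)}'`) is a SyntaxError on Python 3.11 and earlier; quote reuse inside f-string expressions only became legal in Python 3.12 (PEP 701). An empty-string separator would also run the per-loss entries together, which is why the added line above uses the double-quoted, space-joined form. A quick self-contained illustration of the pitfall:

loss_dict = {"loss": 1.23, "acc": 0.9}
# f'lr 0.001 {''.join(...)}'  -> SyntaxError before Python 3.12: the inner ''
# closes the outer string, and ''.join() would concatenate without spaces.
summary = f"lr 0.001 {' '.join([f'{k} {v}' for k, v in loss_dict.items()])}"
print(summary)  # lr 0.001 loss 1.23 acc 0.9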
speech/train.py ADDED
@@ -0,0 +1,199 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import argparse
+import datetime
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+from copy import deepcopy
+import os
+import torch
+import torch.distributed as dist
+import deepspeed
+from loguru import logger
+
+from hyperpyyaml import load_hyperpyyaml
+
+from torch.distributed.elastic.multiprocessing.errors import record
+
+from cosyvoice.utils.losses import DPOLoss
+from cosyvoice.utils.executor import Executor
+from cosyvoice.utils.train_utils import (
+    init_distributed,
+    init_dataset_and_dataloader,
+    init_optimizer_and_scheduler,
+    init_summarywriter, save_model,
+    check_modify_and_save_config)
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='training your network')
+    parser.add_argument('--train_engine',
+                        default='torch_ddp',
+                        choices=['torch_ddp', 'deepspeed'],
+                        help='Engine for paralleled training')
+    parser.add_argument('--model', required=True, help='model which will be trained')
+    parser.add_argument('--ref_model', required=False, help='ref model used in dpo')
+    parser.add_argument('--config', required=True, help='config file')
+    parser.add_argument('--train_data', required=True, help='train data file')
+    parser.add_argument('--cv_data', required=True, help='cv data file')
+    parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
+    parser.add_argument('--checkpoint', help='checkpoint model')
+    parser.add_argument('--model_dir', required=True, help='save model dir')
+    parser.add_argument('--tensorboard_dir',
+                        default='tensorboard',
+                        help='tensorboard log dir')
+    parser.add_argument('--ddp.dist_backend',
+                        dest='dist_backend',
+                        default='nccl',
+                        choices=['nccl', 'gloo'],
+                        help='distributed backend')
+    parser.add_argument('--num_workers',
+                        default=0,
+                        type=int,
+                        help='num of subprocess workers for reading')
+    parser.add_argument('--prefetch',
+                        default=100,
+                        type=int,
+                        help='prefetch number')
+    parser.add_argument('--pin_memory',
+                        action='store_true',
+                        default=False,
+                        help='Use pinned memory buffers used for reading')
+    parser.add_argument('--use_amp',
+                        action='store_true',
+                        default=False,
+                        help='Use automatic mixed precision training')
+    parser.add_argument('--dpo',
+                        action='store_true',
+                        default=False,
+                        help='Use Direct Preference Optimization')
+    parser.add_argument('--deepspeed.save_states',
+                        dest='save_states',
+                        default='model_only',
+                        choices=['model_only', 'model+optimizer'],
+                        help='save model/optimizer states')
+    parser.add_argument('--timeout',
+                        default=60,
+                        type=int,
+                        help='timeout (in seconds) of cosyvoice_join.')
+    parser = deepspeed.add_config_arguments(parser)
+    args = parser.parse_args()
+    return args
+
+
+@record
+def main():
+    args = get_args()
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
+    # gan train has some special initialization logic
+    gan = True if args.model == 'hifigan' else False
+
+    override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
+    if gan is True:
+        override_dict.pop('hift')
+    try:
+        with open(args.config, 'r') as f:
+            configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path})
+    except Exception:
+        with open(args.config, 'r') as f:
+            configs = load_hyperpyyaml(f, overrides=override_dict)
+    if gan is True:
+        configs['train_conf'] = configs['train_conf_gan']
+    configs['train_conf'].update(vars(args))
+
+    # Init env for ddp
+    init_distributed(args)
+
+    # Get dataset & dataloader
+    train_dataset, _, train_data_loader, cv_data_loader = \
+        init_dataset_and_dataloader(args, configs, gan, args.dpo)
+
+    # Do some sanity checks and save config to args.model_dir
+    configs = check_modify_and_save_config(args, configs)
+
+    # Tensorboard summary
+    writer = init_summarywriter(args)
+
+    # load checkpoint
+    if args.dpo is True:
+        configs[args.model].forward = configs[args.model].forward_dpo
+
+    model = configs[args.model]
+    start_step, start_epoch = 0, -1
+    if args.checkpoint is not None:
+        if os.path.exists(args.checkpoint):
+            state_dict = torch.load(args.checkpoint, map_location='cpu')
+            model.load_state_dict(state_dict, strict=False)
+            if 'step' in state_dict:
+                start_step = state_dict['step']
+            if 'epoch' in state_dict:
+                start_epoch = state_dict['epoch']
+        else:
+            logging.warning('checkpoint {} does not exist!'.format(args.checkpoint))
+
+    # Dispatch model from cpu to gpu
+    model = model.cuda()
+    model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
+
+
+    # Get optimizer & scheduler
+    model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
+    scheduler.set_step(start_step)
+    if scheduler_d is not None:
+        scheduler_d.set_step(start_step)
+
+    # Save init checkpoints
+    info_dict = deepcopy(configs['train_conf'])
+    info_dict['step'] = start_step
+    info_dict['epoch'] = start_epoch
+    save_model(model, 'init', info_dict)
+
+    # DPO related
+    if args.dpo is True:
+        ref_model = deepcopy(configs[args.model])
+        state_dict = torch.load(args.ref_model, map_location='cpu')
+        ref_model.load_state_dict(state_dict, strict=False)
+        dpo_loss = DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False)
+        ref_model = ref_model.cuda()
+        ref_model = torch.nn.parallel.DistributedDataParallel(ref_model, find_unused_parameters=True)
+    else:
+        ref_model, dpo_loss = None, None
+
+    # Get executor
+    executor = Executor(gan=gan, ref_model=ref_model, dpo_loss=dpo_loss)
+    executor.step = start_step
+
+    # Init scaler, used for pytorch amp mixed precision training
+    scaler = torch.amp.GradScaler() if args.use_amp else None
+    logger.info(f'start step {start_step} start epoch {start_epoch}')
+
+    # Start training loop
+    for epoch in range(start_epoch + 1, info_dict['max_epoch']):
+        executor.epoch = epoch
+        train_dataset.set_epoch(epoch)
+        dist.barrier()
+        group_join = dist.new_group(backend="nccl", timeout=datetime.timedelta(seconds=args.timeout))
+        if gan is True:
+            executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
+                                        writer, info_dict, scaler, group_join)
+        else:
+            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
+        dist.destroy_process_group(group_join)
+
+
+if __name__ == '__main__':
+    main()
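
The `--dpo` branch in `main()` follows the standard Direct Preference Optimization recipe: a frozen copy of the model (`ref_model`, held in `eval()` by the Executor) anchors the policy's log-probabilities, and `DPOLoss(beta=0.01, ...)` scores preferred against dispreferred samples. For orientation, a generic sketch of the sigmoid DPO objective with the same `beta`; this is the textbook formulation (Rafailov et al., 2023), not the actual body of `cosyvoice.utils.losses.DPOLoss`:

import torch
import torch.nn.functional as F


def dpo_loss(policy_chosen_logps: torch.Tensor,
             policy_rejected_logps: torch.Tensor,
             ref_chosen_logps: torch.Tensor,
             ref_rejected_logps: torch.Tensor,
             beta: float = 0.01) -> torch.Tensor:
    # How far the policy has moved from the frozen reference on
    # preferred vs. dispreferred samples (per-sequence log-probs).
    chosen_logratios = policy_chosen_logps - ref_chosen_logps
    rejected_logratios = policy_rejected_logps - ref_rejected_logps
    # Sigmoid DPO: widen the margin between the two log-ratios.
    return -F.logsigmoid(beta * (chosen_logratios - rejected_logratios)).mean()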