Respair committed · verified
Commit 85e890c · 1 Parent(s): 545988c

Create ddp_train.py

Files changed (1)
  1. ddp_train.py +238 -0
ddp_train.py ADDED
@@ -0,0 +1,238 @@
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import time
+ from tqdm import tqdm
+ import os.path as osp
+ import re
+ import sys
+ import yaml
+ import shutil
+ import click
+ from torch.utils.tensorboard import SummaryWriter
+ from utils import *
+ from optimizers import build_optimizer
+ from model import *
+ from meldataset import build_dataloader
+
+ from accelerate import Accelerator
+ from accelerate.utils import LoggerType
+ from accelerate import DistributedDataParallelKwargs
+
+ import logging
+ from logging import StreamHandler
+ from accelerate.logging import get_logger
+
+ # Console logging; accelerate's get_logger wraps the same underlying stdlib logger
+ # so handlers attached here are reused in multi-process runs.
+ handler = StreamHandler()
+ handler.setLevel(logging.DEBUG)
+ logging.getLogger(__name__).setLevel(logging.DEBUG)
+ logging.getLogger(__name__).addHandler(handler)
+ logger = get_logger(__name__, log_level="DEBUG")
+
+ # torch.autograd.detect_anomaly(True)
+ torch.backends.cudnn.benchmark = True
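+
+ # Note: this script is built around Hugging Face Accelerate, so a typical
+ # multi-GPU run would be launched with something like
+ #   accelerate launch ddp_train.py -p ./Configs/config.yml
+ # (assuming `accelerate config` has already been set up on the machine).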
42
+
+ @click.command()
+ @click.option('-p', '--config_path', default='./Configs/config.yml', type=str)
+ def main(config_path):
+
+     config = yaml.safe_load(open(config_path))
+     log_dir = config['log_dir']
+     if not osp.exists(log_dir): os.mkdir(log_dir)
+     shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
+
+     # TensorBoard writer (created on every process so the unguarded writer calls below always have a target)
+     writer = SummaryWriter(log_dir + "/tensorboard")
+
+     ddp_kwargs = DistributedDataParallelKwargs()
+     accelerator = Accelerator(project_dir=log_dir, split_batches=True, kwargs_handlers=[ddp_kwargs], mixed_precision='bf16')
+
+     # write logs
+     file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
+     file_handler.setLevel(logging.DEBUG)
+     file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
+     logger.logger.addHandler(file_handler)
+
+     save_iter = 1
+     batch_size = config.get('batch_size', 4)
+     log_interval = 1
+     device = accelerator.device
+     train_path = config.get('train_data', None)
+     val_path = config.get('val_data', None)
+     epochs = config.get('epochs', 1000)
+
+     train_list, val_list = get_data_path_list(train_path, val_path)
+     train_dataloader = build_dataloader(train_list,
+                                         batch_size=batch_size,
+                                         num_workers=8,
+                                         dataset_config=config.get('dataset_params', {}),
+                                         device=device)
+
+     val_dataloader = build_dataloader(val_list,
+                                       batch_size=batch_size,
+                                       validation=True,
+                                       num_workers=2,
+                                       device=device,
+                                       dataset_config=config.get('dataset_params', {}))
+
+
+     aligner = AlignerModel()
+     forward_sum_loss = ForwardSumLoss()
+     best_val_loss = float('inf')
+
+     scheduler_params = {
+         "max_lr": float(config['optimizer_params'].get('lr', 5e-4)),
+         "pct_start": float(config['optimizer_params'].get('pct_start', 0.0)),
+         "epochs": epochs,
+         "steps_per_epoch": len(train_dataloader),
+     }
+
+     optimizer, scheduler = build_optimizer(
+         {"params": aligner.parameters(), "optimizer_params": {}, "scheduler_params": scheduler_params})
+
+     aligner, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
+         aligner, optimizer, train_dataloader, val_dataloader, scheduler
+     )
+
+     with accelerator.main_process_first():
+         if config.get('pretrained_model', '') != '':
+             aligner, optimizer, start_epoch, iters = load_checkpoint(aligner, optimizer, config['pretrained_model'],
+                                                                      load_only_params=config.get('load_only_params', True))
+         else:
+             start_epoch = 0
+             iters = 0
+
+     # Training loop
+     for epoch in range(start_epoch + 1, epochs + 1):
+         aligner.train()
+         train_losses = []
+         train_fwd_losses = []
+         start_time = time.time()
+
+         # Training phase
+         pbar = tqdm(train_dataloader, desc=f"Epoch {epoch}/{epochs} [Train]")
+         for i, batch in enumerate(pbar):
+             batch = [b.to(device) for b in batch]
+
+             text_input, text_input_length, mel_input, mel_input_length, attn_prior = batch
+
+             # Forward pass
+             attn_soft, attn_logprob = aligner(spec=mel_input,
+                                               spec_len=mel_input_length,
+                                               text=text_input,
+                                               text_len=text_input_length,
+                                               attn_prior=attn_prior)
+
+             # Calculate loss
+             loss = forward_sum_loss(attn_logprob=attn_logprob,
+                                     in_lens=text_input_length,
+                                     out_lens=mel_input_length)
+
+             # Backward pass and optimization
+             optimizer.zero_grad()
+             accelerator.backward(loss)
+
+             # Optional gradient clipping
+             grad_norm = accelerator.clip_grad_norm_(aligner.parameters(), 5.0)
+
+             optimizer.step()
+             iters = iters + 1
+
+             if scheduler is not None:
+                 scheduler.step()
+
+             if (i + 1) % log_interval == 0 and accelerator.is_main_process:
+                 log_print('Epoch [%d/%d], Step [%d/%d], Forward Sum Loss: %.5f'
+                           % (epoch, epochs, i + 1, len(train_list) // batch_size, loss), logger)
+
+                 writer.add_scalar('train/Forward Sum Loss', loss, iters)
+                 # writer.add_scalar('train/d_loss', d_loss, iters)
+
+             train_losses.append(loss.item())
+             train_fwd_losses.append(loss.item())
+
+         accelerator.print('Time elapsed:', time.time() - start_time)
+
+         # Calculate average training loss for this epoch
+         avg_train_loss = sum(train_losses) / len(train_losses)
+
+         # Validation phase
+         aligner.eval()
+         val_losses = []
+
+         with torch.no_grad():
+             for batch in tqdm(val_dataloader, desc=f"Epoch {epoch}/{epochs} [Val]"):
+                 batch = [b.to(device) for b in batch]
+
+                 text_input, text_input_length, mel_input, mel_input_length, attn_prior = batch
+
+                 # Forward pass
+                 attn_soft, attn_logprob = aligner(spec=mel_input,
+                                                   spec_len=mel_input_length,
+                                                   text=text_input,
+                                                   text_len=text_input_length,
+                                                   attn_prior=attn_prior)
+
+                 # Calculate loss
+                 val_loss = forward_sum_loss(attn_logprob=attn_logprob,
+                                             in_lens=text_input_length,
+                                             out_lens=mel_input_length)
+
+                 val_losses.append(val_loss.item())
+
+         # Calculate average validation loss
+         avg_val_loss = sum(val_losses) / len(val_losses)
+
+         # Log to TensorBoard
+         writer.add_scalar('epoch/train_loss', avg_train_loss, epoch)
+         writer.add_scalar('epoch/val_loss', avg_val_loss, epoch)
+
+         # Save checkpoint every N epochs
+         if epoch % save_iter == 0 and accelerator.is_main_process:
+             print(f'Saving on step {epoch*len(train_dataloader)+i}...')
+             state = {
+                 'net': accelerator.unwrap_model(aligner).state_dict(),
+                 'optimizer': optimizer.state_dict(),
+                 'iters': iters,
+                 'epoch': epoch,
+             }
+             os.makedirs(os.path.join(log_dir, 'checkpoints'), exist_ok=True)
+             save_path = os.path.join(log_dir, 'checkpoints', f'TextAligner_checkpoint_epoch_{epoch}.pt')
+             torch.save(state, save_path)
+
+         # Print summary for this epoch
+         epoch_time = time.time() - start_time
+         accelerator.print(f"Epoch {epoch}/{epochs} completed in {epoch_time:.2f}s | "
+                           f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
+
+         # # Plot and save attention matrices for visualization
+         # if epoch % config.get('plot_every', 10) == 0:
+         #     plot_attention_matrices(aligner, val_dataloader, device,
+         #                             os.path.join(log_dir, 'attention_plots', f'epoch_{epoch}'),
+         #                             num_samples=4)
+
+     writer.close()
+
+ if __name__ == "__main__":
+     main()
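
For reference, a minimal sketch of reloading a checkpoint written by this script, assuming the 'net'/'optimizer'/'iters'/'epoch' layout saved above and that AlignerModel is exported by model.py as used in the script; the path is a hypothetical example:

    import torch
    from model import AlignerModel

    # hypothetical path: <log_dir>/checkpoints/TextAligner_checkpoint_epoch_<N>.pt
    ckpt = torch.load('Logs/checkpoints/TextAligner_checkpoint_epoch_100.pt', map_location='cpu')
    aligner = AlignerModel()
    aligner.load_state_dict(ckpt['net'])   # 'net' holds the unwrapped model weights
    start_epoch, iters = ckpt['epoch'], ckpt['iters']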