Respair committed on
Commit
ea101a8
·
verified ·
1 Parent(s): fe76d2f

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +177 -0
train.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.utils.tensorboard import SummaryWriter
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import time
9
+ from tqdm import tqdm
10
+ from optimizers import build_optimizer
11
+
12
def train_aligner(config, accelerator, train_dataloader, val_dataloader, device, log_dir, epochs=100):
    """Train an alignment model with the forward-sum (CTC-style) alignment loss.

    Args:
        config: dict-like training configuration. Recognized keys:
            ``optimizer_params.lr`` (default 5e-4), ``optimizer_params.pct_start``
            (default 0.0), ``fwd_sum_loss_weight`` (default 1.0),
            ``grad_clip`` (default 5.0), ``save_every`` (default 10),
            ``plot_every`` (default 10).
        accelerator: accepted for interface compatibility; currently unused in
            this function body — presumably intended for HF ``accelerate``
            integration. TODO(review): confirm and wire up or drop.
        train_dataloader: yields batches of
            (text_input, text_input_length, mel_input, mel_input_length, attn_prior).
        val_dataloader: same batch layout, used for validation.
        device: torch device the model and batches are moved to.
        log_dir: directory for TensorBoard logs, checkpoints and plots.
        epochs: number of training epochs.

    Returns:
        The trained aligner model.
    """
    # Create model (AlignerModel is defined elsewhere in the project).
    aligner = AlignerModel().to(device)

    # Define loss function (ForwardSumLoss is defined elsewhere in the project).
    forward_sum_loss = ForwardSumLoss()

    # Setup optimizer + OneCycle-style scheduler sized to the full run.
    scheduler_params = {
        "max_lr": float(config['optimizer_params'].get('lr', 5e-4)),
        "pct_start": float(config['optimizer_params'].get('pct_start', 0.0)),
        "epochs": epochs,
        "steps_per_epoch": len(train_dataloader),
    }

    optimizer, scheduler = build_optimizer(
        {"params": aligner.parameters(), "optimizer_params": {}, "scheduler_params": scheduler_params})

    # Setup TensorBoard writer.
    writer = SummaryWriter(log_dir=log_dir)

    # Create output directories up front; 'attention_plots' was previously
    # never created, which breaks plotting if the helper does not mkdir itself.
    os.makedirs(os.path.join(log_dir, 'checkpoints'), exist_ok=True)
    os.makedirs(os.path.join(log_dir, 'attention_plots'), exist_ok=True)

    # Track best validation loss for best-model checkpointing.
    best_val_loss = float('inf')

    # Loss weight — now actually applied to the loss (it was read but unused).
    # Default 1.0 preserves the original numerical behavior.
    fwd_sum_loss_weight = config.get('fwd_sum_loss_weight', 1.0)

    # Training loop
    for epoch in range(1, epochs + 1):
        aligner.train()
        train_losses = []
        start_time = time.time()

        # ---- Training phase ----
        pbar = tqdm(train_dataloader, desc=f"Epoch {epoch}/{epochs} [Train]")
        for i, batch in enumerate(pbar):
            batch = [b.to(device) for b in batch]

            text_input, text_input_length, mel_input, mel_input_length, attn_prior = batch

            # Forward pass: soft attention and per-frame log-probabilities.
            attn_soft, attn_logprob = aligner(spec=mel_input,
                                              spec_len=mel_input_length,
                                              text=text_input,
                                              text_len=text_input_length,
                                              attn_prior=attn_prior)

            # Weighted forward-sum loss.
            loss = fwd_sum_loss_weight * forward_sum_loss(attn_logprob=attn_logprob,
                                                          in_lens=text_input_length,
                                                          out_lens=mel_input_length)

            # Backward pass and optimization.
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping; returns the pre-clip total norm.
            grad_norm = nn.utils.clip_grad_norm_(aligner.parameters(), config.get('grad_clip', 5.0))

            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            # Log per-step metrics to TensorBoard.
            global_step = (epoch - 1) * len(train_dataloader) + i
            writer.add_scalar('train/total_loss', loss.item(), global_step)
            writer.add_scalar('train/grad_norm', float(grad_norm), global_step)

            train_losses.append(loss.item())

            # Update the progress bar description with the running loss.
            pbar.set_description(f"Epoch {epoch}/{epochs} [Train] Loss: {loss.item():.4f}")

        # Average training loss; guard against an empty dataloader.
        avg_train_loss = sum(train_losses) / len(train_losses) if train_losses else float('inf')

        # ---- Validation phase ----
        aligner.eval()
        val_losses = []

        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc=f"Epoch {epoch}/{epochs} [Val]"):
                batch = [b.to(device) for b in batch]

                text_input, text_input_length, mel_input, mel_input_length, attn_prior = batch

                attn_soft, attn_logprob = aligner(spec=mel_input,
                                                  spec_len=mel_input_length,
                                                  text=text_input,
                                                  text_len=text_input_length,
                                                  attn_prior=attn_prior)

                # Same weighting as training so the two losses stay comparable.
                val_loss = fwd_sum_loss_weight * forward_sum_loss(attn_logprob=attn_logprob,
                                                                  in_lens=text_input_length,
                                                                  out_lens=mel_input_length)

                val_losses.append(val_loss.item())

        # Average validation loss; guard against an empty dataloader
        # (previously a ZeroDivisionError).
        avg_val_loss = sum(val_losses) / len(val_losses) if val_losses else float('inf')

        # Per-epoch TensorBoard summaries.
        writer.add_scalar('epoch/train_loss', avg_train_loss, epoch)
        writer.add_scalar('epoch/val_loss', avg_val_loss, epoch)

        # Save model if it's the best so far.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': aligner.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
            }, os.path.join(log_dir, 'checkpoints', 'best_model.pt'))

        # Save a periodic checkpoint every N epochs.
        if epoch % config.get('save_every', 10) == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': aligner.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
            }, os.path.join(log_dir, 'checkpoints', f'checkpoint_epoch_{epoch}.pt'))

        # Print summary for this epoch.
        epoch_time = time.time() - start_time
        print(f"Epoch {epoch}/{epochs} completed in {epoch_time:.2f}s | "
              f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        # Plot and save attention matrices for visualization
        # (plot_attention_matrices is defined elsewhere in the project).
        if epoch % config.get('plot_every', 10) == 0:
            plot_attention_matrices(aligner, val_dataloader, device,
                                    os.path.join(log_dir, 'attention_plots', f'epoch_{epoch}'),
                                    num_samples=4)

    writer.close()
    print(f"Training completed. Best validation loss: {best_val_loss:.4f}")
    return aligner
160
+
161
# Main execution
def length_to_mask(lengths):
    """Build a boolean padding mask from a 1-D tensor of lengths.

    Args:
        lengths: integer tensor of shape (batch,).

    Returns:
        Bool tensor of shape (batch, max_len) where True marks positions
        *beyond* each sequence's length (i.e. padding positions).
    """
    # Hoisted out of the __main__ guard so it is importable and testable;
    # behavior is unchanged.
    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    mask = torch.gt(mask + 1, lengths.unsqueeze(1))
    return mask


if __name__ == "__main__":
    # Assumes config, accelerator, train_dataloader, val_dataloader, device,
    # log_dir and epochs are defined in the launching script.
    train_aligner(
        config=config,
        # Fix: train_aligner requires `accelerator`; the original call omitted
        # it, which raises TypeError before training starts.
        accelerator=accelerator,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        device=device,
        log_dir=log_dir,
        # Fix: was `epochs=epoch` — apparent typo for the run-length variable.
        epochs=epochs,
    )