mateo496 committed on
Commit
83d5454
·
1 Parent(s): 20ecf01

Added Transformer training

Browse files
Files changed (3) hide show
  1. .gitignore +3 -2
  2. src/app/server.py +1 -1
  3. src/models/traintransformer.py +234 -0
.gitignore CHANGED
@@ -2,8 +2,7 @@ data/
2
  *.npy
3
  *.wav
4
 
5
- models/checkpoints
6
- models/saved
7
  *.pt
8
  *.pth
9
 
@@ -28,6 +27,8 @@ wheels/
28
  .installed.cfg
29
  *.egg
30
 
 
 
31
  .dockerignore
32
  Dockerfile
33
  requirements.txt
 
2
  *.npy
3
  *.wav
4
 
5
+ models/
 
6
  *.pt
7
  *.pth
8
 
 
27
  .installed.cfg
28
  *.egg
29
 
30
+ test/
31
+
32
  .dockerignore
33
  Dockerfile
34
  requirements.txt
src/app/server.py CHANGED
@@ -22,7 +22,7 @@ app = FastAPI(
22
 
23
  model = None
24
  device = None
25
- model_path="models/saved/model30012026_74valacc.pt"
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
  model = load_model(model_path, device)
28
 
 
22
 
23
  model = None
24
  device = None
25
+ model_path="models/cnn/saved/final_model.pt"
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
  model = load_model(model_path, device)
28
 
src/models/traintransformer.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.optim import Adam, AdamW
5
+ from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
6
+ import os
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+ import time
10
+
11
+
12
class TransformerTrainer:
    """Training harness with linear LR warmup, cosine decay and checkpointing.

    The model's output is fed to ``nn.CrossEntropyLoss`` and reduced with
    ``argmax(dim=1)``, so it is expected to return per-class scores along
    dimension 1. Both loaders must yield ``(data, target)`` pairs;
    ``train()`` additionally reads ``loader.dataset`` for reporting.

    Attributes:
        train_losses, val_losses: per-epoch mean batch loss.
        train_accs, val_accs: per-epoch accuracy in percent.
        best_val_acc: highest validation accuracy seen so far (percent).
        best_epoch: 1-based epoch at which ``best_val_acc`` was reached.
    """

    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        num_epochs=50,
        learning_rate=1e-4,
        weight_decay=1e-4,
        warmup_epochs=5,
        checkpoint_dir="models/transformer/checkpoints",
        device="cuda",
        early_stopping_patience=30,
    ):
        """Set up the optimizer, scheduler, loss and metric history.

        Args:
            model: Module to train; it is moved to ``device``.
            train_loader: Iterable of ``(data, target)`` training batches.
            val_loader: Iterable of ``(data, target)`` validation batches.
            num_epochs: Last epoch number to run (epochs are 1-based).
            learning_rate: Peak learning rate reached at the end of warmup.
            weight_decay: AdamW weight decay.
            warmup_epochs: Number of epochs of linear LR warmup.
            checkpoint_dir: Directory for checkpoints; created if missing.
            device: Device passed to ``.to()`` for the model and batches.
            early_stopping_patience: Stop once this many epochs have passed
                since the best validation accuracy.
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.num_epochs = num_epochs
        self.device = device
        self.checkpoint_dir = checkpoint_dir
        self.early_stopping_patience = early_stopping_patience

        os.makedirs(checkpoint_dir, exist_ok=True)

        self.optimizer = AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.999),
        )

        # The cosine schedule only covers the post-warmup epochs. Clamp to 1
        # so a run whose warmup is as long as (or longer than) the whole
        # training does not hand the scheduler a zero/negative period.
        self.scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=max(1, num_epochs - warmup_epochs),
            eta_min=1e-6,
        )

        self.warmup_epochs = warmup_epochs
        self.base_lr = learning_rate

        self.criterion = nn.CrossEntropyLoss()

        # History lists; these names are the ones save_checkpoint,
        # load_checkpoint and train read and write.
        self.train_losses = []
        self.val_losses = []
        self.train_accs = []
        self.val_accs = []
        self.best_val_acc = 0
        self.best_epoch = 0

    def warmup_lr(self, epoch):
        """Set the warmup learning rate for a 0-based ``epoch`` index.

        The rate ramps linearly and reaches ``base_lr`` on the last warmup
        epoch. Does nothing once ``epoch >= warmup_epochs``.
        """
        if epoch < self.warmup_epochs:
            lr = self.base_lr * (epoch + 1) / self.warmup_epochs
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr

    def train_epoch(self, epoch):
        """Run one pass over ``train_loader`` and update the model.

        Args:
            epoch: 1-based epoch number, used only for the progress label.

        Returns:
            ``(avg_loss, avg_acc)``: mean batch loss and accuracy in percent.
        """
        self.model.train()

        total_loss = 0
        correct = 0
        total = 0

        # Keep a handle on the progress bar so its postfix can be refreshed.
        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{self.num_epochs}")
        for batch_idx, (data, target) in enumerate(pbar):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)

            loss.backward()
            # Clip the global gradient norm before the optimizer step.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()

            total_loss += loss.item()
            # size(0) is the batch size; size() alone is a torch.Size.
            total += target.size(0)

            pbar.set_postfix({
                'loss': total_loss / (batch_idx + 1),
                'acc': 100. * correct / total,
                'lr': self.optimizer.param_groups[0]['lr']
            })

        avg_loss = total_loss / len(self.train_loader)
        avg_acc = 100. * correct / total

        return avg_loss, avg_acc

    def validate(self):
        """Evaluate on ``val_loader`` without computing gradients.

        Returns:
            ``(avg_loss, avg_acc)``: mean batch loss and accuracy in percent.
        """
        self.model.eval()

        total_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target in tqdm(self.val_loader, desc='Validation'):
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                total_loss += loss.item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)

        avg_loss = total_loss / len(self.val_loader)
        avg_acc = 100. * correct / total

        return avg_loss, avg_acc

    def save_checkpoint(self, epoch, val_acc, is_best=False):
        """Write ``checkpoint_latest.pth`` and, if ``is_best``, ``checkpoint_best.pth``.

        Args:
            epoch: Epoch number stored in the checkpoint.
            val_acc: Validation accuracy (percent) for that epoch.
            is_best: Also write the best-model checkpoint when true.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'val_acc': val_acc,
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'train_accs': self.train_accs,
            'val_accs': self.val_accs,
            # Stored so a resumed run keeps comparing against the real best
            # instead of starting again from zero.
            'best_val_acc': self.best_val_acc,
            'best_epoch': self.best_epoch,
        }

        # Save latest checkpoint
        path = os.path.join(self.checkpoint_dir, 'checkpoint_latest.pth')
        torch.save(checkpoint, path)

        # Save best checkpoint
        if is_best:
            path = os.path.join(self.checkpoint_dir, 'checkpoint_best.pth')
            torch.save(checkpoint, path)
            print(f'✓ Saved best model with val_acc: {val_acc:.2f}%')

    def load_checkpoint(self, checkpoint_path):
        """Restore model, optimizer, scheduler and history from a checkpoint.

        Args:
            checkpoint_path: Path to a file written by ``save_checkpoint``.

        Returns:
            The epoch number stored in the checkpoint.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        self.train_losses = checkpoint['train_losses']
        self.val_losses = checkpoint['val_losses']
        self.train_accs = checkpoint['train_accs']
        self.val_accs = checkpoint['val_accs']

        # Checkpoints lacking the best-so-far fields fall back to the
        # recorded accuracy history and the checkpoint's own epoch.
        self.best_val_acc = checkpoint.get(
            'best_val_acc', max(self.val_accs, default=0)
        )
        self.best_epoch = checkpoint.get('best_epoch', checkpoint['epoch'])

        print(f'✓ Loaded checkpoint from epoch {checkpoint["epoch"]}')
        return checkpoint['epoch']

    def train(self, resume_from=None):
        """
        Main training loop.

        Args:
            resume_from: Path to checkpoint to resume from

        Returns:
            Best validation accuracy
        """
        start_epoch = 1

        if resume_from:
            start_epoch = self.load_checkpoint(resume_from) + 1

        print(f'\nStarting training for {self.num_epochs} epochs')
        print(f'Device: {self.device}')
        print(f'Training samples: {len(self.train_loader.dataset)}')
        print(f'Validation samples: {len(self.val_loader.dataset)}')
        print('-' * 60)

        start_time = time.time()

        for epoch in range(start_epoch, self.num_epochs + 1):
            # Warmup learning rate (warmup_lr takes a 0-based epoch index)
            if epoch <= self.warmup_epochs:
                self.warmup_lr(epoch - 1)

            # Train
            train_loss, train_acc = self.train_epoch(epoch)

            # Validate
            val_loss, val_acc = self.validate()

            # Update scheduler (after warmup)
            if epoch > self.warmup_epochs:
                self.scheduler.step()

            # Track metrics
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.train_accs.append(train_acc)
            self.val_accs.append(val_acc)

            # Print epoch summary
            print(f'\nEpoch {epoch}/{self.num_epochs}:')
            print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
            print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
            print(f'  LR: {self.optimizer.param_groups[0]["lr"]:.6f}')

            # Save checkpoint
            is_best = val_acc > self.best_val_acc
            if is_best:
                self.best_val_acc = val_acc
                self.best_epoch = epoch

            self.save_checkpoint(epoch, val_acc, is_best)

            # Early stopping check
            if epoch - self.best_epoch > self.early_stopping_patience:
                print(
                    f'\nEarly stopping: no improvement for '
                    f'{self.early_stopping_patience} epochs'
                )
                break

        elapsed_time = time.time() - start_time
        print(f'\n{"="*60}')
        print(f'Training completed in {elapsed_time/3600:.2f} hours')
        print(f'Best validation accuracy: {self.best_val_acc:.2f}% at epoch {self.best_epoch}')
        print(f'{"="*60}')

        return self.best_val_acc
234
+