kfoughali committed
Commit cf47595 Β· verified Β· 1 Parent(s): a7a0326

Update core/trainer.py

Files changed (1):
  1. core/trainer.py +80 -109
core/trainer.py CHANGED
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingWarmRestarts
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 import numpy as np
 import time
 import logging
@@ -10,33 +10,40 @@ from utils.metrics import GraphMetrics
 logger = logging.getLogger(__name__)

 class GraphMambaTrainer:
-    """Enhanced trainer with optimized learning rates and schedules"""
+    """Anti-overfitting trainer with heavy regularization"""

     def __init__(self, model, config, device):
         self.model = model
         self.config = config
         self.device = device

-        # Fixed learning rate (much lower)
-        self.lr = 0.001  # Changed from 0.01
+        # Conservative learning rate
+        self.lr = config['training']['learning_rate']  # Should be 0.0005
         self.epochs = config['training']['epochs']
-        self.patience = config['training'].get('patience', 15)
+        self.patience = config['training'].get('patience', 10)
         self.min_lr = config['training'].get('min_lr', 1e-6)

-        # Enhanced optimizer
+        # Heavily regularized optimizer
         self.optimizer = optim.AdamW(
             model.parameters(),
             lr=self.lr,
-            weight_decay=config['training']['weight_decay'],
+            weight_decay=config['training']['weight_decay'],  # Should be 0.01
             betas=(0.9, 0.999),
             eps=1e-8
         )

-        # Proper loss function
-        self.criterion = nn.CrossEntropyLoss()
+        # Loss function with label smoothing
+        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

-        # Learning rate scheduler (will be set in training)
-        self.scheduler = None
+        # Conservative scheduler
+        self.scheduler = ReduceLROnPlateau(
+            self.optimizer,
+            mode='max',
+            factor=0.5,
+            patience=5,
+            min_lr=self.min_lr,
+            verbose=True
+        )

         # Training state
         self.best_val_acc = 0.0
@@ -46,36 +53,34 @@ class GraphMambaTrainer:
             'train_loss': [], 'train_acc': [],
             'val_loss': [], 'val_acc': [], 'lr': []
         }
-
-    def _setup_scheduler(self, total_steps):
-        """Setup learning rate scheduler"""
-        self.scheduler = OneCycleLR(
-            self.optimizer,
-            max_lr=self.lr,
-            total_steps=total_steps,
-            pct_start=0.1,  # 10% warmup
-            anneal_strategy='cos',
-            div_factor=10.0,  # Start LR = max_lr/10
-            final_div_factor=100.0  # End LR = max_lr/100
-        )
+
+        # Track overfitting
+        self.best_gap = float('inf')
+        self.overfitting_threshold = 0.3  # Stop if train-val gap > 30%

     def train_node_classification(self, data, verbose=True):
-        """Enhanced training with proper LR scheduling"""
+        """Anti-overfitting training"""

         if verbose:
+            total_params = sum(p.numel() for p in self.model.parameters())
+            train_samples = data.train_mask.sum().item()
+            params_per_sample = total_params / train_samples
+
             print(f"πŸ‹οΈ Training GraphMamba for {self.epochs} epochs")
             print(f"πŸ“Š Dataset: {data.num_nodes} nodes, {data.num_edges} edges")
             print(f"🎯 Classes: {len(torch.unique(data.y))}")
             print(f"πŸ’Ύ Device: {self.device}")
-            print(f"βš™οΈ Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
+            print(f"βš™οΈ Parameters: {total_params:,}")
+            print(f"πŸ“š Training samples: {train_samples}")
+            print(f"⚠️ Params per sample: {params_per_sample:.1f}")
+
+            if params_per_sample > 1000:
+                print(f"🚨 WARNING: High params per sample ratio - overfitting risk!")

         # Initialize classifier
         num_classes = len(torch.unique(data.y))
         self.model._init_classifier(num_classes, self.device)

-        # Setup scheduler
-        self._setup_scheduler(self.epochs)
-
         self.model.train()
         start_time = time.time()

@@ -86,6 +91,9 @@ class GraphMambaTrainer:
             # Validation step
             val_metrics = self._validate_epoch(data, epoch)

+            # Calculate overfitting gap
+            acc_gap = train_metrics['acc'] - val_metrics['acc']
+
             # Update history
             self.training_history['train_loss'].append(train_metrics['loss'])
             self.training_history['train_acc'].append(train_metrics['acc'])
@@ -93,59 +101,85 @@ class GraphMambaTrainer:
             self.training_history['val_acc'].append(val_metrics['acc'])
             self.training_history['lr'].append(self.optimizer.param_groups[0]['lr'])

+            # Step scheduler on validation accuracy
+            self.scheduler.step(val_metrics['acc'])
+
             # Check for improvement
             if val_metrics['acc'] > self.best_val_acc:
                 self.best_val_acc = val_metrics['acc']
                 self.best_val_loss = val_metrics['loss']
+                self.best_gap = acc_gap
                 self.patience_counter = 0
                 if verbose:
                     print(f"πŸŽ‰ New best validation accuracy: {self.best_val_acc:.4f}")
             else:
                 self.patience_counter += 1

-            # Progress logging
+            # Overfitting detection
+            if acc_gap > self.overfitting_threshold:
+                if verbose:
+                    print(f"🚨 OVERFITTING detected: {acc_gap:.3f} gap")
+                    print(f"   Train: {train_metrics['acc']:.3f}, Val: {val_metrics['acc']:.3f}")
+
+            # Progress logging with overfitting monitoring
             if verbose and (epoch == 0 or (epoch + 1) % 10 == 0 or epoch == self.epochs - 1):
                 elapsed = time.time() - start_time
+                gap_indicator = "🚨" if acc_gap > 0.2 else "⚠️" if acc_gap > 0.1 else "βœ…"
+
                 print(f"Epoch {epoch:3d} | "
                       f"Train: {train_metrics['loss']:.4f} ({train_metrics['acc']:.4f}) | "
                       f"Val: {val_metrics['loss']:.4f} ({val_metrics['acc']:.4f}) | "
-                      f"LR: {self.optimizer.param_groups[0]['lr']:.6f} | "
-                      f"Time: {elapsed:.1f}s")
+                      f"Gap: {acc_gap:.3f} {gap_indicator} | "
+                      f"LR: {self.optimizer.param_groups[0]['lr']:.6f}")

-            # Early stopping
+            # Early stopping conditions
             if self.patience_counter >= self.patience:
                 if verbose:
-                    print(f"πŸ›‘ Early stopping at epoch {epoch}")
+                    print(f"πŸ›‘ Early stopping at epoch {epoch} (patience)")
+                break
+
+            # Stop if severe overfitting
+            if acc_gap > 0.5:
+                if verbose:
+                    print(f"πŸ›‘ Stopping due to severe overfitting (gap: {acc_gap:.3f})")
                 break
-
-            # Step scheduler
-            self.scheduler.step()

         if verbose:
             total_time = time.time() - start_time
             print(f"βœ… Training completed in {total_time:.2f}s")
             print(f"πŸ† Best validation accuracy: {self.best_val_acc:.4f}")
+            print(f"πŸ“Š Best train-val gap: {self.best_gap:.4f}")
+
+            if self.best_gap < 0.1:
+                print("πŸŽ‰ Excellent generalization!")
+            elif self.best_gap < 0.2:
+                print("πŸ‘ Good generalization")
+            else:
+                print("⚠️ Some overfitting detected")

         return self.training_history

     def _train_epoch(self, data, epoch):
-        """Single training epoch"""
+        """Single training epoch with regularization"""
         self.model.train()
         self.optimizer.zero_grad()

         # Forward pass
         h = self.model(data.x, data.edge_index)
         logits = self.model.classifier(h)

-        # Compute loss on training nodes
+        # Compute loss on training nodes only
         train_loss = self.criterion(logits[data.train_mask], data.y[data.train_mask])

-        # Backward pass
-        train_loss.backward()
+        # Add a small explicit L2 penalty on top of AdamW weight decay
+        l2_reg = 0.0
+        for param in self.model.parameters():
+            l2_reg += torch.norm(param, p=2)
+        train_loss += 1e-5 * l2_reg  # Small additional L2

-        # Gradient clipping
+        # Backward pass with gradient clipping
+        train_loss.backward()
         torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
-
         self.optimizer.step()

         # Compute accuracy
@@ -156,7 +190,7 @@ class GraphMambaTrainer:
         return {'loss': train_loss.item(), 'acc': train_acc}

     def _validate_epoch(self, data, epoch):
-        """Single validation epoch"""
+        """Validation without augmentation"""
         self.model.eval()

         with torch.no_grad():
@@ -171,13 +205,12 @@ class GraphMambaTrainer:
         return {'loss': val_loss.item(), 'acc': val_acc}

     def test(self, data):
-        """Comprehensive test evaluation"""
+        """Test evaluation"""
         self.model.eval()

         with torch.no_grad():
             h = self.model(data.x, data.edge_index)

-            # Ensure classifier exists
             if self.model.classifier is None:
                 num_classes = len(torch.unique(data.y))
                 self.model._init_classifier(num_classes, self.device)
@@ -189,7 +222,6 @@ class GraphMambaTrainer:
             test_pred = logits[data.test_mask]
             test_target = data.y[data.test_mask]

-            # Comprehensive metrics
             metrics = {
                 'test_loss': test_loss.item(),
                 'test_acc': GraphMetrics.accuracy(test_pred, test_target),
@@ -197,7 +229,6 @@ class GraphMambaTrainer:
                 'f1_micro': GraphMetrics.f1_score_micro(test_pred, test_target),
             }

-            # Additional metrics
             precision, recall = GraphMetrics.precision_recall(test_pred, test_target)
             metrics['precision'] = precision
             metrics['recall'] = recall
@@ -208,64 +239,4 @@ class GraphMambaTrainer:
         """Get node embeddings"""
         self.model.eval()
         with torch.no_grad():
-            return self.model(data.x, data.edge_index)
-
-
-class EnhancedGraphMambaTrainer(GraphMambaTrainer):
-    """Enhanced trainer with additional optimizations"""
-
-    def __init__(self, model, config, device):
-        super().__init__(model, config, device)
-
-        # Even more conservative learning rate for complex architectures
-        if hasattr(model, 'multi_scale') or 'Hybrid' in model.__class__.__name__:
-            self.lr = 0.0005  # Lower for complex models
-
-        self.optimizer = optim.AdamW(
-            model.parameters(),
-            lr=self.lr,
-            weight_decay=config['training']['weight_decay'],
-            betas=(0.9, 0.99),  # More stable
-            eps=1e-8
-        )
-
-    def _setup_scheduler(self, total_steps):
-        """Enhanced scheduler for complex models"""
-        # Cosine annealing with warm restarts
-        self.scheduler = CosineAnnealingWarmRestarts(
-            self.optimizer,
-            T_0=20,    # Restart every 20 epochs
-            T_mult=2,  # Double period after restart
-            eta_min=self.min_lr
-        )
-
-    def train_node_classification(self, data, verbose=True):
-        """Training with enhanced monitoring"""
-
-        if verbose:
-            model_type = self.model.__class__.__name__
-            print(f"πŸ‹οΈ Training {model_type} for {self.epochs} epochs")
-            print(f"πŸ“Š Dataset: {data.num_nodes} nodes, {data.num_edges} edges")
-            print(f"🎯 Classes: {len(torch.unique(data.y))}")
-            print(f"πŸ’Ύ Device: {self.device}")
-            print(f"βš™οΈ Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
-            print(f"πŸ“ˆ Learning Rate: {self.lr} (enhanced schedule)")
-
-        # Call parent method with enhancements
-        history = super().train_node_classification(data, verbose)
-
-        # Additional analysis
-        if verbose:
-            final_acc = history['val_acc'][-1] if history['val_acc'] else 0
-            improvement = final_acc - (history['val_acc'][0] if history['val_acc'] else 0)
-            print(f"πŸ“Š Final validation accuracy: {final_acc:.4f}")
-            print(f"πŸ“ˆ Total improvement: {improvement:.4f} ({improvement*100:.1f}%)")
-
-            if final_acc > 0.6:
-                print("πŸŽ‰ Excellent performance! Model converged well.")
-            elif final_acc > 0.4:
-                print("πŸ‘ Good progress! Consider more epochs or tuning.")
-            else:
-                print("⚠️ Low accuracy. Check model architecture or data.")
-
-        return history
+            return self.model(data.x, data.edge_index)
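The trainer now pulls every hyperparameter from config['training'] instead of hard-coding them. A minimal sketch of a matching config dict follows; only the keys are confirmed by the diff, and the values mirror its "Should be" comments (the epochs value is an assumption, not shown in this commit).

# Hypothetical config for GraphMambaTrainer; keys match what __init__ reads.
config = {
    'training': {
        'learning_rate': 0.0005,  # conservative LR read in __init__
        'weight_decay': 0.01,     # strong AdamW regularization
        'epochs': 100,            # assumed for illustration
        'patience': 10,           # early-stopping patience (matches the default)
        'min_lr': 1e-6,           # LR floor for ReduceLROnPlateau
    }
}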
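The move from OneCycleLR / CosineAnnealingWarmRestarts to ReduceLROnPlateau means the learning rate now drops only when validation accuracy stalls. A self-contained sketch (not part of the commit) of that behavior with the settings used above:

import torch

# Dummy single-parameter model, just to drive the scheduler.
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=5e-4)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, mode='max', factor=0.5, patience=5, min_lr=1e-6
)

for acc in [0.70] * 8:  # validation accuracy plateaus
    sched.step(acc)     # once more than patience=5 epochs pass without improvement, LR is halved
print(opt.param_groups[0]['lr'])  # 0.00025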
+ print(f"⚠️ Params per sample: {params_per_sample:.1f}")
76
+
77
+ if params_per_sample > 1000:
78
+ print(f"🚨 WARNING: High params per sample ratio - overfitting risk!")
79
 
80
  # Initialize classifier
81
  num_classes = len(torch.unique(data.y))
82
  self.model._init_classifier(num_classes, self.device)
83
 
 
 
 
84
  self.model.train()
85
  start_time = time.time()
86
 
 
91
  # Validation step
92
  val_metrics = self._validate_epoch(data, epoch)
93
 
94
+ # Calculate overfitting gap
95
+ acc_gap = train_metrics['acc'] - val_metrics['acc']
96
+
97
  # Update history
98
  self.training_history['train_loss'].append(train_metrics['loss'])
99
  self.training_history['train_acc'].append(train_metrics['acc'])
 
101
  self.training_history['val_acc'].append(val_metrics['acc'])
102
  self.training_history['lr'].append(self.optimizer.param_groups[0]['lr'])
103
 
104
+ # Step scheduler
105
+ self.scheduler.step(val_metrics['acc'])
106
+
107
  # Check for improvement
108
  if val_metrics['acc'] > self.best_val_acc:
109
  self.best_val_acc = val_metrics['acc']
110
  self.best_val_loss = val_metrics['loss']
111
+ self.best_gap = acc_gap
112
  self.patience_counter = 0
113
  if verbose:
114
  print(f"πŸŽ‰ New best validation accuracy: {self.best_val_acc:.4f}")
115
  else:
116
  self.patience_counter += 1
117
 
118
+ # Overfitting detection
119
+ if acc_gap > self.overfitting_threshold:
120
+ if verbose:
121
+ print(f"🚨 OVERFITTING detected: {acc_gap:.3f} gap")
122
+ print(f" Train: {train_metrics['acc']:.3f}, Val: {val_metrics['acc']:.3f}")
123
+
124
+ # Progress logging with overfitting monitoring
125
  if verbose and (epoch == 0 or (epoch + 1) % 10 == 0 or epoch == self.epochs - 1):
126
  elapsed = time.time() - start_time
127
+ gap_indicator = "🚨" if acc_gap > 0.2 else "⚠️" if acc_gap > 0.1 else "βœ…"
128
+
129
  print(f"Epoch {epoch:3d} | "
130
  f"Train: {train_metrics['loss']:.4f} ({train_metrics['acc']:.4f}) | "
131
  f"Val: {val_metrics['loss']:.4f} ({val_metrics['acc']:.4f}) | "
132
+ f"Gap: {acc_gap:.3f} {gap_indicator} | "
133
+ f"LR: {self.optimizer.param_groups[0]['lr']:.6f}")
134
 
135
+ # Early stopping conditions
136
  if self.patience_counter >= self.patience:
137
  if verbose:
138
+ print(f"πŸ›‘ Early stopping at epoch {epoch} (patience)")
139
+ break
140
+
141
+ # Stop if severe overfitting
142
+ if acc_gap > 0.5:
143
+ if verbose:
144
+ print(f"πŸ›‘ Stopping due to severe overfitting (gap: {acc_gap:.3f})")
145
  break
 
 
 
146
 
147
  if verbose:
148
  total_time = time.time() - start_time
149
  print(f"βœ… Training completed in {total_time:.2f}s")
150
  print(f"πŸ† Best validation accuracy: {self.best_val_acc:.4f}")
151
+ print(f"πŸ“Š Best train-val gap: {self.best_gap:.4f}")
152
+
153
+ if self.best_gap < 0.1:
154
+ print("πŸŽ‰ Excellent generalization!")
155
+ elif self.best_gap < 0.2:
156
+ print("πŸ‘ Good generalization")
157
+ else:
158
+ print("⚠️ Some overfitting detected")
159
 
160
  return self.training_history
161
 
162
  def _train_epoch(self, data, epoch):
163
+ """Single training epoch with regularization"""
164
  self.model.train()
165
  self.optimizer.zero_grad()
166
 
167
+ # Forward pass (with data augmentation)
168
  h = self.model(data.x, data.edge_index)
169
  logits = self.model.classifier(h)
170
 
171
+ # Compute loss on training nodes only
172
  train_loss = self.criterion(logits[data.train_mask], data.y[data.train_mask])
173
 
174
+ # Add L2 regularization manually if needed
175
+ l2_reg = 0.0
176
+ for param in self.model.parameters():
177
+ l2_reg += torch.norm(param, p=2)
178
+ train_loss += 1e-5 * l2_reg # Small additional L2
179
 
180
+ # Backward pass with gradient clipping
181
+ train_loss.backward()
182
  torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
 
183
  self.optimizer.step()
184
 
185
  # Compute accuracy
 
190
  return {'loss': train_loss.item(), 'acc': train_acc}
191
 
192
  def _validate_epoch(self, data, epoch):
193
+ """Validation without augmentation"""
194
  self.model.eval()
195
 
196
  with torch.no_grad():
 
205
  return {'loss': val_loss.item(), 'acc': val_acc}
206
 
207
  def test(self, data):
208
+ """Test evaluation"""
209
  self.model.eval()
210
 
211
  with torch.no_grad():
212
  h = self.model(data.x, data.edge_index)
213
 
 
214
  if self.model.classifier is None:
215
  num_classes = len(torch.unique(data.y))
216
  self.model._init_classifier(num_classes, self.device)
 
222
  test_pred = logits[data.test_mask]
223
  test_target = data.y[data.test_mask]
224
 
 
225
  metrics = {
226
  'test_loss': test_loss.item(),
227
  'test_acc': GraphMetrics.accuracy(test_pred, test_target),
 
229
  'f1_micro': GraphMetrics.f1_score_micro(test_pred, test_target),
230
  }
231
 
 
232
  precision, recall = GraphMetrics.precision_recall(test_pred, test_target)
233
  metrics['precision'] = precision
234
  metrics['recall'] = recall
 
239
  """Get node embeddings"""
240
  self.model.eval()
241
  with torch.no_grad():
242
+ return self.model(data.x, data.edge_index)
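A hedged usage sketch to tie the pieces together. GraphMamba's constructor and the data loading are not part of this commit, so those two lines are assumptions about the surrounding repo; only the trainer calls match the code above.

import torch
from core.trainer import GraphMambaTrainer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GraphMamba(config).to(device)  # hypothetical: model class not shown in this commit
data = load_cora().to(device)          # hypothetical loader: needs x, y, edge_index and train/val/test masks
trainer = GraphMambaTrainer(model, config, device)

history = trainer.train_node_classification(data, verbose=True)
metrics = trainer.test(data)
print(metrics['test_acc'], metrics['f1_micro'])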