kfoughali committed
Commit d64b5de · verified
1 Parent(s): 83f379f

Update core/trainer.py

Files changed (1):
  1. core/trainer.py  +27 -17
core/trainer.py CHANGED
@@ -20,8 +20,9 @@ class GraphMambaTrainer:
         # Conservative learning rate
         self.lr = config['training']['learning_rate']
         self.epochs = config['training']['epochs']
-        self.patience = config['training'].get('patience', 10)
+        self.patience = config['training'].get('patience', 8)
         self.min_lr = config['training'].get('min_lr', 1e-6)
+        self.max_gap = config['training'].get('max_gap', 0.25)  # New gap threshold
 
         # Heavily regularized optimizer
         self.optimizer = optim.AdamW(
@@ -33,14 +34,16 @@ class GraphMambaTrainer:
         )
 
         # Proper loss function with label smoothing
-        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
+        self.criterion = nn.CrossEntropyLoss(
+            label_smoothing=config['training'].get('label_smoothing', 0.15)
+        )
 
-        # Conservative scheduler - FIXED: removed verbose parameter
+        # Conservative scheduler
         self.scheduler = ReduceLROnPlateau(
             self.optimizer,
             mode='max',
             factor=0.5,
-            patience=5,
+            patience=4,  # Reduced from 5
             min_lr=self.min_lr
         )
 
@@ -55,10 +58,10 @@ class GraphMambaTrainer:
 
         # Track overfitting
         self.best_gap = float('inf')
-        self.overfitting_threshold = 0.3
+        self.overfitting_threshold = 0.2  # Reduced from 0.3
 
     def train_node_classification(self, data, verbose=True):
-        """Anti-overfitting training"""
+        """Anti-overfitting training with gap monitoring"""
 
         if verbose:
             total_params = sum(p.numel() for p in self.model.parameters())
@@ -72,8 +75,9 @@ class GraphMambaTrainer:
             print(f"⚙️ Parameters: {total_params:,}")
             print(f"📚 Training samples: {train_samples}")
             print(f"⚠️ Params per sample: {params_per_sample:.1f}")
+            print(f"🚨 Max allowed gap: {self.max_gap:.3f}")
 
-        if params_per_sample > 1000:
+        if params_per_sample > 500:
             print(f"🚨 WARNING: High params per sample ratio - overfitting risk!")
 
         # Initialize classifier
@@ -114,16 +118,16 @@ class GraphMambaTrainer:
             else:
                 self.patience_counter += 1
 
-            # Overfitting detection
+            # Aggressive overfitting detection
             if acc_gap > self.overfitting_threshold:
                 if verbose:
                     print(f"🚨 OVERFITTING detected: {acc_gap:.3f} gap")
                     print(f" Train: {train_metrics['acc']:.3f}, Val: {val_metrics['acc']:.3f}")
 
             # Progress logging
-            if verbose and (epoch == 0 or (epoch + 1) % 10 == 0 or epoch == self.epochs - 1):
+            if verbose and (epoch == 0 or (epoch + 1) % 5 == 0 or epoch == self.epochs - 1):
                 elapsed = time.time() - start_time
-                gap_indicator = "🚨" if acc_gap > 0.2 else "⚠️" if acc_gap > 0.1 else "✅"
+                gap_indicator = "🚨" if acc_gap > 0.25 else "⚠️" if acc_gap > 0.15 else "✅"
 
                 print(f"Epoch {epoch:3d} | "
                       f"Train: {train_metrics['loss']:.4f} ({train_metrics['acc']:.4f}) | "
@@ -131,16 +135,22 @@ class GraphMambaTrainer:
                       f"Gap: {acc_gap:.3f} {gap_indicator} | "
                       f"LR: {self.optimizer.param_groups[0]['lr']:.6f}")
 
-            # Early stopping conditions
+            # Enhanced early stopping conditions
            if self.patience_counter >= self.patience:
                if verbose:
                    print(f"🛑 Early stopping at epoch {epoch} (patience)")
                break
 
-            # Stop if severe overfitting
-            if acc_gap > 0.5:
+            # Stop if gap exceeds threshold
+            if acc_gap > self.max_gap:
+                if verbose:
+                    print(f"🛑 Stopping due to overfitting gap: {acc_gap:.3f} > {self.max_gap:.3f}")
+                break
+
+            # Stop if severe overfitting (backup check)
+            if acc_gap > 0.6:
                if verbose:
-                    print(f"🛑 Stopping due to severe overfitting (gap: {acc_gap:.3f})")
+                    print(f"🛑 Emergency stop - severe overfitting (gap: {acc_gap:.3f})")
                break
 
        if verbose:
@@ -170,15 +180,15 @@ class GraphMambaTrainer:
        # Compute loss on training nodes only
        train_loss = self.criterion(logits[data.train_mask], data.y[data.train_mask])
 
-        # Add L2 regularization manually
+        # Add stronger L2 regularization
        l2_reg = 0.0
        for param in self.model.parameters():
            l2_reg += torch.norm(param, p=2)
-        train_loss += 1e-5 * l2_reg
+        train_loss += 5e-5 * l2_reg  # Increased from 1e-5
 
        # Backward pass with gradient clipping
        train_loss.backward()
-        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)  # Reduced from 1.0
        self.optimizer.step()
 
        # Compute accuracy
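
For reference, here is a minimal sketch of the config this trainer expects, covering every training key the commit reads, including the new max_gap and label_smoothing entries. The key names and defaults come straight from the diff; the concrete values and the GraphMambaTrainer(model, config) constructor signature are illustrative assumptions, not part of the commit.

    # Illustrative only: key names/defaults taken from the diff; the values
    # and the constructor signature are assumptions.
    config = {
        'training': {
            'learning_rate': 1e-3,    # required
            'epochs': 100,            # required
            'patience': 8,            # optional, default 8
            'min_lr': 1e-6,           # optional, default 1e-6
            'max_gap': 0.25,          # optional, default 0.25 (new in this commit)
            'label_smoothing': 0.15,  # optional, default 0.15 (new in this commit)
        }
    }

    # Hypothetical usage, assuming the constructor takes (model, config):
    # trainer = GraphMambaTrainer(model, config)
    # trainer.train_node_classification(data, verbose=True)

And a self-contained sketch of the regularization recipe the commit strengthens: a manual L2 penalty added on top of AdamW's decoupled weight decay, followed by gradient-norm clipping before the optimizer step. The model, data, and mask below are stand-ins; only the loss/penalty/clip ordering and the 5e-5 and 0.5 values mirror the diff.

    import torch
    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(16, 3)                       # stand-in for the graph model
    x, y = torch.randn(100, 16), torch.randint(0, 3, (100,))
    train_mask = torch.rand(100) < 0.6             # stand-in for data.train_mask

    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.15)

    optimizer.zero_grad()
    logits = model(x)
    train_loss = criterion(logits[train_mask], y[train_mask])

    # Manual L2 penalty (sum of per-tensor norms); 5e-5 coefficient from the diff
    l2_reg = sum(torch.norm(p, p=2) for p in model.parameters())
    train_loss = train_loss + 5e-5 * l2_reg

    train_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)  # tighter clip
    optimizer.step()

Note that the penalty sums unsquared tensor norms, so it pulls whole parameter tensors toward zero on top of AdamW's weight decay, and the tighter max_norm of 0.5 further caps the size of each update step.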