Update core/trainer.py
core/trainer.py  (+27 -17)
@@ -20,8 +20,9 @@ class GraphMambaTrainer:
         # Conservative learning rate
         self.lr = config['training']['learning_rate']
         self.epochs = config['training']['epochs']
-        self.patience = config['training'].get('patience', …)
+        self.patience = config['training'].get('patience', 8)
         self.min_lr = config['training'].get('min_lr', 1e-6)
+        self.max_gap = config['training'].get('max_gap', 0.25)  # New gap threshold

         # Heavily regularized optimizer
         self.optimizer = optim.AdamW(
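For reference, these are the keys the constructor reads from config['training']. The repository's actual config file is not part of this diff, so the values below are only illustrative; the fallbacks match the .get() defaults shown above.

# Hypothetical config illustrating the keys read by GraphMambaTrainer.__init__.
# Only 'learning_rate' and 'epochs' are required; the rest fall back to defaults.
config = {
    'training': {
        'learning_rate': 1e-3,     # illustrative value
        'epochs': 100,             # illustrative value
        'patience': 8,             # early-stopping patience (new default)
        'min_lr': 1e-6,            # floor for ReduceLROnPlateau
        'max_gap': 0.25,           # new train/val accuracy gap threshold
        'label_smoothing': 0.15,   # used by the loss a few lines below
    }
}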
@@ -33,14 +34,16 @@ class GraphMambaTrainer:
         )

         # Proper loss function with label smoothing
-        self.criterion = nn.CrossEntropyLoss(…)
+        self.criterion = nn.CrossEntropyLoss(
+            label_smoothing=config['training'].get('label_smoothing', 0.15)
+        )

         # Conservative scheduler
         self.scheduler = ReduceLROnPlateau(
             self.optimizer,
             mode='max',
             factor=0.5,
-            patience=5,
+            patience=4,  # Reduced from 5
             min_lr=self.min_lr
         )

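A quick way to see what label_smoothing=0.15 buys: the smoothed loss stays bounded away from zero even when the model is extremely confident on the true class, which discourages the oversized logits typical of overfitting. A minimal standalone sketch with made-up logits (not code from this repository):

import torch
import torch.nn as nn

logits = torch.tensor([[8.0, -4.0, -4.0]])   # hypothetical, very confident prediction
target = torch.tensor([0])                    # true class is 0

plain    = nn.CrossEntropyLoss()(logits, target)
smoothed = nn.CrossEntropyLoss(label_smoothing=0.15)(logits, target)
print(plain.item(), smoothed.item())  # smoothed loss stays ~1.2 while the plain loss is ~0.0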
@@ -55,10 +58,10 @@ class GraphMambaTrainer:

         # Track overfitting
         self.best_gap = float('inf')
-        self.overfitting_threshold = 0.3
+        self.overfitting_threshold = 0.2  # Reduced from 0.3

     def train_node_classification(self, data, verbose=True):
-        """Anti-overfitting training"""
+        """Anti-overfitting training with gap monitoring"""

         if verbose:
             total_params = sum(p.numel() for p in self.model.parameters())
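overfitting_threshold is compared against the train/validation accuracy gap computed later in the epoch loop; the evaluation code itself is outside this hunk, so the sketch below only shows the comparison with made-up numbers:

# Hypothetical metrics; the trainer obtains these from its own train/eval passes.
train_acc = 0.95
val_acc = 0.72
acc_gap = train_acc - val_acc      # 0.23

print(acc_gap > 0.2)   # True  -> flagged under the new, stricter threshold
print(acc_gap > 0.3)   # False -> would have passed under the old threshold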
@@ -72,8 +75,9 @@ class GraphMambaTrainer:
             print(f"⚙️ Parameters: {total_params:,}")
             print(f"📊 Training samples: {train_samples}")
             print(f"⚠️ Params per sample: {params_per_sample:.1f}")
+            print(f"🚨 Max allowed gap: {self.max_gap:.3f}")

-            if params_per_sample > …:
+            if params_per_sample > 500:
                 print(f"🚨 WARNING: High params per sample ratio - overfitting risk!")

         # Initialize classifier
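The 500 cut-off is a heuristic on model capacity relative to labeled data. With made-up counts for a small citation-graph split, the check would trigger like this:

# Hypothetical counts, for illustration only.
total_params = 120_000          # e.g. sum(p.numel() for p in model.parameters())
train_samples = 140             # e.g. int(data.train_mask.sum())

params_per_sample = total_params / train_samples
print(f"{params_per_sample:.1f}")   # 857.1
print(params_per_sample > 500)      # True -> the overfitting warning is printed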
@@ -114,16 +118,16 @@ class GraphMambaTrainer:
             else:
                 self.patience_counter += 1

-            # …
+            # Aggressive overfitting detection
             if acc_gap > self.overfitting_threshold:
                 if verbose:
                     print(f"🚨 OVERFITTING detected: {acc_gap:.3f} gap")
                     print(f"   Train: {train_metrics['acc']:.3f}, Val: {val_metrics['acc']:.3f}")

             # Progress logging
-            if verbose and (epoch == 0 or (epoch + 1) % …):
+            if verbose and (epoch == 0 or (epoch + 1) % 5 == 0 or epoch == self.epochs - 1):
                 elapsed = time.time() - start_time
-                gap_indicator = "🚨" if acc_gap > 0.…
+                gap_indicator = "🚨" if acc_gap > 0.25 else "⚠️" if acc_gap > 0.15 else "✅"

                 print(f"Epoch {epoch:3d} | "
                       f"Train: {train_metrics['loss']:.4f} ({train_metrics['acc']:.4f}) | "
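The new logging condition prints the first epoch, every fifth epoch, and the last one. For a hypothetical 20-epoch run it fires on these epochs:

epochs = 20  # hypothetical run length
logged = [e for e in range(epochs)
          if e == 0 or (e + 1) % 5 == 0 or e == epochs - 1]
print(logged)  # [0, 4, 9, 14, 19]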
@@ -131,16 +135,22 @@ class GraphMambaTrainer:
                       f"Gap: {acc_gap:.3f} {gap_indicator} | "
                       f"LR: {self.optimizer.param_groups[0]['lr']:.6f}")

-            # …
+            # Enhanced early stopping conditions
             if self.patience_counter >= self.patience:
                 if verbose:
                     print(f"🛑 Early stopping at epoch {epoch} (patience)")
                 break

-            # Stop if …
-            if acc_gap > …:
+            # Stop if gap exceeds threshold
+            if acc_gap > self.max_gap:
+                if verbose:
+                    print(f"🛑 Stopping due to overfitting gap: {acc_gap:.3f} > {self.max_gap:.3f}")
+                break
+
+            # Stop if severe overfitting (backup check)
+            if acc_gap > 0.6:
                 if verbose:
-                    print(f"🛑 …")
+                    print(f"🛑 Emergency stop - severe overfitting (gap: {acc_gap:.3f})")
                 break

         if verbose:
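The loop now has three exits: patience exhausted, gap above the configurable max_gap, and a hard emergency stop at 0.6. The helper below is not part of the diff; it just pulls the same three checks into one function to make the precedence explicit (patience is evaluated first):

def stop_reason(patience_counter, patience, acc_gap, max_gap):
    """Return why training should stop, or None to continue (illustrative only)."""
    if patience_counter >= patience:
        return "patience exhausted"
    if acc_gap > max_gap:
        return f"overfitting gap {acc_gap:.3f} > {max_gap:.3f}"
    if acc_gap > 0.6:
        return "severe overfitting (emergency stop)"
    return None

print(stop_reason(3, 8, 0.31, 0.25))  # 'overfitting gap 0.310 > 0.250'

Note that the 0.6 branch can only be reached when max_gap is configured above 0.6, which is why the diff itself labels it a backup check.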
@@ -170,15 +180,15 @@ class GraphMambaTrainer:
         # Compute loss on training nodes only
         train_loss = self.criterion(logits[data.train_mask], data.y[data.train_mask])

-        # Add L2 regularization
+        # Add stronger L2 regularization
         l2_reg = 0.0
         for param in self.model.parameters():
             l2_reg += torch.norm(param, p=2)
-        train_loss += 1e-5 * l2_reg
+        train_loss += 5e-5 * l2_reg  # Increased from 1e-5

         # Backward pass with gradient clipping
         train_loss.backward()
-        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)  # Reduced from 1.0
         self.optimizer.step()

         # Compute accuracy
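Worth noting: the 5e-5 * l2_reg term sums unsquared parameter norms, so it is a norm penalty added to the loss on top of whatever decoupled weight_decay AdamW already applies, and the tighter max_norm=0.5 clip bounds the update size. Below is a standalone sketch of the same two operations, assuming a hypothetical model, optimizer, and loss already exist; the zero_grad placement is an assumption, as it is not shown in this hunk.

import torch

def regularized_backward(model, optimizer, loss, l2_coeff=5e-5, max_norm=0.5):
    # Add the same norm-based penalty the trainer uses (sum of per-tensor L2 norms,
    # not their squares), then clip gradients before the optimizer step.
    l2_reg = sum(torch.norm(p, p=2) for p in model.parameters())
    total_loss = loss + l2_coeff * l2_reg

    optimizer.zero_grad()
    total_loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    optimizer.step()
    return float(grad_norm)  # pre-clip gradient norm, handy for logging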