Update core/trainer.py
Browse files- core/trainer.py +7 -7
core/trainer.py
CHANGED
|
@@ -17,12 +17,12 @@ class GraphMambaTrainer:
|
|
| 17 |
self.config = config
|
| 18 |
self.device = device
|
| 19 |
|
| 20 |
-
#
|
| 21 |
self.lr = config['training']['learning_rate']
|
| 22 |
self.epochs = config['training']['epochs']
|
| 23 |
-
self.patience = config['training'].get('patience',
|
| 24 |
self.min_lr = config['training'].get('min_lr', 1e-6)
|
| 25 |
-
self.max_gap = config['training'].get('max_gap', 0.
|
| 26 |
|
| 27 |
# Heavily regularized optimizer
|
| 28 |
self.optimizer = optim.AdamW(
|
|
@@ -35,15 +35,15 @@ class GraphMambaTrainer:
|
|
| 35 |
|
| 36 |
# Proper loss function with label smoothing
|
| 37 |
self.criterion = nn.CrossEntropyLoss(
|
| 38 |
-
label_smoothing=config['training'].get('label_smoothing', 0.
|
| 39 |
)
|
| 40 |
|
| 41 |
-
#
|
| 42 |
self.scheduler = ReduceLROnPlateau(
|
| 43 |
self.optimizer,
|
| 44 |
mode='max',
|
| 45 |
factor=0.5,
|
| 46 |
-
patience=
|
| 47 |
min_lr=self.min_lr
|
| 48 |
)
|
| 49 |
|
|
@@ -58,7 +58,7 @@ class GraphMambaTrainer:
|
|
| 58 |
|
| 59 |
# Track overfitting
|
| 60 |
self.best_gap = float('inf')
|
| 61 |
-
self.overfitting_threshold = 0.
|
| 62 |
|
| 63 |
def train_node_classification(self, data, verbose=True):
|
| 64 |
"""Anti-overfitting training with gap monitoring"""
|
|
|
|
| 17 |
self.config = config
|
| 18 |
self.device = device
|
| 19 |
|
| 20 |
+
# Balanced learning rate
|
| 21 |
self.lr = config['training']['learning_rate']
|
| 22 |
self.epochs = config['training']['epochs']
|
| 23 |
+
self.patience = config['training'].get('patience', 12)
|
| 24 |
self.min_lr = config['training'].get('min_lr', 1e-6)
|
| 25 |
+
self.max_gap = config['training'].get('max_gap', 0.35) # More lenient threshold
|
| 26 |
|
| 27 |
# Heavily regularized optimizer
|
| 28 |
self.optimizer = optim.AdamW(
|
|
|
|
| 35 |
|
| 36 |
# Proper loss function with label smoothing
|
| 37 |
self.criterion = nn.CrossEntropyLoss(
|
| 38 |
+
label_smoothing=config['training'].get('label_smoothing', 0.1)
|
| 39 |
)
|
| 40 |
|
| 41 |
+
# Balanced scheduler
|
| 42 |
self.scheduler = ReduceLROnPlateau(
|
| 43 |
self.optimizer,
|
| 44 |
mode='max',
|
| 45 |
factor=0.5,
|
| 46 |
+
patience=5,
|
| 47 |
min_lr=self.min_lr
|
| 48 |
)
|
| 49 |
|
|
|
|
| 58 |
|
| 59 |
# Track overfitting
|
| 60 |
self.best_gap = float('inf')
|
| 61 |
+
self.overfitting_threshold = 0.25 # Balanced threshold
|
| 62 |
|
| 63 |
def train_node_classification(self, data, verbose=True):
|
| 64 |
"""Anti-overfitting training with gap monitoring"""
|