kfoughali committed on
Commit
bc0f47a
·
verified ·
1 Parent(s): 72acb14

Update core/trainer.py

Browse files
Files changed (1) hide show
  1. core/trainer.py +7 -7
core/trainer.py CHANGED
@@ -17,12 +17,12 @@ class GraphMambaTrainer:
17
  self.config = config
18
  self.device = device
19
 
20
- # Conservative learning rate
21
  self.lr = config['training']['learning_rate']
22
  self.epochs = config['training']['epochs']
23
- self.patience = config['training'].get('patience', 8)
24
  self.min_lr = config['training'].get('min_lr', 1e-6)
25
- self.max_gap = config['training'].get('max_gap', 0.25) # New gap threshold
26
 
27
  # Heavily regularized optimizer
28
  self.optimizer = optim.AdamW(
@@ -35,15 +35,15 @@ class GraphMambaTrainer:
35
 
36
  # Proper loss function with label smoothing
37
  self.criterion = nn.CrossEntropyLoss(
38
- label_smoothing=config['training'].get('label_smoothing', 0.15)
39
  )
40
 
41
- # Conservative scheduler
42
  self.scheduler = ReduceLROnPlateau(
43
  self.optimizer,
44
  mode='max',
45
  factor=0.5,
46
- patience=4, # Reduced from 5
47
  min_lr=self.min_lr
48
  )
49
 
@@ -58,7 +58,7 @@ class GraphMambaTrainer:
58
 
59
  # Track overfitting
60
  self.best_gap = float('inf')
61
- self.overfitting_threshold = 0.2 # Reduced from 0.3
62
 
63
  def train_node_classification(self, data, verbose=True):
64
  """Anti-overfitting training with gap monitoring"""
 
17
  self.config = config
18
  self.device = device
19
 
20
+ # Balanced learning rate
21
  self.lr = config['training']['learning_rate']
22
  self.epochs = config['training']['epochs']
23
+ self.patience = config['training'].get('patience', 12)
24
  self.min_lr = config['training'].get('min_lr', 1e-6)
25
+ self.max_gap = config['training'].get('max_gap', 0.35) # More lenient threshold
26
 
27
  # Heavily regularized optimizer
28
  self.optimizer = optim.AdamW(
 
35
 
36
  # Proper loss function with label smoothing
37
  self.criterion = nn.CrossEntropyLoss(
38
+ label_smoothing=config['training'].get('label_smoothing', 0.1)
39
  )
40
 
41
+ # Balanced scheduler
42
  self.scheduler = ReduceLROnPlateau(
43
  self.optimizer,
44
  mode='max',
45
  factor=0.5,
46
+ patience=5,
47
  min_lr=self.min_lr
48
  )
49
 
 
58
 
59
  # Track overfitting
60
  self.best_gap = float('inf')
61
+ self.overfitting_threshold = 0.25 # Balanced threshold
62
 
63
  def train_node_classification(self, data, verbose=True):
64
  """Anti-overfitting training with gap monitoring"""