b2u commited on
Commit
b537bff
·
1 Parent(s): 1589415

fix default env variables numbers

Browse files
Files changed (1) hide show
  1. model.py +5 -5
model.py CHANGED
@@ -35,11 +35,11 @@ class BertClassifier(LabelStudioMLBase):
35
  def __init__(self, project_id=None, label_config=None, **kwargs):
36
  super(BertClassifier, self).__init__(project_id=project_id, label_config=label_config)
37
 
38
- # Load training configuration from environment variables
39
- self.learning_rate = float(os.getenv('LEARNING_RATE'))
40
- self.num_train_epochs = int(os.getenv('NUM_TRAIN_EPOCHS'))
41
- self.weight_decay = float(os.getenv('WEIGHT_DECAY'))
42
- self.start_training_threshold = int(os.getenv('START_TRAINING_EACH_N_UPDATES'))
43
 
44
  logger.info("=== Training Configuration ===")
45
  logger.info(f"✓ Learning rate: {self.learning_rate}")
 
35
  def __init__(self, project_id=None, label_config=None, **kwargs):
36
  super(BertClassifier, self).__init__(project_id=project_id, label_config=label_config)
37
 
38
+ # Load training configuration from environment variables with defaults
39
+ self.learning_rate = float(os.getenv('LEARNING_RATE', '2e-5'))
40
+ self.num_train_epochs = int(os.getenv('NUM_TRAIN_EPOCHS', '20'))
41
+ self.weight_decay = float(os.getenv('WEIGHT_DECAY', '0.01'))
42
+ self.start_training_threshold = int(os.getenv('START_TRAINING_EACH_N_UPDATES', '1'))
43
 
44
  logger.info("=== Training Configuration ===")
45
  logger.info(f"✓ Learning rate: {self.learning_rate}")