Update geneformer/mtl/train.py

#582
Files changed (1) hide show
  1. geneformer/mtl/train.py +19 -7
geneformer/mtl/train.py CHANGED
@@ -475,14 +475,26 @@ def objective(
475
  param_name, param_config["low"], param_config["high"]
476
  )
477
 
478
- # Set appropriate max layers to freeze based on pretrained model
479
  if "max_layers_to_freeze" in trial_config:
480
- freeze_range = get_layer_freeze_range(trial_config["pretrained_path"])
481
- trial_config["max_layers_to_freeze"] = int(trial.suggest_int(
482
- "max_layers_to_freeze",
483
- freeze_range["min"],
484
- freeze_range["max"]
485
- ))
 
 
 
 
 
 
 
 
 
 
 
 
486
 
487
  trial_config["run_name"] = f"trial_{trial.number}"
488
 
 
475
  param_name, param_config["low"], param_config["high"]
476
  )
477
 
478
+ # Set appropriate max layers to freeze based on pretrained model or user-specific range
479
  if "max_layers_to_freeze" in trial_config:
480
+ if trial_config["max_layers_to_freeze"] is None:
481
+ # infer range from pretrained model
482
+ freeze_range = get_layer_freeze_range(trial_config["pretrained_path"])
483
+ trial_config["max_layers_to_freeze"] = int(
484
+ trial.suggest_int(
485
+ "max_layers_to_freeze",
486
+ freeze_range["min"],
487
+ freeze_range["max"],
488
+ )
489
+ )
490
+ else:
491
+ # user-specified range
492
+ min_freeze = trial_config["max_layers_to_freeze"]["min"]
493
+ max_freeze = trial_config["max_layers_to_freeze"]["max"]
494
+
495
+ trial_config["max_layers_to_freeze"] = int(
496
+ trial.suggest_int("max_layers_to_freeze", min_freeze, max_freeze)
497
+ )
498
 
499
  trial_config["run_name"] = f"trial_{trial.number}"
500