Update geneformer/mtl/train.py
#582
by
MajorasMeow
- opened
- geneformer/mtl/train.py +19 -7
geneformer/mtl/train.py
CHANGED
|
@@ -475,14 +475,26 @@ def objective(
|
|
| 475 |
param_name, param_config["low"], param_config["high"]
|
| 476 |
)
|
| 477 |
|
| 478 |
-
# Set appropriate max layers to freeze based on pretrained model
|
| 479 |
if "max_layers_to_freeze" in trial_config:
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
"
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 486 |
|
| 487 |
trial_config["run_name"] = f"trial_{trial.number}"
|
| 488 |
|
|
|
|
| 475 |
param_name, param_config["low"], param_config["high"]
|
| 476 |
)
|
| 477 |
|
| 478 |
+
# Set appropriate max layers to freeze based on pretrained model or user-specific range
|
| 479 |
if "max_layers_to_freeze" in trial_config:
|
| 480 |
+
if trial_config["max_layers_to_freeze"] is None:
|
| 481 |
+
# infer range from pretrained model
|
| 482 |
+
freeze_range = get_layer_freeze_range(trial_config["pretrained_path"])
|
| 483 |
+
trial_config["max_layers_to_freeze"] = int(
|
| 484 |
+
trial.suggest_int(
|
| 485 |
+
"max_layers_to_freeze",
|
| 486 |
+
freeze_range["min"],
|
| 487 |
+
freeze_range["max"],
|
| 488 |
+
)
|
| 489 |
+
)
|
| 490 |
+
else:
|
| 491 |
+
# user-specified range
|
| 492 |
+
min_freeze = trial_config["max_layers_to_freeze"]["min"]
|
| 493 |
+
max_freeze = trial_config["max_layers_to_freeze"]["max"]
|
| 494 |
+
|
| 495 |
+
trial_config["max_layers_to_freeze"] = int(
|
| 496 |
+
trial.suggest_int("max_layers_to_freeze", min_freeze, max_freeze)
|
| 497 |
+
)
|
| 498 |
|
| 499 |
trial_config["run_name"] = f"trial_{trial.number}"
|
| 500 |
|