Spaces:
Sleeping
Sleeping
File size: 3,879 Bytes
94b1553 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Dict
import numpy as np
from dotenv import load_dotenv
from src.constants import TARGET_NAMES
from src.features import FingerprintFeaturizer
from src.lightgbm_trainer import train_stage_one_models
from src.preprocess import load_tox21_dataset
from src.seed import set_seed
from src.stage_two import train_stage_two_models
def _default_checkpoint_dir(config: Dict) -> Path:
checkpoint_cfg = config.get("output", {})
checkpoint_dir = checkpoint_cfg.get("checkpoint_dir", "./checkpoints")
path = Path(checkpoint_dir)
path.mkdir(parents=True, exist_ok=True)
return path
def _save_stage1_predictions(cache_dir: Path, train_full, val_full) -> None:
    """Cache stage-1 prediction matrices as .npz files under *cache_dir*."""
    # Shared label array so both archives carry identical target metadata.
    names = np.array(TARGET_NAMES, dtype=object)
    np.savez(
        cache_dir / "stage1_train_predictions.npz",
        full=train_full,
        target_names=names,
    )
    # The validation predictions may be absent — only save them when present.
    if val_full is not None:
        np.savez(
            cache_dir / "stage1_validation_predictions.npz",
            full=val_full,
            target_names=names,
        )


def _maybe_train_stage_two(
    multitask_cfg: Dict,
    train_features,
    val_features,
    train_df,
    val_df,
    config: Dict,
    checkpoint_dir: Path,
    stage1_train_full,
    stage1_val_full,
):
    """Run stage 2 when multitask training is enabled; return its metrics or None."""
    if not multitask_cfg.get("enabled", False):
        return None
    print("==== Stage 2: Training multitask-augmented LightGBM models ====")
    stage2_artifacts = train_stage_two_models(
        train_features,
        val_features,
        train_df,
        val_df,
        config,
        checkpoint_dir,
        stage1_train_full,
        stage1_val_full,
        target_names=TARGET_NAMES,
    )
    return stage2_artifacts["metrics"]


def _write_manifest(
    config: Dict,
    dataset_cfg: Dict,
    multitask_cfg: Dict,
    checkpoint_dir: Path,
    stage2_metrics,
) -> None:
    """Write training_manifest.json describing all artifacts of this run."""
    # Stage-2 paths are recorded only when stage 2 actually produced metrics.
    stage2_ran = stage2_metrics is not None
    stage2_entry = {
        "enabled": bool(multitask_cfg.get("enabled", False)),
        "model_dir": str(checkpoint_dir / "stage2") if stage2_ran else None,
        "metrics": str(checkpoint_dir / "metrics_stage2.json") if stage2_ran else None,
    }
    manifest = {
        "feature_config": config.get("features", {}),
        "target_names": TARGET_NAMES,
        "dataset": dataset_cfg,
        "stage1": {
            "model_dir": str(checkpoint_dir / "stage1"),
            "metrics": str(checkpoint_dir / "metrics_stage1.json"),
        },
        "stage2": stage2_entry,
        "multitask": multitask_cfg,
        "seed": config.get("seed", 42),
    }
    manifest_path = checkpoint_dir / "training_manifest.json"
    with manifest_path.open("w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)


def train(config: Dict):
    """Run the two-stage Tox21 training pipeline described by *config*.

    Stage 1 trains baseline LightGBM models on fingerprint features; stage 2
    (optional, gated by ``config["multitask"]["enabled"]``) trains
    multitask-augmented models on top of the stage-1 predictions. Stage-1
    predictions are cached as ``.npz`` files and a JSON manifest of all
    artifacts is written to the checkpoint directory.

    Args:
        config: Parsed configuration dict (seed, dataset, features, output,
            multitask sections are read; all have defaults).

    Raises:
        ValueError: if the dataset lacks a 'train' or 'validation' split.
    """
    load_dotenv()
    set_seed(config.get("seed", 42))

    # TOKEN is read from the environment (optionally populated by .env).
    token = os.getenv("TOKEN")
    dataset_cfg = config.get("dataset", {})
    dataset_name = dataset_cfg.get("name", "ml-jku/tox21")
    splits = load_tox21_dataset(token, dataset_name)
    # Fail fast: both stages assume these two splits exist.
    if "train" not in splits or "validation" not in splits:
        raise ValueError("Dataset must provide 'train' and 'validation' splits.")

    featurizer = FingerprintFeaturizer(config.get("features", {}))
    train_df, train_features = featurizer.featurize_dataframe(splits["train"], "train")
    val_df, val_features = featurizer.featurize_dataframe(splits["validation"], "validation")

    checkpoint_dir = _default_checkpoint_dir(config)
    cache_dir = checkpoint_dir / "cache"
    cache_dir.mkdir(parents=True, exist_ok=True)

    print("==== Stage 1: Training baseline LightGBM models ====")
    stage1_artifacts = train_stage_one_models(
        train_features,
        val_features,
        train_df,
        val_df,
        config,
        checkpoint_dir,
        target_names=TARGET_NAMES,
    )
    stage1_train_full = stage1_artifacts["train_full"]
    stage1_val_full = stage1_artifacts["val_full"]
    _save_stage1_predictions(cache_dir, stage1_train_full, stage1_val_full)

    multitask_cfg = config.get("multitask", {"enabled": False})
    stage2_metrics = _maybe_train_stage_two(
        multitask_cfg,
        train_features,
        val_features,
        train_df,
        val_df,
        config,
        checkpoint_dir,
        stage1_train_full,
        stage1_val_full,
    )

    _write_manifest(config, dataset_cfg, multitask_cfg, checkpoint_dir, stage2_metrics)
    print("Training complete.")
if __name__ == "__main__":
with open("./config/config.json", "r", encoding="utf-8") as f:
config = json.load(f)
train(config)
|