from __future__ import annotations import json import os from pathlib import Path from typing import Dict import numpy as np from dotenv import load_dotenv from src.constants import TARGET_NAMES from src.features import FingerprintFeaturizer from src.lightgbm_trainer import train_stage_one_models from src.preprocess import load_tox21_dataset from src.seed import set_seed from src.stage_two import train_stage_two_models def _default_checkpoint_dir(config: Dict) -> Path: checkpoint_cfg = config.get("output", {}) checkpoint_dir = checkpoint_cfg.get("checkpoint_dir", "./checkpoints") path = Path(checkpoint_dir) path.mkdir(parents=True, exist_ok=True) return path def train(config: Dict): load_dotenv() set_seed(config.get("seed", 42)) token = os.getenv("TOKEN") dataset_cfg = config.get("dataset", {}) dataset_name = dataset_cfg.get("name", "ml-jku/tox21") splits = load_tox21_dataset(token, dataset_name) if "train" not in splits or "validation" not in splits: raise ValueError("Dataset must provide 'train' and 'validation' splits.") featurizer = FingerprintFeaturizer(config.get("features", {})) train_df, train_features = featurizer.featurize_dataframe(splits["train"], "train") val_df, val_features = featurizer.featurize_dataframe(splits["validation"], "validation") checkpoint_dir = _default_checkpoint_dir(config) cache_dir = checkpoint_dir / "cache" cache_dir.mkdir(parents=True, exist_ok=True) print("==== Stage 1: Training baseline LightGBM models ====") stage1_artifacts = train_stage_one_models( train_features, val_features, train_df, val_df, config, checkpoint_dir, target_names=TARGET_NAMES, ) stage1_train_full = stage1_artifacts["train_full"] stage1_val_full = stage1_artifacts["val_full"] np.savez( cache_dir / "stage1_train_predictions.npz", full=stage1_train_full, target_names=np.array(TARGET_NAMES, dtype=object), ) if stage1_val_full is not None: np.savez( cache_dir / "stage1_validation_predictions.npz", full=stage1_val_full, target_names=np.array(TARGET_NAMES, dtype=object), ) stage2_metrics = None multitask_cfg = config.get("multitask", {"enabled": False}) if multitask_cfg.get("enabled", False): print("==== Stage 2: Training multitask-augmented LightGBM models ====") stage2_artifacts = train_stage_two_models( train_features, val_features, train_df, val_df, config, checkpoint_dir, stage1_train_full, stage1_val_full, target_names=TARGET_NAMES, ) stage2_metrics = stage2_artifacts["metrics"] stage2_entry = { "enabled": bool(multitask_cfg.get("enabled", False)), "model_dir": str(checkpoint_dir / "stage2") if stage2_metrics is not None else None, "metrics": str(checkpoint_dir / "metrics_stage2.json") if stage2_metrics is not None else None, } manifest = { "feature_config": config.get("features", {}), "target_names": TARGET_NAMES, "dataset": dataset_cfg, "stage1": { "model_dir": str(checkpoint_dir / "stage1"), "metrics": str((checkpoint_dir / "metrics_stage1.json")), }, "stage2": stage2_entry, "multitask": multitask_cfg, "seed": config.get("seed", 42), } manifest_path = checkpoint_dir / "training_manifest.json" with manifest_path.open("w", encoding="utf-8") as f: json.dump(manifest, f, indent=2) print("Training complete.") if __name__ == "__main__": with open("./config/config.json", "r", encoding="utf-8") as f: config = json.load(f) train(config)