# NOTE(review): removed non-code page-scrape residue ("Spaces: / Sleeping") that
# preceded the module and would be a syntax error in a Python source file.
| from __future__ import annotations | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Dict | |
| import numpy as np | |
| from dotenv import load_dotenv | |
| from src.constants import TARGET_NAMES | |
| from src.features import FingerprintFeaturizer | |
| from src.lightgbm_trainer import train_stage_one_models | |
| from src.preprocess import load_tox21_dataset | |
| from src.seed import set_seed | |
| from src.stage_two import train_stage_two_models | |
| def _default_checkpoint_dir(config: Dict) -> Path: | |
| checkpoint_cfg = config.get("output", {}) | |
| checkpoint_dir = checkpoint_cfg.get("checkpoint_dir", "./checkpoints") | |
| path = Path(checkpoint_dir) | |
| path.mkdir(parents=True, exist_ok=True) | |
| return path | |
def _cache_stage1_predictions(cache_dir: Path, train_full, val_full) -> None:
    """Persist stage-1 prediction arrays to .npz caches in *cache_dir*.

    The validation cache is skipped when no validation predictions exist.
    """
    np.savez(
        cache_dir / "stage1_train_predictions.npz",
        full=train_full,
        target_names=np.array(TARGET_NAMES, dtype=object),
    )
    if val_full is not None:
        np.savez(
            cache_dir / "stage1_validation_predictions.npz",
            full=val_full,
            target_names=np.array(TARGET_NAMES, dtype=object),
        )


def _write_manifest(
    checkpoint_dir: Path,
    config: Dict,
    dataset_cfg: Dict,
    multitask_cfg: Dict,
    stage2_ran: bool,
) -> None:
    """Write training_manifest.json describing where artifacts were saved.

    *stage2_ran* controls whether the stage-2 paths are recorded; when stage 2
    was disabled (or produced no metrics) they are left as ``None``.
    """
    manifest = {
        "feature_config": config.get("features", {}),
        "target_names": TARGET_NAMES,
        "dataset": dataset_cfg,
        "stage1": {
            "model_dir": str(checkpoint_dir / "stage1"),
            "metrics": str(checkpoint_dir / "metrics_stage1.json"),
        },
        "stage2": {
            "enabled": bool(multitask_cfg.get("enabled", False)),
            "model_dir": str(checkpoint_dir / "stage2") if stage2_ran else None,
            "metrics": str(checkpoint_dir / "metrics_stage2.json") if stage2_ran else None,
        },
        "multitask": multitask_cfg,
        "seed": config.get("seed", 42),
    }
    manifest_path = checkpoint_dir / "training_manifest.json"
    with manifest_path.open("w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)


def train(config: Dict):
    """Run the two-stage Tox21 training pipeline.

    Pipeline: load the dataset, featurize the train/validation splits, train
    the stage-1 LightGBM baselines, cache their predictions, optionally train
    the stage-2 multitask-augmented models, and write a manifest describing
    all produced artifacts.

    Args:
        config: Run configuration with optional sections ``seed``, ``dataset``,
            ``features``, ``multitask`` and ``output``.

    Raises:
        ValueError: If the loaded dataset lacks a 'train' or 'validation' split.
    """
    load_dotenv()
    set_seed(config.get("seed", 42))

    # TOKEN may be unset; presumably the loader then falls back to anonymous
    # dataset access — TODO confirm against load_tox21_dataset.
    token = os.getenv("TOKEN")
    dataset_cfg = config.get("dataset", {})
    dataset_name = dataset_cfg.get("name", "ml-jku/tox21")
    splits = load_tox21_dataset(token, dataset_name)
    if "train" not in splits or "validation" not in splits:
        raise ValueError("Dataset must provide 'train' and 'validation' splits.")

    featurizer = FingerprintFeaturizer(config.get("features", {}))
    train_df, train_features = featurizer.featurize_dataframe(splits["train"], "train")
    val_df, val_features = featurizer.featurize_dataframe(splits["validation"], "validation")

    checkpoint_dir = _default_checkpoint_dir(config)
    cache_dir = checkpoint_dir / "cache"
    cache_dir.mkdir(parents=True, exist_ok=True)

    print("==== Stage 1: Training baseline LightGBM models ====")
    stage1_artifacts = train_stage_one_models(
        train_features,
        val_features,
        train_df,
        val_df,
        config,
        checkpoint_dir,
        target_names=TARGET_NAMES,
    )
    stage1_train_full = stage1_artifacts["train_full"]
    stage1_val_full = stage1_artifacts["val_full"]
    _cache_stage1_predictions(cache_dir, stage1_train_full, stage1_val_full)

    # Stage 2 is optional; it reuses stage-1 predictions as extra features.
    stage2_metrics = None
    multitask_cfg = config.get("multitask", {"enabled": False})
    if multitask_cfg.get("enabled", False):
        print("==== Stage 2: Training multitask-augmented LightGBM models ====")
        stage2_artifacts = train_stage_two_models(
            train_features,
            val_features,
            train_df,
            val_df,
            config,
            checkpoint_dir,
            stage1_train_full,
            stage1_val_full,
            target_names=TARGET_NAMES,
        )
        stage2_metrics = stage2_artifacts["metrics"]

    _write_manifest(
        checkpoint_dir,
        config,
        dataset_cfg,
        multitask_cfg,
        stage2_ran=stage2_metrics is not None,
    )
    print("Training complete.")
if __name__ == "__main__":
    # Script entry point: load the run configuration and start training.
    config = json.loads(Path("./config/config.json").read_text(encoding="utf-8"))
    train(config)