import pyrootutils

# Locate the repository root (the directory containing .git / README.md).
# The pythonpath/dotenv flags presumably put the root on sys.path and load
# a .env file (per pyrootutils' documented behavior), so `src.*` imports
# below resolve regardless of the launch directory.
_root_path = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".git", "README.md"],
    pythonpath=True,
    dotenv=True,
)
root = str(_root_path)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| import pandas as pd |
|
|
| from typing import List, Optional, Tuple |
|
|
| import hydra |
| import torch |
| import torch_geometric |
| import pytorch_lightning as pl |
| from omegaconf import OmegaConf, DictConfig |
| from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer |
| from pytorch_lightning.loggers import Logger |
|
|
| from src import utils |
|
|
| |
| |
| |
# Register an `eval` resolver so configs may use ${eval:...} interpolations;
# the has_resolver guard keeps re-imports of this module from raising on
# duplicate registration.
# NOTE(review): `eval` executes arbitrary code — config files must be trusted.
if not OmegaConf.has_resolver("eval"):
    OmegaConf.register_new_resolver("eval", eval)

# Module-level logger from the project's logging helper.
log = utils.get_pylogger(__name__)
|
|
|
|
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[dict, dict]:
    """Trains the model. Can additionally evaluate on a testset, using best weights obtained
    during training.

    This method is wrapped in optional @task_wrapper decorator which applies extra utilities
    before and after the call.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """
    # Seed python, numpy and torch RNGs for reproducibility, if configured.
    if cfg.get("seed"):
        pl.seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)

    # torch.set_float32_matmul_precision exists from torch 2.0 on.
    # FIX: compare the integer major version instead of
    # float('.'.join(torch.__version__.split('.')[:2])), which raises on
    # version strings with a non-numeric second component and mis-orders
    # "x.10"-style versions.
    torch_major = int(torch.__version__.split(".")[0].split("+")[0])
    if torch_major >= 2:
        # FIX: use cfg.get(...) like the rest of this function so a missing
        # key skips the call instead of raising an attribute error.
        matmul_precision = cfg.get("float32_matmul_precision")
        if matmul_precision:
            torch.set_float32_matmul_precision(matmul_precision)

    # Everything instantiated above, handed back to the caller (and to the
    # hyperparameter logger below).
    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    if cfg.get("compile"):
        log.info("Compiling model!")
        model = torch.compile(model, dynamic=True)

    if cfg.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))

    train_metrics = trainer.callback_metrics

    if cfg.get("test"):
        log.info("Starting testing!")
        ckpt_path = trainer.checkpoint_callback.best_model_path
        if ckpt_path == "":
            # No checkpoint was saved (e.g. training skipped or no ckpt callback
            # fired) — fall back to the model's current weights.
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = trainer.callback_metrics

    # Merge train and test metrics (test values win on key collision).
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict
|
|
|
|
@hydra.main(version_base="1.2", config_path=root + "/configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    """Hydra entry point: run training and return the metric value to optimize.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Optional[float]: Value of the metric named by ``cfg.optimized_metric``
        (consumed by hyperparameter sweepers), as produced by
        ``utils.get_metric_value``.
    """
    # Run training; the instantiated-objects dict is not needed here.
    metric_dict, _ = train(cfg)

    # Look up the metric a sweeper optimizes; the metric name may be absent.
    return utils.get_metric_value(
        metric_dict=metric_dict,
        metric_name=cfg.get("optimized_metric"),
    )
|
|
|
|
# Script entry point — Hydra parses the CLI arguments when invoked directly.
if __name__ == "__main__":
    main()
|
|