| | import os |
| | from typing import List, Optional |
| |
|
| | import hydra |
| | from omegaconf import DictConfig |
| | from pytorch_lightning import ( |
| | Callback, |
| | LightningDataModule, |
| | LightningModule, |
| | Trainer, |
| | seed_everything, |
| | ) |
| | from pytorch_lightning.loggers import LightningLoggerBase |
| |
|
| | from src import utils |
| |
|
# Module-level logger used by the training pipeline below.
# NOTE(review): `utils.get_logger` is a project helper (src.utils); presumably it
# returns a configured `logging.Logger` for this module — confirm there.
log = utils.get_logger(__name__)
| |
|
| |
|
def _instantiate_group(group_conf: Optional[DictConfig], kind: str) -> list:
    """Instantiate every entry of a Hydra config group that declares a `_target_`.

    Args:
        group_conf: Mapping of name -> object config (e.g. `config.callbacks`),
            or None / empty when the group is absent or disabled.
        kind: Human-readable object kind used in the log message
            (e.g. "callback", "logger").

    Returns:
        List of instantiated objects, in the group's iteration order. Entries
        without a `_target_` key are skipped.
    """
    objects = []
    # A missing group (None) or an explicitly nulled/empty one yields no objects
    # instead of failing on `.items()`.
    if not group_conf:
        return objects
    for _, obj_conf in group_conf.items():
        if "_target_" in obj_conf:
            log.info(f"Instantiating {kind} <{obj_conf._target_}>")
            objects.append(hydra.utils.instantiate(obj_conf))
    return objects


def train(config: DictConfig) -> Optional[float]:
    """Contains the training pipeline.

    Instantiates the datamodule, model, callbacks, loggers and trainer from the
    Hydra config, optionally fits the model, and can additionally evaluate the
    model on a test set, using the best weights achieved during training.

    Args:
        config (DictConfig): Configuration composed by Hydra.

    Returns:
        Optional[float]: Metric score for hyperparameter optimization. This is the
        value stored in `trainer.callback_metrics` under `config.optimized_metric`,
        read after training and before testing. It is None when
        `optimized_metric` is not set.

    Raises:
        Exception: If `optimized_metric` is set but no metric with that name is
            present in `trainer.callback_metrics`.
    """
    # Seed RNGs when a seed is configured. Compare against None so that an
    # explicit `seed: 0` is honored rather than treated as "no seed".
    if config.get("seed") is not None:
        seed_everything(config.seed, workers=True)

    # A relative resume path is interpreted relative to the directory the app was
    # launched from (Hydra's original cwd), not the current working directory.
    resume_ckpt_path = config.trainer.get("resume_from_checkpoint")
    if resume_ckpt_path and not os.path.isabs(resume_ckpt_path):
        config.trainer.resume_from_checkpoint = os.path.join(
            hydra.utils.get_original_cwd(), resume_ckpt_path
        )

    log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)

    log.info(f"Instantiating model <{config.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(config.model)

    callbacks: List[Callback] = _instantiate_group(config.get("callbacks"), "callback")
    logger: List[LightningLoggerBase] = _instantiate_group(config.get("logger"), "logger")

    log.info(f"Instantiating trainer <{config.trainer._target_}>")
    # `_convert_="partial"` passes the instantiated callback/logger lists through
    # as plain Python containers rather than OmegaConf containers.
    trainer: Trainer = hydra.utils.instantiate(
        config.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
    )

    log.info("Logging hyperparameters!")
    utils.log_hyperparameters(
        config=config,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        callbacks=callbacks,
        logger=logger,
    )

    if config.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule)

    # Read the hyperparameter-search metric right after training, before testing
    # can add or overwrite entries in callback_metrics.
    optimized_metric = config.get("optimized_metric")
    if optimized_metric and optimized_metric not in trainer.callback_metrics:
        raise Exception(
            "Metric for hyperparameter optimization not found! "
            "Make sure the `optimized_metric` in `hparams_search` config is correct!"
        )
    score = trainer.callback_metrics.get(optimized_metric)

    if config.get("test"):
        # Test with the best checkpoint only when training actually ran and
        # produced checkpoints; otherwise test the in-memory weights.
        test_ckpt_path = "best"
        if not config.get("train") or config.trainer.get("fast_dev_run"):
            test_ckpt_path = None
        log.info("Starting testing!")
        trainer.test(model=model, datamodule=datamodule, ckpt_path=test_ckpt_path)

    log.info("Finalizing!")
    utils.finish(
        config=config,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        callbacks=callbacks,
        logger=logger,
    )

    # Report the best checkpoint. The `train` flag lives at the top level of the
    # config (not under `trainer`), and there may be no checkpoint callback at all.
    if not config.trainer.get("fast_dev_run") and config.get("train"):
        checkpoint_callback = trainer.checkpoint_callback
        if checkpoint_callback is not None:
            log.info(f"Best model ckpt at {checkpoint_callback.best_model_path}")

    return score
| |
|