| | import os |
| | from typing import List |
| |
|
| | import hydra |
| | from omegaconf import DictConfig |
| | from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything |
| | from pytorch_lightning.loggers import LightningLoggerBase |
| |
|
| | from src import utils |
| |
|
# Module-level logger obtained from the project's ``src.utils`` helper; the
# testing pipeline uses it to report which components it instantiates.
log = utils.get_logger(__name__)
| |
|
| |
|
def _instantiate_loggers(config: DictConfig) -> List[LightningLoggerBase]:
    """Instantiate every logger described under ``config.logger``.

    Args:
        config (DictConfig): Configuration composed by Hydra. The optional
            ``logger`` entry maps logger names to logger configs.

    Returns:
        List[LightningLoggerBase]: The instantiated loggers. The list is empty
        when ``logger`` is absent, null or empty, or has no instantiable entries.
    """
    loggers: List[LightningLoggerBase] = []

    # ``config.get`` also covers an explicit ``logger: null``. OmegaConf reports
    # such a key as present, so ``"logger" in config`` alone is not enough.
    logger_configs = config.get("logger")
    if not logger_configs:
        return loggers

    for lg_conf in logger_configs.values():
        # Skip entries that are not instantiable configs, e.g. nulled-out loggers.
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf._target_}>")
            loggers.append(hydra.utils.instantiate(lg_conf))

    return loggers


def test(config: DictConfig) -> None:
    """Run the minimal testing pipeline: evaluate a checkpoint on the test set.

    Instantiates the datamodule, model, loggers and trainer from the Hydra
    config, then calls ``trainer.test`` on the checkpoint at ``config.ckpt_path``.

    Args:
        config (DictConfig): Configuration composed by Hydra. Must provide
            ``ckpt_path``, ``datamodule``, ``model`` and ``trainer``.
            ``seed`` and ``logger`` are optional.

    Returns:
        None

    Note:
        If ``config.ckpt_path`` is relative, it is rewritten in place to an
        absolute path resolved against Hydra's original working directory.
    """
    # Seed the RNGs. Compare against None so that an explicit ``seed: 0``
    # is honored instead of being treated as "no seed".
    if config.get("seed") is not None:
        seed_everything(config.seed, workers=True)

    # Hydra may run the job in a different working directory, so resolve
    # relative checkpoint paths against the directory the program started in.
    if not os.path.isabs(config.ckpt_path):
        config.ckpt_path = os.path.join(hydra.utils.get_original_cwd(), config.ckpt_path)

    log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)

    log.info(f"Instantiating model <{config.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(config.model)

    logger: List[LightningLoggerBase] = _instantiate_loggers(config)

    log.info(f"Instantiating trainer <{config.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(config.trainer, logger=logger)

    # With no loggers configured, newer Lightning versions expose
    # ``trainer.logger`` as None. Guard so that testing still runs without loggers.
    if trainer.logger:
        trainer.logger.log_hyperparams({"ckpt_path": config.ckpt_path})

    log.info("Starting testing!")
    trainer.test(model=model, datamodule=datamodule, ckpt_path=config.ckpt_path)
| |
|