|
|
import pyrootutils |
|
|
|
|
|
# Resolve the project root by searching upward from this file for a ".git" or
# "README.md" marker. pythonpath=True / dotenv=True ask pyrootutils to also put
# the root on the Python path and load a ".env" file, so `from src import ...`
# below can resolve. The root is kept as a plain string because it is
# concatenated into the Hydra config path further down.
root = str(
    pyrootutils.setup_root(
        search_from=__file__,
        indicator=[".git", "README.md"],
        pythonpath=True,
        dotenv=True,
    )
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import pandas as pd |
|
|
|
|
|
from typing import List, Tuple |
|
|
|
|
|
import hydra |
|
|
import torch |
|
|
import torch_geometric |
|
|
from omegaconf import OmegaConf, DictConfig |
|
|
from pytorch_lightning import LightningDataModule, LightningModule, Trainer |
|
|
from pytorch_lightning.loggers import Logger |
|
|
|
|
|
from src import utils |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Expose Python's built-in `eval` to configs as the "${eval:...}" interpolation.
# The has_resolver guard keeps re-imports of this module from registering it twice.
# NOTE(review): this evaluates arbitrary Python from config strings, so only load
# trusted configs / command-line overrides.
if not OmegaConf.has_resolver("eval"):
    OmegaConf.register_new_resolver("eval", eval)

# Module-level logger obtained from the project's utils helper.
log = utils.get_pylogger(__name__)
|
|
|
|
|
|
|
|
@utils.task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
    """Evaluates given checkpoint on a datamodule testset.

    This method is wrapped in optional @task_wrapper decorator which applies extra utilities
    before and after the call.

    Args:
        cfg (DictConfig): Configuration composed by Hydra. Must provide ``ckpt_path`` and
            ``datamodule`` / ``model`` / ``trainer`` nodes (each with a ``_target_``).
            ``float32_matmul_precision`` is read when torch >= 2; ``logger`` and
            ``compile`` are optional.

    Returns:
        Tuple[dict, dict]: Dict with metrics (``trainer.callback_metrics`` after testing)
        and dict with all instantiated objects.

    Raises:
        ValueError: If ``ckpt_path`` is missing or empty in the config.
    """
    # Explicit check instead of ``assert``: assertions are stripped under ``python -O``,
    # which would silently let "evaluation" run without any checkpoint.
    if not cfg.get("ckpt_path"):
        raise ValueError("`ckpt_path` must be set in the config to evaluate a checkpoint!")

    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating loggers...")
    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)

    # Compare only the integer major version. Parsing "major.minor" as a float is
    # fragile (raises on non-numeric components, and "1.10" becomes 1.1), while the
    # threshold here only depends on the major version.
    torch_major = int(torch.__version__.split(".")[0])
    if torch_major >= 2:
        torch.set_float32_matmul_precision(cfg.float32_matmul_precision)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "logger": logger,
        "trainer": trainer,
    }

    # Hyperparameters are only logged when at least one logger was configured.
    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    if cfg.get("compile"):
        log.info("Compiling model!")
        model = torch.compile(model, dynamic=True)

    log.info("Starting testing!")
    # Weights are loaded from ckpt_path by the trainer before running the test loop.
    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)

    metric_dict = trainer.callback_metrics

    return metric_dict, object_dict
|
|
|
|
|
|
|
|
@hydra.main(version_base="1.2", config_path=root + "/configs", config_name="eval.yaml")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: compose the ``eval.yaml`` config and run :func:`evaluate`.

    The config is loaded from the ``configs`` directory under the project root.
    The metrics and objects returned by ``evaluate`` are not used here.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.
    """
    evaluate(cfg)
|
|
|
|
|
|
|
|
# Run evaluation only when executed as a script (not when imported).
# Hydra's decorator supplies `cfg` from the command line / config files.
if __name__ == "__main__":
    main()
|
|
|