|
|
import time
import warnings
from functools import wraps
from importlib.util import find_spec
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import hydra
from omegaconf import DictConfig
from pytorch_lightning import Callback
from pytorch_lightning.loggers import Logger
from pytorch_lightning.utilities import rank_zero_only

from src.utils import pylogger, rich_utils
|
|
|
|
|
|
|
|
# Explicit public API of this utilities module.
__all__ = [
    'close_loggers', 'extras', 'get_metric_value', 'instantiate_callbacks',
    'instantiate_loggers', 'log_hyperparameters', 'save_file', 'task_wrapper']


# Module-level logger for this file (provided by src.utils.pylogger;
# presumably rank-zero aware — confirm in pylogger).
log = pylogger.get_pylogger(__name__)
|
|
|
|
|
|
|
|
def task_wrapper(task_func: Callable) -> Callable: |
|
|
"""Optional decorator that wraps the task function in extra utilities. |
|
|
|
|
|
Makes multirun more resistant to failure. |
|
|
|
|
|
Utilities: |
|
|
- Calling the `utils.extras()` before the task is started |
|
|
- Calling the `utils.close_loggers()` after the task is finished |
|
|
- Logging the exception if occurs |
|
|
- Logging the task total execution time |
|
|
- Logging the output dir |
|
|
""" |
|
|
|
|
|
def wrap(cfg: DictConfig): |
|
|
|
|
|
|
|
|
extras(cfg) |
|
|
|
|
|
|
|
|
try: |
|
|
start_time = time.time() |
|
|
metric_dict, object_dict = task_func(cfg=cfg) |
|
|
except Exception as ex: |
|
|
log.exception("") |
|
|
raise ex |
|
|
finally: |
|
|
path = Path(cfg.paths.output_dir, "exec_time.log") |
|
|
content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)" |
|
|
save_file(path, content) |
|
|
close_loggers() |
|
|
|
|
|
log.info(f"Output dir: {cfg.paths.output_dir}") |
|
|
|
|
|
return metric_dict, object_dict |
|
|
|
|
|
return wrap |
|
|
|
|
|
|
|
|
def extras(cfg: DictConfig) -> None:
    """Applies optional utilities before the task is started.

    Utilities:
    - Ignoring python warnings
    - Setting tags from command line
    - Rich config printing
    """
    # nothing to do when the extras sub-config is absent or empty
    extras_cfg = cfg.get("extras")
    if not extras_cfg:
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # optionally silence all python warnings
    if extras_cfg.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # optionally prompt for / validate tags from the command line
    if extras_cfg.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    # optionally pretty-print the resolved config tree
    if extras_cfg.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
|
|
|
|
|
|
|
|
@rank_zero_only
def save_file(path: str, content: str) -> None:
    """Save file in rank zero mode (only on one process in multi-GPU setup).

    Args:
        path: Destination file path.
        content: Text to write (file is created or truncated).
    """
    # "w" truncates/creates exactly like the previous "w+" (read access was unused);
    # an explicit encoding makes output independent of the platform's locale default
    with open(path, "w", encoding="utf-8") as file:
        file.write(content)
|
|
|
|
|
|
|
|
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiates callbacks from config."""
    instantiated: List[Callback] = []

    # empty/missing config is tolerated: return no callbacks
    if not callbacks_cfg:
        log.warning("Callbacks config is empty.")
        return instantiated

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    # instantiate every entry that declares a hydra target; skip plain values
    for cb_conf in callbacks_cfg.values():
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            log.info(f"Instantiating callback <{cb_conf._target_}>")
            instantiated.append(hydra.utils.instantiate(cb_conf))

    return instantiated
|
|
|
|
|
|
|
|
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiates loggers from config."""
    logger: List[Logger] = []

    # empty/missing config is tolerated: return no loggers
    if not logger_cfg:
        log.warning("Logger config is empty.")
        return logger

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    # instantiate every entry that declares a hydra target; skip plain values
    for lg_conf in logger_cfg.values():
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf._target_}>")
            logger.append(hydra.utils.instantiate(lg_conf))

    return logger
|
|
|
|
|
|
|
|
@rank_zero_only
def log_hyperparameters(object_dict: dict) -> None:
    """Controls which config parts are saved by lightning loggers.

    Additionally saves:
    - Number of model parameters
    """
    cfg = object_dict["cfg"]
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    # without a logger attached to the trainer there is nowhere to log to
    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    def _num_params(params) -> int:
        # total number of elements across an iterable of parameter tensors
        return sum(p.numel() for p in params)

    # dict literal keeps the same key insertion order as the original assignments
    hparams: Dict[str, Any] = {
        "model": cfg["model"],
        "model/params/total": _num_params(model.parameters()),
        "model/params/trainable": _num_params(
            p for p in model.parameters() if p.requires_grad
        ),
        "model/params/non_trainable": _num_params(
            p for p in model.parameters() if not p.requires_grad
        ),
        "datamodule": cfg["datamodule"],
        "trainer": cfg["trainer"],
        # optional sections: None when absent from the config
        "callbacks": cfg.get("callbacks"),
        "extras": cfg.get("extras"),
        "task_name": cfg.get("task_name"),
        "tags": cfg.get("tags"),
        "ckpt_path": cfg.get("ckpt_path"),
        "seed": cfg.get("seed"),
    }

    # send everything to the attached lightning logger(s)
    trainer.logger.log_hyperparams(hparams)
|
|
|
|
|
|
|
|
def get_metric_value(metric_dict: dict, metric_name: str) -> Optional[float]:
    """Safely retrieves value of the metric logged in LightningModule.

    Args:
        metric_dict: Mapping of metric names to tensor-like values exposing `.item()`.
        metric_name: Name of the metric to retrieve; a falsy value skips retrieval.

    Returns:
        The metric value as a python float, or None when `metric_name` is falsy.
        (The previous `-> float` annotation was wrong for the skip path.)

    Raises:
        Exception: If `metric_name` is not present in `metric_dict`.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise Exception(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    metric_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value
|
|
|
|
|
|
|
|
def close_loggers() -> None:
    """Makes sure all loggers closed properly (prevents logging failure during multirun)."""
    log.info("Closing loggers...")

    # wandb keeps a background process alive; finish it explicitly,
    # but only if the package is installed and a run is active
    if find_spec("wandb") is not None:
        import wandb

        if wandb.run:
            log.info("Closing wandb!")
            wandb.finish()
|
|
|