import logging
import sys
import warnings
from typing import List, Sequence

import pytorch_lightning as pl
import rich.syntax
import rich.tree
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities import rank_zero_only
|
|
|
|
| |
class LoggingContext:
    """Context manager that temporarily changes a logger's level and/or attaches a handler.

    On entry, the logger's level is set to ``level`` (when it is not ``None``) and
    ``handler`` is added (when given). On exit, the previous level is restored and
    the handler is removed; it is also closed when ``close`` is true. Exceptions
    raised inside the ``with`` body are not suppressed.

    Args:
        logger: The ``logging.Logger`` to modify.
        level: Temporary level, or ``None`` to leave the level untouched.
        handler: Handler to attach for the duration of the block, or ``None``.
        close: Whether to close ``handler`` when the block exits.
    """

    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger = logger
        self.level = level
        self.handler = handler
        self.close = close

    def __enter__(self):
        if self.level is not None:
            # Save the logger's own level attribute (not its effective level) so
            # exactly the same value, including NOTSET, is restored on exit.
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)
        # Return the context object so `with LoggingContext(...) as ctx:` binds it.
        return self

    def __exit__(self, et, ev, tb):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if self.handler:
            self.logger.removeHandler(self.handler)
            if self.close:
                self.handler.close()
        # Implicitly returns None, so exceptions from the with-body propagate.
| |
|
|
|
|
def get_logger(name=__name__) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger.

    Returns the standard-library logger for ``name``. Its level methods (``debug``,
    ``info``, ``warning``, ``error``, ``exception``, ``fatal``, ``critical``) are
    wrapped with pytorch_lightning's ``rank_zero_only``, presumably so that only the
    rank-zero process emits records in distributed runs.

    ``logging.getLogger`` returns the same object for the same name, so the wrapping
    is applied only once per logger. Later calls return the already-wrapped logger
    unchanged.

    Args:
        name: Logger name; defaults to this module's ``__name__``.

    Returns:
        The shared ``logging.Logger`` instance for ``name``.
    """
    logger = logging.getLogger(name)

    # Without this guard, every call (extras() calls get_logger on each invocation)
    # would wrap the already-wrapped methods again, so the wrapper chain would grow
    # by one layer per call.
    if not getattr(logger, "_rank_zero_only_wrapped", False):
        for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"):
            setattr(logger, level, rank_zero_only(getattr(logger, level)))
        logger._rank_zero_only_wrapped = True

    return logger
|
|
|
|
def extras(config: DictConfig) -> None:
    """A couple of optional utilities, controlled by main config file:
    - disabling python warnings (``config.ignore_warnings``)
    - verifying experiment name is set when running in experiment mode
      (``config.experiment_mode`` / ``config.name``)
    - forcing debug friendly configuration when ``config.trainer.fast_dev_run`` is set

    Modifies DictConfig in place.

    Args:
        config (DictConfig): Configuration composed by Hydra.

    Raises:
        SystemExit: With status 1 when experiment mode is enabled but no experiment
            name is set.
    """
    log = get_logger(__name__)

    # disable python warnings if <config.ignore_warnings=True>
    if config.get("ignore_warnings"):
        log.info("Disabling python warnings! <config.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # refuse to run experiment mode without an experiment name
    if config.get("experiment_mode") and not config.get("name"):
        log.info(
            "Running in experiment mode without the experiment name specified! "
            "Use `python run.py mode=exp name=experiment_name`"
        )
        log.info("Exiting...")
        # Use sys.exit rather than the `exit()` builtin: that one is injected by the
        # `site` module and is not guaranteed to exist (e.g. `python -S`, embedded or
        # frozen interpreters). A non-zero status lets shells/CI detect the
        # misconfiguration.
        sys.exit(1)

    # force debugger friendly configuration if <config.trainer.fast_dev_run=True>
    trainer_cfg = config.get("trainer")
    if trainer_cfg is not None and trainer_cfg.get("fast_dev_run"):
        log.info("Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>")
        # Zeroing gpus / num_workers and disabling pin_memory presumably keeps the
        # debug run single-device and in-process.
        if trainer_cfg.get("gpus"):
            trainer_cfg.gpus = 0
        datamodule_cfg = config.get("datamodule")
        if datamodule_cfg is not None:
            if datamodule_cfg.get("pin_memory"):
                datamodule_cfg.pin_memory = False
            if datamodule_cfg.get("num_workers"):
                datamodule_cfg.num_workers = 0
|
|
|
|
@rank_zero_only
def print_config(
    config: DictConfig,
    fields: Sequence[str] = (
        "trainer",
        "model",
        "datamodule",
        "train",
        "eval",
        "callbacks",
        "logger",
        "seed",
        "name",
    ),
    resolve: bool = True,
) -> None:
    """Prints content of DictConfig using Rich library and its tree structure.

    The same tree is printed to the console and written to ``config_tree.txt`` in
    the current working directory.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        fields (Sequence[str], optional): Determines which main fields from config will
            be printed and in what order.
        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    for field in fields:
        branch = tree.add(field, style=style, guide_style=style)

        config_section = config.get(field)
        # Render any OmegaConf container (DictConfig *or* ListConfig, e.g. a list of
        # callbacks) as YAML so `resolve` is honoured for it. Plain values such as
        # seed or name (or None for a missing field) fall back to str().
        if OmegaConf.is_config(config_section):
            branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
        else:
            branch_content = str(config_section)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    rich.print(tree)

    # Explicit utf-8: without it the platform's locale encoding is used (e.g. cp1252
    # on Windows), and non-ASCII config content could fail to encode.
    with open("config_tree.txt", "w", encoding="utf-8") as fp:
        rich.print(tree, file=fp)
|
|
|
|
def finish(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.LightningLoggerBase],
) -> None:
    """Makes sure everything closed properly.

    Calls ``wandb.finish()`` once for every ``WandbLogger`` found in ``logger``.
    ``config``, ``model``, ``datamodule``, ``trainer`` and ``callbacks`` are
    accepted but not used by this function.
    """
    for candidate in logger:
        if not isinstance(candidate, pl.loggers.wandb.WandbLogger):
            continue
        # The import happens inside the loop, so wandb is only imported when a
        # WandbLogger is actually present.
        import wandb

        wandb.finish()
|
|