Spaces:
Sleeping
Sleeping
| import logging | |
| import warnings | |
| from typing import List, Sequence | |
| import pytorch_lightning as pl | |
| import rich.syntax | |
| import rich.tree | |
| from omegaconf import DictConfig, OmegaConf | |
| from pytorch_lightning.utilities import rank_zero_only | |
# Adapted from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
class LoggingContext:
    """Temporarily change a logger's level and/or attach a handler.

    On entry the logger's level is set to ``level`` (when given) and ``handler``
    (when given) is added to the logger. On exit the previous level is restored,
    the handler is removed and, if ``close`` is true, the handler is closed.

    Args:
        logger: The ``logging.Logger`` to modify.
        level: Level to apply inside the ``with`` block; ``None`` leaves the
            logger's level untouched.
        handler: Handler to attach inside the ``with`` block; ``None`` attaches
            nothing.
        close: Whether to close ``handler`` when the block exits.
    """

    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger = logger
        self.level = level
        self.handler = handler
        self.close = close
        # Level that was in effect before __enter__; only used when level is not None.
        self.old_level = None

    def __enter__(self):
        if self.level is not None:
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)
        # Return self so ``with LoggingContext(...) as ctx:`` binds the context object.
        return self

    def __exit__(self, et, ev, tb):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if self.handler:
            self.logger.removeHandler(self.handler)
            if self.close:
                self.handler.close()
        # Implicitly return None => exceptions raised inside the block are not swallowed.
def get_logger(name=__name__) -> logging.Logger:
    """Initializes a multi-GPU-friendly python logger.

    The logger's level methods (``debug``, ``info``, ``warning``, ``error``,
    ``exception``, ``fatal``, ``critical``) are replaced by ``rank_zero_only``
    wrapped versions. Without this, each process in a multi-GPU setup would
    emit its own copy of every log line.

    ``logging.getLogger`` returns the same object for the same ``name``. The
    patch is therefore applied only once per logger, and later calls return the
    already-patched logger instead of stacking another wrapper layer on every call.

    Args:
        name: Logger name passed to ``logging.getLogger``.

    Returns:
        The patched ``logging.Logger`` instance.
    """
    logger = logging.getLogger(name)

    # Already patched by an earlier call: don't wrap the wrappers again.
    if getattr(logger, "_rank_zero_only_wrapped", False):
        return logger

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"):
        setattr(logger, level, rank_zero_only(getattr(logger, level)))
    logger._rank_zero_only_wrapped = True

    return logger
def extras(config: DictConfig) -> None:
    """A couple of optional utilities, controlled by the main config file:

    - disabling python warnings when ``config.ignore_warnings`` is set
    - verifying the experiment name is set when running in experiment mode
    - forcing a debugger-friendly configuration when
      ``config.trainer.fast_dev_run`` is set (``trainer.gpus``,
      ``datamodule.pin_memory`` and ``datamodule.num_workers`` are switched off)

    Modifies DictConfig in place. Missing ``trainer`` / ``datamodule`` sections
    are tolerated and simply skipped.

    Args:
        config (DictConfig): Configuration composed by Hydra.

    Raises:
        SystemExit: When ``experiment_mode`` is set but ``name`` is not.
    """
    # Local import, mirroring the lazy `import wandb` style used elsewhere in this module.
    import sys

    log = get_logger(__name__)

    # disable python warnings if <config.ignore_warnings=True>
    if config.get("ignore_warnings"):
        log.info("Disabling python warnings! <config.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # verify experiment name is set when running in experiment mode
    if config.get("experiment_mode") and not config.get("name"):
        log.info(
            "Running in experiment mode without the experiment name specified! "
            "Use `python run.py mode=exp name=experiment_name`"
        )
        log.info("Exiting...")
        # sys.exit instead of the site-module `exit()` helper, which is meant for the
        # interactive shell and is not defined when Python runs with -S.
        # NOTE(review): this exits with status 0; a non-zero status may be more
        # appropriate for a misconfiguration — confirm with callers/launchers.
        sys.exit()

    # force debugger friendly configuration if <config.trainer.fast_dev_run=True>
    # debuggers don't like GPUs and multiprocessing
    trainer_cfg = config.get("trainer")
    if trainer_cfg is not None and trainer_cfg.get("fast_dev_run"):
        log.info("Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>")
        if trainer_cfg.get("gpus"):
            trainer_cfg.gpus = 0
        datamodule_cfg = config.get("datamodule")
        if datamodule_cfg is not None:
            if datamodule_cfg.get("pin_memory"):
                datamodule_cfg.pin_memory = False
            if datamodule_cfg.get("num_workers"):
                datamodule_cfg.num_workers = 0
def print_config(
    config: DictConfig,
    fields: Sequence[str] = (
        "trainer",
        "model",
        "datamodule",
        "train",
        "eval",
        "callbacks",
        "logger",
        "seed",
        "name",
    ),
    resolve: bool = True,
) -> None:
    """Prints content of DictConfig using Rich library and its tree structure.

    Each entry of ``fields`` becomes one branch of the tree. ``DictConfig``
    sections are rendered as YAML; any other value is rendered with ``str()``,
    so a field missing from ``config`` shows up as ``None``. The tree is printed
    to the console and also written to ``config_tree.txt``, a path relative to
    the current working directory.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        fields (Sequence[str], optional): Determines which main fields from config will
            be printed and in what order.
        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    for field in fields:
        branch = tree.add(field, style=style, guide_style=style)

        config_section = config.get(field)
        branch_content = str(config_section)
        if isinstance(config_section, DictConfig):
            branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    rich.print(tree)

    # Explicit UTF-8 so the saved tree does not depend on the platform's locale
    # encoding, which may be unable to encode non-ASCII config values.
    with open("config_tree.txt", "w", encoding="utf-8") as fp:
        rich.print(tree, file=fp)
def finish(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.LightningLoggerBase],
) -> None:
    """Makes sure everything closed properly.

    Walks ``logger`` and calls ``wandb.finish()`` once for every
    ``WandbLogger`` found in it. ``config``, ``model``, ``datamodule``,
    ``trainer`` and ``callbacks`` are accepted but not used here.
    """
    for candidate in logger:
        # Only Weights & Biases loggers need explicit shutdown.
        if not isinstance(candidate, pl.loggers.wandb.WandbLogger):
            continue

        # Imported lazily so wandb is only needed when a WandbLogger is configured.
        # Without this call, sweeps with the wandb logger might crash.
        import wandb

        wandb.finish()