| | import logging |
| | import warnings |
| | from typing import List, Sequence |
| |
|
| | import pytorch_lightning as pl |
| | import rich.syntax |
| | import rich.tree |
| | from omegaconf import DictConfig, OmegaConf |
| | from pytorch_lightning.utilities import rank_zero_only |
| |
|
| |
|
def get_logger(name=__name__) -> logging.Logger:
    """Build a command line logger that is safe to use in multi-GPU runs.

    Every level method of the returned logger is wrapped with
    ``rank_zero_only`` so that, in a multi-process setup, a message is not
    repeated once per process.

    Args:
        name: Name passed to ``logging.getLogger``. Defaults to this module's
            ``__name__``.

    Returns:
        logging.Logger: The logger whose level methods have been wrapped.
    """
    wrapped_logger = logging.getLogger(name)

    # Level methods that get the rank-zero guard.
    # NOTE: ``logging.getLogger`` returns the same object for the same name,
    # so calling this function again with that name wraps the methods again.
    guarded_levels = (
        "debug",
        "info",
        "warning",
        "error",
        "exception",
        "fatal",
        "critical",
    )
    for level_name in guarded_levels:
        plain_method = getattr(wrapped_logger, level_name)
        setattr(wrapped_logger, level_name, rank_zero_only(plain_method))

    return wrapped_logger
| |
|
| |
|
# Module-level logger used by the helper functions in this file.
log = get_logger(__name__)
| |
|
| |
|
def extras(config: DictConfig) -> None:
    """Run the optional utilities that are switched on in the config.

    Supported flags (each is read with ``config.get`` and skipped when falsy):
    - ``ignore_warnings``: silence all python warnings
    - ``print_config``: pretty-print the config tree with Rich

    Args:
        config (DictConfig): Configuration holding the optional flags.
    """
    # Turn off python warnings if requested.
    silence_warnings = config.get("ignore_warnings")
    if silence_warnings:
        log.info("Disabling python warnings! <config.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # Pretty-print the whole config if requested.
    show_config = config.get("print_config")
    if show_config:
        log.info("Printing config tree with Rich! <config.print_config=True>")
        print_config(config, resolve=True)
| |
|
| |
|
@rank_zero_only
def print_config(
    config: DictConfig,
    print_order: Sequence[str] = (
        "datamodule",
        "model",
        "callbacks",
        "logger",
        "trainer",
    ),
    resolve: bool = True,
) -> None:
    """Prints content of DictConfig using Rich library and its tree structure.

    The rendered tree is printed to the console and also written to
    ``config_tree.log`` (relative path, UTF-8 encoded).

    Args:
        config (DictConfig): Configuration composed by Hydra.
        print_order (Sequence[str], optional): Determines in what order config
            components are printed. Fields listed here but missing from
            ``config`` are reported with ``log.info`` and skipped; all other
            fields of ``config`` follow afterwards in config order.
        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    # Requested fields come first, in the requested order.
    queue = []
    for field in print_order:
        if field in config:
            queue.append(field)
        else:
            log.info(f"Field '{field}' not found in config")

    # Then every remaining field of the config.
    for field in config:
        if field not in queue:
            queue.append(field)

    for field in queue:
        branch = tree.add(field, style=style, guide_style=style)

        config_group = config[field]
        if isinstance(config_group, DictConfig):
            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
        else:
            # Plain values (non-DictConfig) are shown via their string form.
            branch_content = str(config_group)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    rich.print(tree)

    # Pin the encoding so the log file does not depend on the platform's
    # locale default; config values may contain non-ASCII text.
    with open("config_tree.log", "w", encoding="utf-8") as file:
        rich.print(tree, file=file)
| |
|
| |
|
@rank_zero_only
def log_hyperparameters(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.LightningLoggerBase],
) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
    - number of model parameters (total, trainable, non-trainable)

    Args:
        config (DictConfig): Must contain the ``trainer``, ``model`` and
            ``datamodule`` entries; ``seed`` and ``callbacks`` are saved
            only when present.
        model (pl.LightningModule): Its ``parameters()`` are counted.
        datamodule (pl.LightningDataModule): Not used by this function.
        trainer (pl.Trainer): Its ``logger`` receives the hyperparameters.
        callbacks (List[pl.Callback]): Not used by this function.
        logger (List[pl.loggers.LightningLoggerBase]): Not used by this function.
    """
    # ``trainer.logger`` may be None when the Trainer was created without a
    # logger; skip instead of failing with an AttributeError below.
    if trainer.logger is None:
        log.warning("Logger not found! Skipping hyperparameter logging.")
        return

    hparams = {}

    # Choose which parts of the config are saved by the loggers.
    hparams["trainer"] = config["trainer"]
    hparams["model"] = config["model"]
    hparams["datamodule"] = config["datamodule"]

    if "seed" in config:
        hparams["seed"] = config["seed"]
    if "callbacks" in config:
        hparams["callbacks"] = config["callbacks"]

    # Count the parameters in a single pass over the model.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    hparams["model/params/total"] = total_params
    hparams["model/params/trainable"] = trainable_params
    hparams["model/params/non_trainable"] = total_params - trainable_params

    # Send the collected hyperparameters to the trainer's logger.
    trainer.logger.log_hyperparams(hparams)
| |
|
| |
|
def finish(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.LightningLoggerBase],
) -> None:
    """Final cleanup step run at the end of a run.

    Calls ``wandb.finish()`` once for every ``WandbLogger`` found in
    ``logger``. The remaining arguments are not used by this function.
    """
    for candidate in logger:
        if not isinstance(candidate, pl.loggers.wandb.WandbLogger):
            continue

        # Imported lazily so wandb is only needed when a wandb logger is used.
        import wandb

        wandb.finish()
| |
|