Spaces:
Sleeping
Sleeping
| import logging | |
| from lightning.pytorch.utilities import rank_zero_only | |
| from lightning.pytorch.utilities.model_summary import ModelSummary | |
def get_logger(name=__name__) -> logging.Logger:
    """Return a command-line logger that is safe to use in multi-GPU runs.

    Every standard logging method on the logger (debug, info, warning, error,
    exception, fatal, critical) is replaced in place by a version wrapped with
    ``rank_zero_only``. This stops each GPU process from emitting its own
    copy of every message.

    Args:
        name: Logger name handed to ``logging.getLogger``.

    Returns:
        The ``logging.Logger`` registered under ``name``, with its logging
        methods wrapped.

    Note:
        ``logging.getLogger`` returns the same object for the same name. So
        calling this function again with an existing name wraps the
        already-wrapped methods once more.
    """
    logger = logging.getLogger(name)

    # Wrap each level method with the rank-zero decorator; without this the
    # output would be multiplied by the number of GPU processes.
    for method_name in ("debug", "info", "warning", "error", "exception", "fatal", "critical"):
        unwrapped = getattr(logger, method_name)
        setattr(logger, method_name, rank_zero_only(unwrapped))

    return logger
# Module-level logger whose logging methods are wrapped by get_logger with rank_zero_only.
log = get_logger(name=__name__)
def log_hyperparameters(object_dict: dict) -> None:
    """Select which config sections are sent to the Lightning loggers, and send them.

    Besides the selected config sections, the logged hyperparameters include
    the model's total, trainable and non-trainable parameter counts, as
    reported by ``ModelSummary``.

    Args:
        object_dict: Mapping that must contain the keys ``"cfg"`` (run config),
            ``"model"`` and ``"trainer"``.

    Returns:
        None. When the trainer has no logger, a warning is logged and nothing
        is sent.
    """
    cfg = object_dict["cfg"]
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging.")
        return

    # The model config is read before building the summary, matching the
    # original evaluation order.
    hparams = {"model": cfg["model"]}

    summary = ModelSummary(model)
    n_total = summary.total_parameters
    n_trainable = summary.trainable_parameters
    hparams["model/params/total"] = n_total
    hparams["model/params/trainable"] = n_trainable
    hparams["model/params/non_trainable"] = n_total - n_trainable

    # "data" and "trainer" are required config sections.
    for required_key in ("data", "trainer"):
        hparams[required_key] = cfg[required_key]

    # The remaining sections are optional; missing ones are logged as None.
    for optional_key in ("callbacks", "extras", "job_name", "tags", "ckpt_path", "seed"):
        hparams[optional_key] = cfg.get(optional_key)

    # Send the same hparams dict to every attached logger.
    for lightning_logger in trainer.loggers:
        lightning_logger.log_hyperparams(hparams)