Spaces:
Sleeping
Sleeping
| from typing import Any | |
| from pytorch_lightning import Callback, Trainer, LightningModule | |
| from pytorch_lightning.utilities import rank_zero_only | |
| from pytorch_lightning.utilities.parsing import AttributeDict | |
class ParamsLog(Callback):
    """Log the model's parameter counts to the trainer's logger at fit start.

    Up to three values are sent through ``trainer.logger.log_hyperparams``:

    * ``model/params_total`` -- number of elements over all parameters
    * ``model/params_trainable`` -- elements of parameters with
      ``requires_grad`` set
    * ``model/params_not_trainable`` -- elements of parameters without
      ``requires_grad`` set

    Args:
        total_params_log: Include ``model/params_total`` in the logged dict.
        trainable_params_log: Include ``model/params_trainable``.
        non_trainable_params_log: Include ``model/params_not_trainable``.
    """

    def __init__(self, total_params_log: bool = True, trainable_params_log: bool = True,
                 non_trainable_params_log: bool = True):
        super().__init__()
        # Flags are kept in an AttributeDict so they can be read as attributes.
        self._log_stats = AttributeDict(
            {
                'total_params_log': total_params_log,
                'trainable_params_log': trainable_params_log,
                'non_trainable_params_log': non_trainable_params_log,
            }
        )

    @rank_zero_only
    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Count the parameters of ``pl_module`` and log the enabled statistics.

        Decorated with ``rank_zero_only`` so that the counting and the logger
        call are not repeated on every process. Does nothing when the trainer
        has no logger.

        Args:
            trainer: Trainer whose ``logger`` receives the counts.
            pl_module: Module whose ``parameters()`` are counted.
        """
        # No logger means nowhere to send the counts; skip the traversal.
        if trainer.logger is None:
            return

        # Single pass: split element counts by ``requires_grad``; the total is
        # the sum of both buckets.
        trainable = 0
        not_trainable = 0
        for param in pl_module.parameters():
            if param.requires_grad:
                trainable += param.numel()
            else:
                not_trainable += param.numel()

        logs = {}
        if self._log_stats.total_params_log:
            logs["model/params_total"] = trainable + not_trainable
        if self._log_stats.trainable_params_log:
            logs["model/params_trainable"] = trainable
        if self._log_stats.non_trainable_params_log:
            logs["model/params_not_trainable"] = not_trainable

        trainer.logger.log_hyperparams(logs)