| import torch |
| import lightning.pytorch as pl |
| from lightning.pytorch.utilities import grad_norm |
| from torch.optim import Optimizer |
|
|
class GradientMonitor(pl.Callback):
    """Log gradient-norm statistics right before every optimizer step.

    Two scalars are logged through ``pl_module.log_dict``:

    * ``train/grad/max``   -- the largest per-parameter gradient norm.
    * ``train/grad/total`` -- the total norm reported by ``grad_norm``.

    Args:
        norm_type: Order of the norm (e.g. ``2`` or ``'inf'``). The value is
            coerced with ``float()`` and must be strictly positive.

    Raises:
        ValueError: If ``norm_type`` is not strictly positive after coercion.
    """

    def __init__(self, norm_type: "float | int | str" = 2):
        super().__init__()
        norm_type = float(norm_type)
        if norm_type <= 0:
            raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}")
        # Stored as a float on purpose: the same value is interpolated into
        # the dictionary key used to look up the total norm below.
        self.norm_type = norm_type

    def on_before_optimizer_step(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        optimizer: Optimizer,
    ) -> None:
        """Compute the gradient norms of ``pl_module`` and log max and total.

        Nothing is logged when ``grad_norm`` yields no entries (no parameter
        currently holds a gradient).
        """
        norms = grad_norm(pl_module, norm_type=self.norm_type)
        if not norms:
            # No gradients available: taking ``max()`` of an empty tensor
            # and looking up the total key would both raise.
            return

        total_key = f"grad_{self.norm_type}_norm_total"
        # Every entry except the aggregated total is a per-parameter norm.
        per_param_norms = [value for key, value in norms.items() if key != total_key]
        max_grad = torch.tensor(per_param_norms).max()
        pl_module.log_dict({'train/grad/max': max_grad, 'train/grad/total': norms[total_key]})