import logging

import torch
from pytorch_lightning.callbacks import Callback

log = logging.getLogger(__name__)


class FixNANinGrad(Callback):
    """Zeroes out NaN/Inf gradients before each optimizer step and stops
    training after too many consecutive non-finite values of a monitored
    metric."""

    def __init__(self, monitor):
        super().__init__()
        # List of candidate metric names; the first one found in the
        # logged metrics is checked for finiteness after each batch.
        self.monitor = monitor
        self.continuous_nan_batches = 0

    def on_before_optimizer_step(self, trainer, pl_module, optimizer) -> None:
        has_nan = []
        has_inf = []
        for name, param in pl_module.named_parameters():
            if param.grad is not None:
                if torch.isnan(param.grad).any():
                    has_nan.append(name)
                if torch.isinf(param.grad).any():
                    has_inf.append(name)
                # Replace non-finite gradient entries in place so the
                # optimizer step does not propagate them into the weights.
                torch.nan_to_num(param.grad, nan=0.0, posinf=0.0, neginf=0.0, out=param.grad)
        if has_nan:
            log.warning("Found NaN in gradients of: %s", has_nan)
        if has_inf:
            log.warning("Found Inf in gradients of: %s", has_inf)

    def on_train_batch_end(
        self,
        trainer,
        pl_module,
        outputs,
        batch,
        batch_idx,
    ) -> None:
        logs = trainer.callback_metrics
        # Use the first monitored metric that is present in the logs.
        current = None
        for name in self.monitor:
            if name in logs:
                current = logs[name].squeeze()
                break
        if current is None:
            raise ValueError(
                f"None of the monitored metrics {self.monitor} were found in the logs."
            )

        if not torch.isfinite(current):
            self.continuous_nan_batches += 1
            # Stop training once the monitored metric has been non-finite
            # for five consecutive batches.
            if self.continuous_nan_batches >= 5:
                trainer.should_stop = True
                log.info(f"Training interrupted because of NaN in {self.monitor}")
        else:
            self.continuous_nan_batches = 0
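

# Usage sketch, with assumptions flagged: "train_loss" is a hypothetical
# metric name, so substitute whatever keys your LightningModule logs via
# self.log(...); `model` stands in for your LightningModule instance.
# Trainer(callbacks=...) and trainer.fit() are the standard PyTorch
# Lightning entry points.
#
#     from pytorch_lightning import Trainer
#
#     trainer = Trainer(callbacks=[FixNANinGrad(monitor=["train_loss"])])
#     trainer.fit(model)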