| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Optional | |
| import torch | |
| from mmengine.hooks import Hook | |
| from mmengine.runner import Runner | |
| from mmdet.registry import HOOKS | |
class CheckInvalidLossHook(Hook):
    """Check invalid loss hook.

    This hook will regularly check whether the loss is valid
    during training, and aborts with an ``AssertionError`` when it is not.

    Args:
        interval (int): Checking interval (every k iterations).
            Default: 50.
    """

    def __init__(self, interval: int = 50) -> None:
        self.interval = interval

    def after_train_iter(self,
                         runner: Runner,
                         batch_idx: int,
                         data_batch: Optional[dict] = None,
                         outputs: Optional[dict] = None) -> None:
        """Regularly check whether the loss is valid every n iterations.

        The check only runs on iterations for which
        ``self.every_n_train_iters(runner, self.interval)`` is true.

        Args:
            runner (:obj:`Runner`): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict, Optional): Data from dataloader.
                Defaults to None.
            outputs (dict, Optional): Outputs from model. Must provide a
                ``'loss'`` entry on checked iterations. Defaults to None.

        Raises:
            AssertionError: If any element of ``outputs['loss']`` is
                infinite or NaN.
        """
        if not self.every_n_train_iters(runner, self.interval):
            return

        loss = outputs['loss']
        # Explicit raise instead of ``assert``: assert statements are removed
        # when Python runs with -O, which would silently disable this hook.
        # ``.all()`` keeps the check well-defined for a loss tensor holding
        # more than one element.
        if not torch.isfinite(loss).all():
            message = 'loss become infinite or NaN!'
            runner.logger.error(message)
            # Same exception type as the original assert, but now carrying
            # the message (the logger call returns None).
            raise AssertionError(message)