| |
| from unittest import TestCase |
| from unittest.mock import Mock |
|
|
| import torch |
|
|
| from mmdet.engine.hooks import CheckInvalidLossHook |
|
|
|
|
class TestCheckInvalidLossHook(TestCase):
    """Tests for ``CheckInvalidLossHook.after_train_iter``."""

    def test_after_train_iter(self):
        """NaN/inf losses are tolerated at step 10 but rejected at step 49.

        The hook is built with an interval of 50. The test expects:

        * at ``runner.iter == 10``: finite, NaN and infinite losses all pass
          through ``after_train_iter`` without raising;
        * at ``runner.iter == 49`` (``interval - 1``): a finite loss passes,
          while NaN and infinite losses raise ``AssertionError``.
        """
        interval = 50
        hook = CheckInvalidLossHook(interval)

        # Mock runner; the logger and its ``info`` method are stubbed
        # explicitly in case the hook reports through them.
        runner = Mock()
        runner.logger = Mock()
        runner.logger.info = Mock()

        def run_hook(step, loss):
            # Feed a single ``loss`` value through the hook at ``step``.
            hook.after_train_iter(runner, step, outputs=dict(loss=loss))

        # Step 10 -- presumably not on the hook's check schedule for this
        # interval, so even non-finite losses are expected not to raise.
        off_step = 10
        runner.iter = off_step
        run_hook(off_step, torch.LongTensor([2]))
        for bad_value in ('nan', 'inf'):
            run_hook(off_step, torch.tensor(float(bad_value)))

        # Step ``interval - 1`` -- the validity check is expected to fire
        # here: a finite loss passes, non-finite losses must raise.
        on_step = interval - 1
        runner.iter = on_step
        run_hook(on_step, torch.LongTensor([2]))
        for bad_value in ('nan', 'inf'):
            with self.assertRaises(AssertionError):
                run_hook(on_step, torch.tensor(float(bad_value)))
|
|