| import unittest | |
| import torch | |
| from gat.runtime.meter import AverageMeter, calc_cls_accuracy | |
class TestAverageMeter(unittest.TestCase):
    """Unit tests for the running-average bookkeeping of ``AverageMeter``."""

    def _assert_state(self, meter, *, val, avg, total, count):
        """Assert the four tracked attributes of *meter* in one call.

        ``total`` is the expected ``meter.sum`` (named to avoid shadowing
        the ``sum`` builtin).
        """
        self.assertEqual(meter.val, val)
        self.assertEqual(meter.sum, total)
        self.assertEqual(meter.count, count)
        # avg can be non-integral (e.g. 50 / 3 below), so compare with
        # assertAlmostEqual's default 7-place tolerance instead of ==.
        self.assertAlmostEqual(meter.avg, avg)

    def test_initialization(self):
        self._assert_state(AverageMeter(), val=0, avg=0, total=0, count=0)

    def test_update(self):
        meter = AverageMeter()
        meter.update(10)
        self._assert_state(meter, val=10, avg=10, total=10, count=1)
        # n=2 weights the value twice: sum = 10 + 20 * 2 = 50, count = 1 + 2.
        # Compare against the exact quotient rather than a rounded literal.
        meter.update(20, n=2)
        self._assert_state(meter, val=20, avg=50 / 3, total=50, count=3)

    def test_reset(self):
        meter = AverageMeter()
        meter.update(10)
        meter.reset()
        self._assert_state(meter, val=0, avg=0, total=0, count=0)
class TestCalcClsAccuracy(unittest.TestCase):
    """Unit tests for ``calc_cls_accuracy`` on a tiny two-sample batch."""

    def test_calc_cls_accuracy(self):
        # Row 0 scores class 1 highest and its true class 2 second;
        # row 1 scores its true class 0 highest.
        logits = torch.tensor([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]])
        labels = torch.tensor([2, 0])
        # Top-1 is correct only for row 1 (50%); top-2 also covers row 0 (100%).
        expected_by_k = {1: 50.0, 2: 100.0}
        for k, expected in expected_by_k.items():
            with self.subTest(topk=k):
                result = calc_cls_accuracy(logits, labels, topk=(k, ))
                self.assertAlmostEqual(result[0].item(), expected)
if __name__ == '__main__':
    # Run every TestCase in this module when executed as a script;
    # verbosity=1 is unittest.main's default, spelled out for clarity.
    unittest.main(verbosity=1)