"""Unit tests for ``AverageMeter`` and ``calc_cls_accuracy`` from ``gat.runtime.meter``."""
import unittest

import torch

from gat.runtime.meter import AverageMeter, calc_cls_accuracy


class TestAverageMeter(unittest.TestCase):
    """Checks the ``val``/``avg``/``sum``/``count`` attributes of ``AverageMeter``."""

    def test_initialization(self):
        """A freshly constructed meter reports zero for every attribute."""
        meter = AverageMeter()
        for attr in ("val", "avg", "sum", "count"):
            self.assertEqual(getattr(meter, attr), 0)

    def test_update(self):
        """``update`` tracks the last value, the sum, the count and the average."""
        meter = AverageMeter()

        # First update, called without ``n``: expected to count as one sample.
        meter.update(10)
        after_first = (("val", 10), ("sum", 10), ("count", 1), ("avg", 10))
        for attr, expected in after_first:
            self.assertEqual(getattr(meter, attr), expected)

        # Second update with n=2: expected sum is 10 + 20 * 2 = 50 over
        # 3 samples, so the average is 50 / 3 (about 16.67).
        meter.update(20, n=2)
        after_second = (("val", 20), ("sum", 50), ("count", 3))
        for attr, expected in after_second:
            self.assertEqual(getattr(meter, attr), expected)
        self.assertAlmostEqual(meter.avg, 16.67, delta=0.01)

    def test_reset(self):
        """``reset`` brings an updated meter back to all-zero attributes."""
        meter = AverageMeter()
        meter.update(10)

        meter.reset()

        for attr in ("val", "avg", "sum", "count"):
            self.assertEqual(getattr(meter, attr), 0)


class TestCalcClsAccuracy(unittest.TestCase):
    """Checks ``calc_cls_accuracy`` on a two-sample, three-class example."""

    def test_calc_cls_accuracy(self):
        """Top-1 accuracy is 50% and top-2 accuracy is 100% for the fixture."""
        # Row 0 scores class 1 highest but is labelled 2 (its second-highest
        # score); row 1 scores class 0 highest and is labelled 0.
        scores = torch.tensor([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]])
        labels = torch.tensor([2, 0])

        for k, expected_pct in ((1, 50.0), (2, 100.0)):
            result = calc_cls_accuracy(scores, labels, topk=(k, ))
            self.assertAlmostEqual(result[0].item(), expected_pct)


if __name__ == '__main__':
    unittest.main()