File size: 1,702 Bytes
998bb30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import unittest
import torch
from gat.runtime.meter import AverageMeter, calc_cls_accuracy
class TestAverageMeter(unittest.TestCase):
    """Unit tests for ``AverageMeter``'s running-statistics bookkeeping."""

    def _assert_zeroed(self, meter):
        """Assert that ``val``, ``avg``, ``sum`` and ``count`` of *meter* are all 0.

        Shared by the initialization and reset tests, which both expect the
        same all-zero state. The leading underscore (no ``test`` prefix)
        keeps the unittest loader from collecting it as a test.
        """
        self.assertEqual(meter.val, 0)
        self.assertEqual(meter.avg, 0)
        self.assertEqual(meter.sum, 0)
        self.assertEqual(meter.count, 0)

    def test_initialization(self):
        """A freshly constructed meter starts with every statistic at zero."""
        self._assert_zeroed(AverageMeter())

    def test_update(self):
        """``update`` stores the latest value and accumulates a weighted sum/count."""
        avg_meter = AverageMeter()

        avg_meter.update(10)
        self.assertEqual(avg_meter.val, 10)
        self.assertEqual(avg_meter.sum, 10)
        self.assertEqual(avg_meter.count, 1)
        self.assertEqual(avg_meter.avg, 10)

        # With n=2 the value 20 is counted twice: sum = 10 + 20 * 2 = 50
        # over count = 1 + 2 = 3 samples.
        avg_meter.update(20, n=2)
        self.assertEqual(avg_meter.val, 20)
        self.assertEqual(avg_meter.sum, 50)
        self.assertEqual(avg_meter.count, 3)
        # Compare against the exact quotient instead of a hand-rounded literal
        # with a wide tolerance, which would also accept nearby wrong values.
        self.assertAlmostEqual(avg_meter.avg, 50 / 3)

    def test_reset(self):
        """``reset`` returns a previously updated meter to the all-zero state."""
        avg_meter = AverageMeter()
        avg_meter.update(10)
        avg_meter.reset()
        self._assert_zeroed(avg_meter)
class TestCalcClsAccuracy(unittest.TestCase):
    """Tests for ``calc_cls_accuracy`` on a tiny two-sample, three-class batch."""

    def test_calc_cls_accuracy(self):
        """Top-1 and top-2 accuracies match the expected 50.0 and 100.0."""
        # Row 0 ranks class 1 first and class 2 second; row 1 ranks class 0 first.
        logits = torch.tensor([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]])
        labels = torch.tensor([2, 0])

        # (k, expected): only row 1 is correct at top-1, while both rows are
        # correct once the second-ranked class also counts.
        cases = ((1, 50.0), (2, 100.0))
        for k, expected in cases:
            accuracies = calc_cls_accuracy(logits, labels, topk=(k,))
            self.assertAlmostEqual(accuracies[0].item(), expected)
if __name__ == '__main__':
    # Discover and run every TestCase in this module when executed as a script.
    unittest.main()
|