# Ttius's picture
# Upload 192 files
# 998bb30 verified
import unittest
import torch
from gat.runtime.meter import AverageMeter, calc_cls_accuracy
class TestAverageMeter(unittest.TestCase):
    """Unit tests for ``AverageMeter``'s counters and running average."""

    def _assert_zeroed(self, meter):
        """Assert that ``val``, ``avg``, ``sum`` and ``count`` of *meter* are all 0.

        Not prefixed with ``test`` so unittest does not collect it as a case.
        """
        for field in ("val", "avg", "sum", "count"):
            self.assertEqual(getattr(meter, field), 0)

    def test_initialization(self):
        """A freshly constructed meter starts with every field at 0."""
        self._assert_zeroed(AverageMeter())

    def test_update(self):
        """``update`` records the latest value and accumulates a weighted sum."""
        meter = AverageMeter()

        # Observation without an explicit ``n``; the expected count of 1
        # means the default weight is 1.
        meter.update(10)
        self.assertEqual(meter.val, 10)
        self.assertEqual(meter.sum, 10)
        self.assertEqual(meter.count, 1)
        self.assertEqual(meter.avg, 10)

        # Weighted observation: the sum grows by 20 * 2 and the count by 2,
        # so the expected average is 50 / 3 (about 16.667).
        meter.update(20, n=2)
        self.assertEqual(meter.val, 20)
        self.assertEqual(meter.sum, 50)
        self.assertEqual(meter.count, 3)
        self.assertAlmostEqual(meter.avg, 16.67, delta=0.01)

    def test_reset(self):
        """``reset`` returns a used meter to its all-zero initial state."""
        meter = AverageMeter()
        meter.update(10)
        meter.reset()
        self._assert_zeroed(meter)
class TestCalcClsAccuracy(unittest.TestCase):
    """Unit tests for ``calc_cls_accuracy`` top-k classification accuracy."""

    def test_calc_cls_accuracy(self):
        """Top-1 and top-2 accuracy on a two-sample, three-class batch.

        Sample 0 scores highest on class 1, but its target is class 2, which
        is only its second-highest score. Sample 1 scores highest on its
        target, class 0. Top-1 accuracy is therefore 50.0 and top-2 accuracy
        is 100.0; the expected values indicate a percentage scale.
        """
        scores = torch.tensor([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]])
        labels = torch.tensor([2, 0])
        cases = ((1, 50.0), (2, 100.0))
        for k, expected in cases:
            # A single k is requested, so only the first entry of the
            # returned sequence is checked.
            result = calc_cls_accuracy(scores, labels, topk=(k, ))
            self.assertAlmostEqual(result[0].item(), expected)
if __name__ == '__main__':
    # Discover and run every TestCase in this module when executed directly.
    unittest.main()