File size: 1,702 Bytes
998bb30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import unittest

import torch

from gat.runtime.meter import AverageMeter, calc_cls_accuracy


class TestAverageMeter(unittest.TestCase):

    def test_initialization(self):
        avg_meter = AverageMeter()
        self.assertEqual(avg_meter.val, 0)
        self.assertEqual(avg_meter.avg, 0)
        self.assertEqual(avg_meter.sum, 0)
        self.assertEqual(avg_meter.count, 0)

    def test_update(self):
        avg_meter = AverageMeter()
        avg_meter.update(10)
        self.assertEqual(avg_meter.val, 10)
        self.assertEqual(avg_meter.sum, 10)
        self.assertEqual(avg_meter.count, 1)
        self.assertEqual(avg_meter.avg, 10)

        avg_meter.update(20, n=2)
        self.assertEqual(avg_meter.val, 20)
        self.assertEqual(avg_meter.sum, 50)
        self.assertEqual(avg_meter.count, 3)
        self.assertAlmostEqual(avg_meter.avg, 16.67, delta=0.01)

    def test_reset(self):
        avg_meter = AverageMeter()
        avg_meter.update(10)
        avg_meter.reset()
        self.assertEqual(avg_meter.val, 0)
        self.assertEqual(avg_meter.avg, 0)
        self.assertEqual(avg_meter.sum, 0)
        self.assertEqual(avg_meter.count, 0)


class TestCalcClsAccuracy(unittest.TestCase):

    def test_calc_cls_accuracy(self):
        output = torch.tensor([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]])
        target = torch.tensor([2, 0])

        acc1 = calc_cls_accuracy(output, target, topk=(1, ))
        self.assertAlmostEqual(acc1[0].item(), 50.0)

        acc2 = calc_cls_accuracy(output, target, topk=(2, ))
        self.assertAlmostEqual(acc2[0].item(), 100.0)


# Run the tests via unittest's command-line entry point when this module is
# executed directly as a script rather than imported.
if __name__ == '__main__':
    unittest.main()