File size: 3,019 Bytes
29658b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import unittest

import torch

from specforge.core.loss import LogSoftmaxLoss, _compute_loss

from .utils import norm_tensor


class TestLogSoftmaxLoss(unittest.TestCase):
    """Check ``LogSoftmaxLoss`` against the reference ``_compute_loss``.

    Every test feeds the same logits values, target and position mask to
    ``LogSoftmaxLoss.apply`` and to ``_compute_loss`` and requires both the
    loss values and the gradients stored on the logits to agree within
    ``RTOL`` / ``ATOL``.
    """

    # Number of per-step losses combined in ``test_ttt_loss_accumulation``.
    TTT_LENGTH = 7

    # Tolerances shared by every ``torch.testing.assert_close`` call below.
    RTOL = 1e-4
    ATOL = 1e-4

    def setUp(self):
        """Seed torch's global RNG so a tolerance failure can be reproduced."""
        # Makes the ``torch.randint`` position masks repeatable.
        # NOTE(review): ``norm_tensor`` is defined in ``.utils`` and not
        # visible here; this presumably fixes its output too if it draws
        # from torch's global RNG -- confirm.
        torch.manual_seed(0)

    @staticmethod
    def _pick_device():
        """Return ``"cuda"`` when CUDA is available, otherwise ``"cpu"``."""
        return "cuda" if torch.cuda.is_available() else "cpu"

    def _assert_close(self, actual, expected):
        """Assert ``actual`` matches ``expected`` within ``RTOL`` / ``ATOL``."""
        torch.testing.assert_close(
            actual, expected, rtol=self.RTOL, atol=self.ATOL
        )

    def _weighted_ttt_loss(self, step_losses):
        """Return ``sum(0.8**i * step_losses[i]) / TTT_LENGTH``.

        Args:
            step_losses: Per-step loss values, in step order.
        """
        weighted = sum(0.8**i * loss for i, loss in enumerate(step_losses))
        return weighted / self.TTT_LENGTH

    def _test_loss_and_gradient_calculation(self, B, T, V):
        """Compare loss and logits gradient for one ``(B, T, V)`` shape.

        Args:
            B: First dimension of the generated logits/target/mask tensors.
            T: Second dimension of the generated tensors.
            V: Last dimension of the logits and target tensors (the mask
                uses 1 there).
        """
        device = self._pick_device()

        # NOTE(review): reading ``logits.grad`` after ``backward()`` presumes
        # ``norm_tensor`` returns a leaf tensor that requires grad -- confirm
        # against ``.utils``.
        logits = norm_tensor((B, T, V), device, torch.float32)
        # Independent leaf copy so the reference path accumulates its own grad.
        reference_logits = logits.clone().detach().requires_grad_(True)
        target = norm_tensor((B, T, V), device, torch.float32)
        position_mask = torch.randint(
            0, 2, (B, T, 1), dtype=torch.bool, device=device
        )

        loss = LogSoftmaxLoss.apply(logits, target, position_mask)
        reference_loss = _compute_loss(reference_logits, target, position_mask)
        self._assert_close(loss, reference_loss)

        loss.backward()
        reference_loss.backward()
        self._assert_close(logits.grad, reference_logits.grad)

    def test_loss(self):
        """Run the loss/gradient comparison over a grid of tensor shapes."""
        batch_sizes = (1, 2, 4)
        seq_lens = (1024, 2048, 4096, 6000)
        vocab_sizes = (4096, 8192, 10000)
        for b in batch_sizes:
            for t in seq_lens:
                for v in vocab_sizes:
                    # subTest labels a failure with the offending shape and
                    # lets the remaining combinations still run.
                    with self.subTest(B=b, T=t, V=v):
                        self._test_loss_and_gradient_calculation(b, t, v)

    def test_ttt_loss_accumulation(self):
        """Compare losses and gradients when ``TTT_LENGTH`` losses are combined.

        Each step's loss is checked on its own, then the per-step losses are
        combined with weights ``0.8**i`` (divided by ``TTT_LENGTH``) and
        back-propagated once; the resulting gradient on every step's logits
        must match the reference.
        """
        device = self._pick_device()
        B, T, V = 1, 1024, 3200

        # NOTE(review): as above, assumes ``norm_tensor`` yields leaf tensors
        # that require grad -- confirm against ``.utils``.
        logits_list = [
            norm_tensor((B, T, V), device, torch.float32)
            for _ in range(self.TTT_LENGTH)
        ]
        reference_logits_list = [
            logits.clone().detach().requires_grad_(True)
            for logits in logits_list
        ]

        step_losses = []
        reference_step_losses = []
        for logits, reference_logits in zip(logits_list, reference_logits_list):
            target = norm_tensor((B, T, V), device, torch.float32)
            position_mask = torch.randint(
                0, 2, (B, T, 1), dtype=torch.bool, device=device
            )

            loss = LogSoftmaxLoss.apply(logits, target, position_mask)
            reference_loss = _compute_loss(
                reference_logits, target, position_mask
            )
            self._assert_close(loss, reference_loss)
            step_losses.append(loss)
            reference_step_losses.append(reference_loss)

        ploss = self._weighted_ttt_loss(step_losses)
        ploss_reference = self._weighted_ttt_loss(reference_step_losses)
        self._assert_close(ploss, ploss_reference)

        # A single backward pass through the combined loss must populate the
        # gradient of every step's logits.
        ploss.backward()
        ploss_reference.backward()
        for step, (logits, reference_logits) in enumerate(
            zip(logits_list, reference_logits_list)
        ):
            with self.subTest(step=step):
                self._assert_close(logits.grad, reference_logits.grad)