#@title Architecture implementation

# minGRU language model, following "Were RNNs All We Needed?"
# https://arxiv.org/abs/2410.01201v1
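
# The minGRU recurrence, restated from the paper for reference:
#   z_t  = sigmoid(W_z x_t)                  (update gate)
#   h~_t = g(W_h x_t)                        (candidate state)
#   h_t  = (1 - z_t) * h_{t-1} + z_t * h~_t
# Because h~_t does not depend on h_{t-1}, this is a linear recurrence in h_t
# and can be evaluated for all t at once with a parallel scan.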

import torch
import torch.nn as nn
import torch.nn.functional as F


# appendix B: parallel solution of the linear recurrence h_t = a_t * h_{t-1} + b_t,
# computed entirely in log space for numerical stability
# adapted from https://github.com/glassroom/heinsen_sequence

def heinsen_associative_scan_log(log_coeffs, log_values):
    # a*_t = log prod_{s<=t} a_s (cumulative log-coefficient)
    a_star = log_coeffs.cumsum(dim=1)
    # log sum_{s<=t} b_s / prod_{r<=s} a_r, via the stable logcumsumexp
    log_h0_plus_b_star = (log_values - a_star).logcumsumexp(dim=1)
    # h_t = prod_{s<=t} a_s * sum_{s<=t} b_s / prod_{r<=s} a_r
    log_h = a_star + log_h0_plus_b_star
    return log_h.exp()
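
# Sanity check (added; hypothetical helper, not in the original notebook): the
# scan above should agree with a naive sequential evaluation of
# h_t = a_t * h_{t-1} + b_t (with h_0 = 0) up to floating-point error, e.g.
#   lc, lv = -torch.rand(2, 8, 4), torch.randn(2, 8, 4)
#   assert torch.allclose(heinsen_associative_scan_log(lc, lv),
#                         sequential_scan_reference(lc, lv), atol=1e-4)

def sequential_scan_reference(log_coeffs, log_values):
    a, b = log_coeffs.exp(), log_values.exp()
    h, outs = torch.zeros_like(b[:, 0]), []
    for t in range(b.shape[1]):
        h = a[:, t] * h + b[:, t]   # one step of the recurrence
        outs.append(h)
    return torch.stack(outs, dim=1)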

# appendix B.3: g(.) maps pre-activations to strictly positive values so the
# hidden states have well-defined logs; log_g computes log(g(x)) stably
# (the F.relu guards the untaken where-branch against log of a negative number)

def g(x):     return torch.where(x >= 0, x + 0.5, x.sigmoid())
def log_g(x): return torch.where(x >= 0, (F.relu(x) + 0.5).log(), -F.softplus(-x))
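
# Added check (hypothetical, not in the original notebook): log_g is exactly
# the elementwise log of g, so the two parameterizations stay consistent:
#   x = torch.randn(4)
#   assert torch.allclose(log_g(x), g(x).log())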

# log-space version of minGRU (appendix B.3.1);
# hidden states are constrained to be positive so the recurrence can run in log space

class minGRU(nn.Module):
    def __init__(self, d_model, d_inner):
        super().__init__()

        self.d_model = d_model
        self.d_inner = d_inner

        self.hidden_proj = nn.Linear(d_model, d_inner, bias=False)  # candidate state h~
        self.gate_proj   = nn.Linear(d_model, d_inner, bias=False)  # update gate z
        self.out_proj    = nn.Linear(d_inner, d_model, bias=False)


    # sequential (recurrent) form: one token at a time, used at inference
    def step(self, x, h_prev=None):
        hidden = self.hidden_proj(x)
        gate   = self.gate_proj(x)

        # stop gradients from flowing into states carried over from earlier calls
        h_prev = h_prev.detach() if h_prev is not None else None

        hidden = g(hidden)
        gate   = gate.sigmoid()
        # minGRU update: h_t = (1 - z_t) * h_{t-1} + z_t * h~_t;
        # with no previous state (h_0 = 0) this reduces to z_1 * h~_1
        out    = torch.lerp(h_prev, hidden, gate) if h_prev is not None else (hidden * gate)

        h_next = out[:, -1:]          # last hidden state, taken before the output projection
        out    = self.out_proj(out)

        return out, h_next


    # parallel form: processes a whole sequence at once via the log-space scan
    def forward(self, x, h_prev=None):
        seq_len = x.shape[1]
        hidden  = self.hidden_proj(x)
        gate    = self.gate_proj(x)

        h_prev = h_prev.detach() if h_prev is not None else None

        # log(1 - z_t), log z_t and log h~_t, all computed stably (appendix B.3.1)
        log_coeffs  = -F.softplus(gate)
        log_z       = -F.softplus(-gate)
        log_tilde_h = log_g(hidden)
        log_values  = log_z + log_tilde_h

        if h_prev is not None:
            # prepend the carried-in state as scan element 0, with log-coefficient 0
            log_values = torch.cat((h_prev.log(), log_values), dim=1)
            log_coeffs = F.pad(log_coeffs, (0, 0, 1, 0))

        out = heinsen_associative_scan_log(log_coeffs, log_values)
        out = out[:, -seq_len:]  # drop the prepended initial state, if any

        h_next = out[:, -1:]
        out    = self.out_proj(out)

        return out, h_next
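
# Illustration (added; hypothetical helper, not in the original notebook): the
# parallel forward() and the token-by-token step() implement the same
# recurrence, so their outputs should agree up to floating-point error.

def _demo_parallel_matches_sequential():
    torch.manual_seed(0)
    rnn = minGRU(d_model=16, d_inner=32)
    x   = torch.randn(1, 5, 16)

    out_parallel, _ = rnn(x)               # whole sequence via the log-space scan

    h, outs = None, []
    for t in range(x.shape[1]):            # same sequence, one token at a time
        o, h = rnn.step(x[:, t:t + 1], h)
        outs.append(o)
    out_sequential = torch.cat(outs, dim=1)

    assert torch.allclose(out_parallel, out_sequential, atol=1e-4)

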
# RMSNorm (Zhang & Sennrich, 2019): rescales by the root mean square of the
# activations with a learned per-channel gain; no mean subtraction or bias

class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # normalize in float32 for stability, then cast back to the input dtype
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class minGRULM(nn.Module):
    def __init__(self, vocab_size, d_model, d_inner, n_layers):
        super().__init__()

        self.vocab_size = vocab_size
        self.d_model    = d_model
        self.d_inner    = d_inner
        self.n_layers   = n_layers

        self.embed = nn.Embedding(vocab_size, d_model)

        # stack of pre-norm residual blocks: x = x + minGRU(RMSNorm(x))
        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            self.layers.append(nn.ModuleList([
                RMSNorm(d_model),
                minGRU(d_model, d_inner)
            ]))

        self.norm_f  = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)



    # one decoding step: consume a single token (plus optional per-layer hidden
    # states), return logits for the next token and the updated states
    def step(self, x, h_states=None):
        x = self.embed(x)

        h_next   = []
        h_states = iter(h_states if h_states is not None else [])

        for norm, mingru in self.layers:
            h_prev   = next(h_states, None)
            residual = x

            x      = norm(x)
            x, h_t = mingru.step(x, h_prev)
            x      = x + residual

            h_next.append(h_t)

        x      = self.norm_f(x)
        logits = self.lm_head(x)

        return logits, h_next



    # training forward pass: next-token prediction over a full sequence
    def forward(self, x, h_states=None):
        # shift by one: predict token t+1 from tokens up to t
        x, labels = x[:, :-1], x[:, 1:]
        x         = self.embed(x)

        h_next   = []
        h_states = iter(h_states if h_states is not None else [])

        for norm, mingru in self.layers:
            h_prev   = next(h_states, None)
            residual = x

            x      = norm(x)
            x, h_t = mingru(x, h_prev)   # parallel scan over the whole sequence
            x      = x + residual

            h_next.append(h_t)

        x      = self.norm_f(x)
        logits = self.lm_head(x)
        # F.cross_entropy expects (batch, classes, seq), hence the transpose
        loss   = F.cross_entropy(logits.transpose(1, 2), labels)

        return logits, h_next, loss
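

# Usage sketch (added; hypothetical names, not part of the original notebook).
# Training takes a (batch, seq_len) tensor of token ids; forward() shifts it
# internally into inputs and next-token labels. Generation feeds one token at
# a time through step(), carrying the per-layer hidden states between calls.

def greedy_generate(model, prompt_ids, n_new_tokens):
    # prompt_ids: (batch, prompt_len) tensor of token ids
    model.eval()
    h_states = None
    with torch.no_grad():
        # warm up the hidden states on all but the last prompt token
        for t in range(prompt_ids.shape[1] - 1):
            _, h_states = model.step(prompt_ids[:, t:t + 1], h_states)

        token, out = prompt_ids[:, -1:], []
        for _ in range(n_new_tokens):
            logits, h_states = model.step(token, h_states)
            token = logits[:, -1].argmax(dim=-1, keepdim=True)
            out.append(token)
    return torch.cat(out, dim=1)

# e.g.
#   model  = minGRULM(vocab_size=256, d_model=64, d_inner=128, n_layers=2)
#   ids    = torch.randint(0, 256, (1, 16))
#   logits, _, loss = model(ids)                       # training objective
#   sample = greedy_generate(model, ids, n_new_tokens=32)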