#@title Architecture implementation
# TODO: comment and rename variables / clean code
# https://arxiv.org/abs/2410.01201v1
import torch
import torch.nn as nn
import torch.nn.functional as F
# appendix B
# https://github.com/glassroom/heinsen_sequence
def heinsen_associative_scan_log(log_coeffs, log_values):
    # Parallel evaluation of the linear recurrence h_t = a_t * h_{t-1} + b_t in log-space,
    # where log_coeffs = log(a_t) and log_values = log(b_t). Cumulative sums and
    # log-cumsum-exp replace the sequential loop (numerically stable for a_t, b_t > 0).
    a_star = log_coeffs.cumsum(dim=1)
    log_h0_plus_b_star = (log_values - a_star).logcumsumexp(dim=1)
    log_h = a_star + log_h0_plus_b_star
    return log_h.exp()
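# A minimal sanity check (ours, not from the paper or the heinsen_sequence repo): it unrolls the
# naive recurrence h_t = a_t * h_{t-1} + b_t with h_0 = 0 and compares it against the log-space
# scan above. The helper name _scan_sanity_check is purely illustrative.
def _scan_sanity_check():
    torch.manual_seed(0)
    a = torch.rand(1, 5, 3) * 0.9 + 0.05   # coefficients in (0.05, 0.95), playing the role of 1 - z_t
    b = torch.rand(1, 5, 3) + 0.1          # positive values, playing the role of z_t * h~_t
    h, ref = torch.zeros(1, 3), []
    for t in range(5):
        h = a[:, t] * h + b[:, t]
        ref.append(h)
    ref = torch.stack(ref, dim=1)
    out = heinsen_associative_scan_log(a.log(), b.log())
    assert torch.allclose(out, ref, atol=1e-5)
# _scan_sanity_check()  # uncomment to run the check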
# appendix B.3: continuous activation that keeps candidate hidden states strictly positive,
# so their logarithm is well defined; log_g is its numerically stable log-space counterpart
def g(x): return torch.where(x >= 0, x + 0.5, x.sigmoid())
def log_g(x): return torch.where(x >= 0, (F.relu(x) + 0.5).log(), -F.softplus(-x))
# Log-space version of minGRU (Appendix B.3.1).
# Hidden states are kept strictly positive so the recurrence can be evaluated with the parallel scan above.
class minGRU(nn.Module):
def __init__(self, d_model, d_inner):
super().__init__()
self.d_model = d_model
self.d_inner = d_inner
        self.hidden_proj = nn.Linear(d_model, d_inner, bias=False)  # candidate hidden state h~_t
        self.gate_proj = nn.Linear(d_model, d_inner, bias=False)    # update gate z_t
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)     # project hidden state back to d_model
    # Recurrent mode: one minGRU update for a single input token, given the previous hidden state.
    def step(self, x, h_prev=None):
        hidden = self.hidden_proj(x)
        gate = self.gate_proj(x)
        h_prev = h_prev.detach() if h_prev is not None else None
        hidden = g(hidden)                           # positive candidate h~_t
        gate = gate.sigmoid()                        # z_t in (0, 1)
        # h_t = (1 - z_t) * h_{t-1} + z_t * h~_t; with no previous state, h_0 = 0
        out = torch.lerp(h_prev, hidden, gate) if h_prev is not None else (hidden * gate)
        h_next = out[:, -1:]                         # hidden state carried to the next call
        out = self.out_proj(out)
        return out, h_next
    # Parallel mode: processes the whole sequence at once with the log-space scan (used for training).
    def forward(self, x, h_prev=None):
        seq_len = x.shape[1]
        hidden = self.hidden_proj(x)
        gate = self.gate_proj(x)
        h_prev = h_prev.detach() if h_prev is not None else None
        # h_t = (1 - z_t) * h_{t-1} + z_t * h~_t, rewritten in log-space:
        log_coeffs = -F.softplus(gate)       # log(1 - z_t)
        log_z = -F.softplus(-gate)           # log(z_t)
        log_tilde_h = log_g(hidden)          # log(h~_t)
        log_values = log_z + log_tilde_h     # log(z_t * h~_t)
        if h_prev is not None:
            # prepend the previous hidden state as the initial value of the scan
            log_values = torch.cat((h_prev.log(), log_values), dim=1)
            log_coeffs = F.pad(log_coeffs, (0, 0, 1, 0))
        out = heinsen_associative_scan_log(log_coeffs, log_values)
        out = out[:, -seq_len:]              # drop the prepended initial state, if any
        h_next = out[:, -1:]                 # last hidden state, reusable by step()
        out = self.out_proj(out)
        return out, h_next
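# A hedged consistency check (ours, not part of the paper's code): with assumed toy sizes it verifies
# that the parallel forward pass and repeated single-token step() calls produce the same outputs
# for the same weights, up to numerical error from the log-space computation.
def _mingru_mode_check(d_model=16, d_inner=32):
    torch.manual_seed(0)
    cell = minGRU(d_model, d_inner)
    x = torch.randn(2, 7, d_model)
    out_parallel, _ = cell(x)                    # whole sequence at once (log-space scan)
    outs, h = [], None
    for t in range(x.shape[1]):                  # one token at a time, carrying the hidden state
        o, h = cell.step(x[:, t:t+1], h)
        outs.append(o)
    out_sequential = torch.cat(outs, dim=1)
    assert torch.allclose(out_parallel, out_sequential, atol=1e-4)
# _mingru_mode_check()  # uncomment to run the check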
class RMSNorm(nn.Module):
def __init__(self, d_model: int, eps: float=1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(d_model))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
class minGRULM(nn.Module):
def __init__(self, vocab_size, d_model, d_inner, n_layers):
super().__init__()
self.vocab_size = vocab_size
self.d_model = d_model
self.d_inner = d_inner
self.n_layers = n_layers
self.embed = nn.Embedding(vocab_size, d_model)
self.layers = nn.ModuleList([])
for _ in range(n_layers):
self.layers.append(nn.ModuleList([
RMSNorm(d_model),
minGRU(d_model, d_inner)
]))
self.norm_f = RMSNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias = False)
    # One step of the language model: consumes a single token and returns the logits for the next
    # token together with the updated per-layer hidden states.
def step(self, x, h_states=None):
x = self.embed(x)
h_next = []
h_states = iter(h_states if h_states is not None else [])
for norm, mingru in self.layers:
h_prev = next(h_states, None)
residual = x
x = norm(x)
x, h_t = mingru.step(x, h_prev)
x = x + residual
h_next.append(h_t)
x = self.norm_f(x)
logits = self.lm_head(x)
return logits, h_next
    def forward(self, x, h_states=None):
        # Next-token prediction: inputs are x[:, :-1], targets are x shifted by one position.
        x, labels = x[:, :-1], x[:, 1:]
        x = self.embed(x)
        h_next = []
        h_states = iter(h_states if h_states is not None else [])
        for norm, mingru in self.layers:
            h_prev = next(h_states, None)
            residual = x
            x = norm(x)
            x, h_t = mingru(x, h_prev)
            x = x + residual
            h_next.append(h_t)
        x = self.norm_f(x)
        logits = self.lm_head(x)
        # cross_entropy expects (batch, vocab, seq), hence the transpose
        loss = F.cross_entropy(logits.transpose(1, 2), labels)
        return logits, h_next, loss
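# A minimal usage sketch (ours, with made-up hyperparameters and random data): one training-style
# forward pass over a batch of token ids, followed by a few greedy decoding steps with step().
def _lm_usage_demo():
    torch.manual_seed(0)
    model = minGRULM(vocab_size=256, d_model=64, d_inner=128, n_layers=2)
    batch = torch.randint(0, 256, (2, 33))       # (batch, seq_len + 1): inputs plus shifted labels
    logits, h_states, loss = model(batch)
    print(logits.shape, loss.item())             # expected: torch.Size([2, 32, 256]) and a scalar loss
    # greedy decoding: feed one token at a time, carrying the per-layer hidden states
    token, states = batch[:, :1], None
    for _ in range(8):
        logits, states = model.step(token, states)
        token = logits[:, -1].argmax(dim=-1, keepdim=True)
# _lm_usage_demo()  # uncomment to run the demo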