import torch import torch.nn as nn import torch.nn.functional as F def rotate_every_two(x: torch.Tensor) -> torch.Tensor: x1 = x[..., ::2] #select all tensors from even indices along the last dim x2 = x[..., 1::2] #select all odd indices tensor along last dim x_rot = torch.stack((-x2, x1), dim=-1).flatten(-2) #stack those bitches return x_rot def get_model_device(model): return next(iter(model.parameters())).device class RetNet(nn.Module): decay: torch.Tensor angle: torch.Tensor def __init__(self, hidden_size, num_heads=8): super().__init__() self.num_heads = num_heads self.hidden_size = hidden_size self.head_size = hidden_size // num_heads self.scaling = self.head_size**-0.5 self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False) self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False) self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False) self.g_proj = nn.Linear(hidden_size, hidden_size, bias=False) self.out_proj = nn.Linear(hidden_size, hidden_size, bias=False) self.norm = nn.RMSNorm(self.head_size, eps=1e-6, elementwise_affine=False) self.register_buffer("decay", torch.empty(num_heads)) self.register_buffer("angle", torch.empty(self.head_size)) def forward( self, x: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor, torch.Tensor] ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: batch_size, hidden_size = x.shape assert hidden_size == self.hidden_size seq_offsets, scales, recurrent_state = state assert seq_offsets.shape == (batch_size,) assert scales.shape == (batch_size, self.num_heads) assert recurrent_state.shape == ( batch_size, self.num_heads, self.head_size, self.head_size, ) q = self.q_proj(x) k = self.k_proj(x) v = self.v_proj(x) g = self.g_proj(x) k = k * self.scaling q_heads = q.view(batch_size, self.num_heads, self.head_size) k_heads = k.view(batch_size, self.num_heads, self.head_size) v_heads = v.view(batch_size, self.num_heads, self.head_size) # Rope sin = torch.sin(seq_offsets[:, None, None] * self.angle[None, None, :]) cos = torch.cos(seq_offsets[:, None, None] * self.angle[None, None, :]) q_rope = q_heads * cos + rotate_every_two(q_heads) * sin k_rope = k_heads * cos + rotate_every_two(k_heads) * sin # State update kv_outer_prod = k_rope.unsqueeze(-1) * v_heads.unsqueeze(-2) new_recurrent_state = ( recurrent_state * self.decay[None, :, None, None] + kv_outer_prod ) # State scaling new_scales = scales * self.decay + 1.0 scale_factor = (1.0 / new_scales.sqrt())[:, :, None, None] scaled_state = new_recurrent_state * scale_factor # Out out = torch.matmul(q_rope.unsqueeze(2), scaled_state).squeeze(2) out = self.norm(out).reshape(batch_size, self.hidden_size) out = F.silu(g) * out out = self.out_proj(out) return out, (seq_offsets + 1, new_scales, new_recurrent_state) def init_state( self, batch_size: int, device: torch.device | None = None ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if device is None: device = get_model_device(self) return ( torch.zeros(batch_size, dtype=torch.int32, device=device), torch.zeros(batch_size, self.num_heads, device=device), torch.zeros( batch_size, self.num_heads, self.head_size, self.head_size, device=device, ), )