import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the RMS over the last dimension, then rescale.
        norm_x = torch.mean(x * x, dim=-1, keepdim=True)
        x_normed = x * torch.rsqrt(norm_x + self.eps)
        return self.scale * x_normed


class MLP(nn.Module):
    """SwiGLU feed-forward block: silu(W1 x) * (W2 x), projected back to n_embd."""

    def __init__(self, n_embd: int):
        super().__init__()
        hidden_dim = 4 * n_embd
        # Shrink the hidden width to 2/3 so the gated MLP roughly matches
        # the parameter count of a standard 4x feed-forward layer.
        n_hidden = int(2 * hidden_dim / 3)
        self.c_fc1 = nn.Linear(n_embd, n_hidden, bias=False)
        self.c_fc2 = nn.Linear(n_embd, n_hidden, bias=False)
        self.c_proj = nn.Linear(n_hidden, n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.c_proj(F.silu(self.c_fc1(x)) * self.c_fc2(x))


class SelfAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings (RoPE)."""

    def __init__(self, n_embd: int, n_head: int, max_seq_len: int, rope_base: int = 10000):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.n_embd = n_embd
        self.head_dim = n_embd // n_head
        # Fused projection for queries, keys, and values.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=False)
        # Precompute cos/sin tables up to max_seq_len; excluded from the state_dict.
        rope_cache = self._build_rope_cache(max_seq_len, self.head_dim, rope_base)
        self.register_buffer("rope_cache", rope_cache, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, D = x.shape
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, self.head_dim)
        k = k.view(B, T, self.n_head, self.head_dim)
        v = v.view(B, T, self.n_head, self.head_dim)
        # Rotate queries and keys by their position before computing attention.
        q = self._apply_rope(q, self.rope_cache)
        k = self._apply_rope(k, self.rope_cache)
        # (B, T, n_head, head_dim) -> (B, n_head, T, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # Bidirectional (encoder-style) attention; no causal mask is applied.
        y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
        y = y.transpose(1, 2).contiguous().view(B, T, D)
        return self.c_proj(y)

    @staticmethod
    def _build_rope_cache(seq_len: int, head_dim: int, base: int = 10000) -> torch.Tensor:
        # Frequency for each channel pair: base^(-2i / head_dim).
        theta = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        seq_idx = torch.arange(seq_len, dtype=torch.float32)
        # Outer product of positions and frequencies: (seq_len, head_dim // 2).
        idx_theta = torch.outer(seq_idx, theta)
        # Stack cos and sin components: (seq_len, head_dim // 2, 2).
        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
        return cache

    @staticmethod
    def _apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
        T = x.size(1)
        rope_cache = rope_cache[:T]
        # View the last dimension as pairs: (B, T, n_head, head_dim // 2, 2).
        x_shaped = x.float().reshape(*x.shape[:-1], -1, 2)
        rope_cache = rope_cache.view(1, T, 1, x_shaped.size(3), 2)
        # Rotate each pair as a complex multiplication: (a + ib) * (cos + i sin).
        x_out = torch.stack(
            [
                x_shaped[..., 0] * rope_cache[..., 0] - x_shaped[..., 1] * rope_cache[..., 1],
                x_shaped[..., 1] * rope_cache[..., 0] + x_shaped[..., 0] * rope_cache[..., 1],
            ],
            -1,
        )
        return x_out.flatten(3).type_as(x)


class RoFormerBlock(nn.Module):
    """Pre-norm transformer block: RMSNorm -> attention and RMSNorm -> MLP, each with a residual connection."""

    def __init__(self, n_embd: int, n_head: int, max_seq_len: int, rope_base: int = 10000):
        super().__init__()
        self.att_norm = RMSNorm(n_embd)
        self.att = SelfAttention(n_embd, n_head, max_seq_len, rope_base)
        self.ffn_norm = RMSNorm(n_embd)
        self.mlp = MLP(n_embd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.att(self.att_norm(x))
        x = x + self.mlp(self.ffn_norm(x))
        return x


if __name__ == '__main__':
    n_head = 8
    n_embd = 512
    max_seq_len = 256

    model = RoFormerBlock(
        n_embd=n_embd,
        n_head=n_head,
        max_seq_len=max_seq_len,
        rope_base=10000,
    )

    batch_size = 4
    seq_len = 100
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    input_tensor = torch.randn(batch_size, seq_len, n_embd).to(device)
    model.to(device)
    model.eval()

    with torch.no_grad():
        output = model(input_tensor)

    print(f"Input shape: {input_tensor.shape}")
    print(f"Output shape: {output.shape}")
    assert output.shape == (batch_size, seq_len, n_embd)