| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
class CastedLinear(nn.Linear):
    """Linear layer that casts its parameters to the input dtype on the fly.

    Useful when weights are stored in a different precision than the
    activations (e.g. fp32 master weights with bf16 activations).
    """

    def forward(self, x: torch.FloatTensor):
        # Meta-device parameters carry no data, so a dtype cast is not
        # meaningful; run the (shape-only) linear as-is.
        if self.weight.device.type == "meta":
            return nn.functional.linear(x, self.weight, self.bias)
        # Bug fix: the original dropped self.bias entirely, so a
        # bias=True instance computed x @ W.T with no bias term.
        # Forward the bias, cast to the activation dtype when present.
        bias = self.bias.type_as(x) if self.bias is not None else None
        return nn.functional.linear(x, self.weight.type_as(x), bias)
class FeedForward(nn.Module):
    """SwiGLU-style feed-forward block: fc3(silu(fc1(x)) * fc2(x))."""

    def __init__(
        self,
        embedding_dim: int,
        hidden_dim: int,
        device: torch.device,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        factory = {"device": device, "dtype": dtype}
        # Two parallel up-projections (gate and value) plus one
        # down-projection; all bias-free.
        self.fc1 = CastedLinear(embedding_dim, hidden_dim, bias=False, **factory)
        self.fc2 = CastedLinear(embedding_dim, hidden_dim, bias=False, **factory)
        self.fc3 = CastedLinear(hidden_dim, embedding_dim, bias=False, **factory)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        gated = nn.functional.silu(self.fc1(x)) * self.fc2(x)
        return self.fc3(gated)
class MoEFeedForward(nn.Module):
    """Top-k mixture-of-experts feed-forward layer.

    A linear gate scores every expert per token; the top-k scores are
    softmax-renormalized and used to mix the expert outputs. Every
    expert is evaluated densely on the full input (no token dispatch).
    Expert parameters are allocated on the meta device (no storage) —
    presumably materialized later by the caller; TODO confirm.
    """

    def __init__(
        self,
        embedding_dim: int,
        hidden_dim: int,
        num_experts_per_token: int,
        num_experts: int,
        device: torch.device,
        dtype: torch.dtype | None = None,
    ):
        assert num_experts > 0, "num_experts should be greater than zero"
        assert num_experts >= num_experts_per_token > 0, (
            "num_experts_per_token should be greater than zero and less than or equal to num_experts"
        )
        super().__init__()
        self.num_experts_per_token = num_experts_per_token
        self.num_experts = num_experts
        # Router lives on the real device; experts on the meta device.
        self.gate = CastedLinear(
            embedding_dim, num_experts, bias=False, device=device, dtype=dtype
        )
        self.ff = nn.ModuleList(
            [
                FeedForward(
                    embedding_dim,
                    hidden_dim,
                    device=torch.device("meta"),
                    dtype=dtype,
                )
                for _ in range(num_experts)
            ]
        )

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        router_logits = self.gate(x)  # (..., num_experts)
        top_scores, top_idx = torch.topk(
            router_logits, self.num_experts_per_token, dim=-1
        )
        top_weights = torch.softmax(top_scores, dim=-1)
        # Dense evaluation: run every expert on the whole input and
        # stack along a new expert axis -> (..., num_experts, dim).
        expert_out = torch.stack([expert(x) for expert in self.ff], dim=-2)
        # Scatter the renormalized top-k weights back into a dense
        # (..., num_experts) tensor; unselected experts keep weight 0.
        # Top-k indices are unique per token, so one scatter_ suffices.
        dense_weights = torch.zeros_like(router_logits)
        dense_weights.scatter_(dim=-1, index=top_idx, src=top_weights)
        # Weighted sum over the expert axis.
        return (dense_weights.unsqueeze(-1) * expert_out).sum(dim=-2)
| class RMSNorm(nn.Module): | |
| def __init__( | |
| self, | |
| embedding_dim: int, | |
| eps: float = 1e-6, | |
| bias: bool = False, | |
| device: torch.device | None = None, | |
| dtype: torch.dtype | None = None, | |
| ): | |
| factory_kwargs = dict(device=device, dtype=dtype) | |
| super().__init__() | |
| self.embedding_dim = embedding_dim | |
| self.eps = eps | |
| self.bias = bias | |
| self.scale = nn.Parameter(torch.ones(embedding_dim, **factory_kwargs)) | |
| self.shift = ( | |
| nn.Parameter(torch.zeros(embedding_dim, **factory_kwargs)) if bias else None | |
| ) | |
| self.dtype = dtype | |
| def extra_repr(self): | |
| s = "embedding_dim=%r, eps=%r, bias=%r" % ( | |
| self.embedding_dim, | |
| self.eps, | |
| self.bias, | |
| ) | |
| return s | |
| def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: | |
| input_dtype = x.dtype | |
| variance = x.to(self.dtype).pow(2).mean(dim=-1, keepdim=True) | |
| norm_x = x * torch.rsqrt(variance + self.eps) | |
| norm_x = norm_x * self.scale | |
| if self.shift is not None: | |
| norm_x = norm_x + self.shift | |
| return norm_x.to(input_dtype) | |
def compute_rope_params(
    head_dim: int,
    theta_base: int = 10_000,
    context_length: int = 4096,
    dtype: Optional[torch.dtype] = torch.float32,
    device: Optional[torch.device] = None,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
    """Precompute the RoPE cosine/sine tables.

    Args:
        head_dim: per-head embedding size; must be even.
        theta_base: base of the geometric frequency progression.
        context_length: number of positions to precompute.
        dtype: dtype for the position tensor; the frequency exponent is
            always evaluated in float32 for precision.
        device: device the tables are created on.

    Returns:
        ``(cos, sin)``, each of shape ``(context_length, head_dim)``;
        the half-dim angle table is duplicated along the last axis
        (rotate-half layout).
    """
    assert head_dim % 2 == 0, "Embedding dim (head_dim) must be even"
    # arange(0, head_dim, 2) already has exactly head_dim // 2 elements,
    # so the original `[: head_dim // 2]` slice was dead code (removed).
    # .float() deliberately keeps the exponent in float32 even when a
    # lower-precision dtype is requested.
    inv_freq = 1.0 / (
        theta_base
        ** (
            torch.arange(0, head_dim, 2, dtype=dtype, device=device).float()
            / head_dim
        )
    )
    positions = torch.arange(context_length, dtype=dtype, device=device)
    # Outer product -> (context_length, head_dim // 2) angle matrix.
    angles = positions[:, None] * inv_freq[None, :]
    # Duplicate to cover the full head dimension.
    angles = torch.cat([angles, angles], dim=1)
    return torch.cos(angles), torch.sin(angles)
def apply_rope(
    x: torch.FloatTensor,
    cos: torch.FloatTensor,
    sin: torch.FloatTensor,
    offset: int = 0,
) -> torch.FloatTensor:
    """Apply rotary position embeddings (rotate-half convention).

    Args:
        x: tensor of shape ``(B, NH, S, H)`` with even head dim ``H``.
        cos, sin: tables of shape ``(context, H)`` as produced by
            ``compute_rope_params``.
        offset: starting position into the cos/sin tables (for decoding
            with a cache).

    Returns:
        Rotated tensor with the same shape and dtype as ``x``.
    """
    # Fixed message: the check is for 4 dimensions, not 3.
    assert x.dim() == 4, "expected tensor of dimension 4 (B, NH, S, H)"
    _, _, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "head_dim must be even"
    x1 = x[..., : head_dim // 2]
    # Bug fix: the original sliced `[..., : head_dim // 2 :]`, which is
    # the FIRST half again (identical to x1), so the rotation was wrong.
    # The rotate-half trick needs the second half here.
    x2 = x[..., head_dim // 2 :]
    # Select the positions covered by this sequence and broadcast over
    # the batch and head axes.
    cos = cos[offset : offset + seq_len, :].unsqueeze(0).unsqueeze(0)
    sin = sin[offset : offset + seq_len, :].unsqueeze(0).unsqueeze(0)
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)
    return x_rotated.type_as(x)