from typing import Optional

import torch
import torch.nn as nn


class CastedLinear(nn.Linear):
    def forward(self, x: torch.FloatTensor):
        # Meta-device weights hold no data and cannot be cast to the input
        # dtype, so apply them as-is until they are materialized.
        if self.weight.device.type == "meta":
            return nn.functional.linear(x, self.weight)
        return nn.functional.linear(x, self.weight.type_as(x))


class FeedForward(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        hidden_dim: int,
        device: torch.device,
        dtype: torch.dtype | None = None,
    ):
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        self.fc1 = CastedLinear(embedding_dim, hidden_dim, bias=False, **factory_kwargs)
        self.fc2 = CastedLinear(embedding_dim, hidden_dim, bias=False, **factory_kwargs)
        self.fc3 = CastedLinear(hidden_dim, embedding_dim, bias=False, **factory_kwargs)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        # SwiGLU: silu(fc1(x)) gates fc2(x), then fc3 projects back down.
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        x = nn.functional.silu(x_fc1) * x_fc2
        x = self.fc3(x)
        return x


class MoEFeedForward(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        hidden_dim: int,
        num_experts_per_token: int,
        num_experts: int,
        device: torch.device,
        dtype: torch.dtype | None = None,
    ):
        assert num_experts > 0, "num_experts should be greater than zero"
        assert num_experts >= num_experts_per_token > 0, (
            "num_experts_per_token should be greater than zero "
            "and less than or equal to num_experts"
        )
        super().__init__()
        self.num_experts_per_token = num_experts_per_token
        self.num_experts = num_experts
        # Experts are created on the meta device so no memory is allocated
        # up front; their weights are expected to be materialized later
        # (e.g., when loading a checkpoint).
        meta_device = torch.device("meta")
        self.gate = CastedLinear(
            embedding_dim, num_experts, bias=False, device=device, dtype=dtype
        )
        self.ff = nn.ModuleList(
            [
                FeedForward(
                    embedding_dim,
                    hidden_dim,
                    device=meta_device,
                    dtype=dtype,
                )
                for _ in range(num_experts)
            ]
        )

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        scores = self.gate(x)
        topk_scores, topk_indices = torch.topk(
            scores, self.num_experts_per_token, dim=-1
        )
        # Router weights are renormalized over the selected top-k experts.
        topk_probs = torch.softmax(topk_scores, dim=-1)

        # Dense mixing: every expert runs on every token, and the gating
        # mask below zeroes out the experts that were not selected.
        expert_outputs = []
        for i in range(self.num_experts):
            out = self.ff[i](x)
            expert_outputs.append(out.unsqueeze(-2))
        expert_outputs = torch.cat(expert_outputs, dim=-2)

        # Scatter the top-k probabilities back into a (..., num_experts) mask.
        gating_probs = torch.zeros_like(scores)
        for i in range(self.num_experts_per_token):
            indices = topk_indices[..., i : i + 1]
            prob = topk_probs[..., i : i + 1]
            gating_probs.scatter_(dim=-1, index=indices, src=prob)

        gating_probs = gating_probs.unsqueeze(-1)
        y = (gating_probs * expert_outputs).sum(dim=-2)
        return y


class RMSNorm(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        eps: float = 1e-6,
        bias: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        self.embedding_dim = embedding_dim
        self.eps = eps
        self.bias = bias
        self.scale = nn.Parameter(torch.ones(embedding_dim, **factory_kwargs))
        self.shift = (
            nn.Parameter(torch.zeros(embedding_dim, **factory_kwargs)) if bias else None
        )
        self.dtype = dtype

    def extra_repr(self):
        return f"embedding_dim={self.embedding_dim!r}, eps={self.eps!r}, bias={self.bias!r}"

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        input_dtype = x.dtype
        # Compute the mean square in the configured dtype (falling back to
        # float32 when none is set) for numerical stability, then cast the
        # result back to the input dtype.
        compute_dtype = self.dtype if self.dtype is not None else torch.float32
        variance = x.to(compute_dtype).pow(2).mean(dim=-1, keepdim=True)
        norm_x = x * torch.rsqrt(variance + self.eps)
        norm_x = norm_x * self.scale
        if self.shift is not None:
            norm_x = norm_x + self.shift
        return norm_x.to(input_dtype)
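
# A minimal smoke test for the dense modules above: a sketch with small,
# illustrative sizes on CPU, not part of the original module definitions.
# MoEFeedForward is omitted because its experts live on the meta device
# and would need to be materialized from real weights first.
def _feedforward_smoke_test():
    torch.manual_seed(0)
    x = torch.randn(2, 8, 64)  # (batch, seq_len, embedding_dim)
    norm = RMSNorm(embedding_dim=64)
    ff = FeedForward(embedding_dim=64, hidden_dim=128, device=torch.device("cpu"))
    y = ff(norm(x))
    # The feed-forward block preserves the input shape.
    assert y.shape == x.shape
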
(head_dim) must be even" inv_freq = 1.0 / ( theta_base ** ( torch.arange(0, head_dim, 2, dtype=dtype, device=device)[ : head_dim // 2 ].float() / head_dim ) ) positions = torch.arange(context_length, dtype=dtype, device=device) angles = positions[:, None] * inv_freq[None, :] angles = torch.cat([angles, angles], dim=1) cos = torch.cos(angles) sin = torch.sin(angles) return cos, sin def apply_rope( x: torch.FloatTensor, cos: torch.FloatTensor, sin: torch.FloatTensor, offset: int = 0, ) -> torch.FloatTensor: assert x.dim() == 4, "expected tensor of dimension 3 (B, NH, S, H)" _, _, seq_len, head_dim = x.shape assert head_dim % 2 == 0, "head_dim must be even" x1 = x[..., : head_dim // 2] x2 = x[..., : head_dim // 2 :] cos = cos[offset : offset + seq_len, :].unsqueeze(0).unsqueeze(0) sin = sin[offset : offset + seq_len, :].unsqueeze(0).unsqueeze(0) rotated = torch.cat((-x2, x1), dim=-1) x_rotated = (x * cos) + (rotated * sin) x_rotated = x_rotated.type_as(x) return x_rotated