mmcarpi's picture
Upload custom model with source code and tokenizer
7e9dc48 verified
from typing import Optional
import torch
import torch.nn as nn
class CastedLinear(nn.Linear):
    """``nn.Linear`` whose parameters are cast to the input's tensor type per call.

    The stored parameters keep their own dtype; ``forward`` casts the weight
    (and the bias, if present) with ``Tensor.type_as(x)`` before the matmul.
    """

    def forward(self, x: torch.FloatTensor):
        """Return ``x @ weight.T + bias`` with parameters cast to match ``x``.

        Args:
            x: Input whose last dimension is ``in_features``.

        Returns:
            Tensor whose last dimension is ``out_features``.
        """
        weight = self.weight
        bias = self.bias
        # Parameters on the meta device are used as-is. Presumably this keeps
        # type_as from trying to convert storage-less tensors; confirm if the
        # meta path is ever exercised with real inputs.
        if weight.device.type != "meta":
            weight = weight.type_as(x)
            if bias is not None:
                bias = bias.type_as(x)
        # Pass the bias through; previously it was silently ignored.
        return nn.functional.linear(x, weight, bias)
class FeedForward(nn.Module):
    """SwiGLU feed-forward block computing ``fc3(silu(fc1(x)) * fc2(x))``.

    Args:
        embedding_dim: Size of the last dimension of the input and output.
        hidden_dim: Width of the gated hidden layer.
        device: Device on which the three projections are allocated.
        dtype: Optional dtype for the projection weights.
    """

    def __init__(
        self,
        embedding_dim: int,
        hidden_dim: int,
        device: torch.device,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()

        def _projection(in_features: int, out_features: int) -> CastedLinear:
            # All three projections are bias-free and share device/dtype.
            return CastedLinear(
                in_features, out_features, bias=False, device=device, dtype=dtype
            )

        self.fc1 = _projection(embedding_dim, hidden_dim)
        self.fc2 = _projection(embedding_dim, hidden_dim)
        self.fc3 = _projection(hidden_dim, embedding_dim)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the gated projection; the output's last dimension is ``embedding_dim``."""
        gated = nn.functional.silu(self.fc1(x)) * self.fc2(x)
        return self.fc3(gated)
class MoEFeedForward(nn.Module):
    """Top-k mixture-of-experts layer built from SwiGLU ``FeedForward`` experts.

    A bias-free linear router (``gate``) scores every expert for each token.
    The ``num_experts_per_token`` highest scores are kept and renormalized with
    a softmax, and the token's output is the probability-weighted sum of those
    experts' outputs.

    Args:
        embedding_dim: Size of the last dimension of the input and output.
        hidden_dim: Hidden width of each expert.
        num_experts_per_token: Number of experts each token is routed to (top-k).
        num_experts: Total number of experts.
        device: Device for the router weights.
        dtype: Optional dtype for router and expert weights.

    Raises:
        AssertionError: if ``num_experts <= 0`` or ``num_experts_per_token`` is
            not in ``[1, num_experts]``.

    Note:
        Expert weights are always created on the ``meta`` device (no storage is
        allocated), regardless of ``device``. Presumably they are materialized
        or loaded elsewhere before ``forward`` runs; TODO confirm against the
        model-loading code.
    """

    def __init__(
        self,
        embedding_dim: int,
        hidden_dim: int,
        num_experts_per_token: int,
        num_experts: int,
        device: torch.device,
        dtype: torch.dtype | None = None,
    ):
        assert num_experts > 0, "num_experts should be greater than zero"
        assert num_experts >= num_experts_per_token > 0, (
            "num_experts_per_token should be greater than zero and less than or equal to num_experts"
        )
        super().__init__()
        self.num_experts_per_token = num_experts_per_token
        self.num_experts = num_experts
        meta_device = torch.device("meta")
        self.gate = CastedLinear(
            embedding_dim, num_experts, bias=False, device=device, dtype=dtype
        )
        self.ff = nn.ModuleList(
            [
                FeedForward(
                    embedding_dim,
                    hidden_dim,
                    device=meta_device,
                    dtype=dtype,
                )
                for _ in range(num_experts)
            ]
        )

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Route each token to its top-k experts and combine their outputs.

        Only the tokens routed to an expert are passed through it. This is
        mathematically equivalent to evaluating every expert on every token
        and weighting unselected experts by zero. It avoids the
        ``num_experts / num_experts_per_token`` extra compute, and a non-finite
        output from an unselected expert can no longer leak in as ``0 * inf``.
        Summation order differs from a dense evaluation, so results may differ
        in the last floating-point bits.

        Args:
            x: Tensor whose last dimension is ``embedding_dim``. All leading
                dimensions are treated as independent tokens.

        Returns:
            Tensor with the same shape as ``x``.
        """
        scores = self.gate(x)
        topk_scores, topk_indices = torch.topk(
            scores, self.num_experts_per_token, dim=-1
        )
        topk_probs = torch.softmax(topk_scores, dim=-1)

        # One row per token, whatever the leading dimensions of x are.
        flat_x = x.reshape(-1, x.shape[-1])
        flat_indices = topk_indices.reshape(-1, self.num_experts_per_token)
        flat_probs = topk_probs.reshape(-1, self.num_experts_per_token)

        y = torch.zeros_like(flat_x)
        for expert_id, expert in enumerate(self.ff):
            # (token row, top-k slot) pairs routed to this expert. topk returns
            # distinct experts per token, so a row appears at most once here.
            token_rows, slots = torch.nonzero(
                flat_indices == expert_id, as_tuple=True
            )
            # Not skipped when empty: running the expert on a zero-row batch
            # keeps its parameters in the autograd graph with zero gradients,
            # matching the dense formulation.
            expert_out = expert(flat_x[token_rows])
            weights = flat_probs[token_rows, slots].unsqueeze(-1)
            y.index_add_(0, token_rows, expert_out * weights)
        return y.reshape(x.shape)
class RMSNorm(nn.Module):
def __init__(
self,
embedding_dim: int,
eps: float = 1e-6,
bias: bool = False,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
):
factory_kwargs = dict(device=device, dtype=dtype)
super().__init__()
self.embedding_dim = embedding_dim
self.eps = eps
self.bias = bias
self.scale = nn.Parameter(torch.ones(embedding_dim, **factory_kwargs))
self.shift = (
nn.Parameter(torch.zeros(embedding_dim, **factory_kwargs)) if bias else None
)
self.dtype = dtype
def extra_repr(self):
s = "embedding_dim=%r, eps=%r, bias=%r" % (
self.embedding_dim,
self.eps,
self.bias,
)
return s
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
input_dtype = x.dtype
variance = x.to(self.dtype).pow(2).mean(dim=-1, keepdim=True)
norm_x = x * torch.rsqrt(variance + self.eps)
norm_x = norm_x * self.scale
if self.shift is not None:
norm_x = norm_x + self.shift
return norm_x.to(input_dtype)
def compute_rope_params(
    head_dim: int,
    theta_base: int = 10_000,
    context_length: int = 4096,
    dtype: Optional[torch.dtype] = torch.float32,
    device: Optional[torch.device] = None,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
    """Precompute rotary position embedding (RoPE) cosine and sine tables.

    Args:
        head_dim: Per-head embedding size; must be even.
        theta_base: Base of the geometric progression of frequencies.
        context_length: Number of positions to precompute.
        dtype: Requested table dtype. Position and frequency ranges are always
            generated in at least float32, because bfloat16/float16 cannot
            represent integers above 256/2048 exactly. The returned tables
            therefore have dtype ``torch.promote_types(dtype, torch.float32)``,
            or float32 when ``dtype`` is None, as before.
        device: Device on which to allocate the tables.

    Returns:
        ``(cos, sin)``, each of shape ``(context_length, head_dim)``. Columns
        ``j`` and ``j + head_dim // 2`` share the same frequency.

    Raises:
        AssertionError: if ``head_dim`` is odd.
    """
    assert head_dim % 2 == 0, "Embedding dim (head_dim) must be even"
    # Build integer-valued ranges in >= float32 so positions stay exact.
    if dtype is None:
        compute_dtype = torch.float32
    else:
        compute_dtype = torch.promote_types(dtype, torch.float32)
    exponents = (
        torch.arange(0, head_dim, 2, dtype=compute_dtype, device=device)[
            : head_dim // 2
        ].float()
        / head_dim
    )
    inv_freq = 1.0 / (theta_base**exponents)
    positions = torch.arange(context_length, dtype=compute_dtype, device=device)
    angles = positions[:, None] * inv_freq[None, :]
    # Duplicate so both halves of the head dimension use the same frequencies.
    angles = torch.cat([angles, angles], dim=1)
    cos = torch.cos(angles)
    sin = torch.sin(angles)
    return cos, sin
def apply_rope(
    x: torch.FloatTensor,
    cos: torch.FloatTensor,
    sin: torch.FloatTensor,
    offset: int = 0,
) -> torch.FloatTensor:
    """Apply rotary position embeddings to ``x`` (rotate-half convention).

    Each element ``j`` of the first half of the head dimension is rotated
    together with element ``j + head_dim // 2`` by the angle encoded in
    ``cos``/``sin``.

    Args:
        x: Tensor of shape ``(B, NH, S, H)``, where ``H`` (head_dim) is even.
        cos: Cosine table indexed ``[position, head_dim]``.
        sin: Sine table with the same layout as ``cos``.
        offset: Row of ``cos``/``sin`` that corresponds to the first of the
            ``S`` positions in ``x``.

    Returns:
        Tensor with the same shape and type as ``x``.

    Raises:
        AssertionError: if ``x`` is not 4-D or ``head_dim`` is odd.
    """
    assert x.dim() == 4, "expected tensor of dimension 4 (B, NH, S, H)"
    _, _, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "head_dim must be even"
    half = head_dim // 2
    x1 = x[..., :half]
    # Second half of the head dimension (previously sliced the first half again).
    x2 = x[..., half:]
    # Broadcast the (S, H) slice over the batch and head dimensions.
    cos = cos[offset : offset + seq_len, :].unsqueeze(0).unsqueeze(0)
    sin = sin[offset : offset + seq_len, :].unsqueeze(0).unsqueeze(0)
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)
    return x_rotated.type_as(x)