| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| from huggingface_hub import PyTorchModelHubMixin |
|
|
|
|
class RoPEPositionalEncoding(nn.Module):
    """Rotary positional embedding (RoPE) helper.

    Precomputes the inverse-frequency schedule once and lazily builds cos/sin
    tables for the longest sequence length seen so far, rebuilding them when
    the requested device changes. `max_len` is accepted for interface
    compatibility; tables are grown on demand rather than preallocated.
    """

    def __init__(self, dim, max_len=2048):
        super().__init__()
        self.dim = dim

        # Standard RoPE frequency schedule: theta_i = 10000^(-2i/dim).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Lazily-built cos/sin tables plus the length they currently cover.
        self._cached_cos = None
        self._cached_sin = None
        self._cached_len = 0

    def _compute_cache(self, seq_len, device):
        """Return (cos, sin) tables of shape (seq_len, dim) on `device`."""
        stale = seq_len > self._cached_len or (
            self._cached_cos is not None and self._cached_cos.device != device
        )
        if stale:
            positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(positions, self.inv_freq.to(device))
            # Duplicate the half-dim angles so the table spans the full head dim.
            angles = torch.cat((freqs, freqs), dim=-1)
            self._cached_cos = angles.cos()
            self._cached_sin = angles.sin()
            self._cached_len = seq_len

        cos = self._cached_cos[:seq_len].to(device)
        sin = self._cached_sin[:seq_len].to(device)
        return cos, sin

    def rotate_half(self, x):
        """Rotate the last dimension: (x1, x2) -> (-x2, x1)."""
        half = x.shape[-1] // 2
        first, second = x[..., :half], x[..., half:]
        return torch.cat((-second, first), dim=-1)

    def apply_rope(self, q, k, seq_len):
        """Apply rotary embeddings to q/k; assumes (batch, heads, seq, head_dim)."""
        cos, sin = self._compute_cache(seq_len, q.device)
        # Insert singleton batch/head dims so the tables broadcast.
        cos = cos[None, None, :, :]
        sin = sin[None, None, :, :]

        q_rot = q * cos + self.rotate_half(q) * sin
        k_rot = k * cos + self.rotate_half(k) * sin
        return q_rot, k_rot
|
|
|
|
| class BitLinearFunction(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, input, weight, bias=None): |
| scale = 127.0 / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) |
| x_quant = (input * scale).round().clamp(-128, 127) / scale |
|
|
| w_scale = weight.abs().mean().clamp(min=1e-5) |
| w_quant = (weight / w_scale).round().clamp(-1, 1) * w_scale |
|
|
| ctx.save_for_backward(input, weight) |
| ctx.w_quant = w_quant |
|
|
| return F.linear(x_quant, w_quant, bias) |
|
|
| @staticmethod |
| def backward(ctx, grad_output): |
| input, weight = ctx.saved_tensors |
| w_quant = ctx.w_quant |
|
|
| grad_input = grad_output.matmul(w_quant) |
|
|
| grad_output_flat = grad_output.view(-1, grad_output.shape[-1]) |
| input_flat = input.view(-1, input.shape[-1]) |
| grad_weight = grad_output_flat.t().mm(input_flat) |
|
|
| grad_bias = None |
| if ctx.needs_input_grad[2]: |
| grad_bias = grad_output_flat.sum(0) |
|
|
| return grad_input, grad_weight, grad_bias |
|
|
|
|
class RigorousBitLinear(nn.Module):
    """Drop-in linear layer whose forward pass runs through BitLinearFunction."""

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        # NOTE(review): plain randn init (no fan-in scaling) — presumably
        # intentional given the quantizer's mean-abs rescale; confirm.
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.bias = None

    def forward(self, x):
        return BitLinearFunction.apply(x, self.weight, self.bias)
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean centering) with a learned gain."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Scale each vector by the reciprocal of its RMS over the last dim,
        # with eps inside the sqrt for numerical stability.
        rms_inv = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * rms_inv)
|
|
|
|
class ImprovedBitAttention(nn.Module):
    """Multi-head causal self-attention with RoPE and BitLinear projections.

    Args:
        dim: model width; must be divisible by `heads`.
        heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        max_len: maximum sequence length hint forwarded to the RoPE module.
    """

    def __init__(self, dim, heads=8, dropout=0.1, max_len=2048):
        super().__init__()
        if dim % heads != 0:
            # Fail fast: an uneven head split would otherwise surface later
            # as a confusing view()/reshape error inside forward().
            raise ValueError(f"dim ({dim}) must be divisible by heads ({heads})")
        self.heads = heads
        self.head_dim = dim // heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = RigorousBitLinear(dim, dim)
        self.k_proj = RigorousBitLinear(dim, dim)
        self.v_proj = RigorousBitLinear(dim, dim)
        self.out_proj = RigorousBitLinear(dim, dim)

        self.rope = RoPEPositionalEncoding(self.head_dim, max_len)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """x: (batch, seq_len, dim) -> (batch, seq_len, dim)."""
        B, L, D = x.shape

        # Project and split into heads: (B, heads, L, head_dim).
        q = self.q_proj(x).view(B, L, self.heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.heads, self.head_dim).transpose(1, 2)

        q, k = self.rope.apply_rope(q, k, L)

        # Scaled dot-product scores with a causal (lower-triangular) mask.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        mask = torch.tril(torch.ones(L, L, device=x.device, dtype=torch.bool))
        attn = attn.masked_fill(~mask, float("-inf"))

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        # Merge heads back and project out.
        out = (attn @ v).transpose(1, 2).contiguous().view(B, L, D)
        return self.out_proj(out)
|
|
|
|
|
|
class SwiGLUMLP(nn.Module):
    """SwiGLU feed-forward block: down(dropout(silu(gate(x)) * up(x)))."""

    def __init__(self, dim, expansion=2.67, dropout=0.1):
        super().__init__()
        # ~8/3 expansion is the usual SwiGLU compromise for parameter parity.
        hidden = int(dim * expansion)

        self.gate_proj = RigorousBitLinear(dim, hidden)
        self.up_proj = RigorousBitLinear(dim, hidden)
        self.down_proj = RigorousBitLinear(hidden, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(self.dropout(gated))
|
|
|
|
|
|
class ImprovedBitBlock(nn.Module):
    """Pre-norm transformer block: causal attention then SwiGLU MLP, each residual."""

    def __init__(self, dim, heads=8, dropout=0.1, max_len=2048):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = ImprovedBitAttention(dim, heads, dropout, max_len)
        self.norm2 = RMSNorm(dim)
        self.mlp = SwiGLUMLP(dim, dropout=dropout)

    def forward(self, x):
        # Residual around pre-normed attention, then pre-normed MLP.
        x = self.attn(self.norm1(x)) + x
        return self.mlp(self.norm2(x)) + x
|
|
|
|
class ImprovedBitNet(nn.Module, PyTorchModelHubMixin):
    """Decoder-only BitNet-style language model with a full-precision output head.

    Maps token ids to vocabulary logits through an embedding, a stack of
    `depth` ImprovedBitBlock layers, a final RMSNorm, and a standard
    (non-quantized) nn.Linear head.
    """

    def __init__(
        self,
        vocab_size: int = 30522,
        dim: int = 768,
        depth: int = 12,
        heads: int = 12,
        max_len: int = 512,
        dropout: float = 0.05,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim = dim
        self.depth = depth

        # Token embedding table.
        self.token_emb = nn.Embedding(vocab_size, dim)

        # Stack of identical pre-norm transformer blocks.
        blocks = []
        for _ in range(depth):
            blocks.append(
                ImprovedBitBlock(dim=dim, heads=heads, dropout=dropout, max_len=max_len)
            )
        self.blocks = nn.ModuleList(blocks)

        # Final norm and full-precision projection to logits.
        self.norm = RMSNorm(dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (batch, seq_len) token ids -> (batch, seq_len, vocab_size) logits."""
        hidden = self.token_emb(x)

        for block in self.blocks:
            hidden = block(hidden)

        return self.head(self.norm(hidden))
|
|