"""A small BitNet-style causal language model.

Building blocks: rotary position embeddings (RoPE), a fake-quantized
("1.58-bit"-style) linear layer with a straight-through-estimator backward,
RMSNorm, causal multi-head attention, and a SwiGLU MLP.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    # Optional dependency: only needed for Hub upload/download helpers.
    from huggingface_hub import PyTorchModelHubMixin
except ImportError:  # pragma: no cover
    class PyTorchModelHubMixin:
        """No-op stand-in so the model stays importable without huggingface_hub."""


class RoPEPositionalEncoding(nn.Module):
    """Rotary position embeddings applied to per-head query/key tensors.

    The cos/sin tables are built lazily, grown on demand, and rebuilt when
    the requested device changes.
    """

    def __init__(self, dim, max_len=2048):
        super().__init__()
        self.dim = dim
        # Standard RoPE inverse frequencies: 1 / 10000^(2i / dim).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        # Plain attributes (not buffers) so the cache never enters state_dict.
        self._cached_cos = None
        self._cached_sin = None
        self._cached_len = 0

    def _compute_cache(self, seq_len, device):
        """Return (cos, sin) tables of shape (seq_len, dim) on ``device``."""
        stale = seq_len > self._cached_len or (
            self._cached_cos is not None and self._cached_cos.device != device
        )
        if stale:
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            # Duplicate the frequency block so it spans the full head dim.
            emb = torch.cat((freqs, freqs), dim=-1)
            self._cached_cos = emb.cos()
            self._cached_sin = emb.sin()
            self._cached_len = seq_len
        return (
            self._cached_cos[:seq_len].to(device),
            self._cached_sin[:seq_len].to(device),
        )

    def rotate_half(self, x):
        """Rotate halves of the last dim: (x1, x2) -> (-x2, x1)."""
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rope(self, q, k, seq_len):
        """Apply rotary embeddings to q, k of shape (B, heads, L, head_dim)."""
        cos, sin = self._compute_cache(seq_len, q.device)
        # Broadcast over batch and head dimensions.
        cos = cos.unsqueeze(0).unsqueeze(0)
        sin = sin.unsqueeze(0).unsqueeze(0)
        q = (q * cos) + (self.rotate_half(q) * sin)
        k = (k * cos) + (self.rotate_half(k) * sin)
        return q, k


class BitLinearFunction(torch.autograd.Function):
    """Fake-quantized linear: int8 activations, ternary {-1, 0, 1} weights.

    Backward uses the straight-through estimator (STE): gradients pass
    through both quantizers as if they were the identity.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        # Per-row absmax int8 activation quantization (quantize/dequantize).
        scale = 127.0 / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
        x_quant = (input * scale).round().clamp(-128, 127) / scale
        # Ternary weight quantization with a mean-absolute-value scale.
        w_scale = weight.abs().mean().clamp(min=1e-5)
        w_quant = (weight / w_scale).round().clamp(-1, 1) * w_scale
        # FIX: tensors needed in backward must go through save_for_backward;
        # stashing them as plain ctx attributes bypasses autograd's
        # versioning/lifetime checks.  `weight` itself is not needed.
        ctx.save_for_backward(input, w_quant)
        return F.linear(x_quant, w_quant, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, w_quant = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            # STE: gradient w.r.t. the input uses the quantized weight.
            grad_input = grad_output.matmul(w_quant)
        # FIX: reshape (not view) — incoming gradients may be non-contiguous.
        grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1])
        if ctx.needs_input_grad[1]:
            input_flat = input.reshape(-1, input.shape[-1])
            # STE: gradient w.r.t. the weight uses the full-precision input.
            grad_weight = grad_output_flat.t().mm(input_flat)
        if ctx.needs_input_grad[2]:
            grad_bias = grad_output_flat.sum(0)
        return grad_input, grad_weight, grad_bias


class RigorousBitLinear(nn.Module):
    """Linear layer whose forward runs through BitLinearFunction."""

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        # NOTE(review): unscaled randn init is unusually large for deep nets;
        # kept as-is to preserve existing training behavior — confirm intent.
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

    def forward(self, x):
        return BitLinearFunction.apply(x, self.weight, self.bias)


class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction, no bias)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed * self.weight


class ImprovedBitAttention(nn.Module):
    """Causal multi-head self-attention with RoPE and quantized projections."""

    def __init__(self, dim, heads=8, dropout=0.1, max_len=2048):
        super().__init__()
        # FIX: fail loudly instead of silently truncating head_dim, which
        # would otherwise surface as a confusing view() error in forward.
        if dim % heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by heads ({heads})")
        self.heads = heads
        self.head_dim = dim // heads
        self.scale = self.head_dim ** -0.5
        self.q_proj = RigorousBitLinear(dim, dim)
        self.k_proj = RigorousBitLinear(dim, dim)
        self.v_proj = RigorousBitLinear(dim, dim)
        self.out_proj = RigorousBitLinear(dim, dim)
        self.rope = RoPEPositionalEncoding(self.head_dim, max_len)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, L, D = x.shape
        # Project and reshape to (B, heads, L, head_dim).
        q = self.q_proj(x).view(B, L, self.heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.heads, self.head_dim).transpose(1, 2)
        q, k = self.rope.apply_rope(q, k, L)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Causal mask: position i may only attend to positions <= i.
        mask = torch.tril(torch.ones(L, L, device=x.device, dtype=torch.bool))
        attn = attn.masked_fill(~mask, float("-inf"))
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, L, D)
        return self.out_proj(out)


class SwiGLUMLP(nn.Module):
    """SwiGLU feed-forward block: down(silu(gate(x)) * up(x))."""

    def __init__(self, dim, expansion=2.67, dropout=0.1):
        super().__init__()
        hidden = int(dim * expansion)
        # IMPORTANT: keep original names (checkpoint state_dict compatibility).
        self.gate_proj = RigorousBitLinear(dim, hidden)
        self.up_proj = RigorousBitLinear(dim, hidden)
        self.down_proj = RigorousBitLinear(hidden, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gate = F.silu(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(self.dropout(gate * up))


class ImprovedBitBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual."""

    def __init__(self, dim, heads=8, dropout=0.1, max_len=2048):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = ImprovedBitAttention(dim, heads, dropout, max_len)
        self.norm2 = RMSNorm(dim)
        self.mlp = SwiGLUMLP(dim, dropout=dropout)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class ImprovedBitNet(nn.Module, PyTorchModelHubMixin):
    """Decoder-only BitNet-style language model.

    Args:
        vocab_size: token vocabulary size.
        dim: model (embedding) width.
        depth: number of transformer blocks.
        heads: attention heads per block (must divide ``dim``).
        max_len: maximum sequence length hint passed to RoPE.
        dropout: dropout probability used throughout.
    """

    def __init__(
        self,
        vocab_size: int = 30522,
        dim: int = 768,
        depth: int = 12,
        heads: int = 12,
        max_len: int = 512,
        dropout: float = 0.05,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.depth = depth

        # Token embedding
        self.token_emb = nn.Embedding(vocab_size, dim)

        # Transformer blocks
        self.blocks = nn.ModuleList(
            [
                ImprovedBitBlock(
                    dim=dim,
                    heads=heads,
                    dropout=dropout,
                    max_len=max_len,
                )
                for _ in range(depth)
            ]
        )

        # Final normalization + LM head
        self.norm = RMSNorm(dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map token ids (B, L) to logits (B, L, vocab_size)."""
        x = self.token_emb(x)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        logits = self.head(x)
        return logits