import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from huggingface_hub import PyTorchModelHubMixin


class RoPEPositionalEncoding(nn.Module):
    """Rotary position embedding (RoPE) applied to query/key tensors.

    The cos/sin tables are built lazily and cached; max_len is kept for
    interface compatibility but the cache simply grows to the longest
    sequence length seen.
    """

    def __init__(self, dim, max_len=2048):
        super().__init__()
        self.dim = dim

        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        self._cached_cos = None
        self._cached_sin = None
        self._cached_len = 0

    def _compute_cache(self, seq_len, device):
        # Rebuild the cache when a longer sequence arrives or the device changes.
        if seq_len > self._cached_len or (
            self._cached_cos is not None and self._cached_cos.device != device
        ):
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            inv_freq = self.inv_freq.to(device)
            freqs = torch.outer(t, inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)

            self._cached_cos = emb.cos()
            self._cached_sin = emb.sin()
            self._cached_len = seq_len

        return (
            self._cached_cos[:seq_len].to(device),
            self._cached_sin[:seq_len].to(device),
        )

    def rotate_half(self, x):
        # Swap the two halves of the last dimension and negate the second half.
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rope(self, q, k, seq_len):
        cos, sin = self._compute_cache(seq_len, q.device)
        cos = cos.unsqueeze(0).unsqueeze(0)
        sin = sin.unsqueeze(0).unsqueeze(0)

        q = (q * cos) + (self.rotate_half(q) * sin)
        k = (k * cos) + (self.rotate_half(k) * sin)

        return q, k


class BitLinearFunction(torch.autograd.Function):
    """Linear op with 8-bit fake-quantized activations and ternary weights.

    Quantization happens only in the forward pass; the backward pass is a
    straight-through-style estimator that routes gradients through the
    quantized weights.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        # Per-token absmax quantization of activations to 8-bit levels,
        # then dequantize back (fake quantization).
        scale = 127.0 / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
        x_quant = (input * scale).round().clamp(-128, 127) / scale

        # Ternary weight quantization: scale by the mean absolute value
        # (absmean), round, and clip to {-1, 0, +1}, then rescale.
        w_scale = weight.abs().mean().clamp(min=1e-5)
        w_quant = (weight / w_scale).round().clamp(-1, 1) * w_scale

        ctx.save_for_backward(input, weight)
        ctx.w_quant = w_quant

        return F.linear(x_quant, w_quant, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        w_quant = ctx.w_quant

        # Gradient w.r.t. the input flows through the quantized weights.
        grad_input = grad_output.matmul(w_quant)

        # Gradient w.r.t. the weights uses the full-precision input.
        grad_output_flat = grad_output.view(-1, grad_output.shape[-1])
        input_flat = input.view(-1, input.shape[-1])
        grad_weight = grad_output_flat.t().mm(input_flat)

        grad_bias = None
        if ctx.needs_input_grad[2]:
            grad_bias = grad_output_flat.sum(0)

        return grad_input, grad_weight, grad_bias


class RigorousBitLinear(nn.Module):
    """Drop-in replacement for nn.Linear that applies BitLinearFunction."""

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

    def forward(self, x):
        return BitLinearFunction.apply(x, self.weight, self.bias)


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed * self.weight


class ImprovedBitAttention(nn.Module):
    """Multi-head causal self-attention with RoPE and BitLinear projections."""

    def __init__(self, dim, heads=8, dropout=0.1, max_len=2048):
        super().__init__()
        self.heads = heads
        self.head_dim = dim // heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = RigorousBitLinear(dim, dim)
        self.k_proj = RigorousBitLinear(dim, dim)
        self.v_proj = RigorousBitLinear(dim, dim)
        self.out_proj = RigorousBitLinear(dim, dim)

        self.rope = RoPEPositionalEncoding(self.head_dim, max_len)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, L, D = x.shape

        # Project and split into heads: (B, heads, L, head_dim).
        q = self.q_proj(x).view(B, L, self.heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.heads, self.head_dim).transpose(1, 2)

        # Rotary position embedding on queries and keys.
        q, k = self.rope.apply_rope(q, k, L)

        attn = (q @ k.transpose(-2, -1)) * self.scale

        # Causal mask: each position attends only to itself and earlier positions.
        mask = torch.tril(torch.ones(L, L, device=x.device, dtype=torch.bool))
        attn = attn.masked_fill(~mask, float("-inf"))

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        out = (attn @ v).transpose(1, 2).contiguous().view(B, L, D)
        return self.out_proj(out)


class SwiGLUMLP(nn.Module):
    """Feed-forward block with a SwiGLU gate (SiLU-gated, expansion ≈ 8/3)."""

    def __init__(self, dim, expansion=2.67, dropout=0.1):
        super().__init__()
        hidden = int(dim * expansion)

        self.gate_proj = RigorousBitLinear(dim, hidden)
        self.up_proj = RigorousBitLinear(dim, hidden)
        self.down_proj = RigorousBitLinear(hidden, dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gate = F.silu(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(self.dropout(gate * up))


class ImprovedBitBlock(nn.Module):
    def __init__(self, dim, heads=8, dropout=0.1, max_len=2048):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = ImprovedBitAttention(dim, heads, dropout, max_len)
        self.norm2 = RMSNorm(dim)
        self.mlp = SwiGLUMLP(dim, dropout=dropout)

    def forward(self, x):
        # Pre-norm residual connections around attention and MLP.
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class ImprovedBitNet(nn.Module, PyTorchModelHubMixin):
    """Decoder-only BitNet-style language model.

    PyTorchModelHubMixin adds save_pretrained / from_pretrained support for
    the Hugging Face Hub.
    """

    def __init__(
        self,
        vocab_size: int = 30522,
        dim: int = 768,
        depth: int = 12,
        heads: int = 12,
        max_len: int = 512,
        dropout: float = 0.05,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim = dim
        self.depth = depth

        self.token_emb = nn.Embedding(vocab_size, dim)

        self.blocks = nn.ModuleList(
            [
                ImprovedBitBlock(
                    dim=dim,
                    heads=heads,
                    dropout=dropout,
                    max_len=max_len,
                )
                for _ in range(depth)
            ]
        )

        self.norm = RMSNorm(dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.token_emb(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)
        logits = self.head(x)
        return logits
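

# --- Usage sketch (not part of the original module) ---
# A minimal smoke test showing how the pieces above fit together. The
# configuration values here (dim=256, depth=2, heads=4, batch of 2, 16 tokens)
# are arbitrary illustration choices, not values from the original code.
if __name__ == "__main__":
    model = ImprovedBitNet(vocab_size=30522, dim=256, depth=2, heads=4, max_len=128)
    tokens = torch.randint(0, 30522, (2, 16))  # (batch, seq_len) of token ids
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 16, 30522])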