# TernaryLM / model.py
# Source: OpenRAG128 (commit 5c99c91)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from huggingface_hub import PyTorchModelHubMixin
class RoPEPositionalEncoding(nn.Module):
    """Rotary positional embeddings (RoPE) with a grow-on-demand cos/sin cache.

    ``dim`` is the per-head dimension (assumed even — TODO confirm); ``max_len``
    is accepted for API compatibility but the cache actually grows lazily to
    whatever sequence length is requested.
    """

    def __init__(self, dim, max_len=2048):
        super().__init__()
        self.dim = dim
        # Frequencies 10000^(-2i/dim), one per pair of dimensions.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._cached_cos = None
        self._cached_sin = None
        self._cached_len = 0

    def _compute_cache(self, seq_len, device):
        # Rebuild when the cache is too short or lives on the wrong device.
        stale = seq_len > self._cached_len or (
            self._cached_cos is not None and self._cached_cos.device != device
        )
        if stale:
            positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            angles = torch.outer(positions, self.inv_freq.to(device))
            # Duplicate so the table spans both halves of the head dimension.
            table = torch.cat((angles, angles), dim=-1)
            self._cached_cos = table.cos()
            self._cached_sin = table.sin()
            self._cached_len = seq_len
        return (
            self._cached_cos[:seq_len].to(device),
            self._cached_sin[:seq_len].to(device),
        )

    def rotate_half(self, x):
        """Map the halves (x1, x2) of the last dim to (-x2, x1)."""
        half = x.shape[-1] // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    def apply_rope(self, q, k, seq_len):
        """Return rotated copies of q and k, shaped (batch, heads, seq, head_dim)."""
        cos, sin = self._compute_cache(seq_len, q.device)
        # Broadcast the (seq, head_dim) tables over batch and head axes.
        cos = cos[None, None]
        sin = sin[None, None]
        rotated_q = (q * cos) + (self.rotate_half(q) * sin)
        rotated_k = (k * cos) + (self.rotate_half(k) * sin)
        return rotated_q, rotated_k
class BitLinearFunction(torch.autograd.Function):
    """Straight-through-estimator (STE) linear op with 8-bit activations and
    ternary ({-1, 0, +1} * scale) weights, BitNet-style.

    Forward quantizes activations per-row (absmax to the int8 range) and
    weights to ternary values scaled by the mean absolute weight; backward
    passes gradients straight through the non-differentiable rounding.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        # Per-row absmax activation quantization to int8, then dequantize so
        # downstream computation stays in float.
        scale = 127.0 / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
        x_quant = (input * scale).round().clamp(-128, 127) / scale
        # Ternary weight quantization: normalize by mean |w|, round, clamp.
        w_scale = weight.abs().mean().clamp(min=1e-5)
        w_quant = (weight / w_scale).round().clamp(-1, 1) * w_scale
        # Save only the raw tensors via save_for_backward (the sanctioned
        # mechanism); w_quant is cheap to recompute in backward, so we avoid
        # keeping an extra full-size tensor alive on ctx.
        ctx.save_for_backward(input, weight)
        return F.linear(x_quant, w_quant, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        # Recompute the ternary weight (a mean + round — negligible cost).
        w_scale = weight.abs().mean().clamp(min=1e-5)
        w_quant = (weight / w_scale).round().clamp(-1, 1) * w_scale
        if ctx.needs_input_grad[0]:
            # STE: quantization treated as identity, so use quantized weight.
            grad_input = grad_output.matmul(w_quant)
        if ctx.needs_input_grad[1]:
            # reshape (not view): grad_output may be non-contiguous.
            grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1])
            input_flat = input.reshape(-1, input.shape[-1])
            # STE w.r.t. weight uses the full-precision input.
            grad_weight = grad_output_flat.t().mm(input_flat)
        # Guard the index: apply() may be called without the optional bias.
        if len(ctx.needs_input_grad) > 2 and ctx.needs_input_grad[2]:
            grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
        return grad_input, grad_weight, grad_bias
class RigorousBitLinear(nn.Module):
    """Linear layer whose forward pass quantizes through BitLinearFunction.

    Weights are stored in full precision and quantized on the fly each
    forward. Initialization is scaled by 1/sqrt(in_features) so that
    pre-activations stay O(1) regardless of width — the previous unscaled
    randn (std 1) made outputs grow as sqrt(in_features).

    Args:
        in_features: input dimension.
        out_features: output dimension.
        bias: whether to add a learnable bias (default False).
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features) / math.sqrt(in_features)
        )
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

    def forward(self, x):
        return BitLinearFunction.apply(x, self.weight, self.bias)
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: no mean subtraction, no bias.

    Each vector is scaled by 1/RMS(x) and then by a learned per-dimension gain.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps  # numerical floor inside the rsqrt
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normalized = x * torch.rsqrt(mean_square + self.eps)
        return normalized * self.weight
class ImprovedBitAttention(nn.Module):
    """Multi-head causal self-attention with BitLinear projections and RoPE.

    Args:
        dim: model width; must be divisible by ``heads``.
        heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        max_len: maximum sequence length hint passed to the RoPE cache.

    Raises:
        ValueError: if ``dim`` is not divisible by ``heads``.
    """

    def __init__(self, dim, heads=8, dropout=0.1, max_len=2048):
        super().__init__()
        if dim % heads != 0:
            # Fail fast here: otherwise dim // heads silently truncates and
            # the .view() in forward raises a far more confusing error.
            raise ValueError(f"dim ({dim}) must be divisible by heads ({heads})")
        self.heads = heads
        self.head_dim = dim // heads
        self.scale = self.head_dim ** -0.5
        self.q_proj = RigorousBitLinear(dim, dim)
        self.k_proj = RigorousBitLinear(dim, dim)
        self.v_proj = RigorousBitLinear(dim, dim)
        self.out_proj = RigorousBitLinear(dim, dim)
        self.rope = RoPEPositionalEncoding(self.head_dim, max_len)
        self.dropout = nn.Dropout(dropout)
        # Lazily built causal mask, reused across forwards (previously a
        # fresh L x L mask was allocated on every call).
        self._causal_mask = None

    def _get_causal_mask(self, L, device):
        # Rebuild only when the cached mask is too small or on another device.
        if (
            self._causal_mask is None
            or self._causal_mask.size(0) < L
            or self._causal_mask.device != device
        ):
            self._causal_mask = torch.tril(
                torch.ones(L, L, device=device, dtype=torch.bool)
            )
        return self._causal_mask[:L, :L]

    def forward(self, x):
        B, L, D = x.shape
        q = self.q_proj(x).view(B, L, self.heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.heads, self.head_dim).transpose(1, 2)
        q, k = self.rope.apply_rope(q, k, L)
        # Scaled dot-product scores with a strict causal mask.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        mask = self._get_causal_mask(L, x.device)
        attn = attn.masked_fill(~mask, float("-inf"))
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, L, D)
        return self.out_proj(out)
class SwiGLUMLP(nn.Module):
    """Feed-forward block with SwiGLU activation: down(silu(gate(x)) * up(x)).

    ``expansion`` sets the hidden width as int(dim * expansion).
    """

    def __init__(self, dim, expansion=2.67, dropout=0.1):
        super().__init__()
        hidden = int(dim * expansion)
        # IMPORTANT: keep original names
        self.gate_proj = RigorousBitLinear(dim, hidden)
        self.up_proj = RigorousBitLinear(dim, hidden)
        self.down_proj = RigorousBitLinear(hidden, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # The silu-activated gate elementwise-modulates the up projection.
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(self.dropout(gated))
class ImprovedBitBlock(nn.Module):
    """Pre-norm transformer block: residual attention, then residual SwiGLU MLP."""

    def __init__(self, dim, heads=8, dropout=0.1, max_len=2048):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = ImprovedBitAttention(dim, heads, dropout, max_len)
        self.norm2 = RMSNorm(dim)
        self.mlp = SwiGLUMLP(dim, dropout=dropout)

    def forward(self, x):
        # Pre-norm residual ordering: normalize, transform, add back.
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))
class ImprovedBitNet(nn.Module, PyTorchModelHubMixin):
    """Causal language model built from ternary-weight (BitNet-style) blocks.

    The token embedding and LM head stay full precision; every transformer
    block uses BitLinear projections. PyTorchModelHubMixin provides Hugging
    Face Hub save/load support.
    """

    def __init__(
        self,
        vocab_size: int = 30522,
        dim: int = 768,
        depth: int = 12,
        heads: int = 12,
        max_len: int = 512,
        dropout: float = 0.05,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.depth = depth
        # Full-precision token embedding.
        self.token_emb = nn.Embedding(vocab_size, dim)
        # Stack of pre-norm ternary transformer blocks.
        self.blocks = nn.ModuleList(
            ImprovedBitBlock(dim=dim, heads=heads, dropout=dropout, max_len=max_len)
            for _ in range(depth)
        )
        # Final RMSNorm and full-precision LM head.
        self.norm = RMSNorm(dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map token ids of shape (B, L) to logits of shape (B, L, vocab_size)."""
        hidden = self.token_emb(x)
        for block in self.blocks:
            hidden = block(hidden)
        return self.head(self.norm(hidden))