"""PureBit Transformer - Binary-level language model"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """Multi-head self-attention with an optional additive mask.

    Args:
        d: Model width; split evenly across ``heads``.
        heads: Number of attention heads.

    Raises:
        ValueError: If ``d`` is not divisible by ``heads``.
    """

    def __init__(self, d, heads=8):
        super().__init__()
        if d % heads != 0:
            # Without this check the .view() in forward() fails later with an
            # opaque shape error instead of at construction time.
            raise ValueError(f"d ({d}) must be divisible by heads ({heads})")
        self.heads = heads
        self.dk = d // heads
        self.q_proj = nn.Linear(d, d, bias=False)
        self.k_proj = nn.Linear(d, d, bias=False)
        self.v_proj = nn.Linear(d, d, bias=False)
        self.out_proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (B, N, D); returns a tensor of the same shape.

        ``mask``, if given, is added to the (B, heads, N, N) attention scores
        before the softmax, so it must broadcast to that shape; large negative
        entries suppress the corresponding key positions.
        """
        B, N, D = x.shape
        # Project, then split the width into heads: (B, N, D) -> (B, heads, N, dk).
        q = self.q_proj(x).view(B, N, self.heads, self.dk).transpose(1, 2)
        k = self.k_proj(x).view(B, N, self.heads, self.dk).transpose(1, 2)
        v = self.v_proj(x).view(B, N, self.heads, self.dk).transpose(1, 2)
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        if mask is not None:
            att = att + mask
        att = F.softmax(att, dim=-1)
        # Merge heads back: (B, heads, N, dk) -> (B, N, D).
        out = (att @ v).transpose(1, 2).reshape(B, N, D)
        return self.out_proj(out)


class MLP(nn.Module):
    """Position-wise feed-forward layer: Linear(d, d*mult) -> GELU -> Linear(d*mult, d)."""

    def __init__(self, d, mult=4):
        super().__init__()
        self.fc1 = nn.Linear(d, d * mult, bias=False)
        self.fc2 = nn.Linear(d * mult, d, bias=False)

    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))


class Block(nn.Module):
    """Pre-norm transformer block: x + attn(ln1(x)), then + mlp(ln2(x))."""

    def __init__(self, d, heads=8):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.attn = Attention(d, heads)
        self.ln2 = nn.LayerNorm(d)
        self.mlp = MLP(d)

    def forward(self, x, mask):
        x = x + self.attn(self.ln1(x), mask)
        x = x + self.mlp(self.ln2(x))
        return x


class PureBitTransformer(nn.Module):
    """Transformer operating on raw binary bits (vocab_size=2)

    Args:
        d: Model width.
        layers: Number of transformer blocks.
        heads: Attention heads per block (must divide ``d``).
        ctx: Number of most recent bits fed to the model at each step of
            :meth:`generate`; ``forward`` itself does not enforce it.

    NOTE(review): the model has no positional embedding; token order is only
    visible through the causal mask. Adding one would change the state_dict
    and break existing checkpoints, so it is intentionally left as-is.
    """

    def __init__(self, d=256, layers=6, heads=8, ctx=4096):
        super().__init__()
        self.ctx = ctx
        self.emb = nn.Embedding(2, d)  # Binary: 0 or 1
        self.blocks = nn.ModuleList([Block(d, heads) for _ in range(layers)])
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, 2, bias=False)
        self.head.weight = self.emb.weight  # Weight tying

    def forward(self, x):
        """Map bit ids ``x`` of shape (B, N) to next-bit logits of shape (B, N, 2)."""
        _, N = x.shape
        # -1e9 strictly above the diagonal: position i cannot attend to j > i.
        mask = torch.triu(torch.ones(N, N, device=x.device), 1) * -1e9
        h = self.emb(x)
        for b in self.blocks:
            h = b(h, mask)
        return self.head(self.ln(h))

    @torch.no_grad()
    def generate(self, bits, max_new=256, temp=0.8):
        """Generate new bits autoregressively

        Args:
            bits: Non-empty 1-D sequence (list or tensor) of 0/1 values used
                as the prompt.
            max_new: Number of bits to append.
            temp: Sampling temperature. ``0`` selects the most likely bit at
                each step (greedy decoding) instead of dividing by zero.

        Returns:
            list[int]: The prompt followed by ``max_new`` generated bits.

        Raises:
            ValueError: If ``bits`` is empty or not 1-D (the model needs at
                least one bit to condition on), or if ``temp`` is negative.
        """
        if temp < 0:
            # A negative temperature would silently invert the distribution.
            raise ValueError(f"temp must be >= 0, got {temp}")
        device = next(self.parameters()).device
        # Embedding lookups require integer ids, so force long dtype.
        x = torch.as_tensor(bits, dtype=torch.long, device=device)
        if x.dim() != 1 or x.numel() == 0:
            raise ValueError("bits must be a non-empty 1-D sequence of 0/1 values")
        x = x.unsqueeze(0)  # (1, N): a batch of one sequence
        for _ in range(max_new):
            # Only the most recent ``ctx`` bits are fed to the model.
            logits = self(x[:, -self.ctx:])[:, -1, :]
            if temp == 0:
                next_bit = logits.argmax(dim=-1, keepdim=True)
            else:
                next_bit = torch.multinomial(F.softmax(logits / temp, dim=-1), 1)
            x = torch.cat([x, next_bit], dim=1)
        return x[0].tolist()


def text_to_bits(text):
    """Convert UTF-8 text to list of bits

    Each byte of ``text.encode('utf-8')`` contributes 8 bits, most significant
    bit first, e.g. ``"A"`` (0x41) -> ``[0, 1, 0, 0, 0, 0, 0, 1]``.
    """
    return [
        (byte >> shift) & 1
        for byte in text.encode('utf-8')
        for shift in range(7, -1, -1)
    ]


def bits_to_text(bits):
    """Convert list of bits back to UTF-8 text

    Bits are grouped MSB-first into bytes (the inverse of ``text_to_bits``).
    A trailing partial byte is right-padded with zeros. Byte sequences that
    are not valid UTF-8 decode to U+FFFD instead of raising. The caller's
    sequence is not modified; any iterable of 0/1 values is accepted.
    """
    bits = list(bits)  # copy, so padding below never touches the caller's data
    bits += [0] * (-len(bits) % 8)
    data = bytearray()
    for start in range(0, len(bits), 8):
        byte = 0
        for bit in bits[start:start + 8]:
            byte = (byte << 1) | bit
        data.append(byte)
    return data.decode('utf-8', errors='replace')


def load_model(checkpoint_path, device='cuda', **model_kwargs):
    """Load model from checkpoint

    Args:
        checkpoint_path: Path passed to ``torch.load``.
        device: Device that tensors are mapped to and the model is moved to.
        **model_kwargs: Overrides for the ``PureBitTransformer`` constructor
            (``d``, ``layers``, ``heads``, ``ctx``). They must match the
            architecture the checkpoint was saved with. Defaults are
            ``d=256, layers=6, heads=8``.

    Returns:
        tuple: ``(model, ckpt)`` where ``model`` is in eval mode and ``ckpt``
        is the raw object returned by ``torch.load``. Its ``'model'`` entry
        is used as the state dict.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (or pass weights_only=True once the format is confirmed).
    ckpt = torch.load(checkpoint_path, map_location=device)
    config = {'d': 256, 'layers': 6, 'heads': 8, **model_kwargs}
    model = PureBitTransformer(**config).to(device)
    model.load_state_dict(ckpt['model'])
    model.eval()
    return model, ckpt


if __name__ == "__main__":
    model = PureBitTransformer()
    params = sum(p.numel() for p in model.parameters())
    print(f"PureBit Transformer: {params:,} parameters")