OpenTransformer committed on
Commit
eedd277
·
verified ·
1 Parent(s): bde71d6

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +111 -0
model.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PureBit Transformer - Binary-level language model"""
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import math
6
+
7
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Projects the input into per-head queries, keys and values, applies
    softmax attention (optionally biased by an additive mask), then mixes
    the heads back together through an output projection.
    """

    def __init__(self, d, heads=8):
        super().__init__()
        self.heads = heads
        self.dk = d // heads  # per-head width
        self.q_proj = nn.Linear(d, d, bias=False)
        self.k_proj = nn.Linear(d, d, bias=False)
        self.v_proj = nn.Linear(d, d, bias=False)
        self.out_proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        batch, seq, dim = x.shape

        def split_heads(proj):
            # (B, N, D) -> (B, heads, N, dk)
            return proj(x).view(batch, seq, self.heads, self.dk).transpose(1, 2)

        queries = split_heads(self.q_proj)
        keys = split_heads(self.k_proj)
        values = split_heads(self.v_proj)

        # Scaled dot-product scores; mask is added before the softmax.
        scores = torch.matmul(queries, keys.transpose(-1, -2)) / math.sqrt(self.dk)
        if mask is not None:
            scores = scores + mask
        weights = F.softmax(scores, dim=-1)

        mixed = torch.matmul(weights, values).transpose(1, 2).reshape(batch, seq, dim)
        return self.out_proj(mixed)
28
+
29
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    The hidden layer expands the model width by a factor of `mult`.
    """

    def __init__(self, d, mult=4):
        super().__init__()
        self.fc1 = nn.Linear(d, d * mult, bias=False)
        self.fc2 = nn.Linear(d * mult, d, bias=False)

    def forward(self, x):
        expanded = self.fc1(x)
        activated = F.gelu(expanded)
        return self.fc2(activated)
37
+
38
class Block(nn.Module):
    """Pre-norm transformer block.

    LayerNorm feeds the attention sub-layer, then LayerNorm feeds the MLP
    sub-layer; each sub-layer output is added back through a residual
    connection.
    """

    def __init__(self, d, heads=8):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.attn = Attention(d, heads)
        self.ln2 = nn.LayerNorm(d)
        self.mlp = MLP(d)

    def forward(self, x, mask):
        attended = self.attn(self.ln1(x), mask)
        x = x + attended
        fed_forward = self.mlp(self.ln2(x))
        return x + fed_forward
50
+
51
class PureBitTransformer(nn.Module):
    """Transformer operating on raw binary bits (vocab_size=2).

    Each token is a single bit (0 or 1); the output head shares its weight
    matrix with the input embedding (weight tying).

    Args:
        d: model width (embedding dimension).
        layers: number of transformer blocks.
        heads: attention heads per block.
        ctx: maximum context length consulted during generation.
    """
    def __init__(self, d=256, layers=6, heads=8, ctx=4096):
        super().__init__()
        self.ctx = ctx
        self.emb = nn.Embedding(2, d)  # Binary vocabulary: 0 or 1
        self.blocks = nn.ModuleList([Block(d, heads) for _ in range(layers)])
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, 2, bias=False)
        self.head.weight = self.emb.weight  # Weight tying

    def forward(self, x):
        """Return next-bit logits of shape (B, N, 2) for bit indices x of shape (B, N)."""
        B, N = x.shape
        # Additive causal mask: -1e9 above the diagonal hides future positions.
        mask = torch.triu(torch.ones(N, N, device=x.device), 1) * -1e9
        h = self.emb(x)
        for block in self.blocks:
            h = block(h, mask)
        return self.head(self.ln(h))

    @torch.no_grad()
    def generate(self, bits, max_new=256, temp=0.8):
        """Generate new bits autoregressively.

        Args:
            bits: non-empty seed sequence of 0/1 values.
            max_new: number of bits to sample.
            temp: softmax temperature (> 0); lower is greedier.

        Returns:
            Seed plus generated bits as a plain Python list of ints.
        """
        device = next(self.parameters()).device
        # dtype=torch.long: embedding lookup requires integer indices, and the
        # original inferred dtype crashed on float-valued (0.0/1.0) seeds.
        x = torch.tensor(bits, dtype=torch.long, device=device).unsqueeze(0)
        for _ in range(max_new):
            # Only the trailing `ctx` bits fit in the model's context window.
            logits = self(x[:, -self.ctx:])[:, -1, :] / temp
            next_bit = torch.multinomial(F.softmax(logits, -1), 1)
            x = torch.cat([x, next_bit], 1)
        return x[0].tolist()
79
+
80
def text_to_bits(text):
    """Convert UTF-8 text to a flat list of bits, MSB first within each byte."""
    return [
        (byte >> shift) & 1
        for byte in text.encode('utf-8')
        for shift in range(7, -1, -1)
    ]
87
+
88
def bits_to_text(bits):
    """Convert a list of bits (MSB first per byte) back to UTF-8 text.

    The input is right-padded with zero bits up to a whole number of bytes;
    undecodable byte sequences become U+FFFD via errors='replace'.
    """
    # Pad in a single step; the original rebuilt the list once per missing
    # bit via repeated `bits + [0]` concatenation.
    padded = list(bits) + [0] * (-len(bits) % 8)
    bytes_out = bytearray()
    for i in range(0, len(padded), 8):
        byte = 0
        for bit in padded[i:i + 8]:
            byte = (byte << 1) | bit  # accumulate MSB-first
        bytes_out.append(byte)
    return bytes_out.decode('utf-8', errors='replace')
99
+
100
def load_model(checkpoint_path, device='cuda', *, d=256, layers=6, heads=8):
    """Load a PureBitTransformer from a checkpoint file.

    Args:
        checkpoint_path: path to a torch checkpoint containing a 'model'
            state-dict entry.
        device: device to map tensors onto (e.g. 'cuda' or 'cpu').
        d, layers, heads: model hyperparameters; defaults match the original
            hard-coded values and must agree with the checkpoint's weights.

    Returns:
        (model, ckpt): the model in eval mode and the full checkpoint dict.
    """
    # NOTE(security): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources (or pass weights_only=True where
    # the installed torch supports it).
    ckpt = torch.load(checkpoint_path, map_location=device)
    model = PureBitTransformer(d=d, layers=layers, heads=heads).to(device)
    model.load_state_dict(ckpt['model'])
    model.eval()
    return model, ckpt
107
+
108
if __name__ == "__main__":
    # Smoke test: instantiate a default-sized model and report its size.
    model = PureBitTransformer()
    params = sum(tensor.numel() for tensor in model.parameters())
    print(f"PureBit Transformer: {params:,} parameters")