0sparsh2 commited on
Commit
f5c2b6c
·
verified ·
1 Parent(s): 4831dd7

Upload bitnet_test.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. bitnet_test.py +223 -0
bitnet_test.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ import os
6
+ import urllib.request
7
+
8
+ # --- BITNET 1.58b IMPLEMENTATION ---
9
+
10
+ class RoundWithSTE(torch.autograd.Function):
11
+ @staticmethod
12
+ def forward(ctx, x):
13
+ return torch.round(x)
14
+
15
+ @staticmethod
16
+ def backward(ctx, grad_output):
17
+ return grad_output
18
+
19
+ class BitLinear(nn.Linear):
20
+ def __init__(self, in_features, out_features, bias=False):
21
+ super(BitLinear, self).__init__(in_features, out_features, bias=bias)
22
+
23
+ def forward(self, x):
24
+ # 1. Weight Quantization to {-1, 0, 1}
25
+ # scale = mean of absolute values
26
+ weight_scale = self.weight.abs().mean()
27
+
28
+ # Scale, round, clamp
29
+ w_scaled = self.weight / (weight_scale + 1e-8)
30
+ w_quant = torch.clamp(RoundWithSTE.apply(w_scaled), -1.0, 1.0)
31
+
32
+ # 2. Activation Quantization to 8-bit [-128, 127]
33
+ # per-token max absolute value
34
+ x_scale = x.abs().max(dim=-1, keepdim=True)[0]
35
+ x_scaled = x * (127.0 / (x_scale + 1e-8))
36
+ x_quant = torch.clamp(RoundWithSTE.apply(x_scaled), -128.0, 127.0)
37
+
38
+ # 3. Linear Projection (in a real engine this is int8 * int2 bitwise ops)
39
+ out = F.linear(x_quant, w_quant)
40
+
41
+ # 4. Dequantize
42
+ out = out * (weight_scale * x_scale / 127.0)
43
+
44
+ if self.bias is not None:
45
+ out += self.bias
46
+
47
+ return out
48
+
49
+ class RMSNorm(nn.Module):
50
+ def __init__(self, dim, eps=1e-6):
51
+ super().__init__()
52
+ self.weight = nn.Parameter(torch.ones(dim))
53
+ self.eps = eps
54
+
55
+ def forward(self, x):
56
+ variance = x.pow(2).mean(-1, keepdim=True)
57
+ x = x * torch.rsqrt(variance + self.eps)
58
+ return self.weight * x
59
+
60
+ class BitTransformerBlock(nn.Module):
61
+ def __init__(self, embed_dim, num_heads):
62
+ super().__init__()
63
+ self.embed_dim = embed_dim
64
+ self.num_heads = num_heads
65
+ self.head_dim = embed_dim // num_heads
66
+
67
+ self.ln_1 = RMSNorm(embed_dim)
68
+ self.q_proj = BitLinear(embed_dim, embed_dim, bias=False)
69
+ self.k_proj = BitLinear(embed_dim, embed_dim, bias=False)
70
+ self.v_proj = BitLinear(embed_dim, embed_dim, bias=False)
71
+ self.o_proj = BitLinear(embed_dim, embed_dim, bias=False)
72
+
73
+ self.ln_2 = RMSNorm(embed_dim)
74
+ self.gate_proj = BitLinear(embed_dim, embed_dim * 4, bias=False)
75
+ self.up_proj = BitLinear(embed_dim, embed_dim * 4, bias=False)
76
+ self.down_proj = BitLinear(embed_dim * 4, embed_dim, bias=False)
77
+
78
+ def forward(self, x):
79
+ batch, seq, _ = x.shape
80
+
81
+ # --- Attention ---
82
+ norm_x = self.ln_1(x)
83
+ Q = self.q_proj(norm_x).view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)
84
+ K = self.k_proj(norm_x).view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)
85
+ V = self.v_proj(norm_x).view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)
86
+
87
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
88
+ mask = torch.triu(torch.ones(seq, seq, device=x.device), diagonal=1).bool()
89
+ scores.masked_fill_(mask, float('-inf'))
90
+ attn = torch.softmax(scores, dim=-1)
91
+
92
+ context = torch.matmul(attn, V).transpose(1, 2).contiguous().view(batch, seq, self.embed_dim)
93
+ x = x + self.o_proj(context)
94
+
95
+ # --- SwiGLU FFN ---
96
+ norm_x2 = self.ln_2(x)
97
+ gate = F.silu(self.gate_proj(norm_x2))
98
+ up = self.up_proj(norm_x2)
99
+ x = x + self.down_proj(gate * up)
100
+
101
+ return x
102
+
103
+ class BitGPT(nn.Module):
104
+ def __init__(self, vocab_size, embed_dim, num_layers, num_heads, tie_weights=False):
105
+ super().__init__()
106
+ # Continuous embeddings are usually kept continuous in BitNet
107
+ self.vocab_embed = nn.Embedding(vocab_size, embed_dim)
108
+ self.pos_embed = nn.Embedding(1024, embed_dim)
109
+
110
+ self.layers = nn.ModuleList([BitTransformerBlock(embed_dim, num_heads) for _ in range(num_layers)])
111
+
112
+ self.ln_f = RMSNorm(embed_dim)
113
+ # Head is usually continuous or bitlinear, we'll use BitLinear for maximum compression!
114
+ self.head = BitLinear(embed_dim, vocab_size, bias=False)
115
+
116
+ self.apply(self._init_weights)
117
+
118
+ if tie_weights:
119
+ self.head.weight = self.vocab_embed.weight
120
+
121
+ def _init_weights(self, module):
122
+ if isinstance(module, nn.Linear) or isinstance(module, BitLinear):
123
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
124
+ elif isinstance(module, nn.Embedding):
125
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
126
+
127
+ def forward(self, x):
128
+ batch, seq = x.shape
129
+ pos = torch.arange(seq, device=x.device).unsqueeze(0)
130
+
131
+ x = self.vocab_embed(x) + self.pos_embed(pos)
132
+
133
+ for layer in self.layers:
134
+ x = layer(x)
135
+
136
+ x = self.ln_f(x)
137
+ return self.head(x)
138
+
139
+ def get_batch(text, seq_length, batch_size, char_to_ix):
140
+ ixs = torch.randint(0, len(text) - seq_length - 1, (batch_size,))
141
+ x = torch.zeros(batch_size, seq_length, dtype=torch.long)
142
+ y = torch.zeros(batch_size, seq_length, dtype=torch.long)
143
+ for i, idx in enumerate(ixs):
144
+ x[i] = torch.tensor([char_to_ix[ch] for ch in text[idx:idx+seq_length]])
145
+ y[i] = torch.tensor([char_to_ix[ch] for ch in text[idx+1:idx+seq_length+1]])
146
+ return x, y
147
+
148
+ def main():
149
+ if not os.path.exists("tinyshakespeare.txt"):
150
+ url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
151
+ urllib.request.urlretrieve(url, "tinyshakespeare.txt")
152
+
153
+ with open("tinyshakespeare.txt", "r") as f:
154
+ text = f.read()
155
+
156
+ chars = sorted(list(set(text)))
157
+ vocab_size = len(chars)
158
+ char_to_ix = {ch: i for i, ch in enumerate(chars)}
159
+ ix_to_char = {i: ch for i, ch in enumerate(chars)}
160
+
161
+ device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
162
+ print(f"Using device: {device}")
163
+
164
+ # 256 dims, 4 layers, 4 heads
165
+ model = BitGPT(vocab_size, embed_dim=256, num_layers=4, num_heads=4).to(device)
166
+
167
+ # Count parameters
168
+ params = sum(p.numel() for p in model.parameters())
169
+ print(f"Total Model Parameters: {params}")
170
+ print(f"File size if 16-bit Float: {params * 2 / 1024 / 1024:.2f} MB")
171
+ print(f"File size at 1.58 bits: {params * 1.58 / 8 / 1024 / 1024:.2f} MB")
172
+
173
+ # BitNet requires a slightly higher learning rate
174
+ optimizer = torch.optim.AdamW(model.parameters(), lr=0.003, weight_decay=0.01)
175
+ criterion = nn.CrossEntropyLoss()
176
+
177
+ batch_size = 128
178
+ seq_length = 64
179
+
180
+ print("Training the 1.58b BitNet Transformer...")
181
+ for step in range(600):
182
+ x, y = get_batch(text, seq_length, batch_size, char_to_ix)
183
+ x, y = x.to(device), y.to(device)
184
+
185
+ optimizer.zero_grad()
186
+ logits = model(x)
187
+
188
+ loss = criterion(logits.reshape(-1, vocab_size), y.reshape(-1))
189
+ loss.backward()
190
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
191
+ optimizer.step()
192
+
193
+ if step % 50 == 0:
194
+ perplexity = torch.exp(loss)
195
+ print(f"Step {step} | Loss: {loss.item():.4f} | Perplexity: {perplexity.item():.4f}")
196
+
197
+ print("Saving model weights to bitnet_model.pt...")
198
+ torch.save(model.state_dict(), "bitnet_model.pt")
199
+
200
+ print("\n--- GENERATING TEXT ---")
201
+ model.eval()
202
+ with torch.no_grad():
203
+ x = torch.tensor([[char_to_ix['T']]]).to(device)
204
+ out_text = 'T'
205
+ for _ in range(300):
206
+ logits = model(x)
207
+ # Take the last logit
208
+ logits = logits[:, -1, :]
209
+ probs = F.softmax(logits / 0.8, dim=-1)
210
+ next_ix = torch.multinomial(probs, 1).item()
211
+ out_text += ix_to_char[next_ix]
212
+
213
+ # Append next char
214
+ next_tensor = torch.tensor([[next_ix]]).to(device)
215
+ x = torch.cat([x, next_tensor], dim=1)
216
+ # Truncate context if too long
217
+ if x.size(1) > seq_length:
218
+ x = x[:, -seq_length:]
219
+
220
+ print(out_text)
221
+
222
+ if __name__ == '__main__':
223
+ main()