0sparsh2 committed on
Commit
d353ea8
·
verified ·
1 Parent(s): ce103c4

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +236 -0
model.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ import os
6
+ import urllib.request
7
+
8
+ # --- BITNET 1.58b IMPLEMENTATION ---
9
+
10
class RoundWithSTE(torch.autograd.Function):
    """Round to the nearest integer with a straight-through gradient.

    The forward pass applies ``torch.round``. The backward pass returns the
    incoming gradient unchanged, i.e. rounding is treated as the identity
    function for the purpose of differentiation.
    """

    @staticmethod
    def forward(ctx, x):
        rounded = torch.round(x)
        return rounded

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: hand the upstream gradient back as-is.
        upstream = grad_output
        return upstream
18
+
19
class BitLinear(nn.Linear):
    """Linear layer that fake-quantizes weights and activations in ``forward``.

    Weights are divided by their mean absolute value, rounded and clamped to
    {-1, 0, 1}. Activations are scaled along the last dimension by their max
    absolute value to the 8-bit range, rounded and clamped to [-128, 127].
    The product of the quantized tensors is then rescaled back to the
    original magnitude. Rounding goes through ``RoundWithSTE``, whose
    backward pass returns the gradient unchanged.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__(in_features, out_features, bias=bias)

    def forward(self, x):
        # Weights -> ternary values, using one scale for the whole matrix.
        w_absmean = self.weight.abs().mean()
        w_normalized = self.weight / (w_absmean + 1e-8)
        w_ternary = torch.clamp(RoundWithSTE.apply(w_normalized), -1.0, 1.0)

        # Activations -> int8 range, one scale per slice of the last dim.
        a_absmax = x.abs().max(dim=-1, keepdim=True)[0]
        a_normalized = x * (127.0 / (a_absmax + 1e-8))
        a_int8 = torch.clamp(RoundWithSTE.apply(a_normalized), -128.0, 127.0)

        # Projection on the quantized values (still stored as floats here).
        projected = F.linear(a_int8, w_ternary)

        # Undo both scales to return to the original magnitude.
        dequant_scale = w_absmean * a_absmax / 127.0
        result = projected * dequant_scale

        if self.bias is None:
            return result
        return result + self.bias
48
+
49
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Divides the input by ``sqrt(mean(x**2) + eps)`` computed along the last
    axis and multiplies by a learned per-feature ``weight`` (initialized to
    ones). No mean subtraction and no bias term are applied.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalized = x * inv_rms
        return self.weight * normalized
59
+
60
class BitTransformerBlock(nn.Module):
    """Pre-norm transformer block: causal self-attention then a SwiGLU FFN.

    All projections are ``BitLinear`` layers without bias and both
    sub-layers are wrapped in residual connections.

    Args:
        embed_dim: Width of the residual stream.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim`` evenly (the per-head reshape in ``forward`` would
            otherwise fail with an opaque shape error).
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.ln_1 = RMSNorm(embed_dim)
        self.q_proj = BitLinear(embed_dim, embed_dim, bias=False)
        self.k_proj = BitLinear(embed_dim, embed_dim, bias=False)
        self.v_proj = BitLinear(embed_dim, embed_dim, bias=False)
        self.o_proj = BitLinear(embed_dim, embed_dim, bias=False)

        self.ln_2 = RMSNorm(embed_dim)
        self.gate_proj = BitLinear(embed_dim, embed_dim * 4, bias=False)
        self.up_proj = BitLinear(embed_dim, embed_dim * 4, bias=False)
        self.down_proj = BitLinear(embed_dim * 4, embed_dim, bias=False)

    def _split_heads(self, t, batch, seq):
        """Reshape (batch, seq, embed_dim) to (batch, heads, seq, head_dim)."""
        return t.view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, seq, embed_dim)."""
        batch, seq, _ = x.shape

        # --- Attention ---
        norm_x = self.ln_1(x)
        Q = self._split_heads(self.q_proj(norm_x), batch, seq)
        K = self._split_heads(self.k_proj(norm_x), batch, seq)
        V = self._split_heads(self.v_proj(norm_x), batch, seq)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Strictly-upper-triangular mask hides future positions; the diagonal
        # stays visible, so every softmax row has at least one finite entry.
        mask = torch.triu(torch.ones(seq, seq, device=x.device), diagonal=1).bool()
        scores = scores.masked_fill(mask, float('-inf'))
        attn = torch.softmax(scores, dim=-1)

        context = torch.matmul(attn, V).transpose(1, 2).contiguous().view(batch, seq, self.embed_dim)
        x = x + self.o_proj(context)

        # --- SwiGLU FFN ---
        norm_x2 = self.ln_2(x)
        gate = F.silu(self.gate_proj(norm_x2))
        up = self.up_proj(norm_x2)
        x = x + self.down_proj(gate * up)

        return x
102
+
103
class BitGPT(nn.Module):
    """Decoder-only language model built from ``BitTransformerBlock`` layers.

    Args:
        vocab_size: Number of token ids.
        embed_dim: Width of the token/position embeddings and residual stream.
        num_layers: Number of block applications in ``forward``.
        num_heads: Attention heads per block.
        tie_weights: If true, the output head shares its weight tensor with
            the token embedding.
        universal: If true, a single block is created and applied
            ``num_layers`` times; otherwise ``num_layers`` separate blocks
            are created.
        max_seq_len: Size of the learned positional-embedding table, i.e.
            the longest sequence ``forward`` accepts. Defaults to 2048.
    """

    def __init__(self, vocab_size, embed_dim=256, num_layers=12, num_heads=4,
                 tie_weights=True, universal=False, max_seq_len=2048):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.universal = universal
        self.max_seq_len = max_seq_len

        self.vocab_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_seq_len, embed_dim)

        # In universal mode one block's parameters are reused for every layer.
        block_count = 1 if universal else num_layers
        self.layers = nn.ModuleList(
            [BitTransformerBlock(embed_dim, num_heads) for _ in range(block_count)]
        )

        self.ln_f = nn.LayerNorm(embed_dim)
        # Head is usually continuous or bitlinear, we'll use BitLinear for maximum compression!
        self.head = BitLinear(embed_dim, vocab_size, bias=False)

        if tie_weights:
            self.head.weight = self.vocab_embed.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize linear and embedding weights from N(0, 0.02)."""
        # BitLinear subclasses nn.Linear, so it is covered by the nn.Linear check.
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        """Return logits of shape (batch, seq, vocab_size) for token ids ``x``.

        Raises:
            ValueError: If the sequence is longer than the positional table.
        """
        batch, seq = x.shape
        if seq > self.pos_embed.num_embeddings:
            raise ValueError(
                f"sequence length {seq} exceeds the maximum supported "
                f"length {self.pos_embed.num_embeddings}"
            )
        pos = torch.arange(seq, device=x.device).unsqueeze(0)

        x = self.vocab_embed(x) + self.pos_embed(pos)

        if self.universal:
            shared_block = self.layers[0]
            for _ in range(self.num_layers):
                x = shared_block(x)
        else:
            for layer in self.layers:
                x = layer(x)

        x = self.ln_f(x)
        return self.head(x)
151
+
152
def get_batch(text, seq_length, batch_size, char_to_ix):
    """Sample a random batch of next-character prediction pairs from ``text``.

    Args:
        text: Source string to draw windows from.
        seq_length: Number of characters in each input/target row.
        batch_size: Number of rows to sample.
        char_to_ix: Mapping from each character in ``text`` to its integer id.

    Returns:
        A pair ``(x, y)`` of ``torch.long`` tensors of shape
        ``(batch_size, seq_length)``, where ``y`` is ``x`` shifted one
        character to the right in ``text``.

    Raises:
        ValueError: If ``text`` is too short to hold a window of
            ``seq_length + 1`` characters within the sampled start range.
    """
    max_start = len(text) - seq_length - 1
    if max_start <= 0:
        raise ValueError(
            f"text of length {len(text)} is too short for seq_length={seq_length}; "
            f"need at least {seq_length + 2} characters"
        )

    starts = torch.randint(0, max_start, (batch_size,)).tolist()
    # Encode each window once (seq_length + 1 chars); inputs and targets are
    # the two overlapping length-seq_length views of that window.
    windows = [
        [char_to_ix[ch] for ch in text[start:start + seq_length + 1]]
        for start in starts
    ]
    encoded = torch.tensor(windows, dtype=torch.long).reshape(batch_size, seq_length + 1)
    x = encoded[:, :-1].contiguous()
    y = encoded[:, 1:].contiguous()
    return x, y
160
+
161
def main():
    """Train a small character-level BitGPT on Tiny Shakespeare and sample text.

    Downloads the corpus to ``tinyshakespeare.txt`` if it is not already
    present, trains for 600 steps, saves the state dict to
    ``bitnet_model.pt`` and prints 300 sampled characters.
    """
    data_path = "tinyshakespeare.txt"
    if not os.path.exists(data_path):
        url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
        urllib.request.urlretrieve(url, data_path)

    # Explicit encoding so the vocabulary does not depend on the OS locale.
    with open(data_path, "r", encoding="utf-8") as f:
        text = f.read()

    chars = sorted(set(text))
    vocab_size = len(chars)
    char_to_ix = {ch: i for i, ch in enumerate(chars)}
    ix_to_char = dict(enumerate(chars))

    # Prefer an accelerator when one is available, otherwise fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    print(f"Using device: {device}")

    # 256 dims, 4 layers, 4 heads
    model = BitGPT(vocab_size, embed_dim=256, num_layers=4, num_heads=4).to(device)

    # Count parameters
    params = sum(p.numel() for p in model.parameters())
    print(f"Total Model Parameters: {params}")
    print(f"File size if 16-bit Float: {params * 2 / 1024 / 1024:.2f} MB")
    print(f"File size at 1.58 bits: {params * 1.58 / 8 / 1024 / 1024:.2f} MB")

    # BitNet requires a slightly higher learning rate
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.003, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()

    batch_size = 128
    seq_length = 64

    print("Training the 1.58b BitNet Transformer...")
    for step in range(600):
        x, y = get_batch(text, seq_length, batch_size, char_to_ix)
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        logits = model(x)

        loss = criterion(logits.reshape(-1, vocab_size), y.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        if step % 50 == 0:
            perplexity = torch.exp(loss)
            print(f"Step {step} | Loss: {loss.item():.4f} | Perplexity: {perplexity.item():.4f}")

    print("Saving model weights to bitnet_model.pt...")
    torch.save(model.state_dict(), "bitnet_model.pt")

    print("\n--- GENERATING TEXT ---")
    model.eval()
    with torch.no_grad():
        x = torch.tensor([[char_to_ix['T']]], device=device)
        # Collect characters in a list and join once at the end.
        generated = ['T']
        for _ in range(300):
            # Only the logits at the last position predict the next character.
            last_logits = model(x)[:, -1, :]
            # Temperature 0.8 sharpens the distribution before sampling.
            probs = F.softmax(last_logits / 0.8, dim=-1)
            next_ix = torch.multinomial(probs, 1).item()
            generated.append(ix_to_char[next_ix])

            # Append next char
            next_tensor = torch.tensor([[next_ix]], device=device)
            x = torch.cat([x, next_tensor], dim=1)
            # Truncate context if too long
            if x.size(1) > seq_length:
                x = x[:, -seq_length:]

        print("".join(generated))
234
+
235
# Run the training and sampling demo only when executed as a script.
if __name__ == "__main__":
    main()