# File size: 8,304 Bytes (revision f5c2b6c)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import urllib.request
# --- BITNET 1.58b IMPLEMENTATION ---
class RoundWithSTE(torch.autograd.Function):
    """Round-to-nearest with a straight-through gradient.

    The forward pass applies element-wise rounding; the backward pass hands
    the incoming gradient back unchanged, as if rounding were the identity.
    """

    @staticmethod
    def forward(ctx, x):
        # Nothing is saved on ctx: the backward pass ignores the input.
        rounded = x.round()
        return rounded

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the upstream gradient along as-is.
        upstream = grad_output
        return upstream
class BitLinear(nn.Linear):
    """Linear layer that fake-quantizes weights and activations on each call.

    Weights are divided by their mean absolute value, rounded and clamped to
    {-1, 0, 1}. Inputs are scaled by the max absolute value along the last
    dimension into the 8-bit range [-128, 127]. The matmul result is then
    rescaled back to floating point and the optional bias is added.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__(in_features, out_features, bias=bias)

    def forward(self, x):
        eps = 1e-8

        # Weights -> ternary, using the abs-mean as the scale.
        w_absmean = self.weight.abs().mean()
        ternary_w = torch.clamp(
            RoundWithSTE.apply(self.weight / (w_absmean + eps)),
            -1.0,
            1.0,
        )

        # Activations -> int8 range, one scale per last-dimension row.
        act_absmax = x.abs().max(dim=-1, keepdim=True).values
        int8_x = torch.clamp(
            RoundWithSTE.apply(x * (127.0 / (act_absmax + eps))),
            -128.0,
            127.0,
        )

        # Projection on the quantized values (a real engine would use
        # integer kernels here), then undo both scales.
        y = F.linear(int8_x, ternary_w)
        y = y * (w_absmean * act_absmax / 127.0)

        if self.bias is not None:
            y += self.bias
        return y
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    The input is divided by the square root of its mean square (plus ``eps``)
    and multiplied by a learned per-feature gain initialised to ones.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normed = x * inv_rms
        return self.weight * normed
class BitTransformerBlock(nn.Module):
    """Pre-norm decoder block: causal self-attention, then a SwiGLU FFN.

    Every projection is a bias-free ``BitLinear``. Each sub-layer is preceded
    by ``RMSNorm`` and wrapped in a residual connection.

    Args:
        embed_dim: Width of the residual stream.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        ValueError: If ``embed_dim`` is not a multiple of ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Floor division below would silently truncate; without this check a
        # bad configuration only surfaces later as a shape error in .view().
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.ln_1 = RMSNorm(embed_dim)
        self.q_proj = BitLinear(embed_dim, embed_dim, bias=False)
        self.k_proj = BitLinear(embed_dim, embed_dim, bias=False)
        self.v_proj = BitLinear(embed_dim, embed_dim, bias=False)
        self.o_proj = BitLinear(embed_dim, embed_dim, bias=False)

        self.ln_2 = RMSNorm(embed_dim)
        self.gate_proj = BitLinear(embed_dim, embed_dim * 4, bias=False)
        self.up_proj = BitLinear(embed_dim, embed_dim * 4, bias=False)
        self.down_proj = BitLinear(embed_dim * 4, embed_dim, bias=False)

    def forward(self, x):
        """Apply the block to a 3-D input ``(batch, seq, embed_dim)``.

        Returns a tensor of the same shape.
        """
        batch, seq, _ = x.shape

        # --- Attention ---
        norm_x = self.ln_1(x)
        # Split the feature axis into heads: (batch, heads, seq, head_dim).
        Q = self.q_proj(norm_x).view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(norm_x).view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(norm_x).view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Strict upper triangle marks future positions; -inf removes them
        # from the softmax so each position attends only to itself and
        # earlier positions.
        mask = torch.triu(torch.ones(seq, seq, device=x.device), diagonal=1).bool()
        scores.masked_fill_(mask, float('-inf'))
        attn = torch.softmax(scores, dim=-1)

        # Merge heads back into the feature axis.
        context = torch.matmul(attn, V).transpose(1, 2).contiguous().view(batch, seq, self.embed_dim)
        x = x + self.o_proj(context)

        # --- SwiGLU FFN ---
        norm_x2 = self.ln_2(x)
        gate = F.silu(self.gate_proj(norm_x2))
        up = self.up_proj(norm_x2)
        x = x + self.down_proj(gate * up)
        return x
class BitGPT(nn.Module):
    """Decoder-only language model built from ``BitTransformerBlock`` layers.

    Token and learned positional embeddings stay in full precision; the
    output head is a ``BitLinear``.

    Args:
        vocab_size: Number of distinct tokens.
        embed_dim: Width of the residual stream.
        num_layers: Number of stacked transformer blocks.
        num_heads: Attention heads per block.
        tie_weights: If true, the head shares its weight tensor with the
            token embedding.
        max_seq_len: Size of the positional embedding table, i.e. the longest
            sequence ``forward`` accepts. Defaults to 1024.
    """

    def __init__(self, vocab_size, embed_dim, num_layers, num_heads,
                 tie_weights=False, max_seq_len=1024):
        super().__init__()
        # Continuous embeddings are usually kept continuous in BitNet
        self.vocab_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_seq_len, embed_dim)

        self.layers = nn.ModuleList(
            [BitTransformerBlock(embed_dim, num_heads) for _ in range(num_layers)]
        )
        self.ln_f = RMSNorm(embed_dim)

        # Head is usually continuous or bitlinear, we'll use BitLinear for maximum compression!
        self.head = BitLinear(embed_dim, vocab_size, bias=False)

        self.apply(self._init_weights)

        # Tie after initialisation so both modules end up sharing one tensor.
        if tie_weights:
            self.head.weight = self.vocab_embed.weight

    def _init_weights(self, module):
        """Draw Linear and Embedding weights from N(0, 0.02); leave the rest."""
        # BitLinear subclasses nn.Linear, so it is covered by this check.
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        """Map token indices ``(batch, seq)`` to logits ``(batch, seq, vocab)``.

        Raises:
            IndexError: If ``seq`` exceeds the positional embedding table.
        """
        _, seq = x.shape
        max_len = self.pos_embed.num_embeddings
        # Fail early with a readable message instead of inside the embedding
        # lookup. IndexError matches what the CPU lookup raises.
        if seq > max_len:
            raise IndexError(
                f"sequence length {seq} exceeds the positional table size {max_len}"
            )
        pos = torch.arange(seq, device=x.device).unsqueeze(0)
        x = self.vocab_embed(x) + self.pos_embed(pos)

        for layer in self.layers:
            x = layer(x)

        x = self.ln_f(x)
        return self.head(x)
def get_batch(text, seq_length, batch_size, char_to_ix):
    """Sample random training windows from ``text``.

    Args:
        text: Corpus string; needs at least ``seq_length + 1`` characters.
        seq_length: Number of characters per input sequence.
        batch_size: Number of windows to draw.
        char_to_ix: Mapping from each character in ``text`` to its index.

    Returns:
        A pair ``(x, y)`` of ``torch.long`` tensors, each of shape
        ``(batch_size, seq_length)``, where ``y`` is ``x`` shifted one
        character to the right in the corpus (next-character targets).

    Raises:
        KeyError: If a sampled character is missing from ``char_to_ix``.

    If ``text`` is shorter than ``seq_length + 1`` the sampling range is
    empty and ``torch.randint`` raises.
    """
    # A window covers seq_length + 1 characters, so valid starts are
    # 0 .. len(text) - seq_length - 1. randint's upper bound is exclusive.
    starts = torch.randint(0, len(text) - seq_length, (batch_size,)).tolist()

    # Encode each window once; inputs and targets are two views of it.
    windows = torch.tensor(
        [[char_to_ix[ch] for ch in text[s:s + seq_length + 1]] for s in starts],
        dtype=torch.long,
    ).reshape(batch_size, seq_length + 1)  # keeps shape sane when batch_size == 0

    # .contiguous() gives x and y their own storage, like the original.
    x = windows[:, :-1].contiguous()
    y = windows[:, 1:].contiguous()
    return x, y
def main():
    """Train a small character-level BitGPT on Tiny Shakespeare and sample text.

    Steps: download the corpus to ``tinyshakespeare.txt`` if it is missing,
    build a character vocabulary, train for 600 steps with AdamW, save the
    weights to ``bitnet_model.pt``, then print 300 sampled characters.
    """
    data_path = "tinyshakespeare.txt"
    if not os.path.exists(data_path):
        url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
        # Download under a scratch name and rename on success, so an
        # interrupted transfer never leaves a truncated corpus that later
        # runs would mistake for the full file.
        partial_path = data_path + ".part"
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, data_path)

    # Explicit encoding: the default depends on the platform locale.
    with open(data_path, "r", encoding="utf-8") as f:
        text = f.read()

    chars = sorted(set(text))
    vocab_size = len(chars)
    char_to_ix = {ch: i for i, ch in enumerate(chars)}
    ix_to_char = dict(enumerate(chars))

    # Prefer CUDA, then Apple MPS, then CPU. The getattr guard keeps this
    # working on torch builds that do not ship the mps backend module.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    print(f"Using device: {device}")

    # 256 dims, 4 layers, 4 heads
    model = BitGPT(vocab_size, embed_dim=256, num_layers=4, num_heads=4).to(device)

    # Count parameters
    params = sum(p.numel() for p in model.parameters())
    print(f"Total Model Parameters: {params}")
    print(f"File size if 16-bit Float: {params * 2 / 1024 / 1024:.2f} MB")
    print(f"File size at 1.58 bits: {params * 1.58 / 8 / 1024 / 1024:.2f} MB")

    # BitNet requires a slightly higher learning rate
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.003, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()

    batch_size = 128
    seq_length = 64

    print("Training the 1.58b BitNet Transformer...")
    for step in range(600):
        x, y = get_batch(text, seq_length, batch_size, char_to_ix)
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        logits = model(x)
        # Flatten (batch, seq) so every position is one classification target.
        loss = criterion(logits.reshape(-1, vocab_size), y.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        if step % 50 == 0:
            perplexity = torch.exp(loss)
            print(f"Step {step} | Loss: {loss.item():.4f} | Perplexity: {perplexity.item():.4f}")

    print("Saving model weights to bitnet_model.pt...")
    torch.save(model.state_dict(), "bitnet_model.pt")

    print("\n--- GENERATING TEXT ---")
    model.eval()
    with torch.no_grad():
        x = torch.tensor([[char_to_ix['T']]], device=device)
        generated = ['T']
        for _ in range(300):
            # Only the logits at the last position predict the next character.
            logits = model(x)[:, -1, :]
            probs = F.softmax(logits / 0.8, dim=-1)  # temperature 0.8
            next_token = torch.multinomial(probs, 1)  # shape (1, 1), on device
            generated.append(ix_to_char[next_token.item()])

            # Append the sample and keep at most seq_length tokens of context.
            x = torch.cat([x, next_token], dim=1)[:, -seq_length:]

    print("".join(generated))
# Run the training and generation demo only when executed as a script.
if __name__ == "__main__":
    main()
|