MiniLM / bitnet_test.py
0sparsh2's picture
Upload bitnet_test.py with huggingface_hub
f5c2b6c verified
Raw
History Blame Contribute Delete
8.3 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import urllib.request
# --- BITNET 1.58b IMPLEMENTATION ---
class RoundWithSTE(torch.autograd.Function):
    """Element-wise rounding with a straight-through gradient.

    The forward pass rounds its input; the backward pass returns the incoming
    gradient unchanged, so rounding behaves like the identity during
    backpropagation.
    """

    @staticmethod
    def forward(ctx, x):
        """Return ``x`` rounded element-wise. Nothing is stored on ``ctx``."""
        rounded = x.round()
        return rounded

    @staticmethod
    def backward(ctx, grad_output):
        """Hand the upstream gradient back untouched."""
        return grad_output
class BitLinear(nn.Linear):
    """``nn.Linear`` variant that quantizes weights and activations on the fly.

    On every forward pass the stored full-precision weight is mapped to values
    in {-1, 0, 1} and the input is mapped to the range [-128, 127]; the
    product is then rescaled back. Rounding goes through ``RoundWithSTE`` so
    gradients still reach the full-precision weight.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__(in_features, out_features, bias=bias)

    def forward(self, x):
        """Apply the quantized linear projection to ``x`` along its last dim."""
        eps = 1e-8  # keeps both divisions finite when a scale is zero

        # Weights: one scale for the whole matrix (mean absolute value),
        # then round and clip to the ternary set {-1, 0, 1}.
        w_absmean = self.weight.abs().mean()
        ternary_w = RoundWithSTE.apply(self.weight / (w_absmean + eps)).clamp(-1.0, 1.0)

        # Activations: one scale per position along the last dimension
        # (max absolute value), mapped to the 8-bit range [-128, 127].
        row_absmax = x.abs().max(dim=-1, keepdim=True).values
        int8_x = RoundWithSTE.apply(x * (127.0 / (row_absmax + eps))).clamp(-128.0, 127.0)

        # Projection on the quantized operands (a real engine would use
        # integer kernels here), followed by dequantization.
        y = F.linear(int8_x, ternary_w)
        y = y * (w_absmean * row_absmax / 127.0)

        if self.bias is None:
            return y
        y += self.bias
        return y
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    The input is divided by ``sqrt(mean(x**2) + eps)`` computed along the last
    axis and then multiplied by a learned per-feature gain initialised to one.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return the normalized, gain-scaled version of ``x``."""
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return self.weight * (x * inv_rms)
class BitTransformerBlock(nn.Module):
    """Pre-norm decoder block: causal self-attention, then a SwiGLU FFN.

    Every projection is a bias-free ``BitLinear`` and both sub-layers are
    wrapped in a residual connection. The FFN hidden width is ``4 * embed_dim``.

    Args:
        embed_dim: Width of the residual stream.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        ValueError: If ``num_heads`` is not a positive divisor of ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Fail early: with a non-divisible configuration the floor division
        # below would silently truncate and the error would only surface as
        # an opaque shape mismatch inside ``view`` during the first forward.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads (got {num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.ln_1 = RMSNorm(embed_dim)
        self.q_proj = BitLinear(embed_dim, embed_dim, bias=False)
        self.k_proj = BitLinear(embed_dim, embed_dim, bias=False)
        self.v_proj = BitLinear(embed_dim, embed_dim, bias=False)
        self.o_proj = BitLinear(embed_dim, embed_dim, bias=False)

        self.ln_2 = RMSNorm(embed_dim)
        self.gate_proj = BitLinear(embed_dim, embed_dim * 4, bias=False)
        self.up_proj = BitLinear(embed_dim, embed_dim * 4, bias=False)
        self.down_proj = BitLinear(embed_dim * 4, embed_dim, bias=False)

    def _split_heads(self, t, batch, seq):
        """Reshape ``(batch, seq, embed_dim)`` to ``(batch, heads, seq, head_dim)``."""
        return t.view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Run the block on a 3-D input ``(batch, seq, embed_dim)``.

        Returns a tensor of the same shape as ``x``.
        """
        batch, seq, _ = x.shape

        # --- Attention ---
        norm_x = self.ln_1(x)
        Q = self._split_heads(self.q_proj(norm_x), batch, seq)
        K = self._split_heads(self.k_proj(norm_x), batch, seq)
        V = self._split_heads(self.v_proj(norm_x), batch, seq)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Strictly-upper-triangular mask: position i may not attend to j > i.
        mask = torch.triu(torch.ones(seq, seq, device=x.device), diagonal=1).bool()
        scores.masked_fill_(mask, float('-inf'))
        attn = torch.softmax(scores, dim=-1)

        # Merge the heads back into a single embed_dim-wide vector per token.
        context = torch.matmul(attn, V).transpose(1, 2).contiguous().view(batch, seq, self.embed_dim)
        x = x + self.o_proj(context)

        # --- SwiGLU FFN ---
        norm_x2 = self.ln_2(x)
        gate = F.silu(self.gate_proj(norm_x2))
        up = self.up_proj(norm_x2)
        x = x + self.down_proj(gate * up)
        return x
class BitGPT(nn.Module):
    """Decoder-only language model built from ``BitTransformerBlock`` layers.

    Token and learned positional embeddings are summed, passed through
    ``num_layers`` blocks and a final ``RMSNorm``, then projected to vocabulary
    logits by a ``BitLinear`` head.

    Args:
        vocab_size: Number of distinct token ids.
        embed_dim: Width of the residual stream.
        num_layers: Number of transformer blocks.
        num_heads: Attention heads per block.
        tie_weights: If true, the output head shares its weight tensor with
            the token embedding.
        max_seq_len: Size of the positional-embedding table, i.e. the longest
            sequence ``forward`` accepts. Defaults to 1024.
    """

    def __init__(self, vocab_size, embed_dim, num_layers, num_heads, tie_weights=False, max_seq_len=1024):
        super().__init__()
        self.max_seq_len = max_seq_len
        # Continuous embeddings are usually kept continuous in BitNet
        self.vocab_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_seq_len, embed_dim)
        self.layers = nn.ModuleList([BitTransformerBlock(embed_dim, num_heads) for _ in range(num_layers)])
        self.ln_f = RMSNorm(embed_dim)
        # Head is usually continuous or bitlinear, we'll use BitLinear for maximum compression!
        self.head = BitLinear(embed_dim, vocab_size, bias=False)
        self.apply(self._init_weights)
        # Tie after initialisation so both modules end up with the same tensor.
        if tie_weights:
            self.head.weight = self.vocab_embed.weight

    def _init_weights(self, module):
        """Draw linear and embedding weights from N(0, 0.02**2).

        ``BitLinear`` subclasses ``nn.Linear``, so the ``nn.Linear`` check
        already covers it.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        """Map token ids of shape ``(batch, seq)`` to logits ``(batch, seq, vocab_size)``.

        Raises:
            ValueError: If ``seq`` exceeds ``max_seq_len``; without this check
                the failure would come from the positional-embedding lookup.
        """
        batch, seq = x.shape
        if seq > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq} exceeds max_seq_len {self.max_seq_len}"
            )
        pos = torch.arange(seq, device=x.device).unsqueeze(0)
        x = self.vocab_embed(x) + self.pos_embed(pos)
        for layer in self.layers:
            x = layer(x)
        x = self.ln_f(x)
        return self.head(x)
def get_batch(text, seq_length, batch_size, char_to_ix):
    """Sample a random batch of next-character prediction examples.

    ``batch_size`` start offsets are drawn uniformly with ``torch.randint``.
    For each offset the input row is ``seq_length`` encoded characters of
    ``text`` and the target row is the same window shifted one character to
    the right.

    Returns:
        A pair ``(x, y)`` of ``torch.long`` tensors, each of shape
        ``(batch_size, seq_length)``.
    """
    starts = torch.randint(0, len(text) - seq_length - 1, (batch_size,)).tolist()

    def encode(begin):
        # Encode the seq_length-character window that starts at ``begin``.
        return [char_to_ix[ch] for ch in text[begin:begin + seq_length]]

    input_rows = [encode(start) for start in starts]
    target_rows = [encode(start + 1) for start in starts]

    # reshape keeps the (batch_size, seq_length) shape even for empty batches.
    x = torch.tensor(input_rows, dtype=torch.long).reshape(batch_size, seq_length)
    y = torch.tensor(target_rows, dtype=torch.long).reshape(batch_size, seq_length)
    return x, y
def main():
    """Train a small character-level BitGPT on Tiny Shakespeare and sample from it.

    Steps:
      1. Download the corpus to ``tinyshakespeare.txt`` if it is not present.
      2. Build a character vocabulary from the corpus.
      3. Train a 4-layer, 256-dim, 4-head ``BitGPT`` for 600 AdamW steps.
      4. Save the state dict to ``bitnet_model.pt``.
      5. Sample 300 characters at temperature 0.8, seeded with ``'T'``.
    """
    corpus_path = "tinyshakespeare.txt"
    if not os.path.exists(corpus_path):
        url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
        urllib.request.urlretrieve(url, corpus_path)

    # Explicit encoding: the platform default is locale-dependent and may
    # mis-decode the corpus on non-UTF-8 systems.
    with open(corpus_path, "r", encoding="utf-8") as f:
        text = f.read()

    chars = sorted(set(text))
    vocab_size = len(chars)
    char_to_ix = {ch: i for i, ch in enumerate(chars)}
    ix_to_char = {i: ch for i, ch in enumerate(chars)}

    # Prefer CUDA, then Apple MPS, then fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    print(f"Using device: {device}")

    # 256 dims, 4 layers, 4 heads
    model = BitGPT(vocab_size, embed_dim=256, num_layers=4, num_heads=4).to(device)

    # Count parameters
    params = sum(p.numel() for p in model.parameters())
    print(f"Total Model Parameters: {params}")
    print(f"File size if 16-bit Float: {params * 2 / 1024 / 1024:.2f} MB")
    print(f"File size at 1.58 bits: {params * 1.58 / 8 / 1024 / 1024:.2f} MB")

    # BitNet requires a slightly higher learning rate
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.003, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()
    batch_size = 128
    seq_length = 64

    print("Training the 1.58b BitNet Transformer...")
    for step in range(600):
        x, y = get_batch(text, seq_length, batch_size, char_to_ix)
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        # Flatten (batch, seq) so every position is one classification example.
        loss = criterion(logits.reshape(-1, vocab_size), y.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if step % 50 == 0:
            # detach: this is only for reporting, no graph is needed.
            perplexity = torch.exp(loss.detach())
            print(f"Step {step} | Loss: {loss.item():.4f} | Perplexity: {perplexity.item():.4f}")

    print("Saving model weights to bitnet_model.pt...")
    torch.save(model.state_dict(), "bitnet_model.pt")

    print("\n--- GENERATING TEXT ---")
    model.eval()
    with torch.no_grad():
        x = torch.tensor([[char_to_ix['T']]], device=device)
        # Collect pieces and join once instead of repeated string +=.
        generated = ['T']
        for _ in range(300):
            logits = model(x)
            # Take the last logit
            logits = logits[:, -1, :]
            probs = F.softmax(logits / 0.8, dim=-1)
            next_ix = torch.multinomial(probs, 1).item()
            generated.append(ix_to_char[next_ix])
            # Append next char
            next_tensor = torch.tensor([[next_ix]], device=device)
            x = torch.cat([x, next_tensor], dim=1)
            # Truncate context if too long
            if x.size(1) > seq_length:
                x = x[:, -seq_length:]
        out_text = "".join(generated)
    print(out_text)
# Run the training and generation demo only when executed as a script.
if __name__ == "__main__":
    main()