# handcraftparameters.py
import argparse
import random
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import numpy as np
# ============================================================
# Model Hyperparameters
# ============================================================
MODEL_LAYERS = 2
MODEL_DIM = 5
ATTENTION_HEADS = 2
KEY_VALUE_HEADS = 1
HEAD_DIM = 2
INTERMEDIATE_SIZE = 3
VOCAB_SIZE = 10
OUTPUT_DIGITS = 11
MAX_ADDEND = 10**10 - 1
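# Note: the network below resembles a miniature Llama/Qwen-style decoder:
# pre-norm RMSNorm, rotary position embeddings, grouped-query attention with
# per-head QK-norm, and a SwiGLU MLP. Tokens are the digits 0-9 (VOCAB_SIZE = 10);
# the 0 token doubles as the pad/delimiter in the prompt encoding further below.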
# ============================================================
# Layers
# ============================================================
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        norm = torch.mean(x ** 2, dim=-1, keepdim=True)
        return (x / torch.sqrt(norm + self.eps)) * self.weight
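# RMSNorm rescales by the root-mean-square of the features:
#   y = x / sqrt(mean(x**2) + eps) * weight
# Unlike LayerNorm there is no mean subtraction and no bias term.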
class RoPE(nn.Module):
    def __init__(self, dim, base=10000.0):
        super().__init__()
        self.dim = dim
        self.base = base

    def forward(self, x, seq_len):
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)).to(x.device)
        t = torch.arange(seq_len, device=x.device).float()
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()[None, :, None, :]
        sin = emb.sin()[None, :, None, :]
        x_half1, x_half2 = x.chunk(2, dim=-1)
        x_rotated = torch.cat((-x_half2, x_half1), dim=-1)
        return (x * cos) + (x_rotated * sin)
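# This is the "rotate-half" RoPE formulation: for position t and frequency
# theta_i = base**(-2i/dim), each feature pair (x1, x2) is rotated by angle
# t * theta_i, i.e. (x1*cos - x2*sin, x2*cos + x1*sin). Applying the same
# rotation to q and k makes attention scores depend only on relative position.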
class Attention(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(MODEL_DIM, ATTENTION_HEADS * HEAD_DIM, bias=False)
        self.k_proj = nn.Linear(MODEL_DIM, KEY_VALUE_HEADS * HEAD_DIM, bias=False)
        self.v_proj = nn.Linear(MODEL_DIM, KEY_VALUE_HEADS * HEAD_DIM, bias=False)
        self.o_proj = nn.Linear(ATTENTION_HEADS * HEAD_DIM, MODEL_DIM, bias=False)
        self.q_norm = RMSNorm(HEAD_DIM)
        self.k_norm = RMSNorm(HEAD_DIM)
        self.rope = RoPE(HEAD_DIM)

    def forward(self, x):
        B, L, _ = x.shape
        q = self.q_proj(x).view(B, L, ATTENTION_HEADS, HEAD_DIM)
        k = self.k_proj(x).view(B, L, KEY_VALUE_HEADS, HEAD_DIM)
        v = self.v_proj(x).view(B, L, KEY_VALUE_HEADS, HEAD_DIM)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q = self.rope(q, L)
        k = self.rope(k, L)
        # Grouped-query attention: broadcast the single KV head across all query heads.
        k = k.expand(B, L, ATTENTION_HEADS, HEAD_DIM)
        v = v.expand(B, L, ATTENTION_HEADS, HEAD_DIM)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(HEAD_DIM)
        # Causal mask: position i may only attend to positions <= i.
        mask = torch.tril(torch.ones((L, L), device=x.device)).unsqueeze(0).unsqueeze(0) == 1
        scores = scores.masked_fill(~mask, float('-inf'))
        probs = F.softmax(scores, dim=-1)
        out = (probs @ v).transpose(1, 2).contiguous().view(B, L, ATTENTION_HEADS * HEAD_DIM)
        return self.o_proj(out)
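# The explicit mask-and-softmax above matches what PyTorch 2.x's fused kernel
# computes (same 1/sqrt(HEAD_DIM) scaling and causal masking). A minimal sketch,
# not used here, assuming q, k, v are already laid out as (B, heads, L, HEAD_DIM)
# with the KV head expanded:
#
#     out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
#     out = out.transpose(1, 2).contiguous().view(B, L, ATTENTION_HEADS * HEAD_DIM)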
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.gate_proj = nn.Linear(MODEL_DIM, INTERMEDIATE_SIZE, bias=False)
        self.up_proj = nn.Linear(MODEL_DIM, INTERMEDIATE_SIZE, bias=False)
        self.down_proj = nn.Linear(INTERMEDIATE_SIZE, MODEL_DIM, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
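# SwiGLU feed-forward: down(silu(gate(x)) * up(x)). The gate path modulates the
# up projection elementwise before projecting back to MODEL_DIM.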
class Layer(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_layernorm = RMSNorm(MODEL_DIM)
        self.post_attention_layernorm = RMSNorm(MODEL_DIM)
        self.self_attn = Attention()
        self.mlp = MLP()

    def forward(self, x):
        x = x + self.self_attn(self.input_layernorm(x))
        x = x + self.mlp(self.post_attention_layernorm(x))
        return x
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(VOCAB_SIZE, MODEL_DIM)
        self.layers = nn.ModuleList([Layer() for _ in range(MODEL_LAYERS)])
        self.norm = RMSNorm(MODEL_DIM)
        self.lm_head = nn.Linear(MODEL_DIM, VOCAB_SIZE, bias=False)

    def forward(self, input_ids):
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return self.lm_head(x)
# ============================================================
# Helper Functions
# ============================================================
def _validate_addends(a, b):
    if not isinstance(a, int) or not isinstance(b, int):
        raise ValueError("a and b must be ints")
    if a < 0 or a > MAX_ADDEND or b < 0 or b > MAX_ADDEND:
        raise ValueError(f"a and b must be in [0, {MAX_ADDEND}]")


def _encode_addends_internal(a, b):
    _validate_addends(a, b)
    prompt = f"{a:010d}{b:010d}"
    a_digits = [int(c) for c in prompt[:10]]
    b_digits = [int(c) for c in prompt[10:]]
    return [0] + list(reversed(a_digits)) + [0] + [0] + list(reversed(b_digits)) + [0]
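# Worked example (a=12, b=34): prompt = "0000000012" + "0000000034", so each
# addend is stored least-significant digit first, zero-padded to 10 digits and
# bracketed by 0-valued delimiter tokens (24 tokens total):
#   [0, 2,1,0,0,0,0,0,0,0,0, 0, 0, 4,3,0,0,0,0,0,0,0,0, 0]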
def _expected_output(a, b):
    _validate_addends(a, b)
    return str(a + b)[::-1].ljust(OUTPUT_DIGITS, "0")
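# The target is the sum's digits least-significant first, padded to 11 digits,
# e.g. _expected_output(12, 34) == "64000000000" (12 + 34 = 46, reversed "64").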
# ============================================================
# Load weights from JSON
# ============================================================
def load_weights_from_json(model: nn.Module, path: str):
    with open(path, "r") as f:
        data = json.load(f)

    def set_param(module, key_list, value):
        if len(key_list) == 1:
            setattr(module, key_list[0], nn.Parameter(torch.tensor(value, dtype=torch.float32)))
        else:
            attr = getattr(module, key_list[0])
            set_param(attr, key_list[1:], value)

    for key, value in data.items():
        key_list = key.split(".")
        set_param(model, key_list, value)
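# The JSON file is assumed to map dotted parameter paths (mirroring the module
# tree, including ModuleList indices) to nested lists of floats, e.g.:
#   {"embed_tokens.weight": [[...], ...],
#    "layers.0.self_attn.q_proj.weight": [[...], ...]}
# set_param walks each path and replaces the leaf attribute with a new Parameter.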
# ============================================================
# Model Building
# ============================================================
def build_model_from_json(weights_path, device):
    model = Model()
    load_weights_from_json(model, weights_path)
    model.to(device)
    model.eval()
    return model
# ============================================================
# Batch generation & self-test
# ============================================================
def _generate_output_batch(model, addends, device):
    internal = [_encode_addends_internal(a, b) for a, b in addends]
    with torch.no_grad():
        for _ in range(OUTPUT_DIGITS):
            x = torch.tensor(internal, dtype=torch.long, device=device)
            logits = model(x)
            next_digits = logits[:, -1, :].argmax(dim=-1).cpu().numpy()
            for seq, next_digit in zip(internal, next_digits):
                seq.append(int(next_digit))
    return ["".join(str(d) for d in seq[-OUTPUT_DIGITS:]) for seq in internal]
def run_self_test_batched(model, num_tests, batch_size, device):
    rng = random.Random(123)
    tested = 0
    correct = 0
    while tested < num_tests:
        cur_batch_size = min(batch_size, num_tests - tested)
        addends = []
        expected = []
        for _ in range(cur_batch_size):
            a = rng.randint(0, MAX_ADDEND)
            b = rng.randint(0, MAX_ADDEND)
            addends.append((a, b))
            expected.append(_expected_output(a, b))
        actual = _generate_output_batch(model, addends, device)
        for exp, act in zip(expected, actual):
            if act == exp:
                correct += 1
        tested += cur_batch_size
        print(f"self-test progress: {tested}/{num_tests}")
    return (correct / num_tests) * 100 if num_tests > 0 else 0
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--weights", type=str, default="weights.json")
    parser.add_argument("--num-tests", type=int, default=8192)
    parser.add_argument("--batch-size", type=int, default=1024)
    # parse_known_args ignores unrecognized flags (e.g. when run inside a notebook).
    args, _ = parser.parse_known_args()
    return args
def main():
    args = parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    model = build_model_from_json(args.weights, device)
    print(f"parameter count: {count_parameters(model)}")
    accuracy = run_self_test_batched(model, args.num_tests, args.batch_size, device)
    print(f"self-test complete ({args.num_tests} random cases, batch size {args.batch_size})")
    print(f"Model Accuracy: {accuracy:.2f}%")


if __name__ == "__main__":
    main()
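# Usage (defaults shown):
#   python handcraftparameters.py --weights weights.json --num-tests 8192 --batch-size 1024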