import argparse
import json
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

MODEL_LAYERS = 2          # number of transformer blocks
MODEL_DIM = 5             # width of the residual stream
ATTENTION_HEADS = 2       # query heads
KEY_VALUE_HEADS = 1       # shared key/value heads (grouped-query attention)
HEAD_DIM = 2              # dimension per attention head
INTERMEDIATE_SIZE = 3     # hidden width of the gated MLP
VOCAB_SIZE = 10           # tokens are the digits 0-9
OUTPUT_DIGITS = 11        # a sum of two 10-digit numbers has at most 11 digits
MAX_ADDEND = 10**10 - 1   # largest 10-digit addend

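# The model below closely resembles a miniature LLaMA/Qwen-style decoder:
# RMSNorm throughout, rotary position embeddings, grouped-query attention
# with per-head query/key norms, and a SiLU-gated MLP. The hyperparameters
# above are deliberately tiny; the architecture, not the scale, is the point.
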
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Scale by the root mean square of the features; unlike LayerNorm
        # there is no mean-centering and no bias.
        norm = torch.mean(x ** 2, dim=-1, keepdim=True)
        return (x / torch.sqrt(norm + self.eps)) * self.weight

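# Quick numeric check of RMSNorm above (hypothetical, not part of the script):
#   x = torch.tensor([[3.0, 4.0, 0.0, 0.0]])
#   RMSNorm(4)(x)  # rms = sqrt((9 + 16) / 4) = 2.5 -> [1.2, 1.6, 0.0, 0.0]
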
class RoPE(nn.Module):
    def __init__(self, dim, base=10000.0):
        super().__init__()
        self.dim = dim
        self.base = base

    def forward(self, x, seq_len):
        # One frequency per pair of dimensions: base^(-2i/dim) for i = 0..dim/2-1.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)).to(x.device)
        t = torch.arange(seq_len, device=x.device).float()
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()[None, :, None, :]
        sin = emb.sin()[None, :, None, :]
        # Rotate each (first-half, second-half) dimension pair by its
        # position-dependent angle.
        x_half1, x_half2 = x.chunk(2, dim=-1)
        x_rotated = torch.cat((-x_half2, x_half1), dim=-1)
        return (x * cos) + (x_rotated * sin)

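# With HEAD_DIM = 2 the RoPE above reduces to a single frequency of 1.0, i.e.
# the 2-D head vector at position t is rotated by exactly t radians:
#   (x1, x2) -> (x1*cos(t) - x2*sin(t), x2*cos(t) + x1*sin(t))
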
class Attention(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(MODEL_DIM, ATTENTION_HEADS * HEAD_DIM, bias=False)
        self.k_proj = nn.Linear(MODEL_DIM, KEY_VALUE_HEADS * HEAD_DIM, bias=False)
        self.v_proj = nn.Linear(MODEL_DIM, KEY_VALUE_HEADS * HEAD_DIM, bias=False)
        self.o_proj = nn.Linear(ATTENTION_HEADS * HEAD_DIM, MODEL_DIM, bias=False)
        self.q_norm = RMSNorm(HEAD_DIM)
        self.k_norm = RMSNorm(HEAD_DIM)
        self.rope = RoPE(HEAD_DIM)

    def forward(self, x):
        B, L, _ = x.shape
        q = self.q_proj(x).view(B, L, ATTENTION_HEADS, HEAD_DIM)
        k = self.k_proj(x).view(B, L, KEY_VALUE_HEADS, HEAD_DIM)
        v = self.v_proj(x).view(B, L, KEY_VALUE_HEADS, HEAD_DIM)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q = self.rope(q, L)
        k = self.rope(k, L)
        # Grouped-query attention: broadcast the single K/V head across all
        # query heads (expand creates a view, no copy).
        k = k.expand(B, L, ATTENTION_HEADS, HEAD_DIM)
        v = v.expand(B, L, ATTENTION_HEADS, HEAD_DIM)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(HEAD_DIM)
        # Causal mask: position i may attend only to positions <= i.
        mask = torch.tril(torch.ones((L, L), device=x.device)).unsqueeze(0).unsqueeze(0) == 1
        scores = scores.masked_fill(~mask, float('-inf'))
        probs = F.softmax(scores, dim=-1)
        out = (probs @ v).transpose(1, 2).contiguous().view(B, L, ATTENTION_HEADS * HEAD_DIM)
        return self.o_proj(out)

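# The causal mask built above is the lower triangle of an L x L matrix of
# ones, e.g. for L = 3:
#   [[1, 0, 0],
#    [1, 1, 0],
#    [1, 1, 1]]
# Masked positions get -inf before the softmax, so they receive zero weight.
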
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.gate_proj = nn.Linear(MODEL_DIM, INTERMEDIATE_SIZE, bias=False)
        self.up_proj = nn.Linear(MODEL_DIM, INTERMEDIATE_SIZE, bias=False)
        self.down_proj = nn.Linear(INTERMEDIATE_SIZE, MODEL_DIM, bias=False)

    def forward(self, x):
        # Gated MLP: down(silu(gate(x)) * up(x)).
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

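# This is the SwiGLU-style feed-forward block used by LLaMA-family models,
# where SiLU(x) = x * sigmoid(x) and the gate path multiplicatively modulates
# the up projection.
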
class Layer(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_layernorm = RMSNorm(MODEL_DIM)
        self.post_attention_layernorm = RMSNorm(MODEL_DIM)
        self.self_attn = Attention()
        self.mlp = MLP()

    def forward(self, x):
        # Pre-norm residual blocks: normalize, transform, add back.
        x = x + self.self_attn(self.input_layernorm(x))
        x = x + self.mlp(self.post_attention_layernorm(x))
        return x

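# The pre-norm arrangement above (norm inside each residual branch, identity
# on the skip path) is the standard choice in modern decoders and tends to
# train more stably than post-norm.
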
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(VOCAB_SIZE, MODEL_DIM)
        self.layers = nn.ModuleList([Layer() for _ in range(MODEL_LAYERS)])
        self.norm = RMSNorm(MODEL_DIM)
        self.lm_head = nn.Linear(MODEL_DIM, VOCAB_SIZE, bias=False)

    def forward(self, input_ids):
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return self.lm_head(x)

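# Shape check for Model.forward: input_ids of shape (B, L) produce logits of
# shape (B, L, VOCAB_SIZE), one distribution over the ten digit tokens per
# position. Note that embed_tokens and lm_head are separate (untied) weights.
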
def _validate_addends(a, b):
    if not isinstance(a, int) or not isinstance(b, int):
        raise ValueError("a and b must be ints")
    if a < 0 or a > MAX_ADDEND or b < 0 or b > MAX_ADDEND:
        raise ValueError(f"a and b must be in [0, {MAX_ADDEND}]")

def _encode_addends_internal(a, b):
    _validate_addends(a, b)
    # Zero-pad each addend to 10 digits, reverse so the least significant
    # digit comes first, and delimit the two addends with 0 tokens.
    prompt = f"{a:010d}{b:010d}"
    a_digits = [int(c) for c in prompt[:10]]
    b_digits = [int(c) for c in prompt[10:]]
    return [0] + list(reversed(a_digits)) + [0] + [0] + list(reversed(b_digits)) + [0]

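# Worked example: a = 12, b = 7 pads to "0000000012" and "0000000007", so the
# reversed digit lists are [2, 1, 0, ..., 0] and [7, 0, ..., 0], giving the
# 24-token prompt
#   [0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
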
def _expected_output(a, b):
    _validate_addends(a, b)
    # The sum, least significant digit first, right-padded with zeros to a
    # fixed width of OUTPUT_DIGITS.
    return str(a + b)[::-1].ljust(OUTPUT_DIGITS, "0")

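# Continuing the example above: 12 + 7 = 19, reversed to "91" and padded to
# OUTPUT_DIGITS = 11 characters, so the expected output is "91000000000".
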
def load_weights_from_json(model: nn.Module, path: str):
    with open(path, "r") as f:
        data = json.load(f)

    def set_param(module, key_list, value):
        # Walk the dotted attribute path and replace the leaf with a new
        # Parameter built from the JSON value.
        if len(key_list) == 1:
            setattr(module, key_list[0], nn.Parameter(torch.tensor(value, dtype=torch.float32)))
        else:
            attr = getattr(module, key_list[0])
            set_param(attr, key_list[1:], value)

    for key, value in data.items():
        key_list = key.split(".")
        set_param(model, key_list, value)

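# The expected JSON layout (inferred from the loader above) is a flat object
# mapping dotted parameter paths to nested lists, e.g.
#   {"embed_tokens.weight": [[...], ...],
#    "layers.0.self_attn.q_proj.weight": [[...], ...],
#    "norm.weight": [...]}
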
def build_model_from_json(weights_path, device):
    model = Model()
    load_weights_from_json(model, weights_path)
    model.to(device)
    model.eval()
    return model

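# Typical usage (illustrative):
#   model = build_model_from_json("weights.json", torch.device("cpu"))
#   tokens = torch.tensor([_encode_addends_internal(12, 7)])
#   logits = model(tokens)  # shape (1, 24, VOCAB_SIZE)
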
def _generate_output_batch(model, addends, device):
    internal = [_encode_addends_internal(a, b) for a, b in addends]
    with torch.no_grad():
        # Greedy decoding: append the argmax digit and re-run the full
        # forward pass for each of the OUTPUT_DIGITS steps.
        for _ in range(OUTPUT_DIGITS):
            x = torch.tensor(internal, dtype=torch.long, device=device)
            logits = model(x)
            next_digits = logits[:, -1, :].argmax(dim=-1).cpu().tolist()
            for seq, next_digit in zip(internal, next_digits):
                seq.append(int(next_digit))
    return ["".join(str(d) for d in seq[-OUTPUT_DIGITS:]) for seq in internal]

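# Because every sequence is padded to the same 24-token prompt, a whole batch
# can be decoded in lockstep. There is no KV cache; each step redoes the full
# forward pass, which is fine at this scale.
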
def run_self_test_batched(model, num_tests, batch_size, device):
    rng = random.Random(123)  # fixed seed so the self-test is reproducible
    tested = 0
    correct = 0
    while tested < num_tests:
        cur_batch_size = min(batch_size, num_tests - tested)
        addends = []
        expected = []
        for _ in range(cur_batch_size):
            a = rng.randint(0, MAX_ADDEND)
            b = rng.randint(0, MAX_ADDEND)
            addends.append((a, b))
            expected.append(_expected_output(a, b))

        actual = _generate_output_batch(model, addends, device)
        for exp, act in zip(expected, actual):
            if act == exp:
                correct += 1
        tested += cur_batch_size
        print(f"self-test progress: {tested}/{num_tests}")
    return (correct / num_tests) * 100 if num_tests > 0 else 0

def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

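# With the constants above this should report 343 parameters: 50 for the
# embedding, 119 per layer (64 attention + 45 MLP + 10 norm weights) times
# 2 layers, 5 for the final norm, and 50 for the lm_head.
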
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--weights", type=str, default="weights.json")
    parser.add_argument("--num-tests", type=int, default=8192)
    parser.add_argument("--batch-size", type=int, default=1024)
    # parse_known_args tolerates unrecognized extra arguments.
    args, _ = parser.parse_known_args()
    return args

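# Example invocation (script name illustrative):
#   python tiny_adder.py --weights weights.json --num-tests 1024 --batch-size 256
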
def main():
    args = parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    model = build_model_from_json(args.weights, device)
    print(f"parameter count: {count_parameters(model)}")

    accuracy = run_self_test_batched(model, args.num_tests, args.batch_size, device)
    print(f"self-test complete ({args.num_tests} random cases, batch size {args.batch_size})")
    print(f"Model Accuracy: {accuracy:.2f}%")

if __name__ == "__main__":
    main()