import argparse
import json
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

# ============================================================
# Model Hyperparameters
# ============================================================
MODEL_LAYERS = 2
MODEL_DIM = 5
ATTENTION_HEADS = 2
KEY_VALUE_HEADS = 1
HEAD_DIM = 2
INTERMEDIATE_SIZE = 3
VOCAB_SIZE = 10
OUTPUT_DIGITS = 11        # a 10-digit + 10-digit sum has at most 11 digits
MAX_ADDEND = 10**10 - 1   # largest 10-digit addend


# ============================================================
# Layers
# ============================================================
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: no mean subtraction, no bias."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        norm = torch.mean(x ** 2, dim=-1, keepdim=True)
        return (x / torch.sqrt(norm + self.eps)) * self.weight


class RoPE(nn.Module):
    """Rotary position embeddings, applied per head to queries and keys."""

    def __init__(self, dim, base=10000.0):
        super().__init__()
        self.dim = dim
        self.base = base

    def forward(self, x, seq_len):
        # x: (B, L, heads, head_dim)
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)).to(x.device)
        t = torch.arange(seq_len, device=x.device).float()
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()[None, :, None, :]
        sin = emb.sin()[None, :, None, :]
        # Rotate-half formulation: (x1, x2) -> (-x2, x1)
        x_half1, x_half2 = x.chunk(2, dim=-1)
        x_rotated = torch.cat((-x_half2, x_half1), dim=-1)
        return (x * cos) + (x_rotated * sin)


class Attention(nn.Module):
    """Causal grouped-query attention with per-head QK-norm and RoPE."""

    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(MODEL_DIM, ATTENTION_HEADS * HEAD_DIM, bias=False)
        self.k_proj = nn.Linear(MODEL_DIM, KEY_VALUE_HEADS * HEAD_DIM, bias=False)
        self.v_proj = nn.Linear(MODEL_DIM, KEY_VALUE_HEADS * HEAD_DIM, bias=False)
        self.o_proj = nn.Linear(ATTENTION_HEADS * HEAD_DIM, MODEL_DIM, bias=False)
        self.q_norm = RMSNorm(HEAD_DIM)
        self.k_norm = RMSNorm(HEAD_DIM)
        self.rope = RoPE(HEAD_DIM)

    def forward(self, x):
        B, L, _ = x.shape
        q = self.q_proj(x).view(B, L, ATTENTION_HEADS, HEAD_DIM)
        k = self.k_proj(x).view(B, L, KEY_VALUE_HEADS, HEAD_DIM)
        v = self.v_proj(x).view(B, L, KEY_VALUE_HEADS, HEAD_DIM)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q = self.rope(q, L)
        k = self.rope(k, L)
        # Broadcast the single K/V head across all query heads (GQA).
        k = k.expand(B, L, ATTENTION_HEADS, HEAD_DIM)
        v = v.expand(B, L, ATTENTION_HEADS, HEAD_DIM)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(HEAD_DIM)
        mask = torch.tril(torch.ones((L, L), device=x.device)).unsqueeze(0).unsqueeze(0) == 1
        scores = scores.masked_fill(~mask, float('-inf'))
        probs = F.softmax(scores, dim=-1)
        out = (probs @ v).transpose(1, 2).contiguous().view(B, L, ATTENTION_HEADS * HEAD_DIM)
        return self.o_proj(out)
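
# A shape walk-through of the attention block above, plus a hedged sanity
# check (illustrative only; the function below is hypothetical and is never
# called by the script). With ATTENTION_HEADS = 2, KEY_VALUE_HEADS = 1, and
# HEAD_DIM = 2, both query heads share the single key/value head:
#
#   q:    (B, L, 2, 2)                     after q_proj + view
#   k, v: (B, L, 1, 2) -> (B, L, 2, 2)     expand copies the lone KV head
#   scores: (B, 2, L, L)                   causal-masked, softmaxed over keys
#   out:  (B, L, 4) -> (B, L, 5)           heads re-merged, then o_proj
def _attention_shape_check():
    attn = Attention()
    x = torch.randn(3, 7, MODEL_DIM)  # (batch, seq_len, model_dim)
    assert attn(x).shape == (3, 7, MODEL_DIM)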

class MLP(nn.Module):
    """SwiGLU feed-forward block."""

    def __init__(self):
        super().__init__()
        self.gate_proj = nn.Linear(MODEL_DIM, INTERMEDIATE_SIZE, bias=False)
        self.up_proj = nn.Linear(MODEL_DIM, INTERMEDIATE_SIZE, bias=False)
        self.down_proj = nn.Linear(INTERMEDIATE_SIZE, MODEL_DIM, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class Layer(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self):
        super().__init__()
        self.input_layernorm = RMSNorm(MODEL_DIM)
        self.post_attention_layernorm = RMSNorm(MODEL_DIM)
        self.self_attn = Attention()
        self.mlp = MLP()

    def forward(self, x):
        x = x + self.self_attn(self.input_layernorm(x))
        x = x + self.mlp(self.post_attention_layernorm(x))
        return x


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(VOCAB_SIZE, MODEL_DIM)
        self.layers = nn.ModuleList([Layer() for _ in range(MODEL_LAYERS)])
        self.norm = RMSNorm(MODEL_DIM)
        self.lm_head = nn.Linear(MODEL_DIM, VOCAB_SIZE, bias=False)

    def forward(self, input_ids):
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return self.lm_head(x)


# ============================================================
# Helper Functions
# ============================================================
def _validate_addends(a, b):
    if not isinstance(a, int) or not isinstance(b, int):
        raise ValueError("a and b must be ints")
    if a < 0 or a > MAX_ADDEND or b < 0 or b > MAX_ADDEND:
        raise ValueError(f"a and b must be in [0, {MAX_ADDEND}]")


def _encode_addends_internal(a, b):
    """Encode both addends zero-padded, least-significant digit first,
    with 0 tokens doubling as separators."""
    _validate_addends(a, b)
    prompt = f"{a:010d}{b:010d}"
    a_digits = [int(c) for c in prompt[:10]]
    b_digits = [int(c) for c in prompt[10:]]
    return [0] + list(reversed(a_digits)) + [0] + [0] + list(reversed(b_digits)) + [0]


def _expected_output(a, b):
    """The sum, reversed and right-padded with zeros to OUTPUT_DIGITS."""
    _validate_addends(a, b)
    return str(a + b)[::-1].ljust(OUTPUT_DIGITS, "0")


# ============================================================
# Load weights from JSON
# ============================================================
def load_weights_from_json(model: nn.Module, path: str):
    """Load a flat {dotted-name: nested-list} JSON dict into the module tree."""
    with open(path, "r") as f:
        data = json.load(f)

    def set_param(module, key_list, value):
        if len(key_list) == 1:
            setattr(module, key_list[0], nn.Parameter(torch.tensor(value, dtype=torch.float32)))
        else:
            attr = getattr(module, key_list[0])
            set_param(attr, key_list[1:], value)

    for key, value in data.items():
        key_list = key.split(".")
        set_param(model, key_list, value)


# ============================================================
# Model Building
# ============================================================
def build_model_from_json(weights_path, device):
    model = Model()
    load_weights_from_json(model, weights_path)
    model.to(device)
    model.eval()
    return model


# ============================================================
# Batch generation & self-test
# ============================================================
def _generate_output_batch(model, addends, device):
    internal = [_encode_addends_internal(a, b) for a, b in addends]
    with torch.no_grad():
        for _ in range(OUTPUT_DIGITS):
            x = torch.tensor(internal, dtype=torch.long, device=device)
            logits = model(x)
            # Greedy decode: take the argmax digit for every sequence at once.
            next_digits = logits[:, -1, :].argmax(dim=-1).tolist()
            for seq, next_digit in zip(internal, next_digits):
                seq.append(int(next_digit))
    return ["".join(str(d) for d in seq[-OUTPUT_DIGITS:]) for seq in internal]


def run_self_test_batched(model, num_tests, batch_size, device):
    rng = random.Random(123)
    tested = 0
    correct = 0
    while tested < num_tests:
        cur_batch_size = min(batch_size, num_tests - tested)
        addends = []
        expected = []
        for _ in range(cur_batch_size):
            a = rng.randint(0, MAX_ADDEND)
            b = rng.randint(0, MAX_ADDEND)
            addends.append((a, b))
            expected.append(_expected_output(a, b))
        actual = _generate_output_batch(model, addends, device)
        for exp, act in zip(expected, actual):
            if act == exp:
                correct += 1
        tested += cur_batch_size
        print(f"self-test progress: {tested}/{num_tests}")
    return (correct / num_tests) * 100 if num_tests > 0 else 0


def count_parameters(model):
    return sum(p.numel() for p in model.parameters())


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--weights", type=str, default="weights.json")
    parser.add_argument("--num-tests", type=int, default=8192)
    parser.add_argument("--batch-size", type=int, default=1024)
    args, _ = parser.parse_known_args()
    return args
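
# Hedged usage sketch (not wired into main): encode one pair of addends and
# check a single greedy decode against the expected string. The function name
# is hypothetical; it only exercises the helpers defined above.
def _demo_single_pair(model, device):
    a, b = 1234, 5678                  # 1234 + 5678 = 6912
    actual = _generate_output_batch(model, [(a, b)], device)[0]
    expected = _expected_output(a, b)  # "21960000000": reversed, zero-padded
    print(f"{a} + {b}: got {actual}, expected {expected}, match={actual == expected}")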

def main():
    args = parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    model = build_model_from_json(args.weights, device)
    print(f"parameter count: {count_parameters(model)}")
    accuracy = run_self_test_batched(model, args.num_tests, args.batch_size, device)
    print(f"self-test finished ({args.num_tests} random cases, batch size {args.batch_size})")
    print(f"Model Accuracy: {accuracy:.2f}%")


if __name__ == "__main__":
    main()
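
# load_weights_from_json expects a flat JSON object whose dotted keys mirror
# the module tree (ModuleList entries addressed as "layers.0", "layers.1"),
# each mapping to a nested list of floats. A sketch of the expected layout,
# with shapes inferred from the hyperparameters (illustrative only):
#
#   {
#     "embed_tokens.weight":              ...,  # (VOCAB_SIZE, MODEL_DIM) = (10, 5)
#     "layers.0.self_attn.q_proj.weight": ...,  # (ATTENTION_HEADS * HEAD_DIM, MODEL_DIM) = (4, 5)
#     "layers.0.self_attn.q_norm.weight": ...,  # (HEAD_DIM,) = (2,)
#     "layers.0.mlp.gate_proj.weight":    ...,  # (INTERMEDIATE_SIZE, MODEL_DIM) = (3, 5)
#     ...
#     "norm.weight":                      ...,  # (MODEL_DIM,) = (5,)
#     "lm_head.weight":                   ...,  # (VOCAB_SIZE, MODEL_DIM) = (10, 5)
#   }
#
# Example invocation (the script name is whatever this file is saved as):
#
#   python model.py --weights weights.json --num-tests 8192 --batch-size 1024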