| import torch
|
| from safetensors.torch import save_file
|
|
|
|
|
|
|
|
|
|
|
def xor_block(prefix):
    """Return the parameter tensors for one two-input XOR gate.

    The gate is made of three threshold units named ``or``, ``nand`` and
    ``and``.  ``xor2`` evaluates ``or`` and ``nand`` on the two inputs and
    feeds their outputs into ``and``.

    Args:
        prefix: Key prefix identifying the gate, e.g. ``'xor_01'``.

    Returns:
        A new dict with six float32 tensors, inserted in this order:
        ``'<prefix>.or.weight'``, ``'<prefix>.or.bias'``,
        ``'<prefix>.nand.weight'``, ``'<prefix>.nand.bias'``,
        ``'<prefix>.and.weight'``, ``'<prefix>.and.bias'``.
        Each weight tensor has two elements; each bias tensor has one.
    """
    # unit name -> (input weights, bias); iteration order fixes key order.
    unit_table = {
        'or': ([1.0, 1.0], -1.0),
        'nand': ([-1.0, -1.0], 1.0),
        'and': ([1.0, 1.0], -2.0),
    }
    gate_params = {}
    for unit_name, (in_weights, unit_bias) in unit_table.items():
        gate_params[f'{prefix}.{unit_name}.weight'] = torch.tensor(in_weights, dtype=torch.float32)
        gate_params[f'{prefix}.{unit_name}.bias'] = torch.tensor([unit_bias], dtype=torch.float32)
    return gate_params
|
|
|
# Collect the parameters of every XOR gate used by parity7 into one flat
# dict, keyed '<gate>.<unit>.<weight|bias>'. Gates are listed by tree depth.
weights = {
    param_name: tensor
    for gate in (
        'xor_01', 'xor_23', 'xor_45',  # first layer: pairs of raw inputs
        'xor_0123', 'xor_456',         # second layer
        'xor_final',                   # output gate
    )
    for param_name, tensor in xor_block(gate).items()
}

# Written relative to the current working directory.
# NOTE(review): this runs before the parity check further down, so the file
# is saved even when verification fails.
save_file(weights, 'model.safetensors')
|
|
|
def _threshold_unit(params, key, x, y):
    """Evaluate one two-input threshold unit and return 1 if it fires, else 0.

    The unit fires when ``w[0]*x + w[1]*y + b >= 0``; a pre-activation of
    exactly 0 counts as firing.  ``w`` is ``params[key + '.weight']`` and
    ``b`` is ``params[key + '.bias']``, i.e. tensors shaped as produced by
    ``xor_block`` (two weights, one bias).
    """
    weight = params[f'{key}.weight']
    bias = params[f'{key}.bias']
    pre_activation = x * weight[0] + y * weight[1] + bias
    # .item() requires a single-element result, as int(tensor) did before.
    return int(pre_activation.item() >= 0)


def xor2(a, b, prefix, params=None):
    """Evaluate one XOR gate built from three threshold units.

    Computes ``AND(OR(a, b), NAND(a, b))``, which equals ``a XOR b`` for
    0/1 inputs when the units carry the weights built by ``xor_block``.

    Args:
        a: First gate input (0 or 1 for XOR semantics).
        b: Second gate input (0 or 1 for XOR semantics).
        prefix: Key prefix selecting the gate's tensors, e.g. ``'xor_01'``.
        params: Mapping from parameter name to tensor.  Defaults to the
            module-level ``weights`` dict so existing callers are unchanged;
            pass another mapping to evaluate a different parameter set.

    Returns:
        The gate output as an int, 0 or 1.

    Raises:
        KeyError: If one of the ``'<prefix>.{or,nand,and}.{weight,bias}'``
            entries is missing from ``params``.
    """
    if params is None:
        params = weights
    or_out = _threshold_unit(params, f'{prefix}.or', a, b)
    nand_out = _threshold_unit(params, f'{prefix}.nand', a, b)
    return _threshold_unit(params, f'{prefix}.and', or_out, nand_out)
|
|
|
def parity7(x0, x1, x2, x3, x4, x5, x6):
    """Return the parity (XOR of all seven inputs) using a tree of XOR gates.

    Intended for 0/1 inputs.  The tree has three levels:
    ``xor_01``/``xor_23``/``xor_45`` pair up the first six inputs,
    ``xor_0123`` and ``xor_456`` merge those (``x6`` joins at ``xor_456``),
    and ``xor_final`` combines the two halves.
    """
    left_half = xor2(xor2(x0, x1, 'xor_01'), xor2(x2, x3, 'xor_23'), 'xor_0123')
    right_half = xor2(xor2(x4, x5, 'xor_45'), x6, 'xor_456')
    return xor2(left_half, right_half, 'xor_final')
|
|
|
# Exhaustively check parity7 on all 2**7 = 128 input combinations against
# sum(bits) % 2, where bit j of the counter i is fed as input x_j.
print("Verifying parity7...")

errors = 0
for i in range(128):
    bits = [(i >> j) & 1 for j in range(7)]
    result = parity7(*bits)
    expected = sum(bits) % 2
    if result != expected:
        errors += 1
        print(f"ERROR: parity({bits}) = {result}, expected {expected}")

if errors == 0:
    print("All 128 test cases passed!")
else:
    print(f"FAILED: {errors} errors")

# Magnitude: sum of absolute values over every parameter tensor.
print(f"Magnitude: {sum(t.abs().sum().item() for t in weights.values()):.0f}")
# Parameters: total element count across all tensors.
print(f"Parameters: {sum(t.numel() for t in weights.values())}")

# Report failure via the exit status too, so shells/CI do not treat a broken
# network as success. Done last so the statistics above are still printed.
if errors:
    raise SystemExit(1)
|
|
|