CharlesCNorton
committed on
Commit
·
df088a9
1
Parent(s):
a5b1777
Add 3-operand adder circuit (arithmetic.add3_8bit)
Browse files
- build.py: add_full_adder() and add_add3() functions
- build.py: infer_add3_inputs() for routing metadata
- eval.py: _test_add3() with 240 test cases including 15+27+33=75
- Fitness 1.000000, all tests pass
- README.md +7 -5
- build.py +106 -1
- eval.py +110 -0
- neural_computer.safetensors +2 -2
README.md
CHANGED
|
@@ -457,15 +457,17 @@ The interface generalizes to **all** 65,536 8-bit additions once trained—no me
|
|
| 457 |
|
| 458 |
### Extension Roadmap
|
| 459 |
|
| 460 |
-
1. **
|
| 461 |
|
| 462 |
-
2. **
|
| 463 |
|
| 464 |
-
3. **
|
| 465 |
|
| 466 |
-
4. **
|
| 467 |
|
| 468 |
-
|
|
|
|
|
|
|
| 469 |
|
| 470 |
---
|
| 471 |
|
|
|
|
| 457 |
|
| 458 |
### Extension Roadmap
|
| 459 |
|
| 460 |
+
1. **Order of operations (5 + 3 × 2 = 11)** — Parse expression into tree, evaluate depth-first. MUL before ADD. Requires either: (a) expression parser producing evaluation order, or (b) learned routing that implicitly respects precedence.
|
| 461 |
|
| 462 |
+
2. **Parenthetical expressions ((5 + 3) × 2 = 16)** — Explicit grouping overrides precedence. Parser must recognize parens and build correct tree. Evaluation proceeds innermost-out. Adds complexity to extraction layer.
|
| 463 |
|
| 464 |
+
3. **16-bit operations (0-65535)** — Chain two 8-bit circuits with carry propagation. ADD16: low = ADD8(A_lo, B_lo), high = ADD8(A_hi, B_hi, carry_out). MUL16: four partial products + shift-add. Doubles operand extraction width.
|
| 465 |
|
| 466 |
+
4. **Floating point arithmetic** — IEEE 754-style with separate circuits for mantissa and exponent. ADD: align exponents, add mantissas, renormalize. MUL: add exponents, multiply mantissas. Requires sign handling, overflow detection, and rounding logic.
|
| 467 |
|
| 468 |
+
### Completed Extensions
|
| 469 |
+
|
| 470 |
+
- **3-operand addition (15 + 27 + 33 = 75)** — `arithmetic.add3_8bit` chains two 8-bit ripple carry stages. 16 full adders, 144 gates, 240 test cases verified.
|
| 471 |
|
| 472 |
---
|
| 473 |
|
build.py
CHANGED
|
@@ -235,6 +235,51 @@ def add_fetch_load_store_buffers(tensors: Dict[str, torch.Tensor], addr_bits: in
|
|
| 235 |
add_gate(tensors, f"control.mem_addr.bit{bit}", [1.0], [-1.0])
|
| 236 |
|
| 237 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
def add_shl_shr(tensors: Dict[str, torch.Tensor]) -> None:
|
| 239 |
"""Add SHL (shift left) and SHR (shift right) circuits.
|
| 240 |
|
|
@@ -604,6 +649,58 @@ def infer_ripplecarry_inputs(gate: str, prefix: str, bits: int, reg: SignalRegis
|
|
| 604 |
return []
|
| 605 |
|
| 606 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 607 |
def infer_adcsbc_inputs(gate: str, prefix: str, is_sub: bool, reg: SignalRegistry) -> List[int]:
|
| 608 |
for i in range(8):
|
| 609 |
reg.register(f"{prefix}.$a[{i}]")
|
|
@@ -1080,6 +1177,8 @@ def infer_inputs_for_gate(gate: str, reg: SignalRegistry, tensors: Dict[str, tor
|
|
| 1080 |
return infer_ripplecarry_inputs(gate, "arithmetic.ripplecarry4bit", 4, reg)
|
| 1081 |
if 'ripplecarry8bit' in gate:
|
| 1082 |
return infer_ripplecarry_inputs(gate, "arithmetic.ripplecarry8bit", 8, reg)
|
|
|
|
|
|
|
| 1083 |
if 'adc8bit' in gate:
|
| 1084 |
return infer_adcsbc_inputs(gate, "arithmetic.adc8bit", False, reg)
|
| 1085 |
if 'sbc8bit' in gate:
|
|
@@ -1305,7 +1404,7 @@ def cmd_alu(args) -> None:
|
|
| 1305 |
"alu.alu8bit.neg.", "alu.alu8bit.rol.", "alu.alu8bit.ror.",
|
| 1306 |
"arithmetic.greaterthan8bit.", "arithmetic.lessthan8bit.",
|
| 1307 |
"arithmetic.greaterorequal8bit.", "arithmetic.lessorequal8bit.",
|
| 1308 |
-
"arithmetic.equality8bit.",
|
| 1309 |
"control.push.", "control.pop.", "control.ret.",
|
| 1310 |
"combinational.barrelshifter.", "combinational.priorityencoder.",
|
| 1311 |
])
|
|
@@ -1370,6 +1469,12 @@ def cmd_alu(args) -> None:
|
|
| 1370 |
print(" Added GT, GE, LT, LE (single-layer), EQ (two-layer)")
|
| 1371 |
except ValueError as e:
|
| 1372 |
print(f" Comparators already exist: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1373 |
if args.apply:
|
| 1374 |
print(f"\nSaving: {args.model}")
|
| 1375 |
save_file(tensors, str(args.model))
|
|
|
|
| 235 |
add_gate(tensors, f"control.mem_addr.bit{bit}", [1.0], [-1.0])
|
| 236 |
|
| 237 |
|
| 238 |
+
def add_full_adder(tensors: Dict[str, torch.Tensor], prefix: str) -> None:
    """Emit the nine gates of one full adder under *prefix*.

    Topology (classic two-half-adder construction):
    - ha1: first half adder — A XOR B (sum), A AND B (carry)
    - ha2: second half adder — ha1.sum XOR Cin (sum), ha1.sum AND Cin (carry)
    - carry_or: OR of the two partial carries -> final carry out

    Each XOR is realized in two layers: an OR gate and a NAND gate whose
    outputs feed a final AND.
    """
    # Both half adders share the same gate shapes; only the name differs.
    for half in ("ha1", "ha2"):
        # XOR, layer 1: OR and NAND of the two inputs.
        add_gate(tensors, f"{prefix}.{half}.sum.layer1.or", [1.0, 1.0], [-1.0])
        add_gate(tensors, f"{prefix}.{half}.sum.layer1.nand", [-1.0, -1.0], [1.0])
        # XOR, layer 2: AND of the layer-1 outputs.
        add_gate(tensors, f"{prefix}.{half}.sum.layer2", [1.0, 1.0], [-2.0])
        # Half-adder carry: AND of the two inputs.
        add_gate(tensors, f"{prefix}.{half}.carry", [1.0, 1.0], [-2.0])
    # Final carry out: OR of ha1.carry and ha2.carry.
    add_gate(tensors, f"{prefix}.carry_or", [1.0, 1.0], [-1.0])
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def add_add3(tensors: Dict[str, torch.Tensor]) -> None:
    """Emit the 3-operand 8-bit adder circuit (A + B + C).

    Two chained ripple-carry stages of eight full adders each:
    - stage1: temp = A + B
    - stage2: result = temp + C

    Inputs: $a[0-7], $b[0-7], $c[0-7] (MSB-first)
    Outputs: stage2.fa0-7.ha2.sum.layer2 (result bits),
             stage2.fa7.carry_or (overflow)

    Total: 16 full adders = 144 gates.
    """
    # Stage 1 (A + B) first, then stage 2 (temp + C); fa0 is the LSB adder.
    for stage in ("stage1", "stage2"):
        for bit in range(8):
            add_full_adder(tensors, f"arithmetic.add3_8bit.{stage}.fa{bit}")
|
| 281 |
+
|
| 282 |
+
|
| 283 |
def add_shl_shr(tensors: Dict[str, torch.Tensor]) -> None:
|
| 284 |
"""Add SHL (shift left) and SHR (shift right) circuits.
|
| 285 |
|
|
|
|
| 649 |
return []
|
| 650 |
|
| 651 |
|
| 652 |
+
def infer_add3_inputs(gate: str, reg: SignalRegistry) -> List[int]:
    """Resolve the input signal ids for one gate of the 3-operand adder.

    Handles every gate name under ``arithmetic.add3_8bit``; returns [] for
    names that do not belong to the circuit. Registry calls are made in a
    fixed order so id allocation stays deterministic.
    """
    prefix = "arithmetic.add3_8bit"
    # Ensure the operand input signals exist (MSB-first naming).
    for i in range(8):
        reg.register(f"$a[{i}]")
        reg.register(f"$b[{i}]")
        reg.register(f"$c[{i}]")

    # Which ripple stage does this gate live in?
    if '.stage1.' in gate:
        stage = "stage1"
    elif '.stage2.' in gate:
        stage = "stage2"
    else:
        return []

    # Full-adder index within the stage (fa0 is the LSB adder).
    match = re.search(r'\.fa(\d+)\.', gate)
    if not match:
        return []
    bit = int(match.group(1))
    fa_prefix = f"{prefix}.{stage}.fa{bit}"

    # Operand wiring. fa indices run LSB-first while signal names are
    # MSB-first, hence the 7 - bit flip.
    if stage == "stage1":
        # Stage 1 adds A + B.
        a_bit = reg.get_id(f"$a[{7 - bit}]")
        b_bit = reg.get_id(f"$b[{7 - bit}]")
    else:
        # Stage 2 adds the stage-1 sum bit and C.
        a_bit = reg.register(f"{prefix}.stage1.fa{bit}.ha2.sum.layer2")
        b_bit = reg.get_id(f"$c[{7 - bit}]")
    # Carry-in: constant zero for the LSB adder, previous carry otherwise.
    cin = reg.get_id("#0") if bit == 0 else reg.register(f"{prefix}.{stage}.fa{bit - 1}.carry_or")

    # Dispatch on which gate inside the full adder is being wired.
    if '.ha1.sum.layer1' in gate:
        return [a_bit, b_bit]
    if '.ha1.sum.layer2' in gate:
        return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"),
                reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")]
    if '.ha1.carry' in gate and '.layer' not in gate:
        return [a_bit, b_bit]
    if '.ha2.sum.layer1' in gate:
        return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin]
    if '.ha2.sum.layer2' in gate:
        return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"),
                reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")]
    if '.ha2.carry' in gate and '.layer' not in gate:
        return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin]
    if '.carry_or' in gate:
        return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")]
    return []
|
| 702 |
+
|
| 703 |
+
|
| 704 |
def infer_adcsbc_inputs(gate: str, prefix: str, is_sub: bool, reg: SignalRegistry) -> List[int]:
|
| 705 |
for i in range(8):
|
| 706 |
reg.register(f"{prefix}.$a[{i}]")
|
|
|
|
| 1177 |
return infer_ripplecarry_inputs(gate, "arithmetic.ripplecarry4bit", 4, reg)
|
| 1178 |
if 'ripplecarry8bit' in gate:
|
| 1179 |
return infer_ripplecarry_inputs(gate, "arithmetic.ripplecarry8bit", 8, reg)
|
| 1180 |
+
if 'add3_8bit' in gate:
|
| 1181 |
+
return infer_add3_inputs(gate, reg)
|
| 1182 |
if 'adc8bit' in gate:
|
| 1183 |
return infer_adcsbc_inputs(gate, "arithmetic.adc8bit", False, reg)
|
| 1184 |
if 'sbc8bit' in gate:
|
|
|
|
| 1404 |
"alu.alu8bit.neg.", "alu.alu8bit.rol.", "alu.alu8bit.ror.",
|
| 1405 |
"arithmetic.greaterthan8bit.", "arithmetic.lessthan8bit.",
|
| 1406 |
"arithmetic.greaterorequal8bit.", "arithmetic.lessorequal8bit.",
|
| 1407 |
+
"arithmetic.equality8bit.", "arithmetic.add3_8bit.",
|
| 1408 |
"control.push.", "control.pop.", "control.ret.",
|
| 1409 |
"combinational.barrelshifter.", "combinational.priorityencoder.",
|
| 1410 |
])
|
|
|
|
| 1469 |
print(" Added GT, GE, LT, LE (single-layer), EQ (two-layer)")
|
| 1470 |
except ValueError as e:
|
| 1471 |
print(f" Comparators already exist: {e}")
|
| 1472 |
+
print("\nGenerating 3-operand adder circuit...")
|
| 1473 |
+
try:
|
| 1474 |
+
add_add3(tensors)
|
| 1475 |
+
print(" Added ADD3 (16 full adders = 144 gates)")
|
| 1476 |
+
except ValueError as e:
|
| 1477 |
+
print(f" ADD3 already exists: {e}")
|
| 1478 |
if args.apply:
|
| 1479 |
print(f"\nSaving: {args.model}")
|
| 1480 |
save_file(tensors, str(args.model))
|
eval.py
CHANGED
|
@@ -527,6 +527,110 @@ class BatchedFitnessEvaluator:
|
|
| 527 |
|
| 528 |
return correct, num_tests
|
| 529 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 530 |
# =========================================================================
|
| 531 |
# COMPARATORS
|
| 532 |
# =========================================================================
|
|
@@ -2340,6 +2444,12 @@ class BatchedFitnessEvaluator:
|
|
| 2340 |
total_tests += t
|
| 2341 |
self.category_scores[f'ripplecarry{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
|
| 2342 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2343 |
# Comparators
|
| 2344 |
s, t = self._test_comparators(population, debug)
|
| 2345 |
scores += s
|
|
|
|
| 527 |
|
| 528 |
return correct, num_tests
|
| 529 |
|
| 530 |
+
# =========================================================================
|
| 531 |
+
# 3-OPERAND ADDER
|
| 532 |
+
# =========================================================================
|
| 533 |
+
|
| 534 |
+
def _test_add3(self, pop: Dict[str, torch.Tensor], debug: bool) -> Tuple[torch.Tensor, int]:
    """Test 3-operand 8-bit adder (A + B + C).

    Evaluates arithmetic.add3_8bit for the whole population batch: stage 1
    ripples A + B through eight full adders, stage 2 ripples that sum + C
    through eight more, and the reconstructed 8-bit result (modulo 256) is
    compared against the expected value for every test case.

    Returns:
        (correct, num_tests): per-individual count of passing cases as a
        tensor of shape [pop_size], and the total number of test cases.
    """
    # Population size taken from any tensor in the dict (all share dim 0).
    pop_size = next(iter(pop.values())).shape[0]

    if debug:
        print(f"\n=== 3-OPERAND ADDER ===")

    prefix = 'arithmetic.add3_8bit'
    bits = 8

    # Strategic test cases for 3-operand addition
    # Include edge cases and overflow scenarios
    test_cases = []
    # Small values
    for a in [0, 1, 2]:
        for b in [0, 1, 2]:
            for c in [0, 1, 2]:
                test_cases.append((a, b, c))
    # Edge values
    edge = [0, 1, 127, 128, 254, 255]
    for a in edge:
        for b in edge:
            for c in edge:
                test_cases.append((a, b, c))
    # Specific multi-operand expression tests
    test_cases.extend([
        (15, 27, 33),  # Example from roadmap: 15 + 27 + 33 = 75
        (100, 100, 55),  # = 255 (exact fit)
        (100, 100, 56),  # = 256 -> 0 (overflow)
        (85, 85, 85),  # = 255 (exact fit)
        (86, 85, 85),  # = 256 -> 0 (overflow)
    ])
    # Deduplicate overlapping cases; int-tuple hashing keeps the order
    # deterministic across runs.
    test_cases = list(set(test_cases))

    a_vals = torch.tensor([t[0] for t in test_cases], device=self.device)
    b_vals = torch.tensor([t[1] for t in test_cases], device=self.device)
    c_vals = torch.tensor([t[2] for t in test_cases], device=self.device)
    num_tests = len(test_cases)

    # Convert to bits [num_tests, bits] MSB-first
    a_bits = torch.stack([((a_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
    b_bits = torch.stack([((b_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
    c_bits = torch.stack([((c_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)

    # Stage 1: A + B
    # NOTE(review): _eval_single_fa is assumed to return (sum, carry_out),
    # each shaped [num_tests, pop_size] — confirm against its definition.
    carry1 = torch.zeros(num_tests, pop_size, device=self.device)
    stage1_bits = []
    for bit in range(bits):
        bit_idx = bits - 1 - bit  # LSB first
        s, carry1 = self._eval_single_fa(
            pop, f'{prefix}.stage1.fa{bit}',
            a_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
            b_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
            carry1
        )
        stage1_bits.append(s)

    # Stage 2: stage1_result + C
    carry2 = torch.zeros(num_tests, pop_size, device=self.device)
    result_bits = []
    for bit in range(bits):
        bit_idx = bits - 1 - bit  # LSB first
        s, carry2 = self._eval_single_fa(
            pop, f'{prefix}.stage2.fa{bit}',
            stage1_bits[bit],  # Already [num_tests, pop_size]
            c_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
            carry2
        )
        result_bits.append(s)

    # Reconstruct result (bits are in LSB-first order, need to reverse for MSB-first)
    result_bits = torch.stack(result_bits[::-1], dim=-1)  # MSB first
    result = torch.zeros(num_tests, pop_size, device=self.device)
    for i in range(bits):
        result += result_bits[:, :, i] * (1 << (bits - 1 - i))

    # Expected (8-bit wrap)
    expected = ((a_vals + b_vals + c_vals) & 0xFF).unsqueeze(1).expand(-1, pop_size).float()
    correct = (result == expected).float().sum(0)

    # Collect concrete failure examples, but only for single-individual
    # runs (debug/eval mode) and capped at 100 cases.
    failures = []
    if pop_size == 1:
        for i in range(min(num_tests, 100)):
            if result[i, 0].item() != expected[i, 0].item():
                failures.append((
                    [int(a_vals[i].item()), int(b_vals[i].item()), int(c_vals[i].item())],
                    int(expected[i, 0].item()),
                    int(result[i, 0].item())
                ))

    self._record(prefix, int(correct[0].item()), num_tests, failures)
    if debug:
        r = self.results[-1]
        print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        if failures:
            for inp, exp, got in failures[:5]:
                print(f" FAIL: {inp[0]} + {inp[1]} + {inp[2]} = {exp}, got {got}")

    return correct, num_tests
|
| 633 |
+
|
| 634 |
# =========================================================================
|
| 635 |
# COMPARATORS
|
| 636 |
# =========================================================================
|
|
|
|
| 2444 |
total_tests += t
|
| 2445 |
self.category_scores[f'ripplecarry{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
|
| 2446 |
|
| 2447 |
+
# 3-operand adder
|
| 2448 |
+
s, t = self._test_add3(population, debug)
|
| 2449 |
+
scores += s
|
| 2450 |
+
total_tests += t
|
| 2451 |
+
self.category_scores['add3'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
|
| 2452 |
+
|
| 2453 |
# Comparators
|
| 2454 |
s, t = self._test_comparators(population, debug)
|
| 2455 |
scores += s
|
neural_computer.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:270309b1ac57e808827cee555b6f6f9e3f14c37abe23fa21069db4ff251a0b72
|
| 3 |
+
size 34552948
|