|
|
""" |
|
|
Build tools for 8-bit Threshold Computer safetensors. |
|
|
|
|
|
Subcommands: |
|
|
    python build.py memory - Generate memory circuits (64KB by default)
|
|
python build.py inputs - Add .inputs metadata tensors |
|
|
python build.py all - Run both (memory first, then inputs) |
|
|
|
|
|
|
|
|
ROUTING SCHEMA (formerly routing.json) |
|
|
====================================== |
|
|
|
|
|
Routing info is now embedded in safetensors via .inputs tensors and signal registry metadata. |
|
|
|
|
|
|
|
|
INPUT SOURCE TYPES |
|
|
------------------ |
|
|
|
|
|
1. External input: "$input_name" - Named input to the circuit |
|
|
- Example: "$a", "$b", "$cin" |
|
|
|
|
|
2. Gate output: "path.to.gate" - Output of another gate |
|
|
- Example: "ha1.sum", "layer1.or" |
|
|
|
|
|
3. Bit extraction: "$input[i]" - Single bit from multi-bit input |
|
|
- Example: "$a[0]" (LSB), "$a[7]" (MSB for 8-bit) |
|
|
|
|
|
4. Constant: "#0" or "#1" - Fixed value |
|
|
- Example: "#1" for carry-in in two's complement |
|
|
|
|
|
|
|
|
CIRCUIT TYPES |
|
|
------------- |
|
|
|
|
|
Single-Layer Gates: .weight and .bias only |
|
|
"boolean.and": ["$a", "$b"] |
|
|
|
|
|
Two-Layer Gates (XOR, XNOR): layer1 + layer2 |
|
|
"boolean.xor.layer1.or": ["$a", "$b"] |
|
|
"boolean.xor.layer1.nand": ["$a", "$b"] |
|
|
"boolean.xor.layer2": ["layer1.or", "layer1.nand"] |
|
|
|
|
|
Hierarchical Circuits: nested sub-components |
|
|
"arithmetic.fulladder": { |
|
|
"ha1.sum.layer1.or": ["$a", "$b"], |
|
|
"ha1.carry": ["$a", "$b"], |
|
|
"ha2.sum.layer1.or": ["ha1.sum", "$cin"], |
|
|
"carry_or": ["ha1.carry", "ha2.carry"] |
|
|
} |
|
|
|
|
|
Bit-Indexed Circuits: multi-bit operations |
|
|
"arithmetic.ripplecarry8bit.fa0": ["$a[0]", "$b[0]", "#0"] |
|
|
"arithmetic.ripplecarry8bit.fa1": ["$a[1]", "$b[1]", "fa0.cout"] |
|
|
|
|
|
|
|
|
PACKED MEMORY CIRCUITS |
|
|
---------------------- |
|
|
|
|
|
64KB memory uses packed tensors (shapes for 16-bit address, 8-bit data): |
|
|
|
|
|
memory.addr_decode.weight: [65536, 16] |
|
|
memory.addr_decode.bias: [65536] |
|
|
memory.read.and.weight: [8, 65536, 2] |
|
|
memory.read.and.bias: [8, 65536] |
|
|
memory.read.or.weight: [8, 65536] |
|
|
memory.read.or.bias: [8] |
|
|
memory.write.sel.weight: [65536, 2] |
|
|
memory.write.sel.bias: [65536] |
|
|
memory.write.nsel.weight: [65536, 1] |
|
|
memory.write.nsel.bias: [65536] |
|
|
memory.write.and_old.weight: [65536, 8, 2] |
|
|
memory.write.and_old.bias: [65536, 8] |
|
|
memory.write.and_new.weight: [65536, 8, 2] |
|
|
memory.write.and_new.bias: [65536, 8] |
|
|
memory.write.or.weight: [65536, 8, 2] |
|
|
memory.write.or.bias: [65536, 8] |
|
|
|
|
|
    Semantics (H(x) = 1 if x >= 0, else 0):
        decode: sel[i]    = H(addr · weight[i] + bias[i])
        read:   bit[b]    = H(sum_i H([mem[i][b], sel[i]] · and_w[b][i] + and_b[b][i]) · or_w[b][i] + or_b[b])
        write:  new[i][b] = H(H([old[i][b], nsel[i]] · and_old_w[i][b] + and_old_b[i][b])
                            + H([data[b], sel[i]] · and_new_w[i][b] + and_new_b[i][b]) - 1)
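
    Worked example (2-bit decoder, row for address 0b10): weight row =
    [+1, -1] (MSB-first), bias = -1. Address 10 scores 1 - 0 - 1 = 0, so
    H(0) = 1 (selected); 11 scores 1 - 1 - 1 = -1 and 00 scores -1, so
    every other address stays 0.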
|
|
|
|
|
|
|
|
SIGNAL REGISTRY |
|
|
--------------- |
|
|
|
|
|
Signal IDs are stored in safetensors metadata as JSON: |
|
|
|
|
|
{"0": "#0", "1": "#1", "2": "$a", "3": "$b", ...} |
|
|
|
|
|
Each gate's .inputs tensor contains integer IDs referencing this registry. |
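
    Decoding a gate's inputs back to names (sketch; the metadata key name
    "signals" here is an assumption -- use whatever key the writer stores
    the registry under):

        import json
        from safetensors import safe_open

        with safe_open("neural_computer8.safetensors", framework="pt") as f:
            registry = json.loads(f.metadata()["signals"])   # assumed key
            ids = f.get_tensor("boolean.and.inputs").tolist()
            names = [registry[str(i)] for i in ids]          # e.g. ["$a", "$b"]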
|
|
|
|
|
|
|
|
NAMING CONVENTIONS |
|
|
------------------ |
|
|
|
|
|
- External inputs: $name or $name[bit] |
|
|
- Constants: #0, #1 |
|
|
- Internal gates: relative path from circuit root |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import argparse |
|
|
import json |
import math
|
|
import re |
|
|
from pathlib import Path |
|
|
from typing import Dict, Iterable, List, Set |
|
|
|
|
|
import torch |
|
|
from safetensors import safe_open |
|
|
from safetensors.torch import save_file |
|
|
|
|
|
|
|
|
MODEL_DIR = Path(__file__).resolve().parent |
|
|
|
|
|
|
|
|
def get_model_path(bits: int = 8, memory_profile: str | None = None, addr_bits: int | None = None) -> Path:
|
|
"""Generate model filename based on configuration.""" |
|
|
if addr_bits is not None: |
|
|
if addr_bits == 0: |
|
|
has_memory = False |
|
|
mem_suffix = "" |
|
|
else: |
|
|
has_memory = True |
|
|
mem_suffix = f"_addr{addr_bits}" |
|
|
elif memory_profile == "none": |
|
|
has_memory = False |
|
|
mem_suffix = "" |
|
|
elif memory_profile == "full" or memory_profile is None: |
|
|
has_memory = True |
|
|
mem_suffix = "" |
|
|
else: |
|
|
has_memory = True |
|
|
mem_suffix = f"_{memory_profile}" |
|
|
|
|
|
base = "neural_alu" if not has_memory else "neural_computer" |
|
|
|
|
|
return MODEL_DIR / f"{base}{bits}{mem_suffix}.safetensors" |
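

# Example filenames produced by get_model_path:
#   get_model_path()                       -> .../neural_computer8.safetensors
#   get_model_path(memory_profile="none")  -> .../neural_alu8.safetensors
#   get_model_path(16, addr_bits=12)       -> .../neural_computer16_addr12.safetensors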
|
|
|
|
|
|
|
|
MODEL_PATH = MODEL_DIR / "neural_computer8.safetensors" |
|
|
MANIFEST_PATH = MODEL_DIR / "tensors.txt"
|
|
|
|
|
DEFAULT_ADDR_BITS = 16 |
|
|
DEFAULT_MEM_BYTES = 1 << DEFAULT_ADDR_BITS |
|
|
|
|
|
MEMORY_PROFILES = { |
|
|
"full": 16, |
|
|
"reduced": 12, |
|
|
"small": 10, |
|
|
"scratchpad": 8, |
|
|
"registers": 4, |
|
|
"none": 0, |
|
|
} |
|
|
|
|
|
SUPPORTED_BITS = [8, 16, 32] |
|
|
|
|
|
|
|
|
def load_tensors(path: Path) -> Dict[str, torch.Tensor]: |
|
|
tensors: Dict[str, torch.Tensor] = {} |
|
|
with safe_open(str(path), framework="pt") as f: |
|
|
for name in f.keys(): |
|
|
tensors[name] = f.get_tensor(name).clone() |
|
|
return tensors |
|
|
|
|
|
|
|
|
def get_all_gates(tensors: Dict[str, torch.Tensor]) -> Set[str]: |
|
|
gates = set() |
|
|
for name in tensors: |
|
|
if name.endswith('.weight'): |
|
|
            gates.add(name[: -len(".weight")])
|
|
return gates |
|
|
|
|
|
|
|
|
class SignalRegistry: |
|
|
def __init__(self): |
|
|
self.name_to_id: Dict[str, int] = {} |
|
|
self.id_to_name: Dict[int, str] = {} |
|
|
self.next_id = 0 |
|
|
self.register("#0") |
|
|
self.register("#1") |
|
|
|
|
|
def register(self, name: str) -> int: |
|
|
if name not in self.name_to_id: |
|
|
self.name_to_id[name] = self.next_id |
|
|
self.id_to_name[self.next_id] = name |
|
|
self.next_id += 1 |
|
|
return self.name_to_id[name] |
|
|
|
|
|
def get_id(self, name: str) -> int: |
|
|
return self.name_to_id.get(name, -1) |
|
|
|
|
|
def to_metadata(self) -> str: |
|
|
return json.dumps(self.id_to_name) |
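

# Usage sketch: "#0" and "#1" are pre-registered as IDs 0 and 1.
#   reg = SignalRegistry()
#   reg.register("$a")    # -> 2
#   reg.register("$a")    # idempotent, still 2
#   reg.to_metadata()     # '{"0": "#0", "1": "#1", "2": "$a"}'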
|
|
|
|
|
|
|
|
def add_gate(tensors: Dict[str, torch.Tensor], name: str, weight: Iterable[float], bias: Iterable[float]) -> None: |
|
|
w_key = f"{name}.weight" |
|
|
b_key = f"{name}.bias" |
|
|
if w_key in tensors or b_key in tensors: |
|
|
raise ValueError(f"Gate already exists: {name}") |
|
|
tensors[w_key] = torch.tensor(list(weight), dtype=torch.float32) |
|
|
tensors[b_key] = torch.tensor(list(bias), dtype=torch.float32) |
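

# Example: a two-input AND gate, H(a + b - 2):
#   add_gate(tensors, "boolean.and", [1.0, 1.0], [-2.0])
# creates boolean.and.weight = [1., 1.] and boolean.and.bias = [-2.]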
|
|
|
|
|
|
|
|
def drop_prefixes(tensors: Dict[str, torch.Tensor], prefixes: List[str]) -> None: |
|
|
for key in list(tensors.keys()): |
|
|
if any(key.startswith(prefix) for prefix in prefixes): |
|
|
del tensors[key] |
|
|
|
|
|
|
|
|
def add_decoder(tensors: Dict[str, torch.Tensor], addr_bits: int, mem_bytes: int) -> None: |
|
|
weights = torch.empty((mem_bytes, addr_bits), dtype=torch.float32) |
|
|
bias = torch.empty((mem_bytes,), dtype=torch.float32) |
|
|
for addr in range(mem_bytes): |
|
|
bits = [(addr >> (addr_bits - 1 - i)) & 1 for i in range(addr_bits)] |
|
|
weights[addr] = torch.tensor([1.0 if bit == 1 else -1.0 for bit in bits], dtype=torch.float32) |
|
|
bias[addr] = -float(sum(bits)) |
|
|
tensors["memory.addr_decode.weight"] = weights |
|
|
tensors["memory.addr_decode.bias"] = bias |
|
|
|
|
|
|
|
|
def add_memory_read_mux(tensors: Dict[str, torch.Tensor], mem_bytes: int) -> None: |
|
|
and_weight = torch.ones((8, mem_bytes, 2), dtype=torch.float32) |
|
|
and_bias = torch.full((8, mem_bytes), -2.0, dtype=torch.float32) |
|
|
or_weight = torch.ones((8, mem_bytes), dtype=torch.float32) |
|
|
or_bias = torch.full((8,), -1.0, dtype=torch.float32) |
|
|
tensors["memory.read.and.weight"] = and_weight |
|
|
tensors["memory.read.and.bias"] = and_bias |
|
|
tensors["memory.read.or.weight"] = or_weight |
|
|
tensors["memory.read.or.bias"] = or_bias |
|
|
|
|
|
|
|
|
def add_memory_write_cells(tensors: Dict[str, torch.Tensor], mem_bytes: int) -> None: |
|
|
sel_weight = torch.ones((mem_bytes, 2), dtype=torch.float32) |
|
|
sel_bias = torch.full((mem_bytes,), -2.0, dtype=torch.float32) |
|
|
nsel_weight = torch.full((mem_bytes, 1), -1.0, dtype=torch.float32) |
|
|
nsel_bias = torch.zeros((mem_bytes,), dtype=torch.float32) |
|
|
and_old_weight = torch.ones((mem_bytes, 8, 2), dtype=torch.float32) |
|
|
and_old_bias = torch.full((mem_bytes, 8), -2.0, dtype=torch.float32) |
|
|
and_new_weight = torch.ones((mem_bytes, 8, 2), dtype=torch.float32) |
|
|
and_new_bias = torch.full((mem_bytes, 8), -2.0, dtype=torch.float32) |
|
|
or_weight = torch.ones((mem_bytes, 8, 2), dtype=torch.float32) |
|
|
or_bias = torch.full((mem_bytes, 8), -1.0, dtype=torch.float32) |
|
|
tensors["memory.write.sel.weight"] = sel_weight |
|
|
tensors["memory.write.sel.bias"] = sel_bias |
|
|
tensors["memory.write.nsel.weight"] = nsel_weight |
|
|
tensors["memory.write.nsel.bias"] = nsel_bias |
|
|
tensors["memory.write.and_old.weight"] = and_old_weight |
|
|
tensors["memory.write.and_old.bias"] = and_old_bias |
|
|
tensors["memory.write.and_new.weight"] = and_new_weight |
|
|
tensors["memory.write.and_new.bias"] = and_new_bias |
|
|
tensors["memory.write.or.weight"] = or_weight |
|
|
tensors["memory.write.or.bias"] = or_bias |
|
|
|
|
|
|
|
|
def add_fetch_load_store_buffers(tensors: Dict[str, torch.Tensor], data_bits: int, addr_bits: int) -> None: |
|
|
"""Add control buffers for fetch, load, store operations. |
|
|
|
|
|
Args: |
|
|
data_bits: Width of data bus (8/16/32) |
|
|
addr_bits: Width of address bus (determines instruction register width) |
|
|
""" |
|
|
|
|
|
|
|
|
ir_bits = max(16, addr_bits) |
|
|
for bit in range(ir_bits): |
|
|
add_gate(tensors, f"control.fetch.ir.bit{bit}", [1.0], [-1.0]) |
|
|
for bit in range(data_bits): |
|
|
add_gate(tensors, f"control.load.bit{bit}", [1.0], [-1.0]) |
|
|
add_gate(tensors, f"control.store.bit{bit}", [1.0], [-1.0]) |
|
|
for bit in range(addr_bits): |
|
|
add_gate(tensors, f"control.mem_addr.bit{bit}", [1.0], [-1.0]) |
|
|
|
|
|
|
|
|
def add_full_adder(tensors: Dict[str, torch.Tensor], prefix: str) -> None: |
|
|
"""Add a single full adder at the given prefix. |
|
|
|
|
|
Full adder structure: |
|
|
- ha1: first half adder (A XOR B for sum, A AND B for carry) |
|
|
- ha2: second half adder (ha1.sum XOR Cin for sum, ha1.sum AND Cin for carry) |
|
|
- carry_or: OR of ha1.carry and ha2.carry for final carry out |
|
|
""" |
|
|
|
|
|
add_gate(tensors, f"{prefix}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
|
|
|
def add_expr_add_mul(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add expression circuit for A + B × C (order of operations). |
|
|
|
|
|
Computes A + (B × C) where multiplication has higher precedence. |
|
|
|
|
|
Structure: |
|
|
- Stage 1: Multiply B × C using shift-add algorithm |
|
|
- 8 mask stages: mask[i] = B AND C[i] (8 AND gates each, shifted) |
|
|
      - 7 accumulator stages (8 full adders each) to sum the shifted masks
|
|
- Stage 2: Add A to multiplication result (8-bit ripple carry) |
|
|
|
|
|
Inputs: $a[0-7], $b[0-7], $c[0-7] (MSB-first, 8-bit each) |
|
|
Output: 8-bit result of A + (B × C), wrapping on overflow |
|
|
|
|
|
    Total: 64 AND gates + 56 full adders (mul) + 8 full adders (add) = 640 gates (9 gates per full adder)
|
|
""" |
|
|
prefix = "arithmetic.expr_add_mul" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for stage in range(8): |
|
|
for bit in range(8): |
|
|
add_gate(tensors, f"{prefix}.mul.mask.s{stage}.b{bit}", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for stage in range(1, 8): |
|
|
for bit in range(8): |
|
|
add_full_adder(tensors, f"{prefix}.mul.acc.s{stage}.fa{bit}") |
|
|
|
|
|
|
|
|
for bit in range(8): |
|
|
add_full_adder(tensors, f"{prefix}.add.fa{bit}") |
|
|
|
|
|
|
|
|
def add_expr_paren_add_mul(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add expression circuit for (A + B) × C (parenthetical override). |
|
|
|
|
|
Computes (A + B) × C where parentheses override normal precedence. |
|
|
Addition happens first, then multiplication. |
|
|
|
|
|
Structure: |
|
|
- Stage 1: Add A + B (8-bit ripple carry adder) |
|
|
- Stage 2: Multiply sum × C using shift-add algorithm |
|
|
- 8 mask stages: mask[i] = sum AND C[i] (8 AND gates each) |
|
|
      - 7 accumulator stages (8 full adders each) to sum the shifted masks
|
|
|
|
|
Inputs: $a[0-7], $b[0-7], $c[0-7] (MSB-first, 8-bit each) |
|
|
Output: 8-bit result of (A + B) × C, wrapping on overflow |
|
|
|
|
|
    Total: 8 full adders (add) + 64 AND gates + 56 full adders (mul) = 640 gates (9 gates per full adder)
|
|
""" |
|
|
prefix = "arithmetic.expr_paren_add_mul" |
|
|
|
|
|
|
|
|
for bit in range(8): |
|
|
add_full_adder(tensors, f"{prefix}.add.fa{bit}") |
|
|
|
|
|
|
|
|
|
|
|
for stage in range(8): |
|
|
for bit in range(8): |
|
|
add_gate(tensors, f"{prefix}.mul.mask.s{stage}.b{bit}", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
|
|
|
for stage in range(1, 8): |
|
|
for bit in range(8): |
|
|
add_full_adder(tensors, f"{prefix}.mul.acc.s{stage}.fa{bit}") |
|
|
|
|
|
|
|
|
def add_expr_paren(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add expression circuit for (A + B) × C (parenthetical grouping). |
|
|
|
|
|
Computes (A + B) × C where addition happens first due to parentheses. |
|
|
|
|
|
Structure: |
|
|
- Stage 1: Add A + B (8-bit ripple carry) |
|
|
- Stage 2: Multiply sum × C using shift-add algorithm |
|
|
- 8 mask stages: mask[i] = sum AND C[i] (8 AND gates each) |
|
|
      - 7 accumulator stages (8 full adders each) to sum the shifted masks
|
|
|
|
|
Inputs: $a[0-7], $b[0-7], $c[0-7] (MSB-first, 8-bit each) |
|
|
Output: 8-bit result of (A + B) × C, wrapping on overflow |
|
|
|
|
|
    Total: 8 full adders (add) + 64 AND gates + 56 full adders (mul) = 640 gates (9 gates per full adder)
|
|
""" |
|
|
prefix = "arithmetic.expr_paren" |
|
|
|
|
|
|
|
|
for bit in range(8): |
|
|
add_full_adder(tensors, f"{prefix}.add.fa{bit}") |
|
|
|
|
|
|
|
|
|
|
|
for stage in range(8): |
|
|
for bit in range(8): |
|
|
add_gate(tensors, f"{prefix}.mul.mask.s{stage}.b{bit}", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
|
|
|
for stage in range(1, 8): |
|
|
for bit in range(8): |
|
|
add_full_adder(tensors, f"{prefix}.mul.acc.s{stage}.fa{bit}") |
|
|
|
|
|
|
|
|
def add_add3(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add 3-operand 8-bit adder circuit. |
|
|
|
|
|
Computes A + B + C using two chained ripple-carry stages: |
|
|
- Stage 1: temp = A + B (8 full adders) |
|
|
- Stage 2: result = temp + C (8 full adders) |
|
|
|
|
|
Inputs: $a[0-7], $b[0-7], $c[0-7] (MSB-first) |
|
|
Outputs: stage2.fa0-7.ha2.sum.layer2 (result bits), stage2.fa7.carry_or (overflow) |
|
|
|
|
|
Total: 16 full adders = 144 gates |
|
|
""" |
|
|
|
|
|
for bit in range(8): |
|
|
add_full_adder(tensors, f"arithmetic.add3_8bit.stage1.fa{bit}") |
|
|
|
|
|
|
|
|
for bit in range(8): |
|
|
add_full_adder(tensors, f"arithmetic.add3_8bit.stage2.fa{bit}") |
|
|
|
|
|
|
|
|
def add_shl_shr(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add SHL (shift left) and SHR (shift right) circuits. |
|
|
|
|
|
Identity gate: w=2, b=-1 -> H(x*2 - 1) = x for x in {0,1} |
|
|
Zero gate: w=0, b=-1 -> H(-1) = 0 |
|
|
|
|
|
SHL (MSB-first): out[i] = in[i+1] for i<7, out[7] = 0 |
|
|
SHR (MSB-first): out[0] = 0, out[i] = in[i-1] for i>0 |
|
|
""" |
|
|
for bit in range(8): |
|
|
if bit < 7: |
|
|
add_gate(tensors, f"alu.alu8bit.shl.bit{bit}", [2.0], [-1.0]) |
|
|
else: |
|
|
add_gate(tensors, f"alu.alu8bit.shl.bit{bit}", [0.0], [-1.0]) |
|
|
|
|
|
for bit in range(8): |
|
|
if bit > 0: |
|
|
add_gate(tensors, f"alu.alu8bit.shr.bit{bit}", [2.0], [-1.0]) |
|
|
else: |
|
|
add_gate(tensors, f"alu.alu8bit.shr.bit{bit}", [0.0], [-1.0]) |
|
|
|
|
|
|
|
|
def add_mul(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add 8-bit multiplication circuit. |
|
|
|
|
|
Produces low 8 bits of the 16-bit result. |
|
|
|
|
|
Structure: |
|
|
- 64 AND gates for partial products P[i][j] = A[i] AND B[j] |
|
|
- Uses existing ripple-carry adder components for summation |
|
|
|
|
|
The multiply method in ThresholdALU computes: |
|
|
1. Partial products via these AND gates |
|
|
2. Shift-add accumulation via existing 8-bit adder |
|
|
""" |
|
|
|
|
|
|
|
|
for i in range(8): |
|
|
for j in range(8): |
|
|
add_gate(tensors, f"alu.alu8bit.mul.pp.a{i}b{j}", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
|
|
|
def add_div(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add 8-bit division circuit. |
|
|
|
|
|
Produces quotient (8 bits) and remainder (8 bits). |
|
|
|
|
|
Uses restoring division algorithm: |
|
|
- 8 iterations, each producing one quotient bit |
|
|
- Each iteration: compare, conditionally subtract, shift |
|
|
|
|
|
Structure: |
|
|
- 8 comparison gates (one per iteration) |
|
|
- 8 conditional subtraction stages |
|
|
- Uses existing comparator and subtractor components |
|
|
""" |
|
|
|
|
|
for stage in range(8): |
|
|
add_gate(tensors, f"alu.alu8bit.div.stage{stage}.cmp", |
|
|
[128.0, 64.0, 32.0, 16.0, 8.0, 4.0, 2.0, 1.0, |
|
|
-128.0, -64.0, -32.0, -16.0, -8.0, -4.0, -2.0, -1.0], [0.0]) |
|
|
|
|
|
|
|
|
for stage in range(8): |
|
|
for bit in range(8): |
|
|
|
|
|
add_gate(tensors, f"alu.alu8bit.div.stage{stage}.mux.bit{bit}.not_sel", [-1.0], [0.0]) |
|
|
|
|
|
add_gate(tensors, f"alu.alu8bit.div.stage{stage}.mux.bit{bit}.and_a", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"alu.alu8bit.div.stage{stage}.mux.bit{bit}.and_b", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"alu.alu8bit.div.stage{stage}.mux.bit{bit}.or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
|
|
|
def add_inc_dec(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add INC and DEC circuits. |
|
|
|
|
|
    INC: A + 1 using half adders with a carry chain
    DEC: A - 1 using a borrow chain

    For INC, add 1 at the LSB and propagate the carry upward:
    sum = a XOR carry_in, carry_out = a AND carry_in.
    For DEC, each bit is a XOR borrow_in, and the next borrow is
    NOT(a) AND borrow_in (hence the not_a and borrow gates below).
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for bit in range(8): |
|
|
|
|
|
add_gate(tensors, f"alu.alu8bit.inc.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"alu.alu8bit.inc.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"alu.alu8bit.inc.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"alu.alu8bit.inc.bit{bit}.carry", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for bit in range(8): |
|
|
|
|
|
add_gate(tensors, f"alu.alu8bit.dec.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"alu.alu8bit.dec.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"alu.alu8bit.dec.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"alu.alu8bit.dec.bit{bit}.not_a", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"alu.alu8bit.dec.bit{bit}.borrow", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
|
|
|
def add_neg(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add NEG circuit (two's complement negation). |
|
|
|
|
|
NEG(A) = NOT(A) + 1 = ~A + 1 |
|
|
|
|
|
Structure: NOT gates followed by INC-style adder. |
|
|
""" |
|
|
for bit in range(8): |
|
|
|
|
|
add_gate(tensors, f"alu.alu8bit.neg.not.bit{bit}", [-1.0], [0.0]) |
|
|
|
|
|
add_gate(tensors, f"alu.alu8bit.neg.inc.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"alu.alu8bit.neg.inc.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"alu.alu8bit.neg.inc.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"alu.alu8bit.neg.inc.bit{bit}.carry", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
|
|
|
def add_rol_ror(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add 8-bit ROL and ROR circuits (legacy wrapper).""" |
|
|
add_rol_ror_nbits(tensors, 8) |
|
|
|
|
|
|
|
|
def add_rol_ror_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None: |
|
|
"""Add N-bit ROL and ROR circuits (rotate left/right). |
|
|
|
|
|
ROL: out[i] = in[i+1] for i<N-1, out[N-1] = in[0] (MSB wraps to LSB) |
|
|
ROR: out[0] = in[N-1], out[i] = in[i-1] for i>0 (LSB wraps to MSB) |
|
|
|
|
|
Args: |
|
|
bits: Data width (8, 16, 32, etc.) |
|
|
""" |
|
|
|
|
|
for bit in range(bits): |
|
|
add_gate(tensors, f"alu.alu{bits}bit.rol.bit{bit}", [2.0], [-1.0]) |
|
|
|
|
|
|
|
|
for bit in range(bits): |
|
|
add_gate(tensors, f"alu.alu{bits}bit.ror.bit{bit}", [2.0], [-1.0]) |
|
|
|
|
|
|
|
|
def add_stack_ops(tensors: Dict[str, torch.Tensor], data_bits: int, addr_bits: int) -> None: |
|
|
"""Add RET, PUSH, POP circuit components. |
|
|
|
|
|
These are higher-level operations that use memory read/write. |
|
|
We create the control logic gates. |
|
|
|
|
|
Args: |
|
|
data_bits: Width of data to push/pop (8/16/32) |
|
|
addr_bits: Width of stack pointer and return addresses |
|
|
|
|
|
RET: Pop return address from stack, jump to it |
|
|
PUSH: Decrement SP, write value to [SP] |
|
|
POP: Read value from [SP], increment SP |
|
|
""" |
|
|
|
|
|
for bit in range(addr_bits): |
|
|
add_gate(tensors, f"control.push.sp_dec.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"control.push.sp_dec.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"control.push.sp_dec.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"control.push.sp_dec.bit{bit}.borrow", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
|
|
|
for bit in range(addr_bits): |
|
|
add_gate(tensors, f"control.pop.sp_inc.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"control.pop.sp_inc.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"control.pop.sp_inc.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"control.pop.sp_inc.bit{bit}.carry", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
|
|
|
for bit in range(data_bits): |
|
|
add_gate(tensors, f"control.push.data.bit{bit}", [2.0], [-1.0]) |
|
|
|
|
|
|
|
|
for bit in range(data_bits): |
|
|
add_gate(tensors, f"control.pop.data.bit{bit}", [2.0], [-1.0]) |
|
|
|
|
|
|
|
|
for bit in range(addr_bits): |
|
|
add_gate(tensors, f"control.ret.addr.bit{bit}", [2.0], [-1.0]) |
|
|
|
|
|
|
|
|
def add_conditional_jumps(tensors: Dict[str, torch.Tensor], addr_bits: int) -> None: |
|
|
"""Add conditional jump circuits (JZ, JNZ, JC, JNC, JP, JN, JV, JNV). |
|
|
|
|
|
Each conditional jump is a 2:1 MUX per address bit: |
|
|
- If flag is set: output = target_bit |
|
|
- If flag is clear: output = pc_bit |
|
|
|
|
|
Structure per bit: |
|
|
- not_sel: NOT(flag) |
|
|
- and_a: pc_bit AND NOT(flag) |
|
|
- and_b: target_bit AND flag |
|
|
- or: and_a OR and_b |
|
|
|
|
|
Args: |
|
|
addr_bits: Width of program counter / jump target |
|
|
""" |
|
|
jump_types = ['jz', 'jnz', 'jc', 'jnc', 'jp', 'jn', 'jv', 'jnv'] |
|
|
|
|
|
for jmp in jump_types: |
|
|
for bit in range(addr_bits): |
|
|
prefix = f"control.{jmp}.bit{bit}" |
|
|
|
|
|
add_gate(tensors, f"{prefix}.not_sel", [-1.0], [0.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.and_a", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.and_b", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
|
|
|
def add_status_flags(tensors: Dict[str, torch.Tensor], data_bits: int) -> None: |
|
|
"""Add status flag computation circuits (Z, N, C, V). |
|
|
|
|
|
Args: |
|
|
data_bits: Width of ALU data (8/16/32) |
|
|
|
|
|
Flags: |
|
|
- Z (Zero): NOR of all result bits (1 if result == 0) |
|
|
- N (Negative): Copy of MSB (sign bit) |
|
|
- C (Carry): Carry out from adder (external input) |
|
|
- V (Overflow): XOR of carry into and out of MSB (signed overflow) |
|
|
""" |
|
|
|
|
|
|
|
|
add_gate(tensors, "flags.zero", [-1.0] * data_bits, [0.0]) |
|
|
|
|
|
|
|
|
add_gate(tensors, "flags.negative", [2.0], [-1.0]) |
|
|
|
|
|
|
|
|
add_gate(tensors, "flags.carry", [2.0], [-1.0]) |
|
|
|
|
|
|
|
|
|
|
|
add_gate(tensors, "flags.overflow.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, "flags.overflow.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, "flags.overflow.layer2", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
|
|
|
def add_barrel_shifter(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add 8-bit barrel shifter circuit (legacy wrapper).""" |
|
|
add_barrel_shifter_nbits(tensors, 8) |
|
|
|
|
|
|
|
|
def add_barrel_shifter_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None: |
|
|
"""Add N-bit barrel shifter circuit. |
|
|
|
|
|
Shifts input by 0 to (bits-1) positions based on ceil(log2(bits))-bit shift amount. |
|
|
Uses layers of 2:1 muxes controlled by shift amount bits. |
|
|
|
|
|
Args: |
|
|
bits: Data width (8, 16, 32, etc.) |
|
|
""" |
|
|
|
|
num_layers = max(1, math.ceil(math.log2(bits))) |
|
|
|
|
|
for layer in range(num_layers): |
|
|
        # this layer muxes a shift of 1 << (num_layers - 1 - layer) positions
|
|
for bit in range(bits): |
|
|
prefix = f"combinational.barrelshifter{bits}.layer{layer}.bit{bit}" |
|
|
|
|
|
add_gate(tensors, f"{prefix}.not_sel", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.and_a", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.and_b", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
|
|
|
def add_priority_encoder(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add 8-bit priority encoder circuit (legacy wrapper).""" |
|
|
add_priority_encoder_nbits(tensors, 8) |
|
|
|
|
|
|
|
|
def add_priority_encoder_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None: |
|
|
"""Add N-bit priority encoder circuit. |
|
|
|
|
|
Finds the position of the highest set bit (0 to bits-1). |
|
|
Position 0 = MSB (highest priority), Position bits-1 = LSB (lowest priority). |
|
|
Output is ceil(log2(bits))-bit index + valid flag. |
|
|
|
|
|
Circuit structure: |
|
|
1. any_higher{pos}: OR of bits 0 to pos-1 (all higher-priority positions) |
|
|
2. is_highest{pos}: bit[pos] AND NOT(any_higher{pos}) |
|
|
3. out{bit}: OR of is_highest{pos} for positions where (pos >> bit) & 1 |
|
|
4. valid: OR of all input bits |
|
|
|
|
|
Args: |
|
|
bits: Input width (8, 16, 32, etc.) |
|
|
""" |
|
|
|
|
out_bits = max(1, math.ceil(math.log2(bits))) |
|
|
prefix = f"combinational.priorityencoder{bits}" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for pos in range(1, bits): |
|
|
weights = [1.0] * pos |
|
|
add_gate(tensors, f"{prefix}.any_higher{pos}", weights, [-1.0]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for pos in range(1, bits): |
|
|
add_gate(tensors, f"{prefix}.is_highest{pos}.not_higher", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.is_highest{pos}.and", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
|
|
|
|
|
|
for out_bit in range(out_bits): |
|
|
weights = [] |
|
|
for pos in range(bits): |
|
|
if (pos >> out_bit) & 1: |
|
|
weights.append(1.0) |
|
|
if weights: |
|
|
add_gate(tensors, f"{prefix}.out{out_bit}", weights, [-1.0]) |
|
|
|
|
|
|
|
|
add_gate(tensors, f"{prefix}.valid", [1.0] * bits, [-1.0]) |
|
|
|
|
|
|
|
|
def add_comparators(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add 8-bit comparator circuits (GT, LT, GE, LE, EQ). |
|
|
|
|
|
Each comparator takes 16 inputs (8 bits from A, 8 bits from B) in MSB-first order. |
|
|
Uses weighted sum comparison on the binary representation. |
|
|
|
|
|
For unsigned comparison of A vs B: |
|
|
- Assign positional weights: bit i has weight 2^(7-i) |
|
|
- A > B: sum(a_i * w_i) > sum(b_i * w_i) |
|
|
- This becomes: sum(a_i * w_i - b_i * w_i) > 0 |
|
|
- Or: sum((a_i - b_i) * w_i) > 0 |
|
|
|
|
|
Threshold gate: H(sum(x_i * w_i) + b) = 1 if sum >= -b |
|
|
|
|
|
For A > B: weights = [128, 64, 32, 16, 8, 4, 2, 1, -128, -64, -32, -16, -8, -4, -2, -1] |
|
|
bias = -1 (strictly greater, so need sum >= 1) |
|
|
For A >= B: bias = 0 (sum >= 0) |
|
|
For A < B: flip weights, bias = -1 |
|
|
For A <= B: flip weights, bias = 0 |
|
|
For A == B: need A >= B AND A <= B (two-layer) |
|
|
""" |
|
|
pos_weights = [128.0, 64.0, 32.0, 16.0, 8.0, 4.0, 2.0, 1.0] |
|
|
neg_weights = [-128.0, -64.0, -32.0, -16.0, -8.0, -4.0, -2.0, -1.0] |
|
|
|
|
|
gt_weights = pos_weights + neg_weights |
|
|
lt_weights = neg_weights + pos_weights |
|
|
|
|
|
add_gate(tensors, "arithmetic.greaterthan8bit", gt_weights, [-1.0]) |
|
|
add_gate(tensors, "arithmetic.greaterorequal8bit", gt_weights, [0.0]) |
|
|
add_gate(tensors, "arithmetic.lessthan8bit", lt_weights, [-1.0]) |
|
|
add_gate(tensors, "arithmetic.lessorequal8bit", lt_weights, [0.0]) |
|
|
|
|
|
add_gate(tensors, "arithmetic.equality8bit.layer1.geq", gt_weights, [0.0]) |
|
|
add_gate(tensors, "arithmetic.equality8bit.layer1.leq", lt_weights, [0.0]) |
|
|
add_gate(tensors, "arithmetic.equality8bit.layer2", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
|
|
|
def add_ripple_carry_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None: |
|
|
"""Add N-bit ripple carry adder circuit. |
|
|
|
|
|
Creates a chain of full adders for N-bit addition. |
|
|
Works for 8, 16, or 32 bits. |
|
|
|
|
|
Inputs: $a[0..N-1], $b[0..N-1] (MSB-first) |
|
|
Outputs: fa0-fa{N-1} sum bits, fa{N-1}.carry_or for overflow |
|
|
""" |
|
|
prefix = f"arithmetic.ripplecarry{bits}bit" |
|
|
for bit in range(bits): |
|
|
add_full_adder(tensors, f"{prefix}.fa{bit}") |
|
|
|
|
|
|
|
|
def add_sub_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None: |
|
|
"""Add N-bit subtractor circuit (A - B). |
|
|
|
|
|
Uses two's complement: A - B = A + (~B) + 1 |
|
|
|
|
|
Structure: |
|
|
- NOT gates for each bit of B |
|
|
- N-bit ripple carry adder with carry_in = 1 |
|
|
|
|
|
The carry_in=1 is handled by the adder's fa0 having cin=#1 instead of #0. |
|
|
""" |
|
|
prefix = f"arithmetic.sub{bits}bit" |
|
|
|
|
|
for bit in range(bits): |
|
|
add_gate(tensors, f"{prefix}.not_b.bit{bit}", [-1.0], [0.0]) |
|
|
|
|
|
for bit in range(bits): |
|
|
add_full_adder(tensors, f"{prefix}.fa{bit}") |
|
|
|
|
|
|
|
|
def add_comparators_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None: |
|
|
"""Add N-bit comparator circuits (GT, LT, GE, LE, EQ). |
|
|
|
|
|
For bits <= 16: Use single-layer weighted comparison (float32 safe). |
|
|
For bits > 16: Use cascaded byte-wise comparison to avoid float32 precision loss. |
|
|
|
|
|
Cascaded approach compares byte-by-byte from MSB: |
|
|
A > B iff: (A[31:24] > B[31:24]) OR |
|
|
(A[31:24] == B[31:24] AND A[23:16] > B[23:16]) OR ... |
|
|
""" |
|
|
if bits <= 16: |
|
|
pos_weights = [float(1 << (bits - 1 - i)) for i in range(bits)] |
|
|
neg_weights = [-w for w in pos_weights] |
|
|
|
|
|
gt_weights = pos_weights + neg_weights |
|
|
lt_weights = neg_weights + pos_weights |
|
|
|
|
|
add_gate(tensors, f"arithmetic.greaterthan{bits}bit", gt_weights, [-1.0]) |
|
|
add_gate(tensors, f"arithmetic.greaterorequal{bits}bit", gt_weights, [0.0]) |
|
|
add_gate(tensors, f"arithmetic.lessthan{bits}bit", lt_weights, [-1.0]) |
|
|
add_gate(tensors, f"arithmetic.lessorequal{bits}bit", lt_weights, [0.0]) |
|
|
|
|
|
add_gate(tensors, f"arithmetic.equality{bits}bit.layer1.geq", gt_weights, [0.0]) |
|
|
add_gate(tensors, f"arithmetic.equality{bits}bit.layer1.leq", lt_weights, [0.0]) |
|
|
add_gate(tensors, f"arithmetic.equality{bits}bit.layer2", [1.0, 1.0], [-2.0]) |
|
|
else: |
|
|
num_bytes = bits // 8 |
|
|
prefix = f"arithmetic.cmp{bits}bit" |
|
|
|
|
|
byte_pos_weights = [128.0, 64.0, 32.0, 16.0, 8.0, 4.0, 2.0, 1.0] |
|
|
byte_neg_weights = [-128.0, -64.0, -32.0, -16.0, -8.0, -4.0, -2.0, -1.0] |
|
|
byte_gt_weights = byte_pos_weights + byte_neg_weights |
|
|
byte_lt_weights = byte_neg_weights + byte_pos_weights |
|
|
|
|
|
for b in range(num_bytes): |
|
|
add_gate(tensors, f"{prefix}.byte{b}.gt", byte_gt_weights, [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.byte{b}.lt", byte_lt_weights, [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.byte{b}.eq.geq", byte_gt_weights, [0.0]) |
|
|
add_gate(tensors, f"{prefix}.byte{b}.eq.leq", byte_lt_weights, [0.0]) |
|
|
add_gate(tensors, f"{prefix}.byte{b}.eq.and", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
for b in range(num_bytes): |
|
|
if b == 0: |
|
|
add_gate(tensors, f"{prefix}.cascade.gt.stage{b}", [1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.cascade.lt.stage{b}", [1.0], [-1.0]) |
|
|
else: |
|
|
eq_weights = [1.0] * b |
|
|
add_gate(tensors, f"{prefix}.cascade.gt.stage{b}.all_eq", eq_weights, [-float(b)]) |
|
|
add_gate(tensors, f"{prefix}.cascade.gt.stage{b}.and", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.cascade.lt.stage{b}.all_eq", eq_weights, [-float(b)]) |
|
|
add_gate(tensors, f"{prefix}.cascade.lt.stage{b}.and", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
or_weights_gt = [1.0] * num_bytes |
|
|
or_weights_lt = [1.0] * num_bytes |
|
|
add_gate(tensors, f"arithmetic.greaterthan{bits}bit", or_weights_gt, [-1.0]) |
|
|
add_gate(tensors, f"arithmetic.lessthan{bits}bit", or_weights_lt, [-1.0]) |
|
|
|
|
|
not_lt_weights = [-1.0] |
|
|
add_gate(tensors, f"arithmetic.greaterorequal{bits}bit.not_lt", not_lt_weights, [0.0]) |
|
|
add_gate(tensors, f"arithmetic.greaterorequal{bits}bit", [1.0], [-1.0]) |
|
|
|
|
|
not_gt_weights = [-1.0] |
|
|
add_gate(tensors, f"arithmetic.lessorequal{bits}bit.not_gt", not_gt_weights, [0.0]) |
|
|
add_gate(tensors, f"arithmetic.lessorequal{bits}bit", [1.0], [-1.0]) |
|
|
|
|
|
eq_all_weights = [1.0] * num_bytes |
|
|
add_gate(tensors, f"arithmetic.equality{bits}bit", eq_all_weights, [-float(num_bytes)]) |
|
|
|
|
|
|
|
|
def add_mul_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None: |
|
|
"""Add N-bit multiplication circuit. |
|
|
|
|
|
Produces low N bits of the 2N-bit result. |
|
|
|
|
|
Structure: |
|
|
- N*N AND gates for partial products P[i][j] = A[i] AND B[j] |
|
|
- Shift-add accumulation using existing adder circuits |
|
|
|
|
|
For 32-bit: 1024 AND gates for partial products. |
|
|
""" |
|
|
for i in range(bits): |
|
|
for j in range(bits): |
|
|
add_gate(tensors, f"alu.alu{bits}bit.mul.pp.a{i}b{j}", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
|
|
|
def add_div_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None: |
|
|
"""Add N-bit division circuit. |
|
|
|
|
|
Uses restoring division algorithm with N iterations. |
|
|
""" |
|
|
pos_weights = [float(1 << (bits - 1 - i)) for i in range(bits)] |
|
|
neg_weights = [-w for w in pos_weights] |
|
|
cmp_weights = pos_weights + neg_weights |
|
|
|
|
|
for stage in range(bits): |
|
|
add_gate(tensors, f"alu.alu{bits}bit.div.stage{stage}.cmp", cmp_weights, [0.0]) |
|
|
|
|
|
for stage in range(bits): |
|
|
for bit in range(bits): |
|
|
add_gate(tensors, f"alu.alu{bits}bit.div.stage{stage}.mux.bit{bit}.not_sel", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"alu.alu{bits}bit.div.stage{stage}.mux.bit{bit}.and_a", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"alu.alu{bits}bit.div.stage{stage}.mux.bit{bit}.and_b", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"alu.alu{bits}bit.div.stage{stage}.mux.bit{bit}.or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
|
|
|
def add_bitwise_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None: |
|
|
"""Add N-bit bitwise operation circuits (AND, OR, XOR, NOT). |
|
|
|
|
|
These are simply N copies of the 1-bit gates. |
|
|
""" |
|
|
for bit in range(bits): |
|
|
add_gate(tensors, f"alu.alu{bits}bit.and.bit{bit}", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
for bit in range(bits): |
|
|
add_gate(tensors, f"alu.alu{bits}bit.or.bit{bit}", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
for bit in range(bits): |
|
|
add_gate(tensors, f"alu.alu{bits}bit.xor.bit{bit}.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"alu.alu{bits}bit.xor.bit{bit}.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"alu.alu{bits}bit.xor.bit{bit}.layer2", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
for bit in range(bits): |
|
|
add_gate(tensors, f"alu.alu{bits}bit.not.bit{bit}", [-1.0], [0.0]) |
|
|
|
|
|
|
|
|
def add_shift_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None: |
|
|
"""Add N-bit shift circuits (SHL, SHR by 1 position). |
|
|
|
|
|
SHL: out[i] = in[i+1] for i<N-1, out[N-1] = 0 |
|
|
SHR: out[0] = 0, out[i] = in[i-1] for i>0 |
|
|
""" |
|
|
for bit in range(bits): |
|
|
if bit < bits - 1: |
|
|
add_gate(tensors, f"alu.alu{bits}bit.shl.bit{bit}", [2.0], [-1.0]) |
|
|
else: |
|
|
add_gate(tensors, f"alu.alu{bits}bit.shl.bit{bit}", [0.0], [-1.0]) |
|
|
|
|
|
for bit in range(bits): |
|
|
if bit > 0: |
|
|
add_gate(tensors, f"alu.alu{bits}bit.shr.bit{bit}", [2.0], [-1.0]) |
|
|
else: |
|
|
add_gate(tensors, f"alu.alu{bits}bit.shr.bit{bit}", [0.0], [-1.0]) |
|
|
|
|
|
|
|
|
def add_inc_dec_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None: |
|
|
"""Add N-bit INC and DEC circuits.""" |
|
|
for bit in range(bits): |
|
|
add_gate(tensors, f"alu.alu{bits}bit.inc.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"alu.alu{bits}bit.inc.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"alu.alu{bits}bit.inc.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"alu.alu{bits}bit.inc.bit{bit}.carry", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
for bit in range(bits): |
|
|
add_gate(tensors, f"alu.alu{bits}bit.dec.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"alu.alu{bits}bit.dec.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"alu.alu{bits}bit.dec.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"alu.alu{bits}bit.dec.bit{bit}.not_a", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"alu.alu{bits}bit.dec.bit{bit}.borrow", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
|
|
|
def add_neg_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None: |
|
|
"""Add N-bit NEG circuit (two's complement negation).""" |
|
|
for bit in range(bits): |
|
|
add_gate(tensors, f"alu.alu{bits}bit.neg.not.bit{bit}", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"alu.alu{bits}bit.neg.inc.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"alu.alu{bits}bit.neg.inc.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"alu.alu{bits}bit.neg.inc.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"alu.alu{bits}bit.neg.inc.bit{bit}.carry", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
|
|
|
def add_float16_core(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add float16 core circuits (unpack, pack, classify, normalize). |
|
|
|
|
|
IEEE 754 half-precision format (16 bits): |
|
|
- Bit 15: Sign (0=positive, 1=negative) |
|
|
- Bits 14-10: Exponent (5 bits, bias=15) |
|
|
- Bits 9-0: Mantissa/fraction (10 bits, implicit leading 1 for normalized) |
|
|
|
|
|
Special values: |
|
|
- Zero: exp=0, frac=0 |
|
|
- Subnormal: exp=0, frac≠0 |
|
|
- Infinity: exp=31, frac=0 |
|
|
- NaN: exp=31, frac≠0 |
|
|
""" |
|
|
prefix = "float16" |
|
|
|
|
|
for i in range(16): |
|
|
add_gate(tensors, f"{prefix}.unpack.bit{i}", [1.0], [0.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.classify.exp_zero", [-1.0] * 5, [0.0]) |
|
|
add_gate(tensors, f"{prefix}.classify.exp_max", [1.0] * 5, [-5.0]) |
|
|
add_gate(tensors, f"{prefix}.classify.frac_zero", [-1.0] * 10, [0.0]) |
|
|
add_gate(tensors, f"{prefix}.classify.frac_nonzero", [1.0] * 10, [-1.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.classify.is_zero.and", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.classify.is_subnormal.and", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.classify.is_inf.and", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.classify.is_nan.and", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
for stage in range(4): |
|
|
        # this stage muxes a mantissa shift of 1 << (3 - stage) positions
|
|
for bit in range(11): |
|
|
add_gate(tensors, f"{prefix}.normalize.stage{stage}.bit{bit}.not_sel", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.stage{stage}.bit{bit}.and_a", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.stage{stage}.bit{bit}.and_b", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.stage{stage}.bit{bit}.or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
for stage in range(4): |
|
|
        # exponent adjustment matching a shift of 1 << (3 - stage)
|
|
for bit in range(5): |
|
|
add_gate(tensors, f"{prefix}.normalize.exp_adj.stage{stage}.fa{bit}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.exp_adj.stage{stage}.fa{bit}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.exp_adj.stage{stage}.fa{bit}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.exp_adj.stage{stage}.fa{bit}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.exp_adj.stage{stage}.fa{bit}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.exp_adj.stage{stage}.fa{bit}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.exp_adj.stage{stage}.fa{bit}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.exp_adj.stage{stage}.fa{bit}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.exp_adj.stage{stage}.fa{bit}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
for i in range(16): |
|
|
add_gate(tensors, f"{prefix}.pack.bit{i}", [1.0], [0.0]) |
|
|
|
|
|
|
|
|
def add_float16_add(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add float16 addition circuit. |
|
|
|
|
|
Algorithm: |
|
|
1. Unpack both operands |
|
|
2. Compare exponents, align mantissas |
|
|
3. Add/subtract mantissas based on signs |
|
|
4. Normalize result |
|
|
5. Handle special cases (inf, nan, zero) |
|
|
""" |
|
|
prefix = "float16.add" |
|
|
|
|
|
pos_weights = [float(1 << (4 - i)) for i in range(5)] |
|
|
neg_weights = [-w for w in pos_weights] |
|
|
add_gate(tensors, f"{prefix}.exp_cmp.a_gt_b", pos_weights + neg_weights, [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_cmp.a_lt_b", neg_weights + pos_weights, [-1.0]) |
|
|
|
|
|
for bit in range(5): |
|
|
add_gate(tensors, f"{prefix}.exp_diff.fa{bit}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_diff.fa{bit}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_diff.fa{bit}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_diff.fa{bit}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_diff.fa{bit}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_diff.fa{bit}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_diff.fa{bit}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_diff.fa{bit}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_diff.fa{bit}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_diff.not_b.bit{bit}", [-1.0], [0.0]) |
|
|
|
|
|
for stage in range(4): |
|
|
        # this stage conditionally shifts the mantissa by 1 << (3 - stage)
|
|
for bit in range(11): |
|
|
add_gate(tensors, f"{prefix}.align.stage{stage}.bit{bit}.not_sel", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.align.stage{stage}.bit{bit}.and_a", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.align.stage{stage}.bit{bit}.and_b", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.align.stage{stage}.bit{bit}.or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer2", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
for bit in range(12): |
|
|
add_gate(tensors, f"{prefix}.mant_add.fa{bit}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_add.fa{bit}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_add.fa{bit}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_add.fa{bit}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_add.fa{bit}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_add.fa{bit}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_add.fa{bit}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_add.fa{bit}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_add.fa{bit}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
for bit in range(11): |
|
|
add_gate(tensors, f"{prefix}.mant_sub.not_b.bit{bit}", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_sub.fa{bit}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_sub.fa{bit}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_sub.fa{bit}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_sub.fa{bit}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_sub.fa{bit}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_sub.fa{bit}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_sub.fa{bit}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_sub.fa{bit}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_sub.fa{bit}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
for bit in range(11): |
|
|
add_gate(tensors, f"{prefix}.mant_select.bit{bit}.not_sel", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_select.bit{bit}.and_add", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_select.bit{bit}.and_sub", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_select.bit{bit}.or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
|
|
|
def add_float16_mul(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add float16 multiplication circuit. |
|
|
|
|
|
Algorithm: |
|
|
1. Unpack both operands |
|
|
2. XOR signs for result sign |
|
|
3. Add exponents (subtract bias) |
|
|
4. Multiply mantissas (11x11 -> 22 bits) |
|
|
5. Normalize result |
|
|
6. Handle special cases |
|
|
""" |
|
|
prefix = "float16.mul" |
|
|
|
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer2", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
for bit in range(6): |
|
|
add_gate(tensors, f"{prefix}.exp_add.fa{bit}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_add.fa{bit}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_add.fa{bit}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_add.fa{bit}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_add.fa{bit}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_add.fa{bit}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_add.fa{bit}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_add.fa{bit}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_add.fa{bit}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
for bit in range(5): |
|
|
add_gate(tensors, f"{prefix}.bias_sub.not_bias.bit{bit}", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_sub.fa{bit}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_sub.fa{bit}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_sub.fa{bit}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_sub.fa{bit}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_sub.fa{bit}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_sub.fa{bit}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_sub.fa{bit}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_sub.fa{bit}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_sub.fa{bit}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
for i in range(11): |
|
|
for j in range(11): |
|
|
add_gate(tensors, f"{prefix}.mant_mul.pp.a{i}b{j}", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
for stage in range(10): |
|
|
for bit in range(22): |
|
|
add_gate(tensors, f"{prefix}.mant_mul.acc.s{stage}.fa{bit}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_mul.acc.s{stage}.fa{bit}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_mul.acc.s{stage}.fa{bit}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_mul.acc.s{stage}.fa{bit}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_mul.acc.s{stage}.fa{bit}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_mul.acc.s{stage}.fa{bit}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_mul.acc.s{stage}.fa{bit}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_mul.acc.s{stage}.fa{bit}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_mul.acc.s{stage}.fa{bit}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
|
|
|
def add_float16_div(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add float16 division circuit. |
|
|
|
|
|
Algorithm: |
|
|
1. Unpack both operands |
|
|
2. XOR signs for result sign |
|
|
3. Subtract exponents (add bias) |
|
|
4. Divide mantissas (restoring division) |
|
|
5. Normalize result |
|
|
6. Handle special cases (div by zero -> inf) |
|
|
""" |
|
|
prefix = "float16.div" |
|
|
|
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer2", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
for bit in range(5): |
|
|
add_gate(tensors, f"{prefix}.exp_sub.not_b.bit{bit}", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_sub.fa{bit}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_sub.fa{bit}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_sub.fa{bit}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_sub.fa{bit}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_sub.fa{bit}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_sub.fa{bit}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_sub.fa{bit}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_sub.fa{bit}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_sub.fa{bit}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
for bit in range(5): |
|
|
add_gate(tensors, f"{prefix}.bias_add.fa{bit}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_add.fa{bit}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_add.fa{bit}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_add.fa{bit}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_add.fa{bit}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_add.fa{bit}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_add.fa{bit}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_add.fa{bit}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_add.fa{bit}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
for stage in range(11): |
|
|
pos_weights = [float(1 << (10 - i)) for i in range(11)] |
|
|
neg_weights = [-w for w in pos_weights] |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.cmp", pos_weights + neg_weights, [0.0]) |
|
|
|
|
|
for bit in range(11): |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.sub.not_d.bit{bit}", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.sub.fa{bit}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.sub.fa{bit}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.sub.fa{bit}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.sub.fa{bit}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.sub.fa{bit}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.sub.fa{bit}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.sub.fa{bit}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.sub.fa{bit}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.sub.fa{bit}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.mux.bit{bit}.not_sel", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.mux.bit{bit}.and_old", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.mux.bit{bit}.and_new", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.mux.bit{bit}.or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
|
|
|
def add_float16_cmp(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add float16 comparison circuits (EQ, LT, LE, GT, GE). |
|
|
|
|
|
Float comparison: |
|
|
1. Handle NaN (any comparison with NaN is false except NaN != NaN) |
|
|
2. Handle signed zeros (+0 == -0) |
|
|
3. For same signs: compare as integers (exponent then mantissa) |
|
|
4. For different signs: negative < positive (unless both zero) |
|
|
""" |
|
|
prefix = "float16.cmp" |
|
|
|
|
|
add_gate(tensors, f"{prefix}.a.exp_max", [1.0] * 5, [-5.0]) |
|
|
add_gate(tensors, f"{prefix}.a.frac_nz", [1.0] * 10, [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.a.is_nan", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.b.exp_max", [1.0] * 5, [-5.0]) |
|
|
add_gate(tensors, f"{prefix}.b.frac_nz", [1.0] * 10, [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.b.is_nan", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.either_nan", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer2", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.a.is_zero.exp_zero", [-1.0] * 5, [0.0]) |
|
|
add_gate(tensors, f"{prefix}.a.is_zero.frac_zero", [-1.0] * 10, [0.0]) |
|
|
add_gate(tensors, f"{prefix}.a.is_zero.and", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.b.is_zero.exp_zero", [-1.0] * 5, [0.0]) |
|
|
add_gate(tensors, f"{prefix}.b.is_zero.frac_zero", [-1.0] * 10, [0.0]) |
|
|
add_gate(tensors, f"{prefix}.b.is_zero.and", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.both_zero", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
pos_weights = [float(1 << (14 - i)) for i in range(15)] |
|
|
neg_weights = [-w for w in pos_weights] |
|
|
add_gate(tensors, f"{prefix}.mag_a_gt_b", pos_weights + neg_weights, [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.mag_a_ge_b", pos_weights + neg_weights, [0.0]) |
|
|
add_gate(tensors, f"{prefix}.mag_a_lt_b", neg_weights + pos_weights, [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.mag_a_le_b", neg_weights + pos_weights, [0.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.mag_eq.geq", pos_weights + neg_weights, [0.0]) |
|
|
add_gate(tensors, f"{prefix}.mag_eq.leq", neg_weights + pos_weights, [0.0]) |
|
|
add_gate(tensors, f"{prefix}.mag_eq.and", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.eq.not_nan", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.eq.mag_or_zero", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.eq.same_sign_or_zero", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.eq.result", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.lt.not_nan", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.lt.diff_sign.not_a_sign", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.lt.diff_sign.a_neg", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.lt.same_sign.pos_lt", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.lt.same_sign.neg_gt", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.lt.same_sign.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.lt.case_or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.lt.not_both_zero", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.lt.result", [1.0, 1.0, 1.0], [-3.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.gt.not_nan", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.gt.diff_sign.not_b_sign", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.gt.diff_sign.b_neg", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.gt.same_sign.pos_gt", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.gt.same_sign.neg_lt", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.gt.same_sign.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.gt.case_or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.gt.not_both_zero", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.gt.result", [1.0, 1.0, 1.0], [-3.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.le.eq_or_lt", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.le.not_nan", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.le.result", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.ge.eq_or_gt", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.ge.not_nan", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.ge.result", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
|
|
|
def add_float32_core(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add float32 core circuits (unpack, pack, classify, normalize). |
|
|
|
|
|
IEEE 754 single-precision format (32 bits): |
|
|
- Bit 31: Sign |
|
|
- Bits 30-23: Exponent (8 bits, bias=127) |
|
|
- Bits 22-0: Mantissa (23 bits, implicit leading 1) |
|
|
""" |
|
|
prefix = "float32" |
|
|
|
|
|
for i in range(32): |
|
|
add_gate(tensors, f"{prefix}.unpack.bit{i}", [1.0], [0.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.classify.exp_zero", [-1.0] * 8, [0.0]) |
|
|
add_gate(tensors, f"{prefix}.classify.exp_max", [1.0] * 8, [-8.0]) |
|
|
add_gate(tensors, f"{prefix}.classify.frac_zero", [-1.0] * 23, [0.0]) |
|
|
add_gate(tensors, f"{prefix}.classify.frac_nonzero", [1.0] * 23, [-1.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.classify.is_zero.and", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.classify.is_subnormal.and", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.classify.is_inf.and", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.classify.is_nan.and", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
for stage in range(5): |
|
|
shift = 1 << (4 - stage) |
|
|
for bit in range(24): |
|
|
add_gate(tensors, f"{prefix}.normalize.stage{stage}.bit{bit}.not_sel", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.stage{stage}.bit{bit}.and_a", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.stage{stage}.bit{bit}.and_b", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.stage{stage}.bit{bit}.or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
for stage in range(5): |
|
|
for bit in range(8): |
|
|
add_gate(tensors, f"{prefix}.normalize.exp_adj.stage{stage}.fa{bit}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.exp_adj.stage{stage}.fa{bit}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.exp_adj.stage{stage}.fa{bit}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.exp_adj.stage{stage}.fa{bit}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.exp_adj.stage{stage}.fa{bit}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.exp_adj.stage{stage}.fa{bit}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.exp_adj.stage{stage}.fa{bit}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.exp_adj.stage{stage}.fa{bit}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.normalize.exp_adj.stage{stage}.fa{bit}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
for i in range(32): |
|
|
add_gate(tensors, f"{prefix}.pack.bit{i}", [1.0], [0.0]) |
|
|
|
|
|
|
|
|
def add_float32_cmp(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add float32 comparison circuits (EQ, LT, LE, GT, GE).""" |
|
|
prefix = "float32.cmp" |
|
|
|
|
|
add_gate(tensors, f"{prefix}.a.exp_max", [1.0] * 8, [-8.0]) |
|
|
add_gate(tensors, f"{prefix}.a.frac_nz", [1.0] * 23, [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.a.is_nan", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.b.exp_max", [1.0] * 8, [-8.0]) |
|
|
add_gate(tensors, f"{prefix}.b.frac_nz", [1.0] * 23, [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.b.is_nan", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.either_nan", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer2", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.a.is_zero.exp_zero", [-1.0] * 8, [0.0]) |
|
|
add_gate(tensors, f"{prefix}.a.is_zero.frac_zero", [-1.0] * 23, [0.0]) |
|
|
add_gate(tensors, f"{prefix}.a.is_zero.and", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.b.is_zero.exp_zero", [-1.0] * 8, [0.0]) |
|
|
add_gate(tensors, f"{prefix}.b.is_zero.frac_zero", [-1.0] * 23, [0.0]) |
|
|
add_gate(tensors, f"{prefix}.b.is_zero.and", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.both_zero", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
pos_weights = [float(1 << (30 - i)) for i in range(31)] |
|
|
neg_weights = [-w for w in pos_weights] |
|
|
add_gate(tensors, f"{prefix}.mag_a_gt_b", pos_weights + neg_weights, [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.mag_a_ge_b", pos_weights + neg_weights, [0.0]) |
|
|
add_gate(tensors, f"{prefix}.mag_a_lt_b", neg_weights + pos_weights, [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.mag_a_le_b", neg_weights + pos_weights, [0.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.mag_eq.geq", pos_weights + neg_weights, [0.0]) |
|
|
add_gate(tensors, f"{prefix}.mag_eq.leq", neg_weights + pos_weights, [0.0]) |
|
|
add_gate(tensors, f"{prefix}.mag_eq.and", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.eq.not_nan", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.eq.mag_or_zero", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.eq.same_sign_or_zero", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.eq.result", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.lt.not_nan", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.lt.diff_sign.not_a_sign", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.lt.diff_sign.a_neg", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.lt.same_sign.pos_lt", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.lt.same_sign.neg_gt", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.lt.same_sign.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.lt.case_or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.lt.not_both_zero", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.lt.result", [1.0, 1.0, 1.0], [-3.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.gt.not_nan", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.gt.diff_sign.not_b_sign", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.gt.diff_sign.b_neg", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.gt.same_sign.pos_gt", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.gt.same_sign.neg_lt", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.gt.same_sign.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.gt.case_or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.gt.not_both_zero", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.gt.result", [1.0, 1.0, 1.0], [-3.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.le.eq_or_lt", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.le.not_nan", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.le.result", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.ge.eq_or_gt", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.ge.not_nan", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.ge.result", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
|
|
|
def add_float32_add(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add float32 addition circuit. |
|
|
|
|
|
Algorithm: |
|
|
1. Unpack both operands |
|
|
2. Compare exponents, align mantissas |
|
|
3. Add/subtract mantissas based on signs |
|
|
4. Normalize result |
|
|
5. Handle special cases (inf, nan, zero) |
|
|
""" |
|
|
prefix = "float32.add" |
|
|
|
|
|
pos_weights = [float(1 << (7 - i)) for i in range(8)] |
|
|
neg_weights = [-w for w in pos_weights] |
|
|
add_gate(tensors, f"{prefix}.exp_cmp.a_gt_b", pos_weights + neg_weights, [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_cmp.a_lt_b", neg_weights + pos_weights, [-1.0]) |
|
|
|
|
|
for bit in range(8): |
|
|
add_gate(tensors, f"{prefix}.exp_diff.fa{bit}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_diff.fa{bit}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_diff.fa{bit}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_diff.fa{bit}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_diff.fa{bit}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_diff.fa{bit}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_diff.fa{bit}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_diff.fa{bit}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_diff.fa{bit}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_diff.not_b.bit{bit}", [-1.0], [0.0]) |
|
|
|
|
|
for stage in range(5): |
|
|
shift = 1 << (4 - stage) |
|
|
for bit in range(24): |
|
|
add_gate(tensors, f"{prefix}.align.stage{stage}.bit{bit}.not_sel", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.align.stage{stage}.bit{bit}.and_a", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.align.stage{stage}.bit{bit}.and_b", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.align.stage{stage}.bit{bit}.or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer2", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
for bit in range(25): |
|
|
add_gate(tensors, f"{prefix}.mant_add.fa{bit}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_add.fa{bit}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_add.fa{bit}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_add.fa{bit}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_add.fa{bit}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_add.fa{bit}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_add.fa{bit}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_add.fa{bit}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_add.fa{bit}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
for bit in range(24): |
|
|
add_gate(tensors, f"{prefix}.mant_sub.not_b.bit{bit}", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_sub.fa{bit}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_sub.fa{bit}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_sub.fa{bit}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_sub.fa{bit}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_sub.fa{bit}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_sub.fa{bit}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_sub.fa{bit}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_sub.fa{bit}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_sub.fa{bit}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
for bit in range(24): |
|
|
add_gate(tensors, f"{prefix}.mant_select.bit{bit}.not_sel", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_select.bit{bit}.and_add", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_select.bit{bit}.and_sub", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_select.bit{bit}.or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
|
|
|
def add_float32_mul(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add float32 multiplication circuit. |
|
|
|
|
|
Algorithm: |
|
|
1. Unpack both operands |
|
|
2. XOR signs for result sign |
|
|
3. Add exponents (subtract bias) |
|
|
4. Multiply mantissas (24x24 -> 48 bits) |
|
|
5. Normalize result |
|
|
6. Handle special cases |
|
|
""" |
|
|
prefix = "float32.mul" |
|
|
|
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer2", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
for bit in range(9): |
|
|
add_gate(tensors, f"{prefix}.exp_add.fa{bit}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_add.fa{bit}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_add.fa{bit}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_add.fa{bit}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_add.fa{bit}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_add.fa{bit}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_add.fa{bit}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_add.fa{bit}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_add.fa{bit}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
for bit in range(8): |
|
|
add_gate(tensors, f"{prefix}.bias_sub.not_bias.bit{bit}", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_sub.fa{bit}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_sub.fa{bit}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_sub.fa{bit}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_sub.fa{bit}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_sub.fa{bit}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_sub.fa{bit}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_sub.fa{bit}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_sub.fa{bit}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_sub.fa{bit}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
for i in range(24): |
|
|
for j in range(24): |
|
|
add_gate(tensors, f"{prefix}.mant_mul.pp.a{i}b{j}", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
for stage in range(23): |
|
|
for bit in range(48): |
|
|
add_gate(tensors, f"{prefix}.mant_mul.acc.s{stage}.fa{bit}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_mul.acc.s{stage}.fa{bit}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_mul.acc.s{stage}.fa{bit}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_mul.acc.s{stage}.fa{bit}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_mul.acc.s{stage}.fa{bit}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_mul.acc.s{stage}.fa{bit}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_mul.acc.s{stage}.fa{bit}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_mul.acc.s{stage}.fa{bit}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_mul.acc.s{stage}.fa{bit}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
|
|
|
def add_float32_div(tensors: Dict[str, torch.Tensor]) -> None: |
|
|
"""Add float32 division circuit. |
|
|
|
|
|
Algorithm: |
|
|
1. Unpack both operands |
|
|
2. XOR signs for result sign |
|
|
3. Subtract exponents (add bias) |
|
|
4. Divide mantissas (restoring division) |
|
|
5. Normalize result |
|
|
6. Handle special cases (div by zero -> inf) |
|
|
""" |
|
|
prefix = "float32.div" |
|
|
|
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.sign_xor.layer2", [1.0, 1.0], [-2.0]) |
|
|
|
|
|
for bit in range(8): |
|
|
add_gate(tensors, f"{prefix}.exp_sub.not_b.bit{bit}", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_sub.fa{bit}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_sub.fa{bit}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_sub.fa{bit}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_sub.fa{bit}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_sub.fa{bit}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_sub.fa{bit}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_sub.fa{bit}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_sub.fa{bit}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.exp_sub.fa{bit}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
for bit in range(8): |
|
|
add_gate(tensors, f"{prefix}.bias_add.fa{bit}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_add.fa{bit}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_add.fa{bit}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_add.fa{bit}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_add.fa{bit}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_add.fa{bit}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_add.fa{bit}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_add.fa{bit}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.bias_add.fa{bit}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
for stage in range(24): |
|
|
pos_weights = [float(1 << (23 - i)) for i in range(24)] |
|
|
neg_weights = [-w for w in pos_weights] |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.cmp", pos_weights + neg_weights, [0.0]) |
|
|
|
|
|
for bit in range(24): |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.sub.not_d.bit{bit}", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.sub.fa{bit}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.sub.fa{bit}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.sub.fa{bit}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.sub.fa{bit}.ha1.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.sub.fa{bit}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.sub.fa{bit}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.sub.fa{bit}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.sub.fa{bit}.ha2.carry", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.sub.fa{bit}.carry_or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.mux.bit{bit}.not_sel", [-1.0], [0.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.mux.bit{bit}.and_old", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.mux.bit{bit}.and_new", [1.0, 1.0], [-2.0]) |
|
|
add_gate(tensors, f"{prefix}.mant_div.stage{stage}.mux.bit{bit}.or", [1.0, 1.0], [-1.0]) |
|
|
|
|
|
|
|
|
def update_manifest(tensors: Dict[str, torch.Tensor], data_bits: int, addr_bits: int, mem_bytes: int) -> None: |
|
|
"""Update manifest metadata tensors. |
|
|
|
|
|
    Args:
        tensors: Tensor dict, updated in place
|
|
data_bits: ALU/register width (8/16/32) |
|
|
addr_bits: Address bus width (determines memory size) |
|
|
mem_bytes: Memory size in bytes (2^addr_bits) |
|
|
""" |
|
|
tensors["manifest.data_bits"] = torch.tensor([float(data_bits)], dtype=torch.float32) |
|
|
tensors["manifest.addr_bits"] = torch.tensor([float(addr_bits)], dtype=torch.float32) |
|
|
tensors["manifest.memory_bytes"] = torch.tensor([float(mem_bytes)], dtype=torch.float32) |
|
|
tensors["manifest.pc_width"] = torch.tensor([float(addr_bits)], dtype=torch.float32) |
|
|
tensors["manifest.version"] = torch.tensor([4.0], dtype=torch.float32) |
|
|
|
|
|
|
|
|
def write_manifest(path: Path, tensors: Dict[str, torch.Tensor]) -> None:
    """Write a human-readable manifest listing every tensor's shape and values."""
|
|
lines: List[str] = [] |
|
|
lines.append("# Tensor Manifest") |
|
|
lines.append(f"# Total: {len(tensors)} tensors") |
|
|
for name in sorted(tensors.keys()): |
|
|
t = tensors[name] |
|
|
values = ", ".join(f"{v:.1f}" for v in t.flatten().tolist()) |
|
|
lines.append(f"{name}: shape={list(t.shape)}, values=[{values}]") |
|
|
path.write_text("\n".join(lines) + "\n", encoding="utf-8") |
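
# Example: write_manifest(Path("manifest.txt"), tensors) emits one line per
# tensor after the two header lines, e.g. (illustrative):
#     boolean.and.bias: shape=[1], values=[-2.0]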
|
|
|
|
|
|
|
|
def infer_boolean_inputs(gate: str, reg: SignalRegistry) -> List[int]:
    """Infer input signal ids for single- and two-layer boolean gates."""
|
|
if gate == 'boolean.not': |
|
|
return [reg.register("$x")] |
|
|
if gate in ['boolean.and', 'boolean.or', 'boolean.nand', 'boolean.nor', 'boolean.implies']: |
|
|
return [reg.register("$a"), reg.register("$b")] |
|
|
if '.layer1.neuron1' in gate or '.layer1.neuron2' in gate or '.layer1.or' in gate or '.layer1.nand' in gate: |
|
|
return [reg.register("$a"), reg.register("$b")] |
|
|
    # Two-layer gates (XOR/XNOR/biimplies): layer2 consumes layer1's OR and NAND.
    if '.layer2' in gate:
|
|
parent = gate.rsplit('.layer2', 1)[0] |
|
|
if '.layer1.neuron1' in parent or 'xor' in parent or 'xnor' in parent or 'biimplies' in parent: |
|
|
parent = parent.rsplit('.layer1', 1)[0] if '.layer1' in parent else parent |
|
|
return [reg.register(f"{parent}.layer1.or"), reg.register(f"{parent}.layer1.nand")] |
|
|
return [] |
|
|
|
|
|
|
|
|
def infer_halfadder_inputs(gate: str, prefix: str, reg: SignalRegistry) -> List[int]:
    """Infer input ids for a half adder: a two-layer XOR sum plus an AND carry."""
|
|
a = reg.register(f"{prefix}.$a") |
|
|
b = reg.register(f"{prefix}.$b") |
|
|
if '.sum.layer1' in gate: |
|
|
return [a, b] |
|
|
if '.sum.layer2' in gate: |
|
|
return [reg.register(f"{prefix}.sum.layer1.or"), reg.register(f"{prefix}.sum.layer1.nand")] |
|
|
if '.carry' in gate and '.layer' not in gate: |
|
|
return [a, b] |
|
|
return [a, b] |
|
|
|
|
|
|
|
|
def infer_fulladder_inputs(gate: str, prefix: str, reg: SignalRegistry) -> List[int]:
    """Infer input ids for a full adder built from two half adders plus a carry OR."""
|
|
a = reg.register(f"{prefix}.$a") |
|
|
b = reg.register(f"{prefix}.$b") |
|
|
cin = reg.register(f"{prefix}.$cin") |
|
|
if '.ha1.sum.layer1' in gate: |
|
|
return [a, b] |
|
|
if '.ha1.sum.layer2' in gate: |
|
|
return [reg.register(f"{prefix}.ha1.sum.layer1.or"), reg.register(f"{prefix}.ha1.sum.layer1.nand")] |
|
|
if '.ha1.carry' in gate and '.layer' not in gate: |
|
|
return [a, b] |
|
|
if '.ha2.sum.layer1' in gate: |
|
|
return [reg.register(f"{prefix}.ha1.sum.layer2"), cin] |
|
|
if '.ha2.sum.layer2' in gate: |
|
|
return [reg.register(f"{prefix}.ha2.sum.layer1.or"), reg.register(f"{prefix}.ha2.sum.layer1.nand")] |
|
|
if '.ha2.carry' in gate and '.layer' not in gate: |
|
|
return [reg.register(f"{prefix}.ha1.sum.layer2"), cin] |
|
|
if '.carry_or' in gate: |
|
|
return [reg.register(f"{prefix}.ha1.carry"), reg.register(f"{prefix}.ha2.carry")] |
|
|
return [] |
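
# Example of the wiring this resolves (ids are whatever the registry assigns):
#     reg = SignalRegistry()
#     infer_fulladder_inputs(
#         "arithmetic.fulladder.ha2.sum.layer1.or", "arithmetic.fulladder", reg)
#     # -> ids of ["arithmetic.fulladder.ha1.sum.layer2", "arithmetic.fulladder.$cin"]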
|
|
|
|
|
|
|
|
def infer_ripplecarry_inputs(gate: str, prefix: str, bits: int, reg: SignalRegistry) -> List[int]:
    """Infer input ids for an N-bit ripple-carry adder; fa0 takes the constant #0 as carry-in."""
|
|
for i in range(bits): |
|
|
reg.register(f"{prefix}.$a[{i}]") |
|
|
reg.register(f"{prefix}.$b[{i}]") |
|
|
m = re.search(r'\.fa(\d+)\.', gate) |
|
|
if not m: |
|
|
return [] |
|
|
bit = int(m.group(1)) |
|
|
a_bit = reg.get_id(f"{prefix}.$a[{bit}]") |
|
|
b_bit = reg.get_id(f"{prefix}.$b[{bit}]") |
|
|
cin = reg.get_id("#0") if bit == 0 else reg.register(f"{prefix}.fa{bit-1}.carry_or") |
|
|
fa_prefix = f"{prefix}.fa{bit}" |
|
|
if '.ha1.sum.layer1' in gate: |
|
|
return [a_bit, b_bit] |
|
|
if '.ha1.sum.layer2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")] |
|
|
if '.ha1.carry' in gate and '.layer' not in gate: |
|
|
return [a_bit, b_bit] |
|
|
if '.ha2.sum.layer1' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] |
|
|
if '.ha2.sum.layer2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")] |
|
|
if '.ha2.carry' in gate and '.layer' not in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] |
|
|
if '.carry_or' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")] |
|
|
return [] |
|
|
|
|
|
|
|
|
def infer_expr_add_mul_inputs(gate: str, reg: SignalRegistry) -> List[int]: |
|
|
"""Infer inputs for A + B × C expression circuit (order of operations). |
|
|
|
|
|
Circuit structure: |
|
|
- Mask stage: mask.s[stage].b[bit] = B[bit] AND C[stage] |
|
|
- Accumulator stages 1-7: acc.s[stage] = acc.s[stage-1] + (mask.s[stage] << stage) |
|
|
- Final add: result = A + acc.s7 |
|
|
|
|
|
Bit ordering: MSB-first externally, LSB-first internally (fa0 = LSB, fa7 = MSB) |
|
|
- $x[7] = bit 0 (LSB), $x[0] = bit 7 (MSB) |
|
|
""" |
|
|
prefix = "arithmetic.expr_add_mul" |
|
|
|
|
|
|
|
|
for i in range(8): |
|
|
reg.register(f"$a[{i}]") |
|
|
reg.register(f"$b[{i}]") |
|
|
reg.register(f"$c[{i}]") |
|
|
|
|
|
|
|
|
if '.mul.mask.' in gate: |
|
|
m = re.search(r'\.s(\d+)\.b(\d+)', gate) |
|
|
if m: |
|
|
stage = int(m.group(1)) |
|
|
bit = int(m.group(2)) |
|
|
|
|
|
b_input = reg.get_id(f"$b[{7-bit}]") |
|
|
c_input = reg.get_id(f"$c[{7-stage}]") |
|
|
return [b_input, c_input] |
|
|
return [] |
|
|
|
|
|
|
|
|
if '.mul.acc.' in gate: |
|
|
m = re.search(r'\.s(\d+)\.fa(\d+)\.', gate) |
|
|
if not m: |
|
|
return [] |
|
|
stage = int(m.group(1)) |
|
|
bit = int(m.group(2)) |
|
|
|
|
|
|
|
|
if stage == 1: |
|
|
|
|
|
a_input = reg.register(f"{prefix}.mul.mask.s0.b{bit}") |
|
|
else: |
|
|
|
|
|
a_input = reg.register(f"{prefix}.mul.acc.s{stage-1}.fa{bit}.ha2.sum.layer2") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if bit < stage: |
|
|
b_input = reg.get_id("#0") |
|
|
else: |
|
|
b_input = reg.register(f"{prefix}.mul.mask.s{stage}.b{bit-stage}") |
|
|
|
|
|
|
|
|
if bit == 0: |
|
|
cin = reg.get_id("#0") |
|
|
else: |
|
|
cin = reg.register(f"{prefix}.mul.acc.s{stage}.fa{bit-1}.carry_or") |
|
|
|
|
|
fa_prefix = f"{prefix}.mul.acc.s{stage}.fa{bit}" |
|
|
|
|
|
if '.ha1.sum.layer1' in gate: |
|
|
return [a_input, b_input] |
|
|
if '.ha1.sum.layer2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")] |
|
|
if '.ha1.carry' in gate and '.layer' not in gate: |
|
|
return [a_input, b_input] |
|
|
if '.ha2.sum.layer1' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] |
|
|
if '.ha2.sum.layer2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")] |
|
|
if '.ha2.carry' in gate and '.layer' not in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] |
|
|
if '.carry_or' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")] |
|
|
return [] |
|
|
|
|
|
|
|
|
if '.add.fa' in gate: |
|
|
m = re.search(r'\.fa(\d+)\.', gate) |
|
|
if not m: |
|
|
return [] |
|
|
bit = int(m.group(1)) |
|
|
|
|
|
|
|
|
a_input = reg.get_id(f"$a[{7-bit}]") |
|
|
|
|
|
|
|
|
b_input = reg.register(f"{prefix}.mul.acc.s7.fa{bit}.ha2.sum.layer2") |
|
|
|
|
|
|
|
|
if bit == 0: |
|
|
cin = reg.get_id("#0") |
|
|
else: |
|
|
cin = reg.register(f"{prefix}.add.fa{bit-1}.carry_or") |
|
|
|
|
|
fa_prefix = f"{prefix}.add.fa{bit}" |
|
|
|
|
|
if '.ha1.sum.layer1' in gate: |
|
|
return [a_input, b_input] |
|
|
if '.ha1.sum.layer2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")] |
|
|
if '.ha1.carry' in gate and '.layer' not in gate: |
|
|
return [a_input, b_input] |
|
|
if '.ha2.sum.layer1' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] |
|
|
if '.ha2.sum.layer2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")] |
|
|
if '.ha2.carry' in gate and '.layer' not in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] |
|
|
if '.carry_or' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")] |
|
|
return [] |
|
|
|
|
|
return [] |
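
# Reference model (illustrative only) of the dataflow wired up above, using
# the documented LSB-first internal order: stage s adds B << s whenever C's
# bit s is set, then the final ripple-carry stage adds A.
def _demo_expr_add_mul(a: int, b: int, c: int) -> int:
    acc = 0
    for stage in range(8):
        if (c >> stage) & 1:              # mask.s{stage} = B AND C[stage]
            acc = (acc + (b << stage)) & 0xFF
    return (a + acc) & 0xFF               # final .add stage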
|
|
|
|
|
|
|
|
def infer_expr_paren_add_mul_inputs(gate: str, reg: SignalRegistry) -> List[int]: |
|
|
"""Infer inputs for (A + B) × C expression circuit (parenthetical override). |
|
|
|
|
|
Circuit structure: |
|
|
- Add stage: sum = A + B |
|
|
- Mask stage: mask.s[stage].b[bit] = sum[bit] AND C[stage] |
|
|
- Accumulator stages 1-7: acc.s[stage] = acc.s[stage-1] + (mask.s[stage] << stage) |
|
|
|
|
|
Bit ordering: MSB-first externally, LSB-first internally (fa0 = LSB, fa7 = MSB) |
|
|
""" |
|
|
prefix = "arithmetic.expr_paren_add_mul" |
|
|
|
|
|
|
|
|
for i in range(8): |
|
|
reg.register(f"$a[{i}]") |
|
|
reg.register(f"$b[{i}]") |
|
|
reg.register(f"$c[{i}]") |
|
|
|
|
|
|
|
|
if '.add.fa' in gate and '.mul.' not in gate: |
|
|
m = re.search(r'\.fa(\d+)\.', gate) |
|
|
if not m: |
|
|
return [] |
|
|
bit = int(m.group(1)) |
|
|
|
|
|
|
|
|
a_input = reg.get_id(f"$a[{7-bit}]") |
|
|
b_input = reg.get_id(f"$b[{7-bit}]") |
|
|
|
|
|
|
|
|
if bit == 0: |
|
|
cin = reg.get_id("#0") |
|
|
else: |
|
|
cin = reg.register(f"{prefix}.add.fa{bit-1}.carry_or") |
|
|
|
|
|
fa_prefix = f"{prefix}.add.fa{bit}" |
|
|
|
|
|
if '.ha1.sum.layer1' in gate: |
|
|
return [a_input, b_input] |
|
|
if '.ha1.sum.layer2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")] |
|
|
if '.ha1.carry' in gate and '.layer' not in gate: |
|
|
return [a_input, b_input] |
|
|
if '.ha2.sum.layer1' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] |
|
|
if '.ha2.sum.layer2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")] |
|
|
if '.ha2.carry' in gate and '.layer' not in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] |
|
|
if '.carry_or' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")] |
|
|
return [] |
|
|
|
|
|
|
|
|
if '.mul.mask.' in gate: |
|
|
m = re.search(r'\.s(\d+)\.b(\d+)', gate) |
|
|
if m: |
|
|
stage = int(m.group(1)) |
|
|
bit = int(m.group(2)) |
|
|
|
|
|
sum_bit = reg.register(f"{prefix}.add.fa{bit}.ha2.sum.layer2") |
|
|
|
|
|
c_input = reg.get_id(f"$c[{7-stage}]") |
|
|
return [sum_bit, c_input] |
|
|
return [] |
|
|
|
|
|
|
|
|
if '.mul.acc.' in gate: |
|
|
m = re.search(r'\.s(\d+)\.fa(\d+)\.', gate) |
|
|
if not m: |
|
|
return [] |
|
|
stage = int(m.group(1)) |
|
|
bit = int(m.group(2)) |
|
|
|
|
|
|
|
|
if stage == 1: |
|
|
|
|
|
a_input = reg.register(f"{prefix}.mul.mask.s0.b{bit}") |
|
|
else: |
|
|
|
|
|
a_input = reg.register(f"{prefix}.mul.acc.s{stage-1}.fa{bit}.ha2.sum.layer2") |
|
|
|
|
|
|
|
|
if bit < stage: |
|
|
b_input = reg.get_id("#0") |
|
|
else: |
|
|
b_input = reg.register(f"{prefix}.mul.mask.s{stage}.b{bit-stage}") |
|
|
|
|
|
|
|
|
if bit == 0: |
|
|
cin = reg.get_id("#0") |
|
|
else: |
|
|
cin = reg.register(f"{prefix}.mul.acc.s{stage}.fa{bit-1}.carry_or") |
|
|
|
|
|
fa_prefix = f"{prefix}.mul.acc.s{stage}.fa{bit}" |
|
|
|
|
|
if '.ha1.sum.layer1' in gate: |
|
|
return [a_input, b_input] |
|
|
if '.ha1.sum.layer2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")] |
|
|
if '.ha1.carry' in gate and '.layer' not in gate: |
|
|
return [a_input, b_input] |
|
|
if '.ha2.sum.layer1' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] |
|
|
if '.ha2.sum.layer2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")] |
|
|
if '.ha2.carry' in gate and '.layer' not in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] |
|
|
if '.carry_or' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")] |
|
|
return [] |
|
|
|
|
|
return [] |
|
|
|
|
|
|
|
|
def infer_expr_paren_inputs(gate: str, reg: SignalRegistry) -> List[int]: |
|
|
"""Infer inputs for (A + B) × C expression circuit (parenthetical grouping). |
|
|
|
|
|
Circuit structure: |
|
|
- Add stage: sum = A + B |
|
|
- Mask stage: mask.s[stage].b[bit] = sum[bit] AND C[stage] |
|
|
- Accumulator stages 1-7: acc.s[stage] = acc.s[stage-1] + (mask.s[stage] << stage) |
|
|
|
|
|
Bit ordering: MSB-first externally, LSB-first internally (fa0 = LSB, fa7 = MSB) |
|
|
""" |
|
|
prefix = "arithmetic.expr_paren" |
|
|
|
|
|
|
|
|
for i in range(8): |
|
|
reg.register(f"$a[{i}]") |
|
|
reg.register(f"$b[{i}]") |
|
|
reg.register(f"$c[{i}]") |
|
|
|
|
|
|
|
|
if '.add.fa' in gate and '.mul.' not in gate: |
|
|
m = re.search(r'\.fa(\d+)\.', gate) |
|
|
if not m: |
|
|
return [] |
|
|
bit = int(m.group(1)) |
|
|
|
|
|
|
|
|
a_input = reg.get_id(f"$a[{7-bit}]") |
|
|
b_input = reg.get_id(f"$b[{7-bit}]") |
|
|
|
|
|
|
|
|
if bit == 0: |
|
|
cin = reg.get_id("#0") |
|
|
else: |
|
|
cin = reg.register(f"{prefix}.add.fa{bit-1}.carry_or") |
|
|
|
|
|
fa_prefix = f"{prefix}.add.fa{bit}" |
|
|
|
|
|
if '.ha1.sum.layer1' in gate: |
|
|
return [a_input, b_input] |
|
|
if '.ha1.sum.layer2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")] |
|
|
if '.ha1.carry' in gate and '.layer' not in gate: |
|
|
return [a_input, b_input] |
|
|
if '.ha2.sum.layer1' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] |
|
|
if '.ha2.sum.layer2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")] |
|
|
if '.ha2.carry' in gate and '.layer' not in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] |
|
|
if '.carry_or' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")] |
|
|
return [] |
|
|
|
|
|
|
|
|
if '.mul.mask.' in gate: |
|
|
m = re.search(r'\.s(\d+)\.b(\d+)', gate) |
|
|
if m: |
|
|
stage = int(m.group(1)) |
|
|
bit = int(m.group(2)) |
|
|
|
|
|
sum_input = reg.register(f"{prefix}.add.fa{bit}.ha2.sum.layer2") |
|
|
|
|
|
c_input = reg.get_id(f"$c[{7-stage}]") |
|
|
return [sum_input, c_input] |
|
|
return [] |
|
|
|
|
|
|
|
|
if '.mul.acc.' in gate: |
|
|
m = re.search(r'\.s(\d+)\.fa(\d+)\.', gate) |
|
|
if not m: |
|
|
return [] |
|
|
stage = int(m.group(1)) |
|
|
bit = int(m.group(2)) |
|
|
|
|
|
|
|
|
if stage == 1: |
|
|
|
|
|
a_input = reg.register(f"{prefix}.mul.mask.s0.b{bit}") |
|
|
else: |
|
|
|
|
|
a_input = reg.register(f"{prefix}.mul.acc.s{stage-1}.fa{bit}.ha2.sum.layer2") |
|
|
|
|
|
|
|
|
if bit < stage: |
|
|
b_input = reg.get_id("#0") |
|
|
else: |
|
|
b_input = reg.register(f"{prefix}.mul.mask.s{stage}.b{bit-stage}") |
|
|
|
|
|
|
|
|
if bit == 0: |
|
|
cin = reg.get_id("#0") |
|
|
else: |
|
|
cin = reg.register(f"{prefix}.mul.acc.s{stage}.fa{bit-1}.carry_or") |
|
|
|
|
|
fa_prefix = f"{prefix}.mul.acc.s{stage}.fa{bit}" |
|
|
|
|
|
if '.ha1.sum.layer1' in gate: |
|
|
return [a_input, b_input] |
|
|
if '.ha1.sum.layer2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")] |
|
|
if '.ha1.carry' in gate and '.layer' not in gate: |
|
|
return [a_input, b_input] |
|
|
if '.ha2.sum.layer1' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] |
|
|
if '.ha2.sum.layer2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")] |
|
|
if '.ha2.carry' in gate and '.layer' not in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] |
|
|
if '.carry_or' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")] |
|
|
return [] |
|
|
|
|
|
return [] |
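
# Reference model (illustrative only): the parenthesised form first adds,
# then multiplies the 8-bit sum by C via the masked accumulator stages.
def _demo_expr_paren(a: int, b: int, c: int) -> int:
    s = (a + b) & 0xFF                    # .add stage
    acc = 0
    for stage in range(8):
        if (c >> stage) & 1:              # mask.s{stage} = sum AND C[stage]
            acc = (acc + (s << stage)) & 0xFF
    return acc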
|
|
|
|
|
|
|
|
def infer_add3_inputs(gate: str, reg: SignalRegistry) -> List[int]: |
|
|
"""Infer inputs for 3-operand adder: A + B + C.""" |
|
|
prefix = "arithmetic.add3_8bit" |
|
|
|
|
|
for i in range(8): |
|
|
reg.register(f"$a[{i}]") |
|
|
reg.register(f"$b[{i}]") |
|
|
reg.register(f"$c[{i}]") |
|
|
|
|
|
|
|
|
if '.stage1.' in gate: |
|
|
m = re.search(r'\.fa(\d+)\.', gate) |
|
|
if not m: |
|
|
return [] |
|
|
bit = int(m.group(1)) |
|
|
|
|
|
a_bit = reg.get_id(f"$a[{7-bit}]") |
|
|
b_bit = reg.get_id(f"$b[{7-bit}]") |
|
|
cin = reg.get_id("#0") if bit == 0 else reg.register(f"{prefix}.stage1.fa{bit-1}.carry_or") |
|
|
fa_prefix = f"{prefix}.stage1.fa{bit}" |
|
|
elif '.stage2.' in gate: |
|
|
m = re.search(r'\.fa(\d+)\.', gate) |
|
|
if not m: |
|
|
return [] |
|
|
bit = int(m.group(1)) |
|
|
|
|
|
temp_bit = reg.register(f"{prefix}.stage1.fa{bit}.ha2.sum.layer2") |
|
|
c_bit = reg.get_id(f"$c[{7-bit}]") |
|
|
cin = reg.get_id("#0") if bit == 0 else reg.register(f"{prefix}.stage2.fa{bit-1}.carry_or") |
|
|
a_bit = temp_bit |
|
|
b_bit = c_bit |
|
|
fa_prefix = f"{prefix}.stage2.fa{bit}" |
|
|
else: |
|
|
return [] |
|
|
|
|
|
if '.ha1.sum.layer1' in gate: |
|
|
return [a_bit, b_bit] |
|
|
if '.ha1.sum.layer2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")] |
|
|
if '.ha1.carry' in gate and '.layer' not in gate: |
|
|
return [a_bit, b_bit] |
|
|
if '.ha2.sum.layer1' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] |
|
|
if '.ha2.sum.layer2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")] |
|
|
if '.ha2.carry' in gate and '.layer' not in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] |
|
|
if '.carry_or' in gate: |
|
|
return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")] |
|
|
return [] |
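
# Dataflow sketch: stage1 ripple-adds A + B, stage2 adds C to that sum, so
# the circuit computes (a + b + c) & 0xFF with two chained 8-bit adders.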
|
|
|
|
|
|
|
|
def infer_adcsbc_inputs(gate: str, prefix: str, is_sub: bool, reg: SignalRegistry) -> List[int]:
    """Infer input ids for add/subtract-with-carry; SBC inverts $b via the notb gates."""
|
|
for i in range(8): |
|
|
reg.register(f"{prefix}.$a[{i}]") |
|
|
reg.register(f"{prefix}.$b[{i}]") |
|
|
reg.register(f"{prefix}.$cin") |
|
|
if is_sub and '.notb' in gate: |
|
|
m = re.search(r'\.notb(\d+)', gate) |
|
|
if m: |
|
|
return [reg.get_id(f"{prefix}.$b[{int(m.group(1))}]")] |
|
|
return [] |
|
|
m = re.search(r'\.fa(\d+)\.', gate) |
|
|
if not m: |
|
|
return [] |
|
|
bit = int(m.group(1)) |
|
|
    a_bit = reg.get_id(f"{prefix}.$a[{bit}]")

    # SBC feeds the adder the inverted $b bit; ADC feeds $b directly.
    b_in = reg.register(f"{prefix}.notb{bit}") if is_sub else reg.get_id(f"{prefix}.$b[{bit}]")
|
|
cin = reg.get_id(f"{prefix}.$cin") if bit == 0 else reg.register(f"{prefix}.fa{bit-1}.or_carry") |
|
|
fa_prefix = f"{prefix}.fa{bit}" |
|
|
if '.xor1.layer1' in gate: |
|
|
        return [a_bit, b_in]
|
|
if '.xor1.layer2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.xor1.layer1.or"), reg.register(f"{fa_prefix}.xor1.layer1.nand")] |
|
|
if '.xor2.layer1' in gate: |
|
|
return [reg.register(f"{fa_prefix}.xor1.layer2"), cin] |
|
|
if '.xor2.layer2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.xor2.layer1.or"), reg.register(f"{fa_prefix}.xor2.layer1.nand")] |
|
|
if '.and1' in gate: |
|
|
        return [a_bit, b_in]
|
|
if '.and2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.xor1.layer2"), cin] |
|
|
if '.or_carry' in gate: |
|
|
return [reg.register(f"{fa_prefix}.and1"), reg.register(f"{fa_prefix}.and2")] |
|
|
return [] |
|
|
|
|
|
|
|
|
def infer_sub8bit_inputs(gate: str, reg: SignalRegistry) -> List[int]:
    """Infer input ids for the 8-bit subtractor (A + ~B + 1, carry-in fixed to #1)."""
|
|
prefix = "arithmetic.sub8bit" |
|
|
for i in range(8): |
|
|
reg.register(f"{prefix}.$a[{i}]") |
|
|
reg.register(f"{prefix}.$b[{i}]") |
|
|
if gate == f"{prefix}.carry_in": |
|
|
return [reg.get_id("#1")] |
|
|
if '.notb' in gate: |
|
|
m = re.search(r'\.notb(\d+)', gate) |
|
|
if m: |
|
|
return [reg.get_id(f"{prefix}.$b[{int(m.group(1))}]")] |
|
|
return [] |
|
|
m = re.search(r'\.fa(\d+)\.', gate) |
|
|
if not m: |
|
|
return [] |
|
|
bit = int(m.group(1)) |
|
|
a_bit = reg.get_id(f"{prefix}.$a[{bit}]") |
|
|
notb = reg.register(f"{prefix}.notb{bit}") |
|
|
cin = reg.get_id("#1") if bit == 0 else reg.register(f"{prefix}.fa{bit-1}.or_carry") |
|
|
fa_prefix = f"{prefix}.fa{bit}" |
|
|
if '.xor1.layer1' in gate: |
|
|
return [a_bit, notb] |
|
|
if '.xor1.layer2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.xor1.layer1.or"), reg.register(f"{fa_prefix}.xor1.layer1.nand")] |
|
|
if '.xor2.layer1' in gate: |
|
|
return [reg.register(f"{fa_prefix}.xor1.layer2"), cin] |
|
|
if '.xor2.layer2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.xor2.layer1.or"), reg.register(f"{fa_prefix}.xor2.layer1.nand")] |
|
|
if '.and1' in gate: |
|
|
return [a_bit, notb] |
|
|
if '.and2' in gate: |
|
|
return [reg.register(f"{fa_prefix}.xor1.layer2"), cin] |
|
|
if '.or_carry' in gate: |
|
|
return [reg.register(f"{fa_prefix}.and1"), reg.register(f"{fa_prefix}.and2")] |
|
|
return [] |
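
# Two's-complement subtraction: each notb{i} inverts a $b bit and fa0's
# carry-in is the constant #1, so the chain computes A + ~B + 1 = A - B
# (mod 256). Quick check: (a + ((~b) & 0xFF) + 1) & 0xFF == (a - b) & 0xFF.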
|
|
|
|
|
|
|
|
def infer_threshold_inputs(gate: str, reg: SignalRegistry) -> List[int]:
    """All threshold-counting gates read the full 8-bit $x input."""
|
|
for i in range(8): |
|
|
reg.register(f"$x[{i}]") |
|
|
return [reg.get_id(f"$x[{i}]") for i in range(8)] |
|
|
|
|
|
|
|
|
def infer_modular_inputs(gate: str, reg: SignalRegistry) -> List[int]:
    """Infer input ids for modular circuits (geq/leq interval tests, eq ANDs, final OR)."""
|
|
for i in range(8): |
|
|
reg.register(f"$x[{i}]") |
|
|
if '.layer1' in gate or '.layer2' in gate or '.layer3' in gate: |
|
|
if 'layer1.geq' in gate or 'layer1.leq' in gate: |
|
|
return [reg.get_id(f"$x[{i}]") for i in range(8)] |
|
|
if 'layer2.eq' in gate: |
|
|
m = re.search(r'layer2\.eq(\d+)', gate) |
|
|
if m: |
|
|
idx = m.group(1) |
|
|
parent = gate.rsplit('.layer2', 1)[0] |
|
|
return [reg.register(f"{parent}.layer1.geq{idx}"), reg.register(f"{parent}.layer1.leq{idx}")] |
|
|
if 'layer3.or' in gate: |
|
|
parent = gate.rsplit('.layer3', 1)[0] |
|
|
eq_gates = [] |
|
|
for i in range(256): |
|
|
eq_gate = f"{parent}.layer2.eq{i}" |
|
|
if eq_gate in reg.name_to_id: |
|
|
eq_gates.append(reg.get_id(eq_gate)) |
|
|
return eq_gates if eq_gates else [reg.get_id(f"$x[{i}]") for i in range(8)] |
|
|
return [reg.get_id(f"$x[{i}]") for i in range(8)] |
|
|
|
|
|
|
|
|
def infer_control_jump_inputs(gate: str, prefix: str, reg: SignalRegistry) -> List[int]:
    """Infer input ids for conditional jumps: a per-bit mux of $pc vs $target on a flag."""
|
|
for i in range(8): |
|
|
reg.register(f"{prefix}.$pc[{i}]") |
|
|
reg.register(f"{prefix}.$target[{i}]") |
|
|
flag = "$cond" |
|
|
if "jz" in prefix: |
|
|
flag = "$zero" |
|
|
elif "jc" in prefix: |
|
|
flag = "$carry" |
|
|
elif "jn" in prefix and "jnc" not in prefix and "jnz" not in prefix and "jnv" not in prefix: |
|
|
flag = "$negative" |
|
|
elif "jv" in prefix and "jnv" not in prefix: |
|
|
flag = "$overflow" |
|
|
elif "jp" in prefix: |
|
|
flag = "$positive" |
|
|
elif "jnc" in prefix: |
|
|
flag = "$not_carry" |
|
|
elif "jnz" in prefix: |
|
|
flag = "$not_zero" |
|
|
elif "jnv" in prefix: |
|
|
flag = "$not_overflow" |
|
|
reg.register(f"{prefix}.{flag}") |
|
|
m = re.search(r'\.bit(\d+)\.', gate) |
|
|
if not m: |
|
|
return [] |
|
|
bit = int(m.group(1)) |
|
|
bit_prefix = f"{prefix}.bit{bit}" |
|
|
if '.not_sel' in gate: |
|
|
return [reg.get_id(f"{prefix}.{flag}")] |
|
|
if '.and_a' in gate: |
|
|
return [reg.get_id(f"{prefix}.$pc[{bit}]"), reg.register(f"{bit_prefix}.not_sel")] |
|
|
if '.and_b' in gate: |
|
|
return [reg.get_id(f"{prefix}.$target[{bit}]"), reg.get_id(f"{prefix}.{flag}")] |
|
|
if '.or' in gate: |
|
|
return [reg.register(f"{bit_prefix}.and_a"), reg.register(f"{bit_prefix}.and_b")] |
|
|
return [] |
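
# Per-bit 2:1 mux wired above: next_pc[i] = (pc[i] AND NOT flag) OR
# (target[i] AND flag). Reference model (not called by the build):
def _demo_jump_mux(pc_bit: int, target_bit: int, flag: int) -> int:
    not_sel = 1 - flag                                 # .not_sel gate
    return (pc_bit & not_sel) | (target_bit & flag)    # .and_a / .and_b / .or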
|
|
|
|
|
|
|
|
def infer_buffer_inputs(gate: str, reg: SignalRegistry) -> List[int]:
    """Infer input ids for pass-through buffers ($data or per-bit $data[i])."""
|
|
m = re.search(r'\.bit(\d+)$', gate) |
|
|
if m: |
|
|
bit = int(m.group(1)) |
|
|
prefix = gate.rsplit('.bit', 1)[0] |
|
|
return [reg.register(f"{prefix}.$data[{bit}]")] |
|
|
return [reg.register("$data")] |
|
|
|
|
|
|
|
|
def infer_memory_inputs(gate: str, reg: SignalRegistry) -> List[int]:
    """Infer input ids for the packed memory circuits (addr_decode, read, write)."""
|
|
if 'addr_decode' in gate: |
|
|
return [reg.register(f"$addr[{i}]") for i in range(16)] |
|
|
if 'read' in gate: |
|
|
return [reg.register("$mem"), reg.register("$sel")] |
|
|
if 'write' in gate: |
|
|
return [reg.register("$mem"), reg.register("$data"), reg.register("$sel"), reg.register("$we")] |
|
|
return [] |
|
|
|
|
|
|
|
|
def infer_alu_inputs(gate: str, reg: SignalRegistry) -> List[int]:
    """Infer input ids for the 8-bit ALU's operation subcircuits."""
|
|
for i in range(8): |
|
|
reg.register(f"$a[{i}]") |
|
|
reg.register(f"$b[{i}]") |
|
|
for i in range(4): |
|
|
reg.register(f"$opcode[{i}]") |
|
|
if 'alucontrol' in gate: |
|
|
return [reg.get_id(f"$opcode[{i}]") for i in range(4)] |
|
|
if 'aluflags' in gate: |
|
|
return [reg.register("$result"), reg.register("$carry"), reg.register("$overflow")] |
|
|
if '.shl.bit' in gate: |
|
|
m = re.search(r'bit(\d+)', gate) |
|
|
if m: |
|
|
bit = int(m.group(1)) |
|
|
if bit < 7: |
|
|
return [reg.get_id(f"$a[{bit + 1}]")] |
|
|
else: |
|
|
return [reg.get_id("#0")] |
|
|
return [reg.get_id(f"$a[{i}]") for i in range(8)] |
|
|
if '.shr.bit' in gate: |
|
|
m = re.search(r'bit(\d+)', gate) |
|
|
if m: |
|
|
bit = int(m.group(1)) |
|
|
if bit > 0: |
|
|
return [reg.get_id(f"$a[{bit - 1}]")] |
|
|
else: |
|
|
return [reg.get_id("#0")] |
|
|
return [reg.get_id(f"$a[{i}]") for i in range(8)] |
|
|
if '.mul.pp.a' in gate: |
|
|
m = re.search(r'a(\d+)b(\d+)', gate) |
|
|
if m: |
|
|
i, j = int(m.group(1)), int(m.group(2)) |
|
|
return [reg.get_id(f"$a[{i}]"), reg.get_id(f"$b[{j}]")] |
|
|
return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)] |
|
|
if '.mul.' in gate: |
|
|
return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)] |
|
|
if '.div.stage' in gate: |
|
|
if '.cmp' in gate: |
|
|
return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)] |
|
|
if '.mux.bit' in gate: |
|
|
m = re.search(r'stage(\d+)\.mux\.bit(\d+)', gate) |
|
|
if m: |
|
|
stage, bit = int(m.group(1)), int(m.group(2)) |
|
|
prefix = f"alu.alu8bit.div.stage{stage}" |
|
|
if '.not_sel' in gate: |
|
|
return [reg.register(f"{prefix}.cmp")] |
|
|
if '.and_a' in gate: |
|
|
return [reg.register(f"$rem[{bit}]"), reg.register(f"{prefix}.mux.bit{bit}.not_sel")] |
|
|
if '.and_b' in gate: |
|
|
return [reg.register(f"$sub[{bit}]"), reg.register(f"{prefix}.cmp")] |
|
|
if '.or' in gate: |
|
|
return [reg.register(f"{prefix}.mux.bit{bit}.and_a"), reg.register(f"{prefix}.mux.bit{bit}.and_b")] |
|
|
return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)] |
|
|
if '.inc.bit' in gate: |
|
|
m = re.search(r'bit(\d+)', gate) |
|
|
if m: |
|
|
bit = int(m.group(1)) |
|
|
prefix = f"alu.alu8bit.inc.bit{bit}" |
|
|
if 'layer1' in gate: |
|
|
if bit == 7: |
|
|
return [reg.get_id(f"$a[{bit}]"), reg.get_id("#1")] |
|
|
else: |
|
|
return [reg.get_id(f"$a[{bit}]"), reg.register(f"alu.alu8bit.inc.bit{bit+1}.carry")] |
|
|
if 'layer2' in gate: |
|
|
return [reg.register(f"{prefix}.xor.layer1.or"), reg.register(f"{prefix}.xor.layer1.nand")] |
|
|
if '.carry' in gate: |
|
|
if bit == 7: |
|
|
return [reg.get_id(f"$a[{bit}]"), reg.get_id("#1")] |
|
|
else: |
|
|
return [reg.get_id(f"$a[{bit}]"), reg.register(f"alu.alu8bit.inc.bit{bit+1}.carry")] |
|
|
return [reg.get_id(f"$a[{i}]") for i in range(8)] |
|
|
if '.dec.bit' in gate: |
|
|
m = re.search(r'bit(\d+)', gate) |
|
|
if m: |
|
|
bit = int(m.group(1)) |
|
|
prefix = f"alu.alu8bit.dec.bit{bit}" |
|
|
if '.not_a' in gate: |
|
|
return [reg.get_id(f"$a[{bit}]")] |
|
|
if 'layer1' in gate: |
|
|
if bit == 7: |
|
|
return [reg.get_id(f"$a[{bit}]"), reg.get_id("#1")] |
|
|
else: |
|
|
return [reg.get_id(f"$a[{bit}]"), reg.register(f"alu.alu8bit.dec.bit{bit+1}.borrow")] |
|
|
if 'layer2' in gate: |
|
|
return [reg.register(f"{prefix}.xor.layer1.or"), reg.register(f"{prefix}.xor.layer1.nand")] |
|
|
if '.borrow' in gate: |
|
|
if bit == 7: |
|
|
return [reg.register(f"{prefix}.not_a"), reg.get_id("#1")] |
|
|
else: |
|
|
return [reg.register(f"{prefix}.not_a"), reg.register(f"alu.alu8bit.dec.bit{bit+1}.borrow")] |
|
|
return [reg.get_id(f"$a[{i}]") for i in range(8)] |
|
|
if '.neg.' in gate: |
|
|
m = re.search(r'bit(\d+)', gate) |
|
|
if m: |
|
|
bit = int(m.group(1)) |
|
|
if '.not.bit' in gate: |
|
|
return [reg.get_id(f"$a[{bit}]")] |
|
|
prefix = f"alu.alu8bit.neg.inc.bit{bit}" |
|
|
not_bit = f"alu.alu8bit.neg.not.bit{bit}" |
|
|
if 'layer1' in gate: |
|
|
if bit == 7: |
|
|
return [reg.register(not_bit), reg.get_id("#1")] |
|
|
else: |
|
|
return [reg.register(not_bit), reg.register(f"alu.alu8bit.neg.inc.bit{bit+1}.carry")] |
|
|
if 'layer2' in gate: |
|
|
return [reg.register(f"{prefix}.xor.layer1.or"), reg.register(f"{prefix}.xor.layer1.nand")] |
|
|
if '.carry' in gate: |
|
|
if bit == 7: |
|
|
return [reg.register(not_bit), reg.get_id("#1")] |
|
|
else: |
|
|
return [reg.register(not_bit), reg.register(f"alu.alu8bit.neg.inc.bit{bit+1}.carry")] |
|
|
return [reg.get_id(f"$a[{i}]") for i in range(8)] |
|
|
if '.rol.bit' in gate: |
|
|
m = re.search(r'bit(\d+)', gate) |
|
|
if m: |
|
|
bit = int(m.group(1)) |
|
|
src = (bit + 1) % 8 |
|
|
return [reg.get_id(f"$a[{src}]")] |
|
|
return [reg.get_id(f"$a[{i}]") for i in range(8)] |
|
|
if '.ror.bit' in gate: |
|
|
m = re.search(r'bit(\d+)', gate) |
|
|
if m: |
|
|
bit = int(m.group(1)) |
|
|
src = (bit - 1) % 8 |
|
|
return [reg.get_id(f"$a[{src}]")] |
|
|
return [reg.get_id(f"$a[{i}]") for i in range(8)] |
|
|
if '.and' in gate or '.or' in gate or '.xor' in gate: |
|
|
m = re.search(r'bit(\d+)', gate) |
|
|
if m: |
|
|
bit = int(m.group(1)) |
|
|
return [reg.get_id(f"$a[{bit}]"), reg.get_id(f"$b[{bit}]")] |
|
|
return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)] |
|
|
if '.not' in gate: |
|
|
m = re.search(r'bit(\d+)', gate) |
|
|
if m: |
|
|
return [reg.get_id(f"$a[{int(m.group(1))}]")] |
|
|
return [reg.get_id(f"$a[{i}]") for i in range(8)] |
|
|
if 'layer1' in gate or 'layer2' in gate: |
|
|
m = re.search(r'bit(\d+)', gate) |
|
|
if m: |
|
|
bit = int(m.group(1)) |
|
|
if 'layer1' in gate: |
|
|
return [reg.get_id(f"$a[{bit}]"), reg.get_id(f"$b[{bit}]")] |
|
|
parent = gate.rsplit('.layer2', 1)[0] |
|
|
return [reg.register(f"{parent}.layer1.or"), reg.register(f"{parent}.layer1.nand")] |
|
|
return [reg.get_id(f"$a[{i}]") for i in range(8)] |
|
|
|
|
|
|
|
|
def infer_pattern_inputs(gate: str, reg: SignalRegistry) -> List[int]:
    """Infer input ids for pattern circuits; hammingdistance reads $a and $b, the rest read $x."""
|
|
for i in range(8): |
|
|
reg.register(f"$x[{i}]") |
|
|
if 'hammingdistance' in gate: |
|
|
for i in range(8): |
|
|
reg.register(f"$a[{i}]") |
|
|
reg.register(f"$b[{i}]") |
|
|
return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)] |
|
|
return [reg.get_id(f"$x[{i}]") for i in range(8)] |
|
|
|
|
|
|
|
|
def infer_error_detection_inputs(gate: str, reg: SignalRegistry) -> List[int]:
    """Infer input ids for Hamming encode/decode, CRC, and the parity XOR tree."""
|
|
for i in range(8): |
|
|
reg.register(f"$x[{i}]") |
|
|
if 'hamming' in gate: |
|
|
if 'encode' in gate: |
|
|
for i in range(4): |
|
|
reg.register(f"$d[{i}]") |
|
|
return [reg.get_id(f"$d[{i}]") for i in range(4)] |
|
|
if 'decode' in gate or 'syndrome' in gate: |
|
|
for i in range(7): |
|
|
reg.register(f"$c[{i}]") |
|
|
return [reg.get_id(f"$c[{i}]") for i in range(7)] |
|
|
if 'crc' in gate: |
|
|
return [reg.register(f"$data[{i}]") for i in range(8)] |
|
|
if 'parity' in gate and 'stage' in gate: |
|
|
m = re.search(r'stage(\d+)\.xor(\d+)', gate) |
|
|
if m: |
|
|
stage = int(m.group(1)) |
|
|
idx = int(m.group(2)) |
|
|
if stage == 1: |
|
|
return [reg.get_id(f"$x[{2*idx}]"), reg.get_id(f"$x[{2*idx+1}]")] |
|
|
parent = gate.rsplit(f'.stage{stage}', 1)[0] |
|
|
prev_stage = stage - 1 |
|
|
return [ |
|
|
reg.register(f"{parent}.stage{prev_stage}.xor{2*idx}.layer2"), |
|
|
reg.register(f"{parent}.stage{prev_stage}.xor{2*idx+1}.layer2") |
|
|
] |
|
|
if 'output.not' in gate: |
|
|
parent = gate.rsplit('.output', 1)[0] |
|
|
return [reg.register(f"{parent}.stage3.xor0.layer2")] |
|
|
return [reg.get_id(f"$x[{i}]") for i in range(8)] |
|
|
|
|
|
|
|
|
def infer_combinational_inputs(gate: str, reg: SignalRegistry) -> List[int]:
    """Infer input ids for decoders, encoders, muxes, shifters, and priority logic."""
|
|
if 'decoder3to8' in gate: |
|
|
for i in range(3): |
|
|
reg.register(f"$sel[{i}]") |
|
|
return [reg.get_id(f"$sel[{i}]") for i in range(3)] |
|
|
if 'encoder8to3' in gate: |
|
|
for i in range(8): |
|
|
reg.register(f"$x[{i}]") |
|
|
return [reg.get_id(f"$x[{i}]") for i in range(8)] |
|
|
if 'multiplexer' in gate: |
|
|
if '2to1' in gate: |
|
|
return [reg.register("$a"), reg.register("$b"), reg.register("$sel")] |
|
|
if '4to1' in gate: |
|
|
return [reg.register(f"$x[{i}]") for i in range(4)] + [reg.register(f"$sel[{i}]") for i in range(2)] |
|
|
if '8to1' in gate: |
|
|
return [reg.register(f"$x[{i}]") for i in range(8)] + [reg.register(f"$sel[{i}]") for i in range(3)] |
|
|
if 'demultiplexer' in gate: |
|
|
return [reg.register("$x"), reg.register("$sel")] |
|
|
if 'regmux4to1' in gate: |
|
|
for r in range(4): |
|
|
for i in range(8): |
|
|
reg.register(f"$r{r}[{i}]") |
|
|
for i in range(2): |
|
|
reg.register(f"$sel[{i}]") |
|
|
if gate == "combinational.regmux4to1.not_s0": |
|
|
return [reg.get_id("$sel[0]")] |
|
|
if gate == "combinational.regmux4to1.not_s1": |
|
|
return [reg.get_id("$sel[1]")] |
|
|
m = re.search(r'bit(\d+)', gate) |
|
|
if m: |
|
|
bit = int(m.group(1)) |
|
|
if '.not_s' in gate: |
|
|
sidx = 0 if 's0' in gate else 1 |
|
|
return [reg.get_id(f"$sel[{sidx}]")] |
|
|
if '.and' in gate: |
|
|
and_m = re.search(r'\.and(\d+)', gate) |
|
|
if and_m: |
|
|
and_idx = int(and_m.group(1)) |
|
|
sel0 = "combinational.regmux4to1.not_s0" if (and_idx & 1) == 0 else "$sel[0]" |
|
|
sel1 = "combinational.regmux4to1.not_s1" if (and_idx & 2) == 0 else "$sel[1]" |
|
|
return [reg.get_id(f"$r{and_idx}[{bit}]"), reg.register(sel0), reg.register(sel1)] |
|
|
if '.or' in gate: |
|
|
return [reg.register(f"combinational.regmux4to1.bit{bit}.and{i}") for i in range(4)] |
|
|
return [] |
|
|
if 'barrelshifter' in gate: |
|
|
for i in range(8): |
|
|
reg.register(f"$x[{i}]") |
|
|
for i in range(3): |
|
|
reg.register(f"$shift[{i}]") |
|
|
m = re.search(r'layer(\d+)\.bit(\d+)', gate) |
|
|
if m: |
|
|
layer, bit = int(m.group(1)), int(m.group(2)) |
|
|
shift_amount = 1 << (2 - layer) |
|
|
prefix = f"combinational.barrelshifter.layer{layer}.bit{bit}" |
|
|
if '.not_sel' in gate: |
|
|
return [reg.get_id(f"$shift[{2 - layer}]")] |
|
|
if '.and_a' in gate: |
|
|
if layer == 0: |
|
|
return [reg.get_id(f"$x[{bit}]"), reg.register(f"{prefix}.not_sel")] |
|
|
else: |
|
|
prev_prefix = f"combinational.barrelshifter.layer{layer-1}.bit{bit}" |
|
|
return [reg.register(f"{prev_prefix}.or"), reg.register(f"{prefix}.not_sel")] |
|
|
if '.and_b' in gate: |
|
|
src = (bit + shift_amount) % 8 |
|
|
if layer == 0: |
|
|
return [reg.get_id(f"$x[{src}]"), reg.get_id(f"$shift[{2 - layer}]")] |
|
|
else: |
|
|
prev_prefix = f"combinational.barrelshifter.layer{layer-1}.bit{src}" |
|
|
return [reg.register(f"{prev_prefix}.or"), reg.get_id(f"$shift[{2 - layer}]")] |
|
|
if '.or' in gate: |
|
|
return [reg.register(f"{prefix}.and_a"), reg.register(f"{prefix}.and_b")] |
|
|
return [reg.get_id(f"$x[{i}]") for i in range(8)] |
|
|
if 'priorityencoder' in gate: |
|
|
for i in range(8): |
|
|
reg.register(f"$x[{i}]") |
|
|
if '.any_ge' in gate: |
|
|
m = re.search(r'any_ge(\d+)', gate) |
|
|
if m: |
|
|
pos = int(m.group(1)) |
|
|
return [reg.get_id(f"$x[{i}]") for i in range(pos, 8)] |
|
|
if '.is_highest' in gate: |
|
|
m = re.search(r'is_highest(\d+)', gate) |
|
|
if m: |
|
|
pos = int(m.group(1)) |
|
|
if '.not_higher' in gate: |
|
|
if pos == 0: |
|
|
return [reg.get_id("#0")] |
|
|
else: |
|
|
return [reg.register(f"combinational.priorityencoder.any_ge{pos-1}")] |
|
|
if '.and' in gate: |
|
|
return [reg.get_id(f"$x[{pos}]"), reg.register(f"combinational.priorityencoder.is_highest{pos}.not_higher")] |
|
|
if '.out' in gate: |
|
|
m = re.search(r'out(\d+)', gate) |
|
|
if m: |
|
|
out_bit = int(m.group(1)) |
|
|
inputs = [] |
|
|
for pos in range(8): |
|
|
if (pos >> out_bit) & 1: |
|
|
inputs.append(reg.register(f"combinational.priorityencoder.is_highest{pos}.and")) |
|
|
return inputs |
|
|
if '.valid' in gate: |
|
|
return [reg.get_id(f"$x[{i}]") for i in range(8)] |
|
|
return [reg.get_id(f"$x[{i}]") for i in range(8)] |
|
|
return [] |
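
# regmux4to1 dataflow: not_s0/not_s1 invert the select lines, each
# bit{b}.and{k} gates register k's bit with the matching (possibly inverted)
# selects, and bit{b}.or merges the four products:
#     out[b] = OR_k ( r{k}[b] AND s0_k AND s1_k )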
|
|
|
|
|
|
|
|
def infer_inputs_for_gate(gate: str, reg: SignalRegistry, tensors: Dict[str, torch.Tensor]) -> List[int]:
    """Dispatch a gate name to the matching circuit-specific input inference."""
|
|
if gate.startswith('manifest.'): |
|
|
return [] |
|
|
if gate.startswith('boolean.'): |
|
|
return infer_boolean_inputs(gate, reg) |
|
|
if gate.startswith('arithmetic.'): |
|
|
if 'halfadder' in gate: |
|
|
return infer_halfadder_inputs(gate, "arithmetic.halfadder", reg) |
|
|
if 'fulladder' in gate: |
|
|
return infer_fulladder_inputs(gate, "arithmetic.fulladder", reg) |
|
|
if 'ripplecarry2bit' in gate: |
|
|
return infer_ripplecarry_inputs(gate, "arithmetic.ripplecarry2bit", 2, reg) |
|
|
if 'ripplecarry4bit' in gate: |
|
|
return infer_ripplecarry_inputs(gate, "arithmetic.ripplecarry4bit", 4, reg) |
|
|
if 'ripplecarry8bit' in gate: |
|
|
return infer_ripplecarry_inputs(gate, "arithmetic.ripplecarry8bit", 8, reg) |
|
|
if 'add3_8bit' in gate: |
|
|
return infer_add3_inputs(gate, reg) |
|
|
if 'expr_add_mul' in gate and 'paren' not in gate: |
|
|
return infer_expr_add_mul_inputs(gate, reg) |
|
|
if 'expr_paren_add_mul' in gate: |
|
|
return infer_expr_paren_add_mul_inputs(gate, reg) |
|
|
if 'adc8bit' in gate: |
|
|
return infer_adcsbc_inputs(gate, "arithmetic.adc8bit", False, reg) |
|
|
if 'sbc8bit' in gate: |
|
|
return infer_adcsbc_inputs(gate, "arithmetic.sbc8bit", True, reg) |
|
|
if 'sub8bit' in gate: |
|
|
return infer_sub8bit_inputs(gate, reg) |
|
|
if any(cmp in gate for cmp in ['greaterthan8bit', 'lessthan8bit', 'greaterorequal8bit', 'lessorequal8bit']): |
|
|
for i in range(8): |
|
|
reg.register(f"$a[{i}]") |
|
|
reg.register(f"$b[{i}]") |
|
|
return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)] |
|
|
if 'equality8bit' in gate: |
|
|
for i in range(8): |
|
|
reg.register(f"$a[{i}]") |
|
|
reg.register(f"$b[{i}]") |
|
|
if 'layer1' in gate: |
|
|
return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)] |
|
|
if 'layer2' in gate: |
|
|
return [reg.register("arithmetic.equality8bit.layer1.geq"), reg.register("arithmetic.equality8bit.layer1.leq")] |
|
|
return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)] |
|
|
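# Default arithmetic fallback: register both operand buses so their IDs |
# exist, then wire the $a bus as the gate's inputs. |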
for i in range(8): |
|
|
reg.register(f"$a[{i}]") |
|
|
reg.register(f"$b[{i}]") |
|
|
return [reg.get_id(f"$a[{i}]") for i in range(8)] |
|
|
if gate.startswith('threshold.'): |
|
|
return infer_threshold_inputs(gate, reg) |
|
|
if gate.startswith('modular.'): |
|
|
return infer_modular_inputs(gate, reg) |
|
|
if gate.startswith('control.'): |
|
|
if any(j in gate for j in ['jz', 'jc', 'jn', 'jv', 'jp', 'jnz', 'jnc', 'jnv', 'conditionaljump']): |
|
|
prefix = gate.split('.bit')[0] if '.bit' in gate else gate.rsplit('.', 1)[0] |
|
|
return infer_control_jump_inputs(gate, prefix, reg) |
|
|
if any(b in gate for b in ['fetch', 'load', 'store', 'mem_addr']): |
|
|
return infer_buffer_inputs(gate, reg) |
|
|
if 'push.sp_dec' in gate or 'pop.sp_inc' in gate: |
|
|
for i in range(16): |
|
|
reg.register(f"$sp[{i}]") |
|
|
m = re.search(r'bit(\d+)', gate) |
|
|
if m: |
|
|
bit = int(m.group(1)) |
|
|
op = 'push.sp_dec' if 'push' in gate else 'pop.sp_inc' |
|
|
prefix = f"control.{op}.bit{bit}" |
|
|
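# The SP inc/dec ripple seeds bit 15 with constant #1; every other bit |
# consumes the carry (POP) or borrow (PUSH) produced by bit+1. |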
if 'layer1' in gate: |
|
|
if bit == 15: |
|
|
return [reg.get_id(f"$sp[{bit}]"), reg.get_id("#1")] |
|
|
else: |
|
|
carry_name = 'borrow' if 'push' in gate else 'carry' |
|
|
return [reg.get_id(f"$sp[{bit}]"), reg.register(f"control.{op}.bit{bit+1}.{carry_name}")] |
|
|
if 'layer2' in gate: |
|
|
return [reg.register(f"{prefix}.xor.layer1.or"), reg.register(f"{prefix}.xor.layer1.nand")] |
|
|
if '.borrow' in gate or '.carry' in gate: |
|
|
if bit == 15: |
|
|
return [reg.get_id(f"$sp[{bit}]"), reg.get_id("#1")] |
|
|
else: |
|
|
carry_name = 'borrow' if 'push' in gate else 'carry' |
|
|
return [reg.get_id(f"$sp[{bit}]"), reg.register(f"control.{op}.bit{bit+1}.{carry_name}")] |
|
|
return [reg.get_id(f"$sp[{i}]") for i in range(16)] |
|
|
if 'ret.addr' in gate: |
|
|
m = re.search(r'bit(\d+)', gate) |
|
|
if m: |
|
|
bit = int(m.group(1)) |
|
|
return [reg.register(f"$ret_addr[{bit}]")] |
|
|
return [reg.register(f"$ret_addr[{i}]") for i in range(16)] |
|
|
return [reg.register("$ctrl")] |
|
|
if gate.startswith('memory.'): |
|
|
return infer_memory_inputs(gate, reg) |
|
|
if gate.startswith('alu.'): |
|
|
return infer_alu_inputs(gate, reg) |
|
|
if gate.startswith('pattern_recognition.'): |
|
|
return infer_pattern_inputs(gate, reg) |
|
|
if gate.startswith('error_detection.'): |
|
|
return infer_error_detection_inputs(gate, reg) |
|
|
if gate.startswith('combinational.'): |
|
|
return infer_combinational_inputs(gate, reg) |
|
|
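# Fallback: infer fan-in from the weight tensor's shape and wire generic |
# $input[i] signals. |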
weight_key = f"{gate}.weight" |
|
|
if weight_key in tensors: |
|
|
w = tensors[weight_key] |
|
|
n_inputs = w.shape[0] if w.dim() == 1 else w.shape[-1] |
|
|
for i in range(n_inputs): |
|
|
reg.register(f"$input[{i}]") |
|
|
return [reg.get_id(f"$input[{i}]") for i in range(n_inputs)] |
|
|
return [] |
|
|
|
|
|
|
|
|
def build_inputs(tensors: Dict[str, torch.Tensor]) -> tuple[Dict[str, torch.Tensor], SignalRegistry, dict]: |
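"""Add a .inputs tensor for every gate that lacks one; returns (tensors, registry, stats).""" |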
|
|
reg = SignalRegistry() |
|
|
gates = get_all_gates(tensors) |
|
|
stats = {"added": 0, "skipped": 0, "empty": 0} |
|
|
for gate in sorted(gates): |
|
|
inputs_key = f"{gate}.inputs" |
|
|
if inputs_key in tensors: |
|
|
stats["skipped"] += 1 |
|
|
continue |
|
|
inputs = infer_inputs_for_gate(gate, reg, tensors) |
|
|
if inputs: |
|
|
tensors[inputs_key] = torch.tensor(inputs, dtype=torch.int64) |
|
|
stats["added"] += 1 |
|
|
else: |
|
|
stats["empty"] += 1 |
|
|
return tensors, reg, stats |
|
|
|
|
|
|
|
|
def resolve_memory_config(args) -> tuple: |
|
|
"""Resolve memory configuration from args, returns (addr_bits, mem_bytes).""" |
|
|
if hasattr(args, 'memory_profile') and args.memory_profile: |
|
|
addr_bits = MEMORY_PROFILES[args.memory_profile] |
|
|
elif hasattr(args, 'addr_bits') and args.addr_bits is not None: |
|
|
addr_bits = args.addr_bits |
|
|
else: |
|
|
addr_bits = DEFAULT_ADDR_BITS |
|
|
mem_bytes = (1 << addr_bits) if addr_bits > 0 else 0 |
|
|
return addr_bits, mem_bytes |
|
|
|
|
|
|
|
|
def cmd_memory(args) -> None: |
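"""Generate memory, buffer, stack, conditional-jump, and status-flag circuits.""" |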
|
|
addr_bits, mem_bytes = resolve_memory_config(args) |
|
|
|
|
|
print("=" * 60) |
|
|
print(" BUILD MEMORY CIRCUITS") |
|
|
print("=" * 60) |
|
|
print(f"\nMemory configuration:") |
|
|
print(f" Address bits: {addr_bits}") |
|
|
print(f" Memory bytes: {mem_bytes:,}") |
|
|
if addr_bits == 0: |
|
|
print(" Mode: PURE ALU (no memory)") |
|
|
elif addr_bits <= 4: |
|
|
print(" Mode: LLM registers") |
|
|
elif addr_bits <= 8: |
|
|
print(" Mode: LLM scratchpad") |
|
|
elif addr_bits <= 12: |
|
|
print(" Mode: Reduced CPU") |
|
|
else: |
|
|
print(" Mode: Full CPU") |
|
|
|
|
|
print(f"\nLoading: {args.model}") |
|
|
tensors = load_tensors(args.model) |
|
|
print(f" Loaded {len(tensors)} tensors") |
|
|
|
|
|
print("\nDropping existing memory/control tensors...") |
|
|
drop_prefixes(tensors, [ |
|
|
"memory.addr_decode.", "memory.read.", "memory.write.", |
|
|
"control.fetch.ir.", "control.load.", "control.store.", "control.mem_addr.", |
|
|
"control.push.", "control.pop.", "control.ret.", |
|
|
"control.jz.", "control.jnz.", "control.jc.", "control.jnc.", |
|
|
"control.jp.", "control.jn.", "control.jv.", "control.jnv.", |
|
|
"flags.", |
|
|
]) |
|
|
print(f" Now {len(tensors)} tensors") |
|
|
|
|
|
if addr_bits > 0: |
|
|
print("\nGenerating memory circuits...") |
|
|
add_decoder(tensors, addr_bits, mem_bytes) |
|
|
add_memory_read_mux(tensors, mem_bytes) |
|
|
add_memory_write_cells(tensors, mem_bytes) |
|
|
print(" Added decoder, read mux, write cells") |
|
|
|
|
|
print("\nGenerating buffer gates...") |
|
|
try: |
|
|
add_fetch_load_store_buffers(tensors, args.bits, addr_bits) |
|
|
print(f" Added fetch/load/store/mem_addr buffers ({args.bits}-bit data, {addr_bits}-bit addr)") |
|
|
except ValueError as e: |
|
|
print(f" Buffers already exist: {e}") |
|
|
|
|
|
print("\nGenerating stack operation circuits...") |
|
|
try: |
|
|
add_stack_ops(tensors, args.bits, addr_bits) |
|
|
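# Summary gate counts: 4 gates per SP bit (XOR pair + combine + carry or |
# borrow) for each of PUSH and POP, one buffer per data bit for PUSH/POP, |
# and one per return-address bit for RET. |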
sp_gates = addr_bits * 4 * 2 |
|
|
data_gates = args.bits * 2 |
|
|
ret_gates = addr_bits |
|
|
total_gates = sp_gates + data_gates + ret_gates |
|
|
print(f" Added PUSH/POP/RET ({total_gates} gates: {args.bits}-bit data, {addr_bits}-bit SP)") |
|
|
except ValueError as e: |
|
|
print(f" Stack ops already exist: {e}") |
|
|
|
|
|
print("\nGenerating conditional jump circuits...") |
|
|
try: |
|
|
add_conditional_jumps(tensors, addr_bits) |
|
|
jump_gates = 8 * addr_bits * 4 |
|
|
print(f" Added JZ/JNZ/JC/JNC/JP/JN/JV/JNV ({jump_gates} gates: {addr_bits}-bit addresses)") |
|
|
except ValueError as e: |
|
|
print(f" Conditional jumps already exist: {e}") |
|
|
|
|
|
print("\nGenerating status flag circuits...") |
|
|
try: |
|
|
add_status_flags(tensors, args.bits) |
|
|
print(f" Added Z/N/C/V flags ({args.bits}-bit aware)") |
|
|
except ValueError as e: |
|
|
print(f" Status flags already exist: {e}") |
|
|
else: |
|
|
print("\nSkipping memory circuits (addr_bits=0, pure ALU mode)") |
|
|
|
|
|
print("\nUpdating manifest...") |
|
|
update_manifest(tensors, args.bits, addr_bits, mem_bytes) |
|
|
print(f" data_bits={args.bits}, addr_bits={addr_bits}, memory_bytes={mem_bytes:,}") |
|
|
|
|
|
if args.apply: |
|
|
print(f"\nSaving: {args.model}") |
|
|
save_file(tensors, str(args.model)) |
|
|
if args.manifest: |
|
|
write_manifest(MANIFEST_PATH, tensors) |
|
|
print(f" Wrote manifest: {MANIFEST_PATH}") |
|
|
print(" Done.") |
|
|
else: |
|
|
print("\n[DRY-RUN] Use --apply to save.") |
|
|
|
|
|
print(f"\nTotal: {len(tensors)} tensors") |
|
|
mem_params = sum(t.numel() for k, t in tensors.items() if k.startswith("memory.")) |
|
|
alu_params = sum(t.numel() for k, t in tensors.items() if not k.startswith("memory.") and not k.startswith("manifest.")) |
|
|
print(f" Memory params: {mem_params:,}") |
|
|
print(f" ALU/Logic params: {alu_params:,}") |
|
|
print("=" * 60) |
|
|
|
|
|
|
|
|
def cmd_inputs(args) -> None: |
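"""Add .inputs routing tensors and store the signal registry as safetensors metadata.""" |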
|
|
print("=" * 60) |
|
|
print(" BUILD .inputs TENSORS") |
|
|
print("=" * 60) |
|
|
print(f"\nLoading: {args.model}") |
|
|
tensors = load_tensors(args.model) |
|
|
print(f" Loaded {len(tensors)} tensors") |
|
|
gates = get_all_gates(tensors) |
|
|
print(f" Found {len(gates)} gates") |
|
|
print("\nBuilding .inputs tensors...") |
|
|
tensors, reg, stats = build_inputs(tensors) |
|
|
print(f"\nResults:") |
|
|
print(f" Added: {stats['added']}") |
|
|
print(f" Skipped: {stats['skipped']}") |
|
|
print(f" Empty: {stats['empty']}") |
|
|
print(f" Signals: {len(reg.name_to_id)}") |
|
|
print(f" Total: {len(tensors)}") |
|
|
if args.apply: |
|
|
print(f"\nSaving: {args.model}") |
|
|
metadata = {"signal_registry": reg.to_metadata()} |
|
|
save_file(tensors, str(args.model), metadata=metadata) |
|
|
print(" Done.") |
|
|
else: |
|
|
print("\n[DRY-RUN] Use --apply to save.") |
|
|
print("=" * 60) |
|
|
|
|
|
|
|
|
def cmd_alu(args) -> None: |
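"""Generate ALU extension circuits: shifts, mul/div, comparators, expressions, floats.""" |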
|
|
bits = getattr(args, 'bits', 8) or 8 |
|
|
print("=" * 60) |
|
|
print(f" BUILD ALU CIRCUITS ({bits}-bit)") |
|
|
print("=" * 60) |
|
|
print(f"\nLoading: {args.model}") |
|
|
tensors = load_tensors(args.model) |
|
|
print(f" Loaded {len(tensors)} tensors") |
|
|
|
|
|
drop_list = [ |
|
|
"alu.alu8bit.shl.", "alu.alu8bit.shr.", |
|
|
"alu.alu8bit.mul.", "alu.alu8bit.div.", |
|
|
"alu.alu8bit.inc.", "alu.alu8bit.dec.", |
|
|
"alu.alu8bit.neg.", "alu.alu8bit.rol.", "alu.alu8bit.ror.", |
|
|
"arithmetic.greaterthan8bit.", "arithmetic.lessthan8bit.", |
|
|
"arithmetic.greaterorequal8bit.", "arithmetic.lessorequal8bit.", |
|
|
"arithmetic.equality8bit.", "arithmetic.add3_8bit.", "arithmetic.expr_add_mul.", "arithmetic.expr_paren.", |
|
|
"combinational.barrelshifter.", "combinational.priorityencoder.", |
|
|
"float16.", "float32.", |
|
|
] |
|
|
|
|
|
if bits in [16, 32]: |
|
|
drop_list.extend([ |
|
|
f"alu.alu{bits}bit.", f"arithmetic.ripplecarry{bits}bit.", |
|
|
f"arithmetic.sub{bits}bit.", f"arithmetic.greaterthan{bits}bit.", |
|
|
f"arithmetic.lessthan{bits}bit.", f"arithmetic.greaterorequal{bits}bit.", |
|
|
f"arithmetic.lessorequal{bits}bit.", f"arithmetic.equality{bits}bit.", |
|
|
]) |
|
|
|
|
|
print("\nDropping existing ALU extension tensors...") |
|
|
drop_prefixes(tensors, drop_list) |
|
|
print(f" Now {len(tensors)} tensors") |
|
|
print("\nGenerating SHL/SHR circuits...") |
|
|
try: |
|
|
add_shl_shr(tensors) |
|
|
print(" Added SHL (8 gates), SHR (8 gates)") |
|
|
except ValueError as e: |
|
|
print(f" SHL/SHR already exist: {e}") |
|
|
print("\nGenerating MUL circuit...") |
|
|
try: |
|
|
add_mul(tensors) |
|
|
print(" Added MUL (64 partial product AND gates)") |
|
|
except ValueError as e: |
|
|
print(f" MUL already exists: {e}") |
|
|
print("\nGenerating DIV circuit...") |
|
|
try: |
|
|
add_div(tensors) |
|
|
print(" Added DIV (8 stages x comparison + mux)") |
|
|
except ValueError as e: |
|
|
print(f" DIV already exists: {e}") |
|
|
print("\nGenerating INC/DEC circuits...") |
|
|
try: |
|
|
add_inc_dec(tensors) |
|
|
print(" Added INC (32 gates), DEC (40 gates)") |
|
|
except ValueError as e: |
|
|
print(f" INC/DEC already exist: {e}") |
|
|
print("\nGenerating NEG circuit...") |
|
|
try: |
|
|
add_neg(tensors) |
|
|
print(" Added NEG (40 gates)") |
|
|
except ValueError as e: |
|
|
print(f" NEG already exists: {e}") |
|
|
print("\nGenerating ROL/ROR circuits...") |
|
|
try: |
|
|
add_rol_ror(tensors) |
|
|
print(" Added ROL (8 gates), ROR (8 gates)") |
|
|
except ValueError as e: |
|
|
print(f" ROL/ROR already exist: {e}") |
|
|
print("\nGenerating barrel shifter...") |
|
|
try: |
|
|
add_barrel_shifter(tensors) |
|
|
print(" Added barrel shifter (96 gates)") |
|
|
except ValueError as e: |
|
|
print(f" Barrel shifter already exists: {e}") |
|
|
print("\nGenerating priority encoder...") |
|
|
try: |
|
|
add_priority_encoder(tensors) |
|
|
print(" Added priority encoder (28 gates)") |
|
|
except ValueError as e: |
|
|
print(f" Priority encoder already exists: {e}") |
|
|
print("\nGenerating comparator circuits...") |
|
|
try: |
|
|
add_comparators(tensors) |
|
|
print(" Added GT, GE, LT, LE (single-layer), EQ (two-layer)") |
|
|
except ValueError as e: |
|
|
print(f" Comparators already exist: {e}") |
|
|
print("\nGenerating 3-operand adder circuit...") |
|
|
try: |
|
|
add_add3(tensors) |
|
|
print(" Added ADD3 (16 full adders = 144 gates)") |
|
|
except ValueError as e: |
|
|
print(f" ADD3 already exists: {e}") |
|
|
print("\nGenerating expression A + B × C circuit...") |
|
|
try: |
|
|
add_expr_add_mul(tensors) |
|
|
print(" Added EXPR_ADD_MUL (64 AND + 56 + 8 full adders = 640 gates)") |
|
|
except ValueError as e: |
|
|
print(f" EXPR_ADD_MUL already exists: {e}") |
|
|
print("\nGenerating expression (A + B) × C circuit...") |
|
|
try: |
|
|
add_expr_paren(tensors) |
|
|
print(" Added EXPR_PAREN (8 + 64 AND + 56 full adders = 640 gates)") |
|
|
except ValueError as e: |
|
|
print(f" EXPR_PAREN already exists: {e}") |
|
|
|
|
|
if bits in [16, 32]: |
|
|
print(f"\n{'=' * 60}") |
|
|
print(f" GENERATING {bits}-BIT CIRCUITS") |
|
|
print(f"{'=' * 60}") |
|
|
|
|
|
print(f"\nGenerating {bits}-bit ripple carry adder...") |
|
|
try: |
|
|
add_ripple_carry_nbits(tensors, bits) |
|
|
print(f" Added {bits}-bit adder ({bits} full adders = {bits * 9} gates)") |
|
|
except ValueError as e: |
|
|
print(f" {bits}-bit adder already exists: {e}") |
|
|
|
|
|
print(f"\nGenerating {bits}-bit subtractor...") |
|
|
try: |
|
|
add_sub_nbits(tensors, bits) |
|
|
print(f" Added {bits}-bit subtractor ({bits} NOT + {bits} full adders)") |
|
|
except ValueError as e: |
|
|
print(f" {bits}-bit subtractor already exists: {e}") |
|
|
|
|
|
print(f"\nGenerating {bits}-bit comparators...") |
|
|
try: |
|
|
add_comparators_nbits(tensors, bits) |
|
|
print(f" Added {bits}-bit GT, GE, LT, LE, EQ") |
|
|
except ValueError as e: |
|
|
print(f" {bits}-bit comparators already exist: {e}") |
|
|
|
|
|
print(f"\nGenerating {bits}-bit multiplication...") |
|
|
try: |
|
|
add_mul_nbits(tensors, bits) |
|
|
print(f" Added {bits}-bit MUL ({bits * bits} partial product AND gates)") |
|
|
except ValueError as e: |
|
|
print(f" {bits}-bit MUL already exists: {e}") |
|
|
|
|
|
print(f"\nGenerating {bits}-bit division...") |
|
|
try: |
|
|
add_div_nbits(tensors, bits) |
|
|
print(f" Added {bits}-bit DIV ({bits} stages)") |
|
|
except ValueError as e: |
|
|
print(f" {bits}-bit DIV already exists: {e}") |
|
|
|
|
|
print(f"\nGenerating {bits}-bit bitwise ops (AND, OR, XOR, NOT)...") |
|
|
try: |
|
|
add_bitwise_nbits(tensors, bits) |
|
|
print(f" Added {bits}-bit AND, OR, XOR, NOT") |
|
|
except ValueError as e: |
|
|
print(f" {bits}-bit bitwise ops already exist: {e}") |
|
|
|
|
|
print(f"\nGenerating {bits}-bit shift ops (SHL, SHR)...") |
|
|
try: |
|
|
add_shift_nbits(tensors, bits) |
|
|
print(f" Added {bits}-bit SHL, SHR") |
|
|
except ValueError as e: |
|
|
print(f" {bits}-bit shift ops already exist: {e}") |
|
|
|
|
|
print(f"\nGenerating {bits}-bit INC/DEC...") |
|
|
try: |
|
|
add_inc_dec_nbits(tensors, bits) |
|
|
print(f" Added {bits}-bit INC, DEC") |
|
|
except ValueError as e: |
|
|
print(f" {bits}-bit INC/DEC already exist: {e}") |
|
|
|
|
|
print(f"\nGenerating {bits}-bit NEG...") |
|
|
try: |
|
|
add_neg_nbits(tensors, bits) |
|
|
print(f" Added {bits}-bit NEG") |
|
|
except ValueError as e: |
|
|
print(f" {bits}-bit NEG already exists: {e}") |
|
|
|
|
|
print(f"\nGenerating {bits}-bit barrel shifter...") |
|
|
try: |
|
|
add_barrel_shifter_nbits(tensors, bits) |
|
|
import math |
|
|
num_layers = max(1, math.ceil(math.log2(bits))) |
|
|
print(f" Added {bits}-bit barrel shifter ({num_layers} layers x {bits} muxes)") |
|
|
except ValueError as e: |
|
|
print(f" {bits}-bit barrel shifter already exists: {e}") |
|
|
|
|
|
print(f"\nGenerating {bits}-bit priority encoder...") |
|
|
try: |
|
|
add_priority_encoder_nbits(tensors, bits) |
|
|
import math |
|
|
out_bits = max(1, math.ceil(math.log2(bits))) |
|
|
print(f" Added {bits}-bit priority encoder ({out_bits}-bit output)") |
|
|
except ValueError as e: |
|
|
print(f" {bits}-bit priority encoder already exists: {e}") |
|
|
|
|
|
print(f"\n{'=' * 60}") |
|
|
print(f" GENERATING FLOAT CIRCUITS") |
|
|
print(f"{'=' * 60}") |
|
|
|
|
|
print("\nGenerating float16 core circuits...") |
|
|
try: |
|
|
add_float16_core(tensors) |
|
|
print(" Added float16 unpack/pack/classify/normalize") |
|
|
except ValueError as e: |
|
|
print(f" float16 core already exists: {e}") |
|
|
|
|
|
print("\nGenerating float16 ADD circuit...") |
|
|
try: |
|
|
add_float16_add(tensors) |
|
|
print(" Added float16 addition (exp align + mantissa add/sub)") |
|
|
except ValueError as e: |
|
|
print(f" float16 ADD already exists: {e}") |
|
|
|
|
|
print("\nGenerating float16 MUL circuit...") |
|
|
try: |
|
|
add_float16_mul(tensors) |
|
|
print(" Added float16 multiplication (11x11 mantissa mul)") |
|
|
except ValueError as e: |
|
|
print(f" float16 MUL already exists: {e}") |
|
|
|
|
|
print("\nGenerating float16 DIV circuit...") |
|
|
try: |
|
|
add_float16_div(tensors) |
|
|
print(" Added float16 division (11-stage restoring div)") |
|
|
except ValueError as e: |
|
|
print(f" float16 DIV already exists: {e}") |
|
|
|
|
|
print("\nGenerating float16 CMP circuits...") |
|
|
try: |
|
|
add_float16_cmp(tensors) |
|
|
print(" Added float16 comparisons (EQ, LT, LE, GT, GE)") |
|
|
except ValueError as e: |
|
|
print(f" float16 CMP already exists: {e}") |
|
|
|
|
|
print("\nGenerating float32 core circuits...") |
|
|
try: |
|
|
add_float32_core(tensors) |
|
|
print(" Added float32 unpack/pack/classify/normalize") |
|
|
except ValueError as e: |
|
|
print(f" float32 core already exists: {e}") |
|
|
|
|
|
print("\nGenerating float32 ADD circuit...") |
|
|
try: |
|
|
add_float32_add(tensors) |
|
|
print(" Added float32 addition (exp align + mantissa add/sub)") |
|
|
except ValueError as e: |
|
|
print(f" float32 ADD already exists: {e}") |
|
|
|
|
|
print("\nGenerating float32 MUL circuit...") |
|
|
try: |
|
|
add_float32_mul(tensors) |
|
|
print(" Added float32 multiplication (24x24 mantissa mul)") |
|
|
except ValueError as e: |
|
|
print(f" float32 MUL already exists: {e}") |
|
|
|
|
|
print("\nGenerating float32 DIV circuit...") |
|
|
try: |
|
|
add_float32_div(tensors) |
|
|
print(" Added float32 division (24-stage restoring div)") |
|
|
except ValueError as e: |
|
|
print(f" float32 DIV already exists: {e}") |
|
|
|
|
|
print("\nGenerating float32 CMP circuits...") |
|
|
try: |
|
|
add_float32_cmp(tensors) |
|
|
print(" Added float32 comparisons (EQ, LT, LE, GT, GE)") |
|
|
except ValueError as e: |
|
|
print(f" float32 CMP already exists: {e}") |
|
|
|
|
|
if args.apply: |
|
|
print(f"\nSaving: {args.model}") |
|
|
save_file(tensors, str(args.model)) |
|
|
print(" Done.") |
|
|
else: |
|
|
print("\n[DRY-RUN] Use --apply to save.") |
|
|
|
|
|
print(f"\nTotal: {len(tensors)} tensors") |
|
|
total_params = sum(t.numel() for t in tensors.values()) |
|
|
print(f"Total params: {total_params:,}") |
|
|
print("=" * 60) |
|
|
|
|
|
|
|
|
def cmd_all(args) -> None: |
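"""Run every build step in dependency order: memory -> alu -> inputs.""" |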
|
|
print("Running: memory") |
|
|
cmd_memory(args) |
|
|
print("\nRunning: alu") |
|
|
cmd_alu(args) |
|
|
print("\nRunning: inputs") |
|
|
cmd_inputs(args) |
|
|
|
|
|
|
|
|
def main() -> None: |
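"""Parse CLI arguments, resolve the output model path, and dispatch to a subcommand.""" |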
|
|
parser = argparse.ArgumentParser( |
|
|
description="Build tools for threshold computer safetensors", |
|
|
formatter_class=argparse.RawDescriptionHelpFormatter, |
|
|
epilog=""" |
|
|
Memory Profiles: |
|
|
full 64KB (16-bit addr) - Full CPU mode |
|
|
reduced 4KB (12-bit addr) - Reduced CPU |
|
|
small 1KB (10-bit addr) - 32-bit arithmetic scratch |
|
|
scratchpad 256B (8-bit addr) - LLM scratchpad |
|
|
registers 16B (4-bit addr) - LLM register file |
|
|
none 0B (no memory) - Pure ALU for LLM |
|
|
|
|
|
ALU Bit Widths: |
|
|
8 Standard 8-bit ALU (default) |
|
|
16 16-bit ALU (0-65535) |
|
|
32 32-bit ALU (0-4294967295) |
|
|
|
|
|
Output Filenames (auto-generated from config): |
|
|
Format: neural_{alu|computer}{BITS}[_{MEMORY}].safetensors |
|
|
|
|
|
Memory suffix: |
|
|
-m full -> (none) |
|
|
-m reduced -> _reduced |
|
|
-m small -> _small |
|
|
-m scratchpad -> _scratchpad |
|
|
-m registers -> _registers |
|
|
-m none -> (uses "alu" instead of "computer") |
|
|
-a N -> _addrN |
|
|
|
|
|
Examples: |
|
|
neural_alu8.safetensors # 8-bit, no memory |
|
|
neural_alu16.safetensors # 16-bit, no memory |
|
|
neural_alu32.safetensors # 32-bit, no memory |
|
|
neural_computer8.safetensors # 8-bit, full memory |
|
|
neural_computer16.safetensors # 16-bit, full memory |
|
|
neural_computer32.safetensors # 32-bit, full memory |
|
|
neural_computer8_reduced.safetensors # 8-bit, reduced memory |
|
|
neural_computer32_reduced.safetensors # 32-bit, reduced memory |
|
|
neural_computer8_small.safetensors # 8-bit, small memory |
|
|
neural_computer32_small.safetensors # 32-bit, small memory |
|
|
neural_computer8_addr12.safetensors # 8-bit, custom 12-bit address |
|
|
neural_computer32_addr10.safetensors # 32-bit, custom 10-bit address |
|
|
|
|
|
Usage (note: options must come BEFORE subcommand): |
|
|
python build.py --apply all # -> neural_computer8.safetensors |
|
|
python build.py -m none --apply all # -> neural_alu8.safetensors |
|
|
python build.py -m reduced --apply all # -> neural_computer8_reduced.safetensors |
|
|
python build.py --bits 16 --apply all # -> neural_computer16.safetensors |
|
|
python build.py --bits 32 --apply all # -> neural_computer32.safetensors |
|
|
python build.py --bits 32 -m none --apply all # -> neural_alu32.safetensors |
|
|
python build.py --bits 32 -m small --apply all # -> neural_computer32_small.safetensors |
|
|
python build.py --bits 32 -a 10 --apply all # -> neural_computer32_addr10.safetensors |
|
|
""" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--model", type=Path, default=None, |
|
|
help="Output path. Auto-generated as neural_{alu|computer}{BITS}[_{MEMORY}].safetensors if not specified" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--apply", action="store_true", |
|
|
help="Apply changes to model file. Without this flag, runs in dry-run mode (no writes)" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--manifest", action="store_true", |
|
|
help="Write tensors.txt manifest listing all tensors (memory command only)" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--bits", "-b", |
|
|
type=int, |
|
|
choices=SUPPORTED_BITS, |
|
|
default=8, |
|
|
help="ALU bit width. 8=0-255 (default), 16=0-65535, 32=0-4294967295" |
|
|
) |
|
|
|
|
|
mem_group = parser.add_mutually_exclusive_group() |
|
|
mem_group.add_argument( |
|
|
"--memory-profile", "-m", |
|
|
choices=list(MEMORY_PROFILES.keys()), |
|
|
help="""Memory profile: |
|
|
full=64KB/16-bit addr (suffix: none), |
|
|
reduced=4KB/12-bit (suffix: _reduced), |
|
|
small=1KB/10-bit (suffix: _small), |
|
|
scratchpad=256B/8-bit (suffix: _scratchpad), |
|
|
registers=16B/4-bit (suffix: _registers), |
|
|
none=0B/pure ALU (uses 'alu' in filename)""" |
|
|
) |
|
|
mem_group.add_argument( |
|
|
"--addr-bits", "-a", |
|
|
type=int, |
|
|
choices=range(0, 17), |
|
|
metavar="N", |
|
|
help="Custom address bus width 0-16. Memory size=2^N bytes. 0=pure ALU. Suffix: _addrN" |
|
|
) |
|
|
|
|
|
subparsers = parser.add_subparsers(dest="command", help="Subcommands") |
|
|
subparsers.add_parser("memory", help="Generate memory circuits (decoder, read mux, write cells)") |
|
|
subparsers.add_parser("alu", help="Generate N-bit ALU circuits (adder, sub, mul, div, cmp, bitwise, shift)") |
|
|
subparsers.add_parser("inputs", help="Add .inputs metadata tensors for gate routing") |
|
|
subparsers.add_parser("all", help="Run all: memory -> alu -> inputs") |
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
if args.model is None: |
|
|
args.model = get_model_path( |
|
|
bits=args.bits, |
|
|
memory_profile=getattr(args, 'memory_profile', None), |
|
|
addr_bits=getattr(args, 'addr_bits', None) |
|
|
) |
|
|
print(f"Auto-generated model path: {args.model}") |
|
|
|
|
|
if args.command == "memory": |
|
|
cmd_memory(args) |
|
|
elif args.command == "alu": |
|
|
cmd_alu(args) |
|
|
elif args.command == "inputs": |
|
|
cmd_inputs(args) |
|
|
elif args.command == "all": |
|
|
cmd_all(args) |
|
|
else: |
|
|
parser.print_help() |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|