# 8bit-threshold-computer / eval / build_memory.py
# Commit ea46629 (PortfolioAI): "Add packed memory routing and 16-bit addressing"
"""
Generate 64KB memory circuits and fetch/load/store buffers for the 8-bit threshold computer.
Updates neural_computer.safetensors and tensors.txt in-place.
"""
from __future__ import annotations

from pathlib import Path
from typing import Dict, Iterable, List, Optional

import torch
from safetensors import safe_open
from safetensors.torch import save_file
# Both artifacts live one directory above this script (eval/ -> repo root).
MODEL_PATH = Path(__file__).resolve().parent.parent / "neural_computer.safetensors"
MANIFEST_PATH = Path(__file__).resolve().parent.parent / "tensors.txt"
# 16 address bits -> 64 KiB of byte-addressable memory.
ADDR_BITS = 16
MEM_BYTES = 1 << ADDR_BITS
def load_tensors(path: Path) -> Dict[str, torch.Tensor]:
    """Read every tensor from a safetensors file as an independent float32 CPU copy."""
    with safe_open(str(path), framework="pt") as handle:
        # .clone() detaches each tensor from the memory-mapped file buffer.
        return {key: handle.get_tensor(key).float().cpu().clone() for key in handle.keys()}
def add_gate(tensors: Dict[str, torch.Tensor], name: str, weight: Iterable[float], bias: Iterable[float]) -> None:
    """Insert one gate's weight/bias vectors under '<name>.weight' / '<name>.bias'.

    Raises:
        ValueError: if either key already exists in *tensors*.
    """
    keys = (f"{name}.weight", f"{name}.bias")
    if any(key in tensors for key in keys):
        raise ValueError(f"Gate already exists: {name}")
    for key, values in zip(keys, (weight, bias)):
        tensors[key] = torch.tensor(list(values), dtype=torch.float32)
def drop_prefixes(tensors: Dict[str, torch.Tensor], prefixes: List[str]) -> None:
    """Delete every entry whose key starts with any of *prefixes* (in place)."""
    matchers = tuple(prefixes)
    # Collect first, then delete: never mutate the dict while iterating it.
    doomed = [key for key in tensors if key.startswith(matchers)]
    for key in doomed:
        del tensors[key]
def add_decoder(tensors: Dict[str, torch.Tensor], addr_bits: Optional[int] = None) -> None:
    """Build the one-hot address decoder for a 2**addr_bits byte memory.

    Row ``addr`` has weight +1 where that address bit is 1 and -1 where it is 0
    (MSB-first), with bias ``-popcount(addr)``: the row's pre-activation is 0
    exactly when the input bits equal ``addr`` and <= -1 for any other input.

    Args:
        tensors: tensor dict updated in place with
            ``memory.addr_decode.weight`` / ``memory.addr_decode.bias``.
        addr_bits: address width; defaults to the module-level ADDR_BITS (16).

    Vectorized bit extraction replaces the original per-address Python loop,
    which built a 16-element list for each of the 65536 addresses.
    """
    if addr_bits is None:
        addr_bits = ADDR_BITS
    n = 1 << addr_bits
    # bits[a, i] = (a >> (addr_bits - 1 - i)) & 1  — MSB-first, as before.
    addrs = torch.arange(n, dtype=torch.int64).unsqueeze(1)
    shifts = torch.arange(addr_bits - 1, -1, -1, dtype=torch.int64)
    bits = ((addrs >> shifts) & 1).to(torch.float32)
    # Map bit {0, 1} -> weight {-1, +1}; bias is minus the number of set bits.
    tensors["memory.addr_decode.weight"] = bits * 2.0 - 1.0
    tensors["memory.addr_decode.bias"] = -bits.sum(dim=1)
def add_memory_read_mux(tensors: Dict[str, torch.Tensor]) -> None:
    """Install the packed AND/OR weight tensors for the 8-bit read mux.

    NOTE(review): `memory.read.or.bias` is shape (8,) while `and.bias` is
    (8, MEM_BYTES) — presumably one OR bias per output bit, broadcast over
    cells; confirm against the circuit interpreter.
    """
    tensors.update({
        "memory.read.and.weight": torch.ones((8, MEM_BYTES, 2), dtype=torch.float32),
        "memory.read.and.bias": torch.full((8, MEM_BYTES), -2.0, dtype=torch.float32),
        "memory.read.or.weight": torch.ones((8, MEM_BYTES), dtype=torch.float32),
        "memory.read.or.bias": torch.full((8,), -1.0, dtype=torch.float32),
    })
def add_memory_write_cells(tensors: Dict[str, torch.Tensor]) -> None:
    """Install the packed write-path gate tensors.

    Gates, per memory cell: `sel` (2-input AND), `nsel` (1-input NOT),
    `and_old`/`and_new` (per-bit 2-input ANDs) and `or` (per-bit 2-input OR).
    """
    cell = (MEM_BYTES,)
    bit = (MEM_BYTES, 8)
    bit_pair = (MEM_BYTES, 8, 2)
    # gate name -> (weight tensor, bias tensor); insertion order matches the
    # original explicit assignments.
    specs = {
        "sel": (torch.ones((MEM_BYTES, 2), dtype=torch.float32),
                torch.full(cell, -2.0, dtype=torch.float32)),
        "nsel": (torch.full((MEM_BYTES, 1), -1.0, dtype=torch.float32),
                 torch.zeros(cell, dtype=torch.float32)),
        "and_old": (torch.ones(bit_pair, dtype=torch.float32),
                    torch.full(bit, -2.0, dtype=torch.float32)),
        "and_new": (torch.ones(bit_pair, dtype=torch.float32),
                    torch.full(bit, -2.0, dtype=torch.float32)),
        "or": (torch.ones(bit_pair, dtype=torch.float32),
               torch.full(bit, -1.0, dtype=torch.float32)),
    }
    for gate, (weight, bias) in specs.items():
        tensors[f"memory.write.{gate}.weight"] = weight
        tensors[f"memory.write.{gate}.bias"] = bias
def add_fetch_load_store_buffers(tensors: Dict[str, torch.Tensor]) -> None:
    """Register identity buffer gates for the IR fetch, load, store, and
    memory-address paths (each gate: weight=1, bias=-1, i.e. output = input)."""

    def buffer(gate_name: str) -> None:
        # One-input pass-through gate.
        add_gate(tensors, gate_name, [1.0], [-1.0])

    for bit in range(16):
        buffer(f"control.fetch.ir.bit{bit}")
    for bit in range(8):
        # load/store interleaved per bit, matching the original insertion order.
        for channel in ("load", "store"):
            buffer(f"control.{channel}.bit{bit}")
    for bit in range(ADDR_BITS):
        buffer(f"control.mem_addr.bit{bit}")
def update_manifest(tensors: Dict[str, torch.Tensor]) -> None:
    """Refresh the manifest scalars to reflect the 16-bit address space (v3)."""
    scalars = {
        "manifest.memory_bytes": float(MEM_BYTES),
        "manifest.pc_width": float(ADDR_BITS),
        "manifest.version": 3.0,
    }
    for key, value in scalars.items():
        tensors[key] = torch.tensor([value], dtype=torch.float32)
def write_manifest(path: Path, tensors: Dict[str, torch.Tensor]) -> None:
    """Write a human-readable manifest: one line per tensor, sorted by name.

    NOTE(review): every value of every tensor is rendered, so the 64K-entry
    memory tensors produce very long lines — presumably intentional; confirm.
    """
    header = ["# Tensor Manifest", f"# Total: {len(tensors)} tensors"]
    body: List[str] = []
    for name in sorted(tensors):
        tensor = tensors[name]
        rendered = ", ".join(f"{v:.1f}" for v in tensor.flatten().tolist())
        body.append(f"{name}: shape={list(tensor.shape)}, values=[{rendered}]")
    path.write_text("\n".join(header + body) + "\n", encoding="utf-8")
def main() -> None:
    """Rebuild the memory/control tensors in-place and rewrite the manifest.

    Drops any stale memory and control-buffer tensors first so the builders
    can re-add them without tripping add_gate's duplicate check.
    """
    stale_prefixes = [
        "memory.addr_decode.",
        "memory.read.",
        "memory.write.",
        "control.fetch.ir.",
        "control.load.",
        "control.store.",
        "control.mem_addr.",
    ]
    state = load_tensors(MODEL_PATH)
    drop_prefixes(state, stale_prefixes)
    # Builders run in dependency-free order; each mutates `state` in place.
    for builder in (
        add_decoder,
        add_memory_read_mux,
        add_memory_write_cells,
        add_fetch_load_store_buffers,
        update_manifest,
    ):
        builder(state)
    save_file(state, str(MODEL_PATH))
    write_manifest(MANIFEST_PATH, state)
# Run only when executed as a script; importing the module stays side-effect free.
if __name__ == "__main__":
    main()