| """ |
| Generate 64KB memory circuits and fetch/load/store buffers for the 8-bit threshold computer. |
| Updates neural_computer.safetensors and tensors.txt in-place. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import Dict, Iterable, List |
|
|
| import torch |
| from safetensors import safe_open |
| from safetensors.torch import save_file |
|
|
|
|
# Artifacts live two levels above this script (next to the package root).
MODEL_PATH = Path(__file__).resolve().parent.parent / "neural_computer.safetensors"
# Human-readable dump of every tensor; rewritten together with the model.
MANIFEST_PATH = Path(__file__).resolve().parent.parent / "tensors.txt"


# 16-bit address bus -> 1 << 16 = 65536 addressable bytes (64KB).
ADDR_BITS = 16
MEM_BYTES = 1 << ADDR_BITS
|
|
|
|
def load_tensors(path: Path) -> Dict[str, torch.Tensor]:
    """Read every tensor from *path* into a plain dict.

    Each tensor is converted to float32 on CPU and cloned, so the returned
    dict owns its storage and can be mutated freely before re-saving.
    """
    with safe_open(str(path), framework="pt") as handle:
        return {
            name: handle.get_tensor(name).float().cpu().clone()
            for name in handle.keys()
        }
|
|
|
|
def add_gate(tensors: Dict[str, torch.Tensor], name: str, weight: Iterable[float], bias: Iterable[float]) -> None:
    """Register one threshold gate as a weight/bias tensor pair.

    Writes ``{name}.weight`` and ``{name}.bias`` into *tensors* as float32
    1-D tensors. Raises ValueError if either key is already present, so a
    gate can never be silently overwritten.
    """
    keys = (f"{name}.weight", f"{name}.bias")
    if any(key in tensors for key in keys):
        raise ValueError(f"Gate already exists: {name}")
    for key, values in zip(keys, (weight, bias)):
        tensors[key] = torch.tensor(list(values), dtype=torch.float32)
|
|
|
|
def drop_prefixes(tensors: Dict[str, torch.Tensor], prefixes: List[str]) -> None:
    """Delete, in place, every entry whose key starts with any given prefix."""
    # Collect first, then delete: never mutate a dict while iterating it.
    doomed = [key for key in tensors if any(key.startswith(p) for p in prefixes)]
    for key in doomed:
        del tensors[key]
|
|
|
|
| def add_decoder(tensors: Dict[str, torch.Tensor]) -> None: |
| weights = torch.empty((MEM_BYTES, ADDR_BITS), dtype=torch.float32) |
| bias = torch.empty((MEM_BYTES,), dtype=torch.float32) |
| for addr in range(MEM_BYTES): |
| bits = [(addr >> (ADDR_BITS - 1 - i)) & 1 for i in range(ADDR_BITS)] |
| weights[addr] = torch.tensor([1.0 if bit == 1 else -1.0 for bit in bits], dtype=torch.float32) |
| bias[addr] = -float(sum(bits)) |
| tensors["memory.addr_decode.weight"] = weights |
| tensors["memory.addr_decode.bias"] = bias |
|
|
|
|
def add_memory_read_mux(tensors: Dict[str, torch.Tensor]) -> None:
    """Install the read multiplexer: 8 output bits, one AND/OR tree each.

    Per bit, MEM_BYTES two-input AND gates (bias -2: both inputs must be 1)
    feed a wide OR (bias -1: any input suffices) -- standard threshold
    AND/OR assuming gates fire on pre-activation >= 0. Presumably each AND
    pairs a decoded address line with a stored cell bit; confirm wiring
    against the interpreter.
    """
    specs = {
        "memory.read.and.weight": ((8, MEM_BYTES, 2), 1.0),
        "memory.read.and.bias": ((8, MEM_BYTES), -2.0),
        "memory.read.or.weight": ((8, MEM_BYTES), 1.0),
        "memory.read.or.bias": ((8,), -1.0),
    }
    for key, (shape, fill) in specs.items():
        tensors[key] = torch.full(shape, fill, dtype=torch.float32)
|
|
|
|
def add_memory_write_cells(tensors: Dict[str, torch.Tensor]) -> None:
    """Install the per-byte write path gates under ``memory.write.*``.

    Gate roles (threshold semantics, firing at pre-activation >= 0):
      - sel:  2-input AND (bias -2) -- presumably address-select AND write-enable.
      - nsel: single -1 weight, bias 0 -- inverts the select line.
      - and_old / and_new: per-bit 2-input ANDs gating the retained vs.
        incoming data.
      - or: per-bit 2-input OR (bias -1) merging the two AND outputs.
    Wiring is defined by the interpreter, not here -- confirm there.
    """
    # (weight shape, weight fill, bias shape, bias fill) per gate, in the
    # same key order the file has always used.
    specs = {
        "sel": ((MEM_BYTES, 2), 1.0, (MEM_BYTES,), -2.0),
        "nsel": ((MEM_BYTES, 1), -1.0, (MEM_BYTES,), 0.0),
        "and_old": ((MEM_BYTES, 8, 2), 1.0, (MEM_BYTES, 8), -2.0),
        "and_new": ((MEM_BYTES, 8, 2), 1.0, (MEM_BYTES, 8), -2.0),
        "or": ((MEM_BYTES, 8, 2), 1.0, (MEM_BYTES, 8), -1.0),
    }
    for gate, (w_shape, w_fill, b_shape, b_fill) in specs.items():
        tensors[f"memory.write.{gate}.weight"] = torch.full(w_shape, w_fill, dtype=torch.float32)
        tensors[f"memory.write.{gate}.bias"] = torch.full(b_shape, b_fill, dtype=torch.float32)
|
|
|
|
def add_fetch_load_store_buffers(tensors: Dict[str, torch.Tensor]) -> None:
    """Install pass-through buffer gates for fetch/load/store plumbing.

    Every gate is an identity buffer (weight [1.0], bias [-1.0]): with
    firing at pre-activation >= 0 it outputs 1 iff its single input is 1.
    Creates a 16-bit instruction register buffer, 8-bit load and store
    data buffers (interleaved per bit), and an ADDR_BITS-wide memory
    address buffer.
    """
    identity = ([1.0], [-1.0])
    for bit in range(16):
        add_gate(tensors, f"control.fetch.ir.bit{bit}", *identity)
    # load/store are interleaved per bit to preserve the original key order.
    for bit in range(8):
        for group in ("load", "store"):
            add_gate(tensors, f"control.{group}.bit{bit}", *identity)
    for bit in range(ADDR_BITS):
        add_gate(tensors, f"control.mem_addr.bit{bit}", *identity)
|
|
|
|
def update_manifest(tensors: Dict[str, torch.Tensor]) -> None:
    """Record machine parameters as single-element ``manifest.*`` tensors."""
    entries = {
        "manifest.memory_bytes": float(MEM_BYTES),
        "manifest.pc_width": float(ADDR_BITS),
        "manifest.version": 3.0,  # bumped when the circuit layout changes
    }
    for key, value in entries.items():
        tensors[key] = torch.tensor([value], dtype=torch.float32)
|
|
|
|
def write_manifest(path: Path, tensors: Dict[str, torch.Tensor]) -> None:
    """Write a plain-text manifest listing every tensor, sorted by name.

    Format per line: ``name: shape=[...], values=[v0, v1, ...]`` with one
    decimal place per value. NOTE(review): every element is dumped, so the
    64K-scale memory tensors make this file very large -- presumably
    intentional for inspection; confirm before relying on it.
    """
    header = [
        "# Tensor Manifest",
        f"# Total: {len(tensors)} tensors",
    ]
    body: List[str] = []
    for name, tensor in sorted(tensors.items()):
        flat = ", ".join(f"{v:.1f}" for v in tensor.flatten().tolist())
        body.append(f"{name}: shape={list(tensor.shape)}, values=[{flat}]")
    path.write_text("\n".join(header + body) + "\n", encoding="utf-8")
|
|
|
|
def main() -> None:
    """Rebuild the memory-subsystem tensors and rewrite both artifacts."""
    tensors = load_tensors(MODEL_PATH)
    # Drop any stale copies of the circuits we are about to regenerate.
    stale_prefixes = [
        "memory.addr_decode.",
        "memory.read.",
        "memory.write.",
        "control.fetch.ir.",
        "control.load.",
        "control.store.",
        "control.mem_addr.",
    ]
    drop_prefixes(tensors, stale_prefixes)
    # Each builder writes a disjoint set of keys into the shared dict.
    builders = (
        add_decoder,
        add_memory_read_mux,
        add_memory_write_cells,
        add_fetch_load_store_buffers,
        update_manifest,
    )
    for build in builders:
        build(tensors)
    save_file(tensors, str(MODEL_PATH))
    write_manifest(MANIFEST_PATH, tensors)
|
|
|
|
# Allow the generator to be run directly as a script.
if __name__ == "__main__":
    main()
|
|