CharlesCNorton committed on
Commit ·
3b3f7aa
1
Parent(s): 597e7c2
infer .inputs from routing/routing.json (restored): unblocks multiplier8x8, multiplier2x2, float16/32 add/normalize/pack/classify, and ~120 other circuits with explicit inter-gate wiring
Browse files- build.py +115 -1
- routing/routing.json +0 -0
build.py
CHANGED
|
@@ -105,7 +105,7 @@ import argparse
|
|
| 105 |
import json
|
| 106 |
import re
|
| 107 |
from pathlib import Path
|
| 108 |
-
from typing import Dict, Iterable, List, Set
|
| 109 |
|
| 110 |
import torch
|
| 111 |
from safetensors import safe_open
|
|
@@ -2861,9 +2861,123 @@ def _infer_threshold_computer_bit_cascade(gate: str, reg: SignalRegistry) -> Lis
|
|
| 2861 |
return None
|
| 2862 |
|
| 2863 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2864 |
def infer_inputs_for_gate(gate: str, reg: SignalRegistry, tensors: Dict[str, torch.Tensor]) -> List[int]:
|
| 2865 |
if gate.startswith('manifest.'):
|
| 2866 |
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2867 |
# Bit-cascade comparators (integer + float + div) and ternary modular detectors.
|
| 2868 |
# These are the gates emitted by add_bit_cascade_compare and the ternary
|
| 2869 |
# modular rebuild in quantize.py; cover them up front so they don't fall
|
|
|
|
| 105 |
import json
|
| 106 |
import re
|
| 107 |
from pathlib import Path
|
| 108 |
+
from typing import Any, Dict, Iterable, List, Optional, Set
|
| 109 |
|
| 110 |
import torch
|
| 111 |
from safetensors import safe_open
|
|
|
|
| 2861 |
return None
|
| 2862 |
|
| 2863 |
|
| 2864 |
+
# Process-wide memo for the parsed routing table. ``None`` means "not loaded
# yet"; an empty dict means "loaded, but nothing to route".
# Spelled with ``Optional[...]`` (as the rest of this module does) because a
# module-level ``X | None`` annotation is evaluated at import time and fails
# on interpreters older than 3.10.
_ROUTING_TABLE_CACHE: Optional[Dict[str, Any]] = None


def _load_routing_table() -> Dict[str, Any]:
    """Load `routing/routing.json` once; on missing file, return {}.

    The JSON maps each top-level circuit name (e.g. ``arithmetic.multiplier8x8``)
    to a dict with at least ``inputs`` (top-level external ports) and
    ``internal`` (gate-suffix -> list-of-source-signal-name).  The source
    signal names are either external port references (``$a[3]``), constants
    (``#0`` / ``#1``), or relative gate names within the same circuit
    (e.g. ``pp.r0.c3``).  Anchored cross-circuit references are not
    expressed here; circuits in the table are assumed self-contained.

    Returns:
        The value stored under the top-level ``"circuits"`` key, or ``{}``
        when the file is absent or that key is missing/empty.  The same
        object is returned on every subsequent call (cached in
        ``_ROUTING_TABLE_CACHE``).

    Raises:
        json.JSONDecodeError: if the file exists but is not valid JSON.
    """
    global _ROUTING_TABLE_CACHE
    if _ROUTING_TABLE_CACHE is not None:
        return _ROUTING_TABLE_CACHE
    path = Path(__file__).parent / "routing" / "routing.json"
    try:
        # Open directly instead of checking ``exists()`` first, so a file that
        # disappears between the check and the open cannot raise.
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except (FileNotFoundError, NotADirectoryError):
        # No routing file shipped: remember the empty table so the lookup is
        # not retried on every call.
        _ROUTING_TABLE_CACHE = {}
        return _ROUTING_TABLE_CACHE
    # ``or {}`` also normalises an explicit ``"circuits": null``.
    _ROUTING_TABLE_CACHE = data.get("circuits", {}) or {}
    return _ROUTING_TABLE_CACHE
|
| 2889 |
+
|
| 2890 |
+
|
| 2891 |
+
def _infer_from_routing_table(
    gate: str, reg: SignalRegistry,
    known_gate_suffixes: Optional[Set[str]] = None,
) -> Optional[List[int]]:
    """Look up ``gate`` in routing.json and register its source signals.

    The gate is attributed to the longest circuit name in the routing table
    that equals it or is a dotted prefix of it.  The remainder of the gate
    name is then used as the key into that circuit's ``internal`` map.

    ``known_gate_suffixes`` is the set of gate-name suffixes (relative to the
    circuit) that really exist in the loaded tensors.  When given, a source
    name that is not in the set is redirected to the first of
    ``<name>.layer2``, ``<name>.layer1``, ``<name>.out`` that is; if none of
    them exists the name is used unchanged.

    Returns:
        The ids returned by ``reg.register`` for each source, in routing
        order, or ``None`` when the routing table is empty, no circuit
        covers the gate, or the gate's suffix has no ``internal`` entry.
    """
    routing = _load_routing_table()
    if not routing:
        return None

    # Two distinct circuit names of equal length cannot both be a prefix of
    # the same gate, so taking the max by length is unambiguous.
    owner = max(
        (name for name in routing if gate == name or gate.startswith(name + ".")),
        key=len,
        default=None,
    )
    if owner is None:
        return None

    wiring = routing[owner].get("internal") or {}
    sources = wiring.get(gate[len(owner):].lstrip("."))
    if sources is None:
        return None

    def _signal_name(src: str) -> str:
        """Translate one routing source entry into a registry signal name."""
        # Constants are registered verbatim, without a circuit prefix.
        if src in ("#0", "#1"):
            return src
        # External ports live in the owning circuit's namespace.
        if src.startswith("$"):
            return f"{owner}.{src}"
        # Internal gate reference: routing may use a short form such as
        # ``stage0.bit0.ha2.sum`` where only ``...sum.layer2`` exists as a
        # gate, so fall through the known output-layer spellings.
        if known_gate_suffixes is not None and src not in known_gate_suffixes:
            for alias in (f"{src}.layer2", f"{src}.layer1", f"{src}.out"):
                if alias in known_gate_suffixes:
                    return f"{owner}.{alias}"
        return f"{owner}.{src}"

    return [reg.register(_signal_name(src)) for src in sources]
|
| 2945 |
+
|
| 2946 |
+
|
| 2947 |
+
def _gate_suffixes_under(prefix: str, tensors: Dict[str, torch.Tensor]) -> Set[str]:
    """Collect the gate names found under ``prefix``, relative to it.

    A tensor key counts as belonging to a gate when its last dotted
    component is ``weight``, ``bias`` or ``inputs``; the gate name is the
    key without that component.  Gates equal to ``prefix`` or nested below
    ``prefix + "."`` are returned with the prefix and any leading dots
    removed, so the gate named exactly ``prefix`` contributes ``""``.
    """
    tensor_kinds = ("weight", "bias", "inputs")
    dotted_prefix = prefix + "."
    found: Set[str] = set()
    for key in tensors:
        gate_name, dot, kind = key.rpartition(".")
        # Keys without a dot, or ending in some other component, are not
        # gate tensors.
        if not dot or kind not in tensor_kinds:
            continue
        if gate_name != prefix and not gate_name.startswith(dotted_prefix):
            continue
        found.add(gate_name[len(prefix):].lstrip("."))
    return found
|
| 2958 |
+
|
| 2959 |
+
|
| 2960 |
def infer_inputs_for_gate(gate: str, reg: SignalRegistry, tensors: Dict[str, torch.Tensor]) -> List[int]:
|
| 2961 |
if gate.startswith('manifest.'):
|
| 2962 |
return []
|
| 2963 |
+
# routing.json holds explicit per-gate input lists for circuits where
|
| 2964 |
+
# naming alone is insufficient (multipliers, float pipelines, etc.).
|
| 2965 |
+
# Consult it first; family-specific inferrers below remain the
|
| 2966 |
+
# fallback for anything not covered there.
|
| 2967 |
+
table = _load_routing_table()
|
| 2968 |
+
matched_circuit: Optional[str] = None
|
| 2969 |
+
if table:
|
| 2970 |
+
for circuit in table:
|
| 2971 |
+
if gate == circuit or gate.startswith(circuit + "."):
|
| 2972 |
+
if matched_circuit is None or len(circuit) > len(matched_circuit):
|
| 2973 |
+
matched_circuit = circuit
|
| 2974 |
+
suffixes: Optional[Set[str]] = (
|
| 2975 |
+
_gate_suffixes_under(matched_circuit, tensors)
|
| 2976 |
+
if matched_circuit else None
|
| 2977 |
+
)
|
| 2978 |
+
rt = _infer_from_routing_table(gate, reg, known_gate_suffixes=suffixes)
|
| 2979 |
+
if rt is not None:
|
| 2980 |
+
return rt
|
| 2981 |
# Bit-cascade comparators (integer + float + div) and ternary modular detectors.
|
| 2982 |
# These are the gates emitted by add_bit_cascade_compare and the ternary
|
| 2983 |
# modular rebuild in quantize.py; cover them up front so they don't fall
|
routing/routing.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|