CharlesCNorton commited on
Commit
3b3f7aa
·
1 Parent(s): 597e7c2

infer .inputs from routing/routing.json (restored): unblocks multiplier8x8, multiplier2x2, float16/32 add/normalize/pack/classify, and ~120 other circuits with explicit inter-gate wiring

Browse files
Files changed (2) hide show
  1. build.py +115 -1
  2. routing/routing.json +0 -0
build.py CHANGED
@@ -105,7 +105,7 @@ import argparse
105
  import json
106
  import re
107
  from pathlib import Path
108
- from typing import Dict, Iterable, List, Set
109
 
110
  import torch
111
  from safetensors import safe_open
@@ -2861,9 +2861,123 @@ def _infer_threshold_computer_bit_cascade(gate: str, reg: SignalRegistry) -> Lis
2861
  return None
2862
 
2863
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2864
  def infer_inputs_for_gate(gate: str, reg: SignalRegistry, tensors: Dict[str, torch.Tensor]) -> List[int]:
2865
  if gate.startswith('manifest.'):
2866
  return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2867
  # Bit-cascade comparators (integer + float + div) and ternary modular detectors.
2868
  # These are the gates emitted by add_bit_cascade_compare and the ternary
2869
  # modular rebuild in quantize.py; cover them up front so they don't fall
 
105
  import json
106
  import re
107
  from pathlib import Path
108
+ from typing import Any, Dict, Iterable, List, Optional, Set
109
 
110
  import torch
111
  from safetensors import safe_open
 
2861
  return None
2862
 
2863
 
2864
+ _ROUTING_TABLE_CACHE: Dict[str, Any] | None = None
2865
+
2866
+
2867
+ def _load_routing_table() -> Dict[str, Any]:
2868
+ """Load `routing/routing.json` once; on missing file, return {}.
2869
+
2870
+ The JSON maps each top-level circuit name (e.g. ``arithmetic.multiplier8x8``)
2871
+ to a dict with at least ``inputs`` (top-level external ports) and
2872
+ ``internal`` (gate-suffix -> list-of-source-signal-name). The source
2873
+ signal names are either external port references (``$a[3]``), constants
2874
+ (``#0`` / ``#1``), or relative gate names within the same circuit
2875
+ (e.g. ``pp.r0.c3``). Anchored cross-circuit references are not
2876
+ expressed here; circuits in the table are assumed self-contained.
2877
+ """
2878
+ global _ROUTING_TABLE_CACHE
2879
+ if _ROUTING_TABLE_CACHE is not None:
2880
+ return _ROUTING_TABLE_CACHE
2881
+ path = Path(__file__).parent / "routing" / "routing.json"
2882
+ if not path.exists():
2883
+ _ROUTING_TABLE_CACHE = {}
2884
+ return _ROUTING_TABLE_CACHE
2885
+ with open(path, encoding="utf-8") as f:
2886
+ data = json.load(f)
2887
+ _ROUTING_TABLE_CACHE = data.get("circuits", {}) or {}
2888
+ return _ROUTING_TABLE_CACHE
2889
+
2890
+
2891
+ def _infer_from_routing_table(
2892
+ gate: str, reg: SignalRegistry,
2893
+ known_gate_suffixes: Optional[Set[str]] = None,
2894
+ ) -> Optional[List[int]]:
2895
+ """Resolve a gate's inputs by consulting routing.json.
2896
+
2897
+ ``known_gate_suffixes`` is the set of gate-name suffixes that actually
2898
+ exist in the loaded safetensors (within the same circuit prefix).
2899
+ When the routing JSON names a producer like ``stage0.bit0.ha2.sum``
2900
+ but that suffix isn't a real gate (only ``stage0.bit0.ha2.sum.layer2``
2901
+ is), this set is consulted to redirect the reference to the canonical
2902
+ final-layer signal. Returns None when the gate's circuit isn't
2903
+ covered or its suffix isn't in the routing's ``internal`` map.
2904
+ """
2905
+ table = _load_routing_table()
2906
+ if not table:
2907
+ return None
2908
+ # Find the longest circuit prefix that matches.
2909
+ best: Optional[str] = None
2910
+ for circuit in table:
2911
+ if gate == circuit or gate.startswith(circuit + "."):
2912
+ if best is None or len(circuit) > len(best):
2913
+ best = circuit
2914
+ if best is None:
2915
+ return None
2916
+ info = table[best]
2917
+ internal = info.get("internal") or {}
2918
+ suffix = gate[len(best):].lstrip(".")
2919
+ src_names = internal.get(suffix)
2920
+ if src_names is None:
2921
+ return None
2922
+
2923
+ out: List[int] = []
2924
+ for nm in src_names:
2925
+ if nm in ("#0", "#1"):
2926
+ out.append(reg.register(nm))
2927
+ continue
2928
+ if nm.startswith("$"):
2929
+ out.append(reg.register(f"{best}.{nm}"))
2930
+ continue
2931
+ # Internal cross-gate reference. Some routing entries use a
2932
+ # short form (e.g. ``stage0.bit0.ha2.sum``) that resolves to a
2933
+ # multi-layer gate (``stage0.bit0.ha2.sum.layer2``). When the
2934
+ # short name isn't a real gate, redirect to its ``.layer2``
2935
+ # final-output sibling, which is what threshold-network XOR
2936
+ # cells expose.
2937
+ resolved = nm
2938
+ if known_gate_suffixes is not None and nm not in known_gate_suffixes:
2939
+ for cand in (f"{nm}.layer2", f"{nm}.layer1", f"{nm}.out"):
2940
+ if cand in known_gate_suffixes:
2941
+ resolved = cand
2942
+ break
2943
+ out.append(reg.register(f"{best}.{resolved}"))
2944
+ return out
2945
+
2946
+
2947
+ def _gate_suffixes_under(prefix: str, tensors: Dict[str, torch.Tensor]) -> Set[str]:
2948
+ """Return the set of gate-name suffixes whose full name starts with `prefix`."""
2949
+ out: Set[str] = set()
2950
+ for k in tensors:
2951
+ for suf in (".weight", ".bias", ".inputs"):
2952
+ if k.endswith(suf):
2953
+ gate = k[: -len(suf)]
2954
+ if gate == prefix or gate.startswith(prefix + "."):
2955
+ out.add(gate[len(prefix):].lstrip("."))
2956
+ break
2957
+ return out
2958
+
2959
+
2960
  def infer_inputs_for_gate(gate: str, reg: SignalRegistry, tensors: Dict[str, torch.Tensor]) -> List[int]:
2961
  if gate.startswith('manifest.'):
2962
  return []
2963
+ # routing.json holds explicit per-gate input lists for circuits where
2964
+ # naming alone is insufficient (multipliers, float pipelines, etc.).
2965
+ # Consult it first; family-specific inferrers below remain the
2966
+ # fallback for anything not covered there.
2967
+ table = _load_routing_table()
2968
+ matched_circuit: Optional[str] = None
2969
+ if table:
2970
+ for circuit in table:
2971
+ if gate == circuit or gate.startswith(circuit + "."):
2972
+ if matched_circuit is None or len(circuit) > len(matched_circuit):
2973
+ matched_circuit = circuit
2974
+ suffixes: Optional[Set[str]] = (
2975
+ _gate_suffixes_under(matched_circuit, tensors)
2976
+ if matched_circuit else None
2977
+ )
2978
+ rt = _infer_from_routing_table(gate, reg, known_gate_suffixes=suffixes)
2979
+ if rt is not None:
2980
+ return rt
2981
  # Bit-cascade comparators (integer + float + div) and ternary modular detectors.
2982
  # These are the gates emitted by add_bit_cascade_compare and the ternary
2983
  # modular rebuild in quantize.py; cover them up front so they don't fall
routing/routing.json ADDED
The diff for this file is too large to render. See raw diff