CharlesCNorton commited on
Commit
3942c4f
·
1 Parent(s): 536bb59

infer_inputs_for_gate: handle bit-cascade and ternary modular patterns

Browse files

Adds pattern matchers covering every gate name family the bit-cascade
migration introduced:

arithmetic.cmp{N}bit.bit{i}.{gt,lt,eq.layer1.and,eq.layer1.nor,eq}
arithmetic.cmp{N}bit.cascade.eq_prefix.bit{i}
arithmetic.cmp{N}bit.cascade.{gt,lt}.bit{i}
arithmetic.{greaterthan,lessthan,equality,greaterorequal,lessorequal}{N}bit
arithmetic.{greaterorequal,lessorequal}{N}bit.{not_lt,not_gt}

alu.alu{N}bit.div.stage{S}.cmp_bc.* (same internal cascade structure)
alu.alu{N}bit.div.stage{S}.cmp(.not_lt)?

float{16,32}.cmp.mag_bc.* + mag_a_{gt,lt,ge,le}_b + mag_eq.and
float{16,32}.add.exp_cmp_bc.* + exp_cmp.{a_gt_b,a_lt_b}
float{16,32}.div.mant_div.stage{S}.cmp_bc.*
float{16,32}.div.mant_div.stage{S}.cmp(.not_lt)?

modular.mod{N}.eq.k{k}.bit{i}.match
modular.mod{N}.eq.k{k}.all
modular.mod{N} (final OR over per-multiple equality detectors)

Two helper functions: _infer_bit_cascade walks one cascade tree given
its prefix and per-bit input name templates; _infer_compare_final
handles the public output gates emitted by add_bit_cascade_compare
(N-input OR for gt/lt, N-input AND for eq, NOT-then-buffer for ge/le).

Test build of an 8-bit small variant after this change: cmd_inputs
reports Empty=0, where it previously produced thousands of gates with
no routing metadata. Stale entries from the seed file still persist
because cmd_inputs only adds .inputs (it doesn't overwrite); refreshing
those requires either a fresh build_all pass or a separate flag to drop
existing .inputs before regeneration.

Files changed (1) hide show
  1. build.py +189 -0
build.py CHANGED
@@ -2679,9 +2679,198 @@ def infer_combinational_inputs(gate: str, reg: SignalRegistry, tensors: Dict[str
2679
  return []
2680
 
2681
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2682
  def infer_inputs_for_gate(gate: str, reg: SignalRegistry, tensors: Dict[str, torch.Tensor]) -> List[int]:
2683
  if gate.startswith('manifest.'):
2684
  return []
 
 
 
 
 
 
 
2685
  if gate.startswith('boolean.'):
2686
  return infer_boolean_inputs(gate, reg)
2687
  if gate.startswith('arithmetic.'):
 
2679
  return []
2680
 
2681
 
2682
+ def _infer_bit_cascade(gate: str, reg: SignalRegistry, cmp_prefix: str,
2683
+ a_template: str, b_template: str) -> List[int] | None:
2684
+ """If `gate` is part of the bit-cascade comparator at `cmp_prefix`, return
2685
+ its inputs. `a_template` / `b_template` are format strings like
2686
+ "$a[{}]" / "$b[{}]" that produce the per-bit input signal names. Returns
2687
+ None if `gate` isn't a member of this cascade.
2688
+ """
2689
+ if not gate.startswith(cmp_prefix + "."):
2690
+ return None
2691
+ suffix = gate[len(cmp_prefix) + 1:]
2692
+
2693
+ m = re.match(r"^bit(\d+)(?:\.(.+))?$", suffix)
2694
+ if m:
2695
+ i = int(m.group(1))
2696
+ sub = m.group(2) or ""
2697
+ a_sig = a_template.format(i)
2698
+ b_sig = b_template.format(i)
2699
+ if sub in ("gt", "lt", "eq.layer1.and", "eq.layer1.nor"):
2700
+ reg.register(a_sig)
2701
+ reg.register(b_sig)
2702
+ return [reg.get_id(a_sig), reg.get_id(b_sig)]
2703
+ if sub == "eq":
2704
+ return [
2705
+ reg.register(f"{cmp_prefix}.bit{i}.eq.layer1.and"),
2706
+ reg.register(f"{cmp_prefix}.bit{i}.eq.layer1.nor"),
2707
+ ]
2708
+
2709
+ m = re.match(r"^cascade\.eq_prefix\.bit(\d+)$", suffix)
2710
+ if m:
2711
+ i = int(m.group(1))
2712
+ return [reg.register(f"{cmp_prefix}.bit{j}.eq") for j in range(i)]
2713
+
2714
+ m = re.match(r"^cascade\.(gt|lt)\.bit(\d+)$", suffix)
2715
+ if m:
2716
+ kind = m.group(1)
2717
+ i = int(m.group(2))
2718
+ return [
2719
+ reg.register(f"{cmp_prefix}.cascade.eq_prefix.bit{i}"),
2720
+ reg.register(f"{cmp_prefix}.bit{i}.{kind}"),
2721
+ ]
2722
+ return None
2723
+
2724
+
2725
+ def _infer_compare_final(gate: str, reg: SignalRegistry, cmp_prefix: str,
2726
+ out_gt: str, out_lt: str, out_ge: str, out_le: str,
2727
+ out_eq: str, bits: int) -> List[int] | None:
2728
+ """Inputs for the final outputs of an add_bit_cascade_compare emit:
2729
+ out_gt/out_lt are N-input ORs; out_eq is N-input AND; ge/le are
2730
+ NOT-then-buffer chains over lt/gt.
2731
+ """
2732
+ if gate == out_gt:
2733
+ return ([reg.register(f"{cmp_prefix}.bit0.gt")]
2734
+ + [reg.register(f"{cmp_prefix}.cascade.gt.bit{i}") for i in range(1, bits)])
2735
+ if gate == out_lt:
2736
+ return ([reg.register(f"{cmp_prefix}.bit0.lt")]
2737
+ + [reg.register(f"{cmp_prefix}.cascade.lt.bit{i}") for i in range(1, bits)])
2738
+ if gate == out_eq:
2739
+ return [reg.register(f"{cmp_prefix}.bit{i}.eq") for i in range(bits)]
2740
+ if gate == f"{out_ge}.not_lt":
2741
+ return [reg.register(out_lt)]
2742
+ if gate == out_ge:
2743
+ return [reg.register(f"{out_ge}.not_lt")]
2744
+ if gate == f"{out_le}.not_gt":
2745
+ return [reg.register(out_gt)]
2746
+ if gate == out_le:
2747
+ return [reg.register(f"{out_le}.not_gt")]
2748
+ return None
2749
+
2750
+
2751
+ def _infer_threshold_computer_bit_cascade(gate: str, reg: SignalRegistry) -> List[int] | None:
2752
+ """Match every bit-cascade location used in this codebase: integer
2753
+ comparators, integer division stage cmps, float magnitude comparators,
2754
+ float exp_cmp, and float division mantissa cmps.
2755
+ """
2756
+ # Integer arithmetic.cmp{N}bit.* + the public greaterthan/lessthan/etc.
2757
+ m = re.match(r"^arithmetic\.cmp(\d+)bit\..+", gate)
2758
+ if m:
2759
+ bits = int(m.group(1))
2760
+ return _infer_bit_cascade(gate, reg, f"arithmetic.cmp{bits}bit", "$a[{}]", "$b[{}]")
2761
+
2762
+ m = re.match(r"^arithmetic\.(greaterthan|lessthan|equality|greaterorequal|lessorequal)(\d+)bit(\.not_lt|\.not_gt)?$", gate)
2763
+ if m:
2764
+ bits = int(m.group(2))
2765
+ return _infer_compare_final(
2766
+ gate, reg, f"arithmetic.cmp{bits}bit",
2767
+ f"arithmetic.greaterthan{bits}bit", f"arithmetic.lessthan{bits}bit",
2768
+ f"arithmetic.greaterorequal{bits}bit", f"arithmetic.lessorequal{bits}bit",
2769
+ f"arithmetic.equality{bits}bit", bits,
2770
+ )
2771
+
2772
+ # Integer division stage cmps: alu.alu{N}bit.div.stage{S}.cmp_bc.* + .cmp + .cmp.not_lt
2773
+ m = re.match(r"^alu\.alu(\d+)bit\.div\.stage(\d+)\.cmp_bc\..+", gate)
2774
+ if m:
2775
+ bits, stage = int(m.group(1)), int(m.group(2))
2776
+ cp = f"alu.alu{bits}bit.div.stage{stage}.cmp_bc"
2777
+ return _infer_bit_cascade(gate, reg, cp, "$rem[{}]", "$div[{}]")
2778
+
2779
+ m = re.match(r"^alu\.alu(\d+)bit\.div\.stage(\d+)\.cmp(\.not_lt)?$", gate)
2780
+ if m:
2781
+ bits, stage = int(m.group(1)), int(m.group(2))
2782
+ cp = f"alu.alu{bits}bit.div.stage{stage}.cmp_bc"
2783
+ cmp_name = f"alu.alu{bits}bit.div.stage{stage}.cmp"
2784
+ if gate == f"{cmp_name}.not_lt":
2785
+ return [reg.register(f"{cp}.lt")]
2786
+ if gate == cmp_name:
2787
+ return [reg.register(f"{cmp_name}.not_lt")]
2788
+
2789
+ # Float magnitude comparators: float{16,32}.cmp.mag_bc.* + final outputs
2790
+ m = re.match(r"^(float16|float32)\.cmp\.mag_bc\..+", gate)
2791
+ if m:
2792
+ family = m.group(1)
2793
+ bits = 15 if family == "float16" else 31
2794
+ return _infer_bit_cascade(gate, reg, f"{family}.cmp.mag_bc",
2795
+ f"${family}_mag_a[{{}}]", f"${family}_mag_b[{{}}]")
2796
+
2797
+ m = re.match(r"^(float16|float32)\.cmp\.(mag_a_gt_b|mag_a_lt_b|mag_a_ge_b|mag_a_le_b|mag_eq\.and)(\.not_lt|\.not_gt)?$", gate)
2798
+ if m:
2799
+ family = m.group(1)
2800
+ bits = 15 if family == "float16" else 31
2801
+ return _infer_compare_final(
2802
+ gate, reg, f"{family}.cmp.mag_bc",
2803
+ f"{family}.cmp.mag_a_gt_b", f"{family}.cmp.mag_a_lt_b",
2804
+ f"{family}.cmp.mag_a_ge_b", f"{family}.cmp.mag_a_le_b",
2805
+ f"{family}.cmp.mag_eq.and", bits,
2806
+ )
2807
+
2808
+ # Float add exp_cmp: float{16,32}.add.exp_cmp_bc.* + a_gt_b / a_lt_b
2809
+ m = re.match(r"^(float16|float32)\.add\.exp_cmp_bc\..+", gate)
2810
+ if m:
2811
+ family = m.group(1)
2812
+ bits = 5 if family == "float16" else 8
2813
+ return _infer_bit_cascade(gate, reg, f"{family}.add.exp_cmp_bc",
2814
+ f"${family}_exp_a[{{}}]", f"${family}_exp_b[{{}}]")
2815
+
2816
+ m = re.match(r"^(float16|float32)\.add\.exp_cmp\.(a_gt_b|a_lt_b)$", gate)
2817
+ if m:
2818
+ family = m.group(1)
2819
+ bits = 5 if family == "float16" else 8
2820
+ cp = f"{family}.add.exp_cmp_bc"
2821
+ kind = m.group(2)
2822
+ short = "gt" if kind == "a_gt_b" else "lt"
2823
+ return ([reg.register(f"{cp}.bit0.{short}")]
2824
+ + [reg.register(f"{cp}.cascade.{short}.bit{i}") for i in range(1, bits)])
2825
+
2826
+ # Float div mantissa stage cmps
2827
+ m = re.match(r"^(float16|float32)\.div\.mant_div\.stage(\d+)\.cmp_bc\..+", gate)
2828
+ if m:
2829
+ family, stage = m.group(1), int(m.group(2))
2830
+ cp = f"{family}.div.mant_div.stage{stage}.cmp_bc"
2831
+ return _infer_bit_cascade(gate, reg, cp, "$mant_rem[{}]", "$mant_div[{}]")
2832
+
2833
+ m = re.match(r"^(float16|float32)\.div\.mant_div\.stage(\d+)\.cmp(\.not_lt)?$", gate)
2834
+ if m:
2835
+ family, stage = m.group(1), int(m.group(2))
2836
+ cp = f"{family}.div.mant_div.stage{stage}.cmp_bc"
2837
+ cmp_name = f"{family}.div.mant_div.stage{stage}.cmp"
2838
+ if gate == f"{cmp_name}.not_lt":
2839
+ return [reg.register(f"{cp}.lt")]
2840
+ if gate == cmp_name:
2841
+ return [reg.register(f"{cmp_name}.not_lt")]
2842
+
2843
+ # Modular ternary: per-multiple equality detectors + final OR
2844
+ m = re.match(r"^modular\.mod(\d+)\.eq\.k(\d+)\.bit(\d+)\.match$", gate)
2845
+ if m:
2846
+ bit_i = int(m.group(3))
2847
+ sig = f"$x[{bit_i}]"
2848
+ reg.register(sig)
2849
+ return [reg.get_id(sig)]
2850
+ m = re.match(r"^modular\.mod(\d+)\.eq\.k(\d+)\.all$", gate)
2851
+ if m:
2852
+ mod, k = int(m.group(1)), int(m.group(2))
2853
+ prefix = f"modular.mod{mod}.eq.k{k}"
2854
+ return [reg.register(f"{prefix}.bit{i}.match") for i in range(8)]
2855
+ m = re.match(r"^modular\.mod(\d+)$", gate)
2856
+ if m:
2857
+ mod = int(m.group(1))
2858
+ multiples = list(range(0, 256, mod))
2859
+ return [reg.register(f"modular.mod{mod}.eq.k{k}.all") for k in multiples]
2860
+
2861
+ return None
2862
+
2863
+
2864
  def infer_inputs_for_gate(gate: str, reg: SignalRegistry, tensors: Dict[str, torch.Tensor]) -> List[int]:
2865
  if gate.startswith('manifest.'):
2866
  return []
2867
+ # Bit-cascade comparators (integer + float + div) and ternary modular detectors.
2868
+ # These are the gates emitted by add_bit_cascade_compare and the ternary
2869
+ # modular rebuild in quantize.py; cover them up front so they don't fall
2870
+ # through to the generic per-family handlers below.
2871
+ bc = _infer_threshold_computer_bit_cascade(gate, reg)
2872
+ if bc is not None:
2873
+ return bc
2874
  if gate.startswith('boolean.'):
2875
  return infer_boolean_inputs(gate, reg)
2876
  if gate.startswith('arithmetic.'):