infer_inputs_for_gate: handle bit-cascade and ternary modular patterns
Browse filesAdds pattern matchers covering every gate name family the bit-cascade
migration introduced:
arithmetic.cmp{N}bit.bit{i}.{gt,lt,eq.layer1.and,eq.layer1.nor,eq}
arithmetic.cmp{N}bit.cascade.eq_prefix.bit{i}
arithmetic.cmp{N}bit.cascade.{gt,lt}.bit{i}
arithmetic.{greaterthan,lessthan,equality,greaterorequal,lessorequal}{N}bit
arithmetic.{greaterorequal,lessorequal}{N}bit.{not_lt,not_gt}
alu.alu{N}bit.div.stage{S}.cmp_bc.* (same internal cascade structure)
alu.alu{N}bit.div.stage{S}.cmp(.not_lt)?
float{16,32}.cmp.mag_bc.* + mag_a_{gt,lt,ge,le}_b + mag_eq.and
float{16,32}.add.exp_cmp_bc.* + exp_cmp.{a_gt_b,a_lt_b}
float{16,32}.div.mant_div.stage{S}.cmp_bc.*
float{16,32}.div.mant_div.stage{S}.cmp(.not_lt)?
modular.mod{N}.eq.k{k}.bit{i}.match
modular.mod{N}.eq.k{k}.all
modular.mod{N} (final OR over per-multiple equality detectors)
Two helper functions: _infer_bit_cascade walks one cascade tree given
its prefix and per-bit input name templates; _infer_compare_final
handles the public output gates emitted by add_bit_cascade_compare
(N-input OR for gt/lt, N-input AND for eq, NOT-then-buffer for ge/le).
Test build of an 8-bit small variant after this change: cmd_inputs
reports Empty=0, where it previously produced thousands of gates with
no routing metadata. Stale entries from the seed file still persist
because cmd_inputs only adds .inputs (it doesn't overwrite); refreshing
those requires either a fresh build_all pass or a separate flag to drop
existing .inputs before regeneration.
|
@@ -2679,9 +2679,198 @@ def infer_combinational_inputs(gate: str, reg: SignalRegistry, tensors: Dict[str
|
|
| 2679 |
return []
|
| 2680 |
|
| 2681 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2682 |
def infer_inputs_for_gate(gate: str, reg: SignalRegistry, tensors: Dict[str, torch.Tensor]) -> List[int]:
|
| 2683 |
if gate.startswith('manifest.'):
|
| 2684 |
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2685 |
if gate.startswith('boolean.'):
|
| 2686 |
return infer_boolean_inputs(gate, reg)
|
| 2687 |
if gate.startswith('arithmetic.'):
|
|
|
|
| 2679 |
return []
|
| 2680 |
|
| 2681 |
|
| 2682 |
+
def _infer_bit_cascade(gate: str, reg: SignalRegistry, cmp_prefix: str,
|
| 2683 |
+
a_template: str, b_template: str) -> List[int] | None:
|
| 2684 |
+
"""If `gate` is part of the bit-cascade comparator at `cmp_prefix`, return
|
| 2685 |
+
its inputs. `a_template` / `b_template` are format strings like
|
| 2686 |
+
"$a[{}]" / "$b[{}]" that produce the per-bit input signal names. Returns
|
| 2687 |
+
None if `gate` isn't a member of this cascade.
|
| 2688 |
+
"""
|
| 2689 |
+
if not gate.startswith(cmp_prefix + "."):
|
| 2690 |
+
return None
|
| 2691 |
+
suffix = gate[len(cmp_prefix) + 1:]
|
| 2692 |
+
|
| 2693 |
+
m = re.match(r"^bit(\d+)(?:\.(.+))?$", suffix)
|
| 2694 |
+
if m:
|
| 2695 |
+
i = int(m.group(1))
|
| 2696 |
+
sub = m.group(2) or ""
|
| 2697 |
+
a_sig = a_template.format(i)
|
| 2698 |
+
b_sig = b_template.format(i)
|
| 2699 |
+
if sub in ("gt", "lt", "eq.layer1.and", "eq.layer1.nor"):
|
| 2700 |
+
reg.register(a_sig)
|
| 2701 |
+
reg.register(b_sig)
|
| 2702 |
+
return [reg.get_id(a_sig), reg.get_id(b_sig)]
|
| 2703 |
+
if sub == "eq":
|
| 2704 |
+
return [
|
| 2705 |
+
reg.register(f"{cmp_prefix}.bit{i}.eq.layer1.and"),
|
| 2706 |
+
reg.register(f"{cmp_prefix}.bit{i}.eq.layer1.nor"),
|
| 2707 |
+
]
|
| 2708 |
+
|
| 2709 |
+
m = re.match(r"^cascade\.eq_prefix\.bit(\d+)$", suffix)
|
| 2710 |
+
if m:
|
| 2711 |
+
i = int(m.group(1))
|
| 2712 |
+
return [reg.register(f"{cmp_prefix}.bit{j}.eq") for j in range(i)]
|
| 2713 |
+
|
| 2714 |
+
m = re.match(r"^cascade\.(gt|lt)\.bit(\d+)$", suffix)
|
| 2715 |
+
if m:
|
| 2716 |
+
kind = m.group(1)
|
| 2717 |
+
i = int(m.group(2))
|
| 2718 |
+
return [
|
| 2719 |
+
reg.register(f"{cmp_prefix}.cascade.eq_prefix.bit{i}"),
|
| 2720 |
+
reg.register(f"{cmp_prefix}.bit{i}.{kind}"),
|
| 2721 |
+
]
|
| 2722 |
+
return None
|
| 2723 |
+
|
| 2724 |
+
|
| 2725 |
+
def _infer_compare_final(gate: str, reg: SignalRegistry, cmp_prefix: str,
|
| 2726 |
+
out_gt: str, out_lt: str, out_ge: str, out_le: str,
|
| 2727 |
+
out_eq: str, bits: int) -> List[int] | None:
|
| 2728 |
+
"""Inputs for the final outputs of an add_bit_cascade_compare emit:
|
| 2729 |
+
out_gt/out_lt are N-input ORs; out_eq is N-input AND; ge/le are
|
| 2730 |
+
NOT-then-buffer chains over lt/gt.
|
| 2731 |
+
"""
|
| 2732 |
+
if gate == out_gt:
|
| 2733 |
+
return ([reg.register(f"{cmp_prefix}.bit0.gt")]
|
| 2734 |
+
+ [reg.register(f"{cmp_prefix}.cascade.gt.bit{i}") for i in range(1, bits)])
|
| 2735 |
+
if gate == out_lt:
|
| 2736 |
+
return ([reg.register(f"{cmp_prefix}.bit0.lt")]
|
| 2737 |
+
+ [reg.register(f"{cmp_prefix}.cascade.lt.bit{i}") for i in range(1, bits)])
|
| 2738 |
+
if gate == out_eq:
|
| 2739 |
+
return [reg.register(f"{cmp_prefix}.bit{i}.eq") for i in range(bits)]
|
| 2740 |
+
if gate == f"{out_ge}.not_lt":
|
| 2741 |
+
return [reg.register(out_lt)]
|
| 2742 |
+
if gate == out_ge:
|
| 2743 |
+
return [reg.register(f"{out_ge}.not_lt")]
|
| 2744 |
+
if gate == f"{out_le}.not_gt":
|
| 2745 |
+
return [reg.register(out_gt)]
|
| 2746 |
+
if gate == out_le:
|
| 2747 |
+
return [reg.register(f"{out_le}.not_gt")]
|
| 2748 |
+
return None
|
| 2749 |
+
|
| 2750 |
+
|
| 2751 |
+
def _infer_threshold_computer_bit_cascade(gate: str, reg: SignalRegistry) -> List[int] | None:
|
| 2752 |
+
"""Match every bit-cascade location used in this codebase: integer
|
| 2753 |
+
comparators, integer division stage cmps, float magnitude comparators,
|
| 2754 |
+
float exp_cmp, and float division mantissa cmps.
|
| 2755 |
+
"""
|
| 2756 |
+
# Integer arithmetic.cmp{N}bit.* + the public greaterthan/lessthan/etc.
|
| 2757 |
+
m = re.match(r"^arithmetic\.cmp(\d+)bit\..+", gate)
|
| 2758 |
+
if m:
|
| 2759 |
+
bits = int(m.group(1))
|
| 2760 |
+
return _infer_bit_cascade(gate, reg, f"arithmetic.cmp{bits}bit", "$a[{}]", "$b[{}]")
|
| 2761 |
+
|
| 2762 |
+
m = re.match(r"^arithmetic\.(greaterthan|lessthan|equality|greaterorequal|lessorequal)(\d+)bit(\.not_lt|\.not_gt)?$", gate)
|
| 2763 |
+
if m:
|
| 2764 |
+
bits = int(m.group(2))
|
| 2765 |
+
return _infer_compare_final(
|
| 2766 |
+
gate, reg, f"arithmetic.cmp{bits}bit",
|
| 2767 |
+
f"arithmetic.greaterthan{bits}bit", f"arithmetic.lessthan{bits}bit",
|
| 2768 |
+
f"arithmetic.greaterorequal{bits}bit", f"arithmetic.lessorequal{bits}bit",
|
| 2769 |
+
f"arithmetic.equality{bits}bit", bits,
|
| 2770 |
+
)
|
| 2771 |
+
|
| 2772 |
+
# Integer division stage cmps: alu.alu{N}bit.div.stage{S}.cmp_bc.* + .cmp + .cmp.not_lt
|
| 2773 |
+
m = re.match(r"^alu\.alu(\d+)bit\.div\.stage(\d+)\.cmp_bc\..+", gate)
|
| 2774 |
+
if m:
|
| 2775 |
+
bits, stage = int(m.group(1)), int(m.group(2))
|
| 2776 |
+
cp = f"alu.alu{bits}bit.div.stage{stage}.cmp_bc"
|
| 2777 |
+
return _infer_bit_cascade(gate, reg, cp, "$rem[{}]", "$div[{}]")
|
| 2778 |
+
|
| 2779 |
+
m = re.match(r"^alu\.alu(\d+)bit\.div\.stage(\d+)\.cmp(\.not_lt)?$", gate)
|
| 2780 |
+
if m:
|
| 2781 |
+
bits, stage = int(m.group(1)), int(m.group(2))
|
| 2782 |
+
cp = f"alu.alu{bits}bit.div.stage{stage}.cmp_bc"
|
| 2783 |
+
cmp_name = f"alu.alu{bits}bit.div.stage{stage}.cmp"
|
| 2784 |
+
if gate == f"{cmp_name}.not_lt":
|
| 2785 |
+
return [reg.register(f"{cp}.lt")]
|
| 2786 |
+
if gate == cmp_name:
|
| 2787 |
+
return [reg.register(f"{cmp_name}.not_lt")]
|
| 2788 |
+
|
| 2789 |
+
# Float magnitude comparators: float{16,32}.cmp.mag_bc.* + final outputs
|
| 2790 |
+
m = re.match(r"^(float16|float32)\.cmp\.mag_bc\..+", gate)
|
| 2791 |
+
if m:
|
| 2792 |
+
family = m.group(1)
|
| 2793 |
+
bits = 15 if family == "float16" else 31
|
| 2794 |
+
return _infer_bit_cascade(gate, reg, f"{family}.cmp.mag_bc",
|
| 2795 |
+
f"${family}_mag_a[{{}}]", f"${family}_mag_b[{{}}]")
|
| 2796 |
+
|
| 2797 |
+
m = re.match(r"^(float16|float32)\.cmp\.(mag_a_gt_b|mag_a_lt_b|mag_a_ge_b|mag_a_le_b|mag_eq\.and)(\.not_lt|\.not_gt)?$", gate)
|
| 2798 |
+
if m:
|
| 2799 |
+
family = m.group(1)
|
| 2800 |
+
bits = 15 if family == "float16" else 31
|
| 2801 |
+
return _infer_compare_final(
|
| 2802 |
+
gate, reg, f"{family}.cmp.mag_bc",
|
| 2803 |
+
f"{family}.cmp.mag_a_gt_b", f"{family}.cmp.mag_a_lt_b",
|
| 2804 |
+
f"{family}.cmp.mag_a_ge_b", f"{family}.cmp.mag_a_le_b",
|
| 2805 |
+
f"{family}.cmp.mag_eq.and", bits,
|
| 2806 |
+
)
|
| 2807 |
+
|
| 2808 |
+
# Float add exp_cmp: float{16,32}.add.exp_cmp_bc.* + a_gt_b / a_lt_b
|
| 2809 |
+
m = re.match(r"^(float16|float32)\.add\.exp_cmp_bc\..+", gate)
|
| 2810 |
+
if m:
|
| 2811 |
+
family = m.group(1)
|
| 2812 |
+
bits = 5 if family == "float16" else 8
|
| 2813 |
+
return _infer_bit_cascade(gate, reg, f"{family}.add.exp_cmp_bc",
|
| 2814 |
+
f"${family}_exp_a[{{}}]", f"${family}_exp_b[{{}}]")
|
| 2815 |
+
|
| 2816 |
+
m = re.match(r"^(float16|float32)\.add\.exp_cmp\.(a_gt_b|a_lt_b)$", gate)
|
| 2817 |
+
if m:
|
| 2818 |
+
family = m.group(1)
|
| 2819 |
+
bits = 5 if family == "float16" else 8
|
| 2820 |
+
cp = f"{family}.add.exp_cmp_bc"
|
| 2821 |
+
kind = m.group(2)
|
| 2822 |
+
short = "gt" if kind == "a_gt_b" else "lt"
|
| 2823 |
+
return ([reg.register(f"{cp}.bit0.{short}")]
|
| 2824 |
+
+ [reg.register(f"{cp}.cascade.{short}.bit{i}") for i in range(1, bits)])
|
| 2825 |
+
|
| 2826 |
+
# Float div mantissa stage cmps
|
| 2827 |
+
m = re.match(r"^(float16|float32)\.div\.mant_div\.stage(\d+)\.cmp_bc\..+", gate)
|
| 2828 |
+
if m:
|
| 2829 |
+
family, stage = m.group(1), int(m.group(2))
|
| 2830 |
+
cp = f"{family}.div.mant_div.stage{stage}.cmp_bc"
|
| 2831 |
+
return _infer_bit_cascade(gate, reg, cp, "$mant_rem[{}]", "$mant_div[{}]")
|
| 2832 |
+
|
| 2833 |
+
m = re.match(r"^(float16|float32)\.div\.mant_div\.stage(\d+)\.cmp(\.not_lt)?$", gate)
|
| 2834 |
+
if m:
|
| 2835 |
+
family, stage = m.group(1), int(m.group(2))
|
| 2836 |
+
cp = f"{family}.div.mant_div.stage{stage}.cmp_bc"
|
| 2837 |
+
cmp_name = f"{family}.div.mant_div.stage{stage}.cmp"
|
| 2838 |
+
if gate == f"{cmp_name}.not_lt":
|
| 2839 |
+
return [reg.register(f"{cp}.lt")]
|
| 2840 |
+
if gate == cmp_name:
|
| 2841 |
+
return [reg.register(f"{cmp_name}.not_lt")]
|
| 2842 |
+
|
| 2843 |
+
# Modular ternary: per-multiple equality detectors + final OR
|
| 2844 |
+
m = re.match(r"^modular\.mod(\d+)\.eq\.k(\d+)\.bit(\d+)\.match$", gate)
|
| 2845 |
+
if m:
|
| 2846 |
+
bit_i = int(m.group(3))
|
| 2847 |
+
sig = f"$x[{bit_i}]"
|
| 2848 |
+
reg.register(sig)
|
| 2849 |
+
return [reg.get_id(sig)]
|
| 2850 |
+
m = re.match(r"^modular\.mod(\d+)\.eq\.k(\d+)\.all$", gate)
|
| 2851 |
+
if m:
|
| 2852 |
+
mod, k = int(m.group(1)), int(m.group(2))
|
| 2853 |
+
prefix = f"modular.mod{mod}.eq.k{k}"
|
| 2854 |
+
return [reg.register(f"{prefix}.bit{i}.match") for i in range(8)]
|
| 2855 |
+
m = re.match(r"^modular\.mod(\d+)$", gate)
|
| 2856 |
+
if m:
|
| 2857 |
+
mod = int(m.group(1))
|
| 2858 |
+
multiples = list(range(0, 256, mod))
|
| 2859 |
+
return [reg.register(f"modular.mod{mod}.eq.k{k}.all") for k in multiples]
|
| 2860 |
+
|
| 2861 |
+
return None
|
| 2862 |
+
|
| 2863 |
+
|
| 2864 |
def infer_inputs_for_gate(gate: str, reg: SignalRegistry, tensors: Dict[str, torch.Tensor]) -> List[int]:
|
| 2865 |
if gate.startswith('manifest.'):
|
| 2866 |
return []
|
| 2867 |
+
# Bit-cascade comparators (integer + float + div) and ternary modular detectors.
|
| 2868 |
+
# These are the gates emitted by add_bit_cascade_compare and the ternary
|
| 2869 |
+
# modular rebuild in quantize.py; cover them up front so they don't fall
|
| 2870 |
+
# through to the generic per-family handlers below.
|
| 2871 |
+
bc = _infer_threshold_computer_bit_cascade(gate, reg)
|
| 2872 |
+
if bc is not None:
|
| 2873 |
+
return bc
|
| 2874 |
if gate.startswith('boolean.'):
|
| 2875 |
return infer_boolean_inputs(gate, reg)
|
| 2876 |
if gate.startswith('arithmetic.'):
|