cmd_inputs: detect and regenerate stale .inputs metadata
Browse filesWhen a gate's .weight tensor was rewritten by a later build pass (most
notably the bit-cascade comparator and modular ternarization work),
its existing .inputs entry from an earlier seed file no longer matches
the new fan-in. build_inputs now compares each existing .inputs length
against its corresponding .weight (for single-gate tensors with
weight.dim() == 1) and regenerates when they disagree, rather than
silently keeping stale routing.
Packed multi-gate tensors (weight.dim() > 1, e.g. memory.read.and)
use a different routing convention and are left alone.
This fixes the inconsistency that downstream tools like
safetensors2verilog were tripping on; on neural_alu8.safetensors,
15,264 stale entries get regenerated. The remaining ~3.4k gates
whose new naming patterns infer_inputs_for_gate doesn't yet recognize
end up with no .inputs at all (rather than a wrong-length one), which
is the cleaner failure mode.
|
@@ -2977,12 +2977,24 @@ def infer_inputs_for_gate(gate: str, reg: SignalRegistry, tensors: Dict[str, tor
|
|
| 2977 |
def build_inputs(tensors: Dict[str, torch.Tensor]) -> tuple[Dict[str, torch.Tensor], SignalRegistry, dict]:
|
| 2978 |
reg = SignalRegistry()
|
| 2979 |
gates = get_all_gates(tensors)
|
| 2980 |
-
stats = {"added": 0, "skipped": 0, "empty": 0}
|
| 2981 |
for gate in sorted(gates):
|
| 2982 |
inputs_key = f"{gate}.inputs"
|
|
|
|
| 2983 |
if inputs_key in tensors:
|
| 2984 |
-
|
| 2985 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2986 |
inputs = infer_inputs_for_gate(gate, reg, tensors)
|
| 2987 |
if inputs:
|
| 2988 |
tensors[inputs_key] = torch.tensor(inputs, dtype=torch.int64)
|
|
|
|
| 2977 |
def build_inputs(tensors: Dict[str, torch.Tensor]) -> tuple[Dict[str, torch.Tensor], SignalRegistry, dict]:
|
| 2978 |
reg = SignalRegistry()
|
| 2979 |
gates = get_all_gates(tensors)
|
| 2980 |
+
stats = {"added": 0, "skipped": 0, "empty": 0, "regenerated": 0}
|
| 2981 |
for gate in sorted(gates):
|
| 2982 |
inputs_key = f"{gate}.inputs"
|
| 2983 |
+
weight_key = f"{gate}.weight"
|
| 2984 |
if inputs_key in tensors:
|
| 2985 |
+
# Detect stale .inputs (length doesn't match the gate's fan-in)
|
| 2986 |
+
# for single-gate tensors and regenerate them. Packed multi-gate
|
| 2987 |
+
# tensors have weight.dim() > 1 and use a different convention,
|
| 2988 |
+
# so we leave their .inputs alone.
|
| 2989 |
+
existing = tensors[inputs_key]
|
| 2990 |
+
weight = tensors.get(weight_key)
|
| 2991 |
+
if (weight is not None and weight.dim() == 1
|
| 2992 |
+
and existing.numel() != weight.numel()):
|
| 2993 |
+
del tensors[inputs_key]
|
| 2994 |
+
stats["regenerated"] += 1
|
| 2995 |
+
else:
|
| 2996 |
+
stats["skipped"] += 1
|
| 2997 |
+
continue
|
| 2998 |
inputs = infer_inputs_for_gate(gate, reg, tensors)
|
| 2999 |
if inputs:
|
| 3000 |
tensors[inputs_key] = torch.tensor(inputs, dtype=torch.int64)
|