CharlesCNorton commited on
Commit
2e671d4
·
1 Parent(s): 3942c4f

cmd_inputs: detect and regenerate stale .inputs metadata

Browse files

When a gate's .weight tensor was rewritten by a later build pass (most
notably the bit-cascade comparator and modular ternarization work),
its existing .inputs entry from an earlier seed file no longer matches
the new fan-in. build_inputs now compares each existing .inputs length
against its corresponding .weight (for single-gate tensors with
weight.dim() == 1) and regenerates when they disagree, rather than
silently keeping stale routing.

Packed multi-gate tensors (weight.dim() > 1, e.g. memory.read.and)
use a different routing convention and are left alone.

This fixes the inconsistency that downstream tools like
safetensors2verilog were tripping on; on neural_alu8.safetensors,
15,264 stale entries get regenerated. The remaining ~3.4k gates
whose new naming patterns infer_inputs_for_gate doesn't yet recognize
end up with no .inputs at all (rather than a wrong-length one), which
is the cleaner failure mode.

Files changed (1) hide show
  1. build.py +15 -3
build.py CHANGED
@@ -2977,12 +2977,24 @@ def infer_inputs_for_gate(gate: str, reg: SignalRegistry, tensors: Dict[str, tor
2977
  def build_inputs(tensors: Dict[str, torch.Tensor]) -> tuple[Dict[str, torch.Tensor], SignalRegistry, dict]:
2978
  reg = SignalRegistry()
2979
  gates = get_all_gates(tensors)
2980
- stats = {"added": 0, "skipped": 0, "empty": 0}
2981
  for gate in sorted(gates):
2982
  inputs_key = f"{gate}.inputs"
 
2983
  if inputs_key in tensors:
2984
- stats["skipped"] += 1
2985
- continue
 
 
 
 
 
 
 
 
 
 
 
2986
  inputs = infer_inputs_for_gate(gate, reg, tensors)
2987
  if inputs:
2988
  tensors[inputs_key] = torch.tensor(inputs, dtype=torch.int64)
 
2977
  def build_inputs(tensors: Dict[str, torch.Tensor]) -> tuple[Dict[str, torch.Tensor], SignalRegistry, dict]:
2978
  reg = SignalRegistry()
2979
  gates = get_all_gates(tensors)
2980
+ stats = {"added": 0, "skipped": 0, "empty": 0, "regenerated": 0}
2981
  for gate in sorted(gates):
2982
  inputs_key = f"{gate}.inputs"
2983
+ weight_key = f"{gate}.weight"
2984
  if inputs_key in tensors:
2985
+ # Detect stale .inputs (length doesn't match the gate's fan-in)
2986
+ # for single-gate tensors and regenerate them. Packed multi-gate
2987
+ # tensors have weight.dim() > 1 and use a different convention,
2988
+ # so we leave their .inputs alone.
2989
+ existing = tensors[inputs_key]
2990
+ weight = tensors.get(weight_key)
2991
+ if (weight is not None and weight.dim() == 1
2992
+ and existing.numel() != weight.numel()):
2993
+ del tensors[inputs_key]
2994
+ stats["regenerated"] += 1
2995
+ else:
2996
+ stats["skipped"] += 1
2997
+ continue
2998
  inputs = infer_inputs_for_gate(gate, reg, tensors)
2999
  if inputs:
3000
  tensors[inputs_key] = torch.tensor(inputs, dtype=torch.int64)