""" Convert arithmetic.safetensors to self-documenting format with explicit .inputs tensors. Each gate gets: - .weight (existing) - .bias (existing) - .inputs (NEW) - tensor of signal IDs referencing input sources Signal registry stored in file metadata maps IDs to signal names: - "$name" = external input (e.g., "$a", "$b", "$dividend[0]") - "#value" = constant (e.g., "#0", "#1") - "gate.path" = output of another gate """ import torch from safetensors import safe_open from safetensors.torch import save_file import json import re import struct import math from collections import defaultdict from typing import Dict, List, Tuple, Set, Callable, Optional class SignalRegistry: """Manages signal ID assignments.""" def __init__(self): self.name_to_id: Dict[str, int] = {} self.id_to_name: Dict[int, str] = {} self.next_id = 0 # Pre-register constants self.register("#0") self.register("#1") def register(self, name: str) -> int: if name not in self.name_to_id: self.name_to_id[name] = self.next_id self.id_to_name[self.next_id] = name self.next_id += 1 return self.name_to_id[name] def get_id(self, name: str) -> int: return self.name_to_id.get(name, -1) def to_metadata(self) -> str: return json.dumps(self.id_to_name) def float16_bits_to_float(bits: int) -> float: """Interpret 16-bit int as IEEE-754 float16.""" packed = struct.pack('>H', bits & 0xFFFF) return struct.unpack('>e', packed)[0] def float16_float_to_bits(val: float) -> int: """Convert float to IEEE-754 float16 bits with canonical NaN.""" try: packed = struct.pack('>e', float(val)) return struct.unpack('>H', packed)[0] except (OverflowError, struct.error): if val == float('inf'): return 0x7C00 if val == float('-inf'): return 0xFC00 if val != val: return 0x7E00 return 0x7BFF if val > 0 else 0xFBFF def float32_bits_to_float(bits: int) -> float: """Interpret 32-bit int as IEEE-754 float32.""" packed = struct.pack('>I', bits & 0xFFFFFFFF) return struct.unpack('>f', packed)[0] def float32_float_to_bits(val: float) -> int: 
"""Convert float to IEEE-754 float32 bits with canonical NaN.""" try: packed = struct.pack('>f', float(val)) return struct.unpack('>I', packed)[0] except (OverflowError, struct.error): if val == float('inf'): return 0x7F800000 if val == float('-inf'): return 0xFF800000 if val != val: return 0x7FC00000 return 0x7F7FFFFF if val > 0 else 0xFF7FFFFF def build_float16_const_tensors(prefix: str, value: float) -> Dict[str, torch.Tensor]: """Build constant float16 outputs (prefix.out0..out15) using #1 input.""" tensors: Dict[str, torch.Tensor] = {} bits = float16_float_to_bits(value) for i in range(16): bit = (bits >> i) & 1 if bit: tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-0.5]) else: tensors[f"{prefix}.out{i}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.out{i}.bias"] = torch.tensor([0.0]) return tensors def compute_float16_unary_lut_outputs(op_fn: Callable[[torch.Tensor], torch.Tensor]) -> List[int]: """Compute output bits for all 65536 float16 inputs using a unary op.""" outputs: List[int] = [0] * 65536 for bits in range(65536): val = float16_bits_to_float(bits) out = op_fn(torch.tensor(val, dtype=torch.float16)).item() if out != out: outputs[bits] = 0x7E00 else: outputs[bits] = float16_float_to_bits(float(out)) return outputs def compute_float16_domain_flags(op: str) -> List[int]: """Compute domain error flags (1=invalid) for all 65536 float16 inputs.""" flags: List[int] = [0] * 65536 for bits in range(65536): val = float16_bits_to_float(bits) invalid = False if val != val: invalid = True elif op in ("sqrt", "rsqrt") and val < 0: invalid = True elif op in ("ln", "log2", "log10") and val <= 0: invalid = True elif op in ("asin", "acos", "asin_deg", "acos_deg") and abs(val) > 1.0: invalid = True flags[bits] = 1 if invalid else 0 return flags def unary_float32(op_fn: Callable[[torch.Tensor], torch.Tensor]) -> Callable[[torch.Tensor], torch.Tensor]: """Wrap unary op to run in float32 for wider CPU 
support, returning float32.""" def _fn(x: torch.Tensor) -> torch.Tensor: return op_fn(x.float()) return _fn def wrap_deg_trig(op_fn: Callable[[torch.Tensor], torch.Tensor]) -> Callable[[torch.Tensor], torch.Tensor]: """Wrap trig op to interpret input as degrees.""" def _fn(x: torch.Tensor) -> torch.Tensor: return op_fn(x.float() * (math.pi / 180.0)) return _fn def wrap_inv_trig_deg(op_fn: Callable[[torch.Tensor], torch.Tensor]) -> Callable[[torch.Tensor], torch.Tensor]: """Wrap inverse trig op to return degrees.""" def _fn(x: torch.Tensor) -> torch.Tensor: return op_fn(x.float()) * (180.0 / math.pi) return _fn def build_float16_lut_match_tensors(prefix: str) -> Dict[str, torch.Tensor]: """Build exact-match gates for all 16-bit patterns under prefix.matchXXXX.""" tensors: Dict[str, torch.Tensor] = {} for bits in range(65536): ones = bits.bit_count() weights = [1.0 if (bits >> i) & 1 else -1.0 for i in range(16)] bias = -(ones - 0.5) name = f"{prefix}.match{bits:04x}" tensors[f"{name}.weight"] = torch.tensor(weights) tensors[f"{name}.bias"] = torch.tensor([bias]) return tensors def build_float16_lut_output_tensors(prefix: str, outputs: List[int]) -> Dict[str, torch.Tensor]: """Build LUT output gates (prefix.out0..out15) using one-hot match inputs.""" tensors: Dict[str, torch.Tensor] = {} for bit in range(16): weights = torch.zeros(65536) for idx, out_bits in enumerate(outputs): if (out_bits >> bit) & 1: weights[idx] = 1.0 tensors[f"{prefix}.out{bit}.weight"] = weights tensors[f"{prefix}.out{bit}.bias"] = torch.tensor([-0.5]) return tensors def build_float16_lut_flag_tensors(prefix: str, flags: List[int], flag_name: str = "domain") -> Dict[str, torch.Tensor]: """Build a 1-bit LUT flag gate (prefix.{flag_name}) using one-hot match inputs.""" weights = torch.zeros(65536) for idx, flag in enumerate(flags): if flag: weights[idx] = 1.0 tensors: Dict[str, torch.Tensor] = {} tensors[f"{prefix}.{flag_name}.weight"] = weights tensors[f"{prefix}.{flag_name}.bias"] = 
torch.tensor([-0.5]) return tensors def build_float16_checked_outputs(prefix: str) -> Dict[str, torch.Tensor]: """Build checked outputs that force NaN bits when domain flag is set.""" tensors: Dict[str, torch.Tensor] = {} add_not_gate(tensors, f"{prefix}.domain_not") nan_bits = 0x7E00 for i in range(16): gate = f"{prefix}.checked_out{i}" if (nan_bits >> i) & 1: add_or_gate(tensors, gate) else: add_and_gate(tensors, gate) return tensors def clone_prefix_tensors(src: Dict[str, torch.Tensor], old_prefix: str, new_prefix: str) -> Dict[str, torch.Tensor]: """Clone tensors and rewrite the prefix in tensor names.""" out: Dict[str, torch.Tensor] = {} for name, tensor in src.items(): if name.startswith(old_prefix + "."): out_name = new_prefix + name[len(old_prefix):] out[out_name] = tensor.clone() return out def extract_gate_name(tensor_name: str) -> str: """Extract gate name from tensor name (remove .weight or .bias suffix).""" if tensor_name.endswith('.weight'): return tensor_name[:-7] elif tensor_name.endswith('.bias'): return tensor_name[:-5] return tensor_name def get_all_gates(tensors: Dict[str, torch.Tensor]) -> Set[str]: """Get all unique gate names (anything with a .weight).""" gates = set() for name in tensors: if name.endswith('.weight'): gates.add(extract_gate_name(name)) return gates def infer_boolean_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for boolean gates.""" base = gate.split('.')[-1] if gate == 'boolean.not': registry.register("$x") return [registry.get_id("$x")] if gate in ['boolean.and', 'boolean.or', 'boolean.nand', 'boolean.nor', 'boolean.implies']: registry.register("$a") registry.register("$b") return [registry.get_id("$a"), registry.get_id("$b")] # Two-layer gates (xor, xnor, biimplies) if 'layer1.neuron1' in gate or 'layer1.neuron2' in gate: registry.register("$a") registry.register("$b") return [registry.get_id("$a"), registry.get_id("$b")] if 'layer2' in gate: parent = gate.rsplit('.layer2', 1)[0] n1_out = 
registry.register(f"{parent}.layer1.neuron1") n2_out = registry.register(f"{parent}.layer1.neuron2") return [n1_out, n2_out] return [] def get_lut_match_ids(registry: SignalRegistry, match_prefix: str) -> List[int]: """Get (and cache) match gate IDs for a LUT prefix.""" cache = getattr(registry, "_lut_match_ids", None) if cache is None: cache = {} setattr(registry, "_lut_match_ids", cache) if match_prefix not in cache: cache[match_prefix] = [registry.register(f"{match_prefix}.match{idx:04x}") for idx in range(65536)] return cache[match_prefix] def infer_float16_lut_match_inputs(gate: str, registry: SignalRegistry, match_prefix: str, input_bits: List[str]) -> List[int]: """Infer inputs for LUT match gates (exact pattern match).""" if not gate.startswith(f"{match_prefix}.match"): return [] for name in input_bits: registry.register(name) return [registry.get_id(name) for name in input_bits] def infer_float16_lut_out_inputs(gate: str, registry: SignalRegistry, match_prefix: str) -> List[int]: """Infer inputs for LUT output gates (one-hot match vector).""" if gate.endswith(".domain"): return get_lut_match_ids(registry, match_prefix) match = re.search(r'\.out(\d+)$', gate) if not match: return [] return get_lut_match_ids(registry, match_prefix) def infer_float16_lut_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for shared float16.lut match gates.""" prefix = "float16.lut" input_bits = [f"{prefix}.$x[{i}]" for i in range(16)] return infer_float16_lut_match_inputs(gate, registry, prefix, input_bits) def infer_float16_pow_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for float16.pow circuit (ln -> mul -> exp).""" prefix = "float16.pow" # External inputs for i in range(16): registry.register(f"{prefix}.$a[{i}]") registry.register(f"{prefix}.$b[{i}]") # ln subcircuit (match + outputs) ln_prefix = f"{prefix}.ln" ln_input_bits = [f"{prefix}.$a[{i}]" for i in range(16)] inputs = infer_float16_lut_match_inputs(gate, 
registry, ln_prefix, ln_input_bits) if inputs: return inputs if gate.startswith(f"{ln_prefix}."): return infer_float16_lut_out_inputs(gate, registry, ln_prefix) # mul subcircuit (a = ln.out, b = external b) if gate.startswith(f"{prefix}.mul."): a_bits = [f"{ln_prefix}.out{i}" for i in range(16)] b_bits = [f"{prefix}.$b[{i}]" for i in range(16)] return infer_float16_mul_inputs(gate, registry, prefix=f"{prefix}.mul", a_bits=a_bits, b_bits=b_bits) # exp subcircuit (match + outputs) with input from mul outputs exp_prefix = f"{prefix}.exp" exp_input_bits = [f"{prefix}.mul.out{i}" for i in range(16)] inputs = infer_float16_lut_match_inputs(gate, registry, exp_prefix, exp_input_bits) if inputs: return inputs if gate.startswith(f"{exp_prefix}."): return infer_float16_lut_out_inputs(gate, registry, exp_prefix) # pow outputs (pass-through from exp.out) match = re.search(r'\.out(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{exp_prefix}.out{i}")] return [] def infer_halfadder_inputs(gate: str, prefix: str, registry: SignalRegistry) -> List[int]: """Infer inputs for half adder gates.""" registry.register(f"{prefix}.$a") registry.register(f"{prefix}.$b") if '.sum.layer1' in gate: return [registry.get_id(f"{prefix}.$a"), registry.get_id(f"{prefix}.$b")] if '.sum.layer2' in gate: parent = gate.rsplit('.layer2', 1)[0] or_out = registry.register(f"{parent}.layer1.or") nand_out = registry.register(f"{parent}.layer1.nand") return [or_out, nand_out] if '.carry' in gate and 'layer' not in gate: return [registry.get_id(f"{prefix}.$a"), registry.get_id(f"{prefix}.$b")] return [] def infer_fulladder_inputs(gate: str, prefix: str, registry: SignalRegistry) -> List[int]: """Infer inputs for full adder gates.""" # Register external inputs registry.register(f"{prefix}.$a") registry.register(f"{prefix}.$b") registry.register(f"{prefix}.$cin") # HA1 inputs if '.ha1.sum.layer1' in gate or '.ha1.carry' in gate: return [registry.get_id(f"{prefix}.$a"), 
registry.get_id(f"{prefix}.$b")] if '.ha1.sum.layer2' in gate: parent = gate.rsplit('.layer2', 1)[0] or_out = registry.register(f"{parent}.layer1.or") nand_out = registry.register(f"{parent}.layer1.nand") return [or_out, nand_out] # HA2 inputs (ha1.sum output + cin) # Use full gate name for ha1.sum output (which is ha1.sum.layer2) ha1_sum = registry.register(f"{prefix}.ha1.sum.layer2") if '.ha2.sum.layer1' in gate or '.ha2.carry' in gate: return [ha1_sum, registry.get_id(f"{prefix}.$cin")] if '.ha2.sum.layer2' in gate: parent = gate.rsplit('.layer2', 1)[0] or_out = registry.register(f"{parent}.layer1.or") nand_out = registry.register(f"{parent}.layer1.nand") return [or_out, nand_out] # Carry OR if '.carry_or' in gate: ha1_carry = registry.register(f"{prefix}.ha1.carry") ha2_carry = registry.register(f"{prefix}.ha2.carry") return [ha1_carry, ha2_carry] return [] def infer_ripplecarry_inputs(gate: str, prefix: str, bits: int, registry: SignalRegistry) -> List[int]: """Infer inputs for ripple carry adder gates.""" # Register all input bits for i in range(bits): registry.register(f"{prefix}.$a[{i}]") registry.register(f"{prefix}.$b[{i}]") # Find which FA this gate belongs to match = re.search(r'\.fa(\d+)\.', gate) if not match: return [] fa_idx = int(match.group(1)) fa_prefix = f"{prefix}.fa{fa_idx}" # Determine carry input if fa_idx == 0: cin = registry.register("#0") else: # Carry output is from carry_or gate cin = registry.register(f"{prefix}.fa{fa_idx-1}.carry_or") # Register this FA's external inputs a_bit = registry.get_id(f"{prefix}.$a[{fa_idx}]") b_bit = registry.get_id(f"{prefix}.$b[{fa_idx}]") # Now infer based on gate type within FA if '.ha1.sum.layer1' in gate or '.ha1.carry' in gate: return [a_bit, b_bit] if '.ha1.sum.layer2' in gate: parent = gate.rsplit('.layer2', 1)[0] or_out = registry.register(f"{parent}.layer1.or") nand_out = registry.register(f"{parent}.layer1.nand") return [or_out, nand_out] ha1_sum = 
registry.register(f"{fa_prefix}.ha1.sum.layer2") if '.ha2.sum.layer1' in gate or '.ha2.carry' in gate: return [ha1_sum, cin] if '.ha2.sum.layer2' in gate: parent = gate.rsplit('.layer2', 1)[0] or_out = registry.register(f"{parent}.layer1.or") nand_out = registry.register(f"{parent}.layer1.nand") return [or_out, nand_out] if '.carry_or' in gate or '.or_carry' in gate: ha1_carry = registry.register(f"{fa_prefix}.ha1.carry") ha2_carry = registry.register(f"{fa_prefix}.ha2.carry") return [ha1_carry, ha2_carry] return [] def infer_threshold_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for threshold gates (k-of-n).""" # 8-bit input inputs = [] for i in range(8): sig = registry.register(f"{gate}.$x[{i}]") inputs.append(sig) return inputs def infer_modular_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for modular arithmetic gates.""" # Extract mod value match = re.search(r'modular\.mod(\d+)', gate) if not match: return [] mod = int(match.group(1)) prefix = f"modular.mod{mod}" # Register 8-bit input for i in range(8): registry.register(f"{prefix}.$x[{i}]") # Single layer (powers of 2) - handles both old (modular.mod2) and new (modular.mod2.out0) naming if mod in [2, 4, 8] and (gate == prefix or gate.startswith(f"{prefix}.out")): return [registry.get_id(f"{prefix}.$x[{i}]") for i in range(8)] # Multi-layer if '.layer1.geq' in gate or '.layer1.leq' in gate: return [registry.get_id(f"{prefix}.$x[{i}]") for i in range(8)] if '.layer2.eq' in gate: match = re.search(r'\.eq(\d+)', gate) if match: idx = int(match.group(1)) geq = registry.register(f"{prefix}.layer1.geq{idx}") leq = registry.register(f"{prefix}.layer1.leq{idx}") return [geq, leq] if '.layer3.or' in gate: # Find all eq outputs inputs = [] idx = 0 while True: eq_name = f"{prefix}.layer2.eq{idx}" if eq_name in registry.name_to_id: inputs.append(registry.get_id(eq_name)) idx += 1 else: break return inputs if inputs else [registry.register(f"{prefix}.layer2.eq0")] 
return [] def infer_comparator_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for comparator gates.""" prefix = gate.rsplit('.', 1)[0] # Remove .comparator if "32bit" in prefix: bits = 32 elif "16bit" in prefix: bits = 16 else: bits = 8 inputs = [] for i in range(bits): registry.register(f"{prefix}.$a[{i}]") registry.register(f"{prefix}.$b[{i}]") # Comparator takes difference of bit pairs for i in range(bits): inputs.append(registry.get_id(f"{prefix}.$a[{i}]")) for i in range(bits): inputs.append(registry.get_id(f"{prefix}.$b[{i}]")) return inputs def infer_adc_sbc_inputs(gate: str, prefix: str, registry: SignalRegistry, bits: int = 8) -> List[int]: """Infer inputs for ADC/SBC (add/subtract with carry) gates.""" # Register inputs for i in range(bits): registry.register(f"{prefix}.$a[{i}]") registry.register(f"{prefix}.$b[{i}]") registry.register(f"{prefix}.$cin") # SBC has NOT gates for B if '.notb' in gate: match = re.search(r'\.notb(\d+)', gate) if match: idx = int(match.group(1)) return [registry.get_id(f"{prefix}.$b[{idx}]")] # Find which FA this belongs to match = re.search(r'\.fa(\d+)\.', gate) if not match: return [] fa_idx = int(match.group(1)) fa_prefix = f"{prefix}.fa{fa_idx}" a_bit = registry.get_id(f"{prefix}.$a[{fa_idx}]") b_bit = registry.get_id(f"{prefix}.$b[{fa_idx}]") # Carry chain if fa_idx == 0: cin = registry.get_id(f"{prefix}.$cin") else: cin = registry.register(f"{prefix}.fa{fa_idx-1}.cout") # XOR1: a XOR b if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: or_out = registry.register(f"{fa_prefix}.xor1.layer1.or") nand_out = registry.register(f"{fa_prefix}.xor1.layer1.nand") return [or_out, nand_out] xor1_out = registry.register(f"{fa_prefix}.xor1") # XOR2: xor1 XOR cin if '.xor2.layer1' in gate: return [xor1_out, cin] if '.xor2.layer2' in gate: or_out = registry.register(f"{fa_prefix}.xor2.layer1.or") nand_out = registry.register(f"{fa_prefix}.xor2.layer1.nand") return [or_out, nand_out] # 
AND gates for carry if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1_out, cin] # OR for carry out if '.or_carry' in gate: and1 = registry.register(f"{fa_prefix}.and1") and2 = registry.register(f"{fa_prefix}.and2") return [and1, and2] return [] def infer_sub_inputs(gate: str, prefix: str, bits: int, registry: SignalRegistry) -> List[int]: """Infer inputs for subtractor (complement addition) gates.""" for i in range(bits): registry.register(f"{prefix}.$a[{i}]") registry.register(f"{prefix}.$b[{i}]") # NOT gates for B (two's complement) if '.notb' in gate: match = re.search(r'\.notb(\d+)', gate) if match: idx = int(match.group(1)) return [registry.get_id(f"{prefix}.$b[{idx}]")] # Carry in (set to 1 for subtraction) if '.carry_in' in gate: return [registry.get_id("#1")] # Full adder chain match = re.search(r'\.fa(\d+)\.', gate) if match: fa_idx = int(match.group(1)) fa_prefix = f"{prefix}.fa{fa_idx}" a_bit = registry.get_id(f"{prefix}.$a[{fa_idx}]") notb_bit = registry.register(f"{prefix}.notb{fa_idx}") if fa_idx == 0: cin = registry.register(f"{prefix}.carry_in") else: cin = registry.register(f"{prefix}.fa{fa_idx-1}.cout") if '.xor1.layer1' in gate: return [a_bit, notb_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1_out = registry.register(f"{fa_prefix}.xor1") if '.xor2.layer1' in gate: return [xor1_out, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, notb_bit] if '.and2' in gate: return [xor1_out, cin] if '.or_carry' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] return [] def infer_sub8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for SUB8BIT (subtraction via complement addition).""" return infer_sub_inputs(gate, "arithmetic.sub8bit", 
8, registry) def infer_sub16bit_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for SUB16BIT (subtraction via complement addition).""" return infer_sub_inputs(gate, "arithmetic.sub16bit", 16, registry) def infer_cmp_inputs(gate: str, prefix: str, bits: int, registry: SignalRegistry) -> List[int]: """Infer inputs for comparator via subtraction.""" for i in range(bits): registry.register(f"{prefix}.$a[{i}]") registry.register(f"{prefix}.$b[{i}]") if '.notb' in gate: match = re.search(r'\.notb(\d+)', gate) if match: idx = int(match.group(1)) return [registry.get_id(f"{prefix}.$b[{idx}]")] match = re.search(r'\.fa(\d+)\.', gate) if match: fa_idx = int(match.group(1)) fa_prefix = f"{prefix}.fa{fa_idx}" a_bit = registry.get_id(f"{prefix}.$a[{fa_idx}]") notb_bit = registry.register(f"{prefix}.notb{fa_idx}") if fa_idx == 0: cin = registry.get_id("#1") else: cin = registry.register(f"{prefix}.fa{fa_idx-1}.cout") if '.xor1.layer1' in gate: return [a_bit, notb_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1_out = registry.register(f"{fa_prefix}.xor1") if '.xor2.layer1' in gate: return [xor1_out, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, notb_bit] if '.and2' in gate: return [xor1_out, cin] if '.or_carry' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] if '.flags.' 
in gate: return [registry.register(f"{prefix}.fa{i}.sum") for i in range(bits)] return [] def infer_cmp8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for CMP8BIT (compare via subtraction).""" return infer_cmp_inputs(gate, "arithmetic.cmp8bit", 8, registry) def infer_cmp16bit_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for CMP16BIT (compare via subtraction).""" return infer_cmp_inputs(gate, "arithmetic.cmp16bit", 16, registry) def infer_equality_inputs(gate: str, prefix: str, bits: int, registry: SignalRegistry) -> List[int]: """Infer inputs for equality circuit (XNOR chain + AND).""" for i in range(bits): registry.register(f"{prefix}.$a[{i}]") registry.register(f"{prefix}.$b[{i}]") match = re.search(r'\.xnor(\d+)\.', gate) if match: idx = int(match.group(1)) a_bit = registry.get_id(f"{prefix}.$a[{idx}]") b_bit = registry.get_id(f"{prefix}.$b[{idx}]") if '.layer1.and' in gate or '.layer1.nor' in gate: return [a_bit, b_bit] if '.layer2' in gate: and_out = registry.register(f"{prefix}.xnor{idx}.layer1.and") nor_out = registry.register(f"{prefix}.xnor{idx}.layer1.nor") return [and_out, nor_out] if '.and' in gate or '.final_and' in gate: return [registry.register(f"{prefix}.xnor{i}") for i in range(bits)] return [] def infer_equality8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for equality8bit circuit (XNOR chain + AND).""" return infer_equality_inputs(gate, "arithmetic.equality8bit", 8, registry) def infer_equality16bit_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for equality16bit circuit (XNOR chain + AND).""" return infer_equality_inputs(gate, "arithmetic.equality16bit", 16, registry) def infer_neg_inputs(gate: str, prefix: str, bits: int, registry: SignalRegistry) -> List[int]: """Infer inputs for negation (two's complement).""" for i in range(bits): registry.register(f"{prefix}.$x[{i}]") if '.not' in gate and 'layer' not in gate: match = 
re.search(r'\.not(\d+)', gate) if match: idx = int(match.group(1)) return [registry.get_id(f"{prefix}.$x[{idx}]")] if '.sum0' in gate: return [registry.register(f"{prefix}.not0")] if '.carry0' in gate: return [registry.register(f"{prefix}.not0"), registry.get_id("#1")] match = re.search(r'\.xor(\d+)\.', gate) if match: idx = int(match.group(1)) not_bit = registry.register(f"{prefix}.not{idx}") if idx == 1: carry_in = registry.register(f"{prefix}.carry0") else: carry_in = registry.register(f"{prefix}.and{idx-1}") if '.layer1' in gate: return [not_bit, carry_in] if '.layer2' in gate: return [registry.register(f"{prefix}.xor{idx}.layer1.nand"), registry.register(f"{prefix}.xor{idx}.layer1.or")] match = re.search(r'\.and(\d+)', gate) if match and 'layer' not in gate: idx = int(match.group(1)) not_bit = registry.register(f"{prefix}.not{idx}") if idx == 1: carry_in = registry.register(f"{prefix}.carry0") else: carry_in = registry.register(f"{prefix}.and{idx-1}") return [not_bit, carry_in] return [] def infer_neg8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for NEG8BIT (two's complement negation).""" return infer_neg_inputs(gate, "arithmetic.neg8bit", 8, registry) def infer_neg16bit_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for NEG16BIT (two's complement negation).""" return infer_neg_inputs(gate, "arithmetic.neg16bit", 16, registry) def infer_shift_rotate_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for ASR, ROL, ROR.""" # Determine which circuit if 'asr32bit' in gate: prefix = "arithmetic.asr32bit" bits = 32 elif 'rol32bit' in gate: prefix = "arithmetic.rol32bit" bits = 32 elif 'ror32bit' in gate: prefix = "arithmetic.ror32bit" bits = 32 elif 'asr16bit' in gate: prefix = "arithmetic.asr16bit" bits = 16 elif 'rol16bit' in gate: prefix = "arithmetic.rol16bit" bits = 16 elif 'ror16bit' in gate: prefix = "arithmetic.ror16bit" bits = 16 elif 'asr8bit' in gate: prefix = 
"arithmetic.asr8bit" bits = 8 elif 'rol8bit' in gate: prefix = "arithmetic.rol8bit" bits = 8 elif 'ror8bit' in gate: prefix = "arithmetic.ror8bit" bits = 8 else: return [] for i in range(bits): registry.register(f"{prefix}.$x[{i}]") # Bit selectors match = re.search(r'\.bit(\d+)', gate) if match: idx = int(match.group(1)) # Each output bit selects from input bits based on shift return [registry.get_id(f"{prefix}.$x[{i}]") for i in range(bits)] # Carry/shift out if '.cout' in gate or '.shiftout' in gate: if 'rol' in gate: return [registry.get_id(f"{prefix}.$x[{bits-1}]")] # MSB shifts out elif 'ror' in gate: return [registry.get_id(f"{prefix}.$x[0]")] # LSB shifts out elif 'asr' in gate: return [registry.get_id(f"{prefix}.$x[0]")] # src tensors (metadata, not gates) if '.src' in gate: return [] return [] def infer_multiplier_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for multiplier circuits.""" # Determine size if 'multiplier8x8' in gate: prefix = "arithmetic.multiplier8x8" size = 8 elif 'multiplier4x4' in gate: prefix = "arithmetic.multiplier4x4" size = 4 elif 'multiplier2x2' in gate: prefix = "arithmetic.multiplier2x2" size = 2 else: return [] for i in range(size): registry.register(f"{prefix}.$a[{i}]") registry.register(f"{prefix}.$b[{i}]") # Partial products (AND gates) if '.pp.' 
in gate: match = re.search(r'\.r(\d+)\.c(\d+)', gate) if match: row, col = int(match.group(1)), int(match.group(2)) return [registry.get_id(f"{prefix}.$a[{col}]"), registry.get_id(f"{prefix}.$b[{row}]")] # Direct AND gates used by multiplier2x2 if 'multiplier2x2' in gate: match = re.search(r'\.and(\d)(\d)$', gate) if match: row, col = int(match.group(1)), int(match.group(2)) if row < size and col < size: return [registry.get_id(f"{prefix}.$a[{col}]"), registry.get_id(f"{prefix}.$b[{row}]")] # Stage adders match = re.search(r'\.stage(\d+)\.bit(\d+)\.', gate) if match: stage, bit = int(match.group(1)), int(match.group(2)) stage_prefix = f"{prefix}.stage{stage}.bit{bit}" # Previous result bit (output is ha2.sum.layer2 for stage adders) if stage == 0: if bit < size: prev_bit = registry.register(f"{prefix}.pp.r0.c{bit}") else: prev_bit = registry.get_id("#0") else: prev_bit = registry.register(f"{prefix}.stage{stage-1}.bit{bit}.ha2.sum.layer2") # Partial product for this stage row = stage + 1 shift = row if bit >= shift and bit < shift + size: pp_bit = registry.register(f"{prefix}.pp.r{row}.c{bit-shift}") else: pp_bit = registry.get_id("#0") # Carry from previous bit if bit == 0: carry_in = registry.get_id("#0") else: carry_in = registry.register(f"{prefix}.stage{stage}.bit{bit-1}.carry_or") if '.ha1.sum.layer1' in gate or '.ha1.carry' in gate: return [prev_bit, pp_bit] if '.ha1.sum.layer2' in gate: return [registry.register(f"{stage_prefix}.ha1.sum.layer1.or"), registry.register(f"{stage_prefix}.ha1.sum.layer1.nand")] ha1_sum = registry.register(f"{stage_prefix}.ha1.sum.layer2") if '.ha2.sum.layer1' in gate or '.ha2.carry' in gate: return [ha1_sum, carry_in] if '.ha2.sum.layer2' in gate: return [registry.register(f"{stage_prefix}.ha2.sum.layer1.or"), registry.register(f"{stage_prefix}.ha2.sum.layer1.nand")] if '.carry_or' in gate or '.or_carry' in gate: return [registry.register(f"{stage_prefix}.ha1.carry"), registry.register(f"{stage_prefix}.ha2.carry")] # 2x2 
multiplier special cases if 'multiplier2x2' in gate: if '.ha0.sum' in gate or '.ha0.carry' in gate: return [registry.register(f"{prefix}.and01"), registry.register(f"{prefix}.and10")] return [] def infer_incr_decr_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for incrementer/decrementer.""" if 'incrementer32bit' in gate: prefix = "arithmetic.incrementer32bit" bits = 32 elif 'decrementer32bit' in gate: prefix = "arithmetic.decrementer32bit" bits = 32 elif 'incrementer16bit' in gate: prefix = "arithmetic.incrementer16bit" bits = 16 elif 'decrementer16bit' in gate: prefix = "arithmetic.decrementer16bit" bits = 16 elif 'incrementer' in gate: prefix = "arithmetic.incrementer8bit" bits = 8 elif 'decrementer' in gate: prefix = "arithmetic.decrementer8bit" bits = 8 else: return [] for i in range(bits): registry.register(f"{prefix}.$x[{i}]") # These typically just reference adder and constant return [registry.get_id(f"{prefix}.$x[{i}]") for i in range(bits)] def infer_minmax_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for min/max/absolutedifference.""" if 'max32bit' in gate: prefix = "arithmetic.max32bit" bits = 32 elif 'min32bit' in gate: prefix = "arithmetic.min32bit" bits = 32 elif 'absolutedifference32bit' in gate: prefix = "arithmetic.absolutedifference32bit" bits = 32 elif 'max16bit' in gate: prefix = "arithmetic.max16bit" bits = 16 elif 'min16bit' in gate: prefix = "arithmetic.min16bit" bits = 16 elif 'absolutedifference16bit' in gate: prefix = "arithmetic.absolutedifference16bit" bits = 16 elif 'max8bit' in gate: prefix = "arithmetic.max8bit" bits = 8 elif 'min8bit' in gate: prefix = "arithmetic.min8bit" bits = 8 elif 'absolutedifference' in gate: prefix = "arithmetic.absolutedifference8bit" bits = 8 else: return [] for i in range(bits): registry.register(f"{prefix}.$a[{i}]") registry.register(f"{prefix}.$b[{i}]") # Select/diff weights take comparison + both operands inputs = [] for i in range(bits): 
inputs.append(registry.get_id(f"{prefix}.$a[{i}]")) for i in range(bits): inputs.append(registry.get_id(f"{prefix}.$b[{i}]")) return inputs def infer_clz16bit_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for CLZ16BIT (count leading zeros, 16-bit).""" prefix = "arithmetic.clz16bit" # Register 16-bit input for i in range(16): registry.register(f"{prefix}.$x[{i}]") # pz gates: prefix zero detectors (NOR of top k bits) if '.pz' in gate: match = re.search(r'\.pz(\d+)', gate) if match: k = int(match.group(1)) return [registry.get_id(f"{prefix}.$x[{15-i}]") for i in range(k)] # Register pz outputs for i in range(1, 17): registry.register(f"{prefix}.pz{i}") pz_ids = [registry.get_id(f"{prefix}.pz{i}") for i in range(1, 17)] # ge gates: sum of pz >= k if '.ge' in gate and '.not_ge' not in gate: match = re.search(r'\.ge(\d+)', gate) if match: return pz_ids # Register ge outputs for k in range(1, 17): registry.register(f"{prefix}.ge{k}") # NOT gates if '.not_ge' in gate: match = re.search(r'\.not_ge(\d+)', gate) if match: k = int(match.group(1)) return [registry.get_id(f"{prefix}.ge{k}")] # Register NOT outputs for k in [2, 4, 6, 8, 10, 12, 14, 16]: registry.register(f"{prefix}.not_ge{k}") # AND gates for ranges if '.and_8_15' in gate: return [registry.get_id(f"{prefix}.ge8"), registry.get_id(f"{prefix}.not_ge16")] if '.and_4_7' in gate: return [registry.get_id(f"{prefix}.ge4"), registry.get_id(f"{prefix}.not_ge8")] if '.and_12_15' in gate: return [registry.get_id(f"{prefix}.ge12"), registry.get_id(f"{prefix}.not_ge16")] if '.and_2_3' in gate: return [registry.get_id(f"{prefix}.ge2"), registry.get_id(f"{prefix}.not_ge4")] if '.and_6_7' in gate: return [registry.get_id(f"{prefix}.ge6"), registry.get_id(f"{prefix}.not_ge8")] if '.and_10_11' in gate: return [registry.get_id(f"{prefix}.ge10"), registry.get_id(f"{prefix}.not_ge12")] if '.and_14_15' in gate: return [registry.get_id(f"{prefix}.ge14"), registry.get_id(f"{prefix}.not_ge16")] # Odd number 
    # AND gates (use regex for exact match to avoid .and_1 matching .and_15)
    match = re.search(r'\.and_(\d+)$', gate)
    if match:
        i = int(match.group(1))
        if i in [1, 3, 5, 7, 9, 11, 13, 15]:
            # and_i = (count >= i) AND NOT (count >= i+1), i.e. count == i.
            return [registry.get_id(f"{prefix}.ge{i}"), registry.get_id(f"{prefix}.not_ge{i+1}")]
    # Register AND outputs
    for name in ['and_8_15', 'and_4_7', 'and_12_15', 'and_2_3', 'and_6_7', 'and_10_11', 'and_14_15']:
        registry.register(f"{prefix}.{name}")
    for i in [1, 3, 5, 7, 9, 11, 13, 15]:
        registry.register(f"{prefix}.and_{i}")
    # OR gates for bits: each result bit ORs the ranges where that bit is set.
    if '.or_bit2' in gate:
        return [registry.get_id(f"{prefix}.and_4_7"), registry.get_id(f"{prefix}.and_12_15")]
    if '.or_bit1' in gate:
        return [registry.get_id(f"{prefix}.and_2_3"), registry.get_id(f"{prefix}.and_6_7"),
                registry.get_id(f"{prefix}.and_10_11"), registry.get_id(f"{prefix}.and_14_15")]
    if '.or_bit0' in gate:
        return [registry.get_id(f"{prefix}.and_{i}") for i in [1, 3, 5, 7, 9, 11, 13, 15]]
    registry.register(f"{prefix}.or_bit2")
    registry.register(f"{prefix}.or_bit1")
    registry.register(f"{prefix}.or_bit0")
    # Output gates (out4 = count of 16 means the whole word is zero)
    if '.out4' in gate:
        return [registry.get_id(f"{prefix}.ge16")]
    if '.out3' in gate:
        return [registry.get_id(f"{prefix}.and_8_15")]
    if '.out2' in gate:
        return [registry.get_id(f"{prefix}.or_bit2")]
    if '.out1' in gate:
        return [registry.get_id(f"{prefix}.or_bit1")]
    if '.out0' in gate:
        return [registry.get_id(f"{prefix}.or_bit0")]
    return []


def infer_clz8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
    """Infer inputs for CLZ8BIT (count leading zeros).

    Same layered structure as the 16-bit version: pz prefix-zero detectors,
    ge threshold gates, range AND gates, then the 4 output bits.
    Registration order matters (IDs are assigned sequentially).
    """
    prefix = "arithmetic.clz8bit"
    # Register 8-bit input
    for i in range(8):
        registry.register(f"{prefix}.$x[{i}]")
    # pz gates: prefix zero detectors (NOR of top k bits)
    if '.pz' in gate:
        match = re.search(r'\.pz(\d+)', gate)
        if match:
            k = int(match.group(1))
            # pz[k] takes x[7], x[6], ..., x[7-k+1] (top k bits)
            return [registry.get_id(f"{prefix}.$x[{7-i}]") for i in range(k)]
    # Register pz outputs
    for i in range(1, 9):
        registry.register(f"{prefix}.pz{i}")
    pz_ids = [registry.get_id(f"{prefix}.pz{i}") for i in range(1, 9)]
    # ge gates: sum of pz >= k.  No '.not_ge' exclusion is needed here:
    # 'not_ge' names contain '_ge', not '.ge', so the substring test is safe.
    if '.ge' in gate:
        match = re.search(r'\.ge(\d+)', gate)
        if match:
            return pz_ids
    # Register ge outputs
    for k in [1, 2, 3, 4, 5, 6, 7, 8]:
        registry.register(f"{prefix}.ge{k}")
    # NOT gates
    if '.not_ge' in gate:
        match = re.search(r'\.not_ge(\d+)', gate)
        if match:
            k = int(match.group(1))
            return [registry.get_id(f"{prefix}.ge{k}")]
    # Register NOT outputs
    for k in [2, 4, 6, 8]:
        registry.register(f"{prefix}.not_ge{k}")
    # AND gates for ranges: and_lo_hi / and_i isolate exact count ranges.
    if '.and_2_3' in gate:
        return [registry.get_id(f"{prefix}.ge2"), registry.get_id(f"{prefix}.not_ge4")]
    if '.and_6_7' in gate:
        return [registry.get_id(f"{prefix}.ge6"), registry.get_id(f"{prefix}.not_ge8")]
    if '.and_1' in gate:
        return [registry.get_id(f"{prefix}.ge1"), registry.get_id(f"{prefix}.not_ge2")]
    if '.and_3' in gate:
        return [registry.get_id(f"{prefix}.ge3"), registry.get_id(f"{prefix}.not_ge4")]
    if '.and_5' in gate:
        return [registry.get_id(f"{prefix}.ge5"), registry.get_id(f"{prefix}.not_ge6")]
    if '.and_7' in gate:
        return [registry.get_id(f"{prefix}.ge7"), registry.get_id(f"{prefix}.not_ge8")]
    # Register AND outputs
    for name in ['and_2_3', 'and_6_7', 'and_1', 'and_3', 'and_5', 'and_7']:
        registry.register(f"{prefix}.{name}")
    # Output gates (out3 = count of 8 means the whole byte is zero)
    if '.out3' in gate:
        return [registry.get_id(f"{prefix}.ge8")]
    if '.out2' in gate:
        return [registry.get_id(f"{prefix}.ge4"), registry.get_id(f"{prefix}.not_ge8")]
    if '.out1' in gate:
        return [registry.get_id(f"{prefix}.and_2_3"), registry.get_id(f"{prefix}.and_6_7")]
    if '.out0' in gate:
        return [registry.get_id(f"{prefix}.and_1"), registry.get_id(f"{prefix}.and_3"),
                registry.get_id(f"{prefix}.and_5"), registry.get_id(f"{prefix}.and_7")]
    return []


def infer_pattern_recognition_inputs(gate: str, registry: SignalRegistry) -> List[int]:
    """Infer inputs for pattern recognition gates.

    The circuit prefix is the first two dotted components of the gate name
    (e.g. "pattern_recognition.popcount").
    """
    prefix = gate.split('.')[0] + '.' \
+ gate.split('.')[1]
    # Most take 8-bit input
    if 'popcount' in gate or 'allzeros' in gate or 'allones' in gate:
        inputs = []
        for i in range(8):
            sig = registry.register(f"{prefix}.$x[{i}]")
            inputs.append(sig)
        return inputs
    if 'onehotdetector' in gate:
        # Two-stage circuit: atleast1/atmost1 read the raw bits, the final
        # AND combines those two intermediate signals.
        if '.atleast1' in gate or '.atmost1' in gate:
            return [registry.register(f"{prefix}.$x[{i}]") for i in range(8)]
        if '.and' in gate:
            return [registry.register(f"{prefix}.atleast1"), registry.register(f"{prefix}.atmost1")]
    # Default 8-bit input
    return [registry.register(f"{prefix}.$x[{i}]") for i in range(8)]


def infer_combinational_inputs(gate: str, registry: SignalRegistry) -> List[int]:
    """Infer inputs for combinational gates.

    Covers decoder3to8, encoder8to3, multiplexer2to1 (with its internal
    not_s/and0/and1/or sub-gates) and demultiplexer1to2.  Registration
    order matters (IDs are assigned sequentially).
    """
    if 'decoder3to8' in gate:
        prefix = "combinational.decoder3to8"
        for i in range(3):
            registry.register(f"{prefix}.$sel[{i}]")
        return [registry.get_id(f"{prefix}.$sel[{i}]") for i in range(3)]
    if 'encoder8to3' in gate:
        prefix = "combinational.encoder8to3"
        for i in range(8):
            registry.register(f"{prefix}.$x[{i}]")
        return [registry.get_id(f"{prefix}.$x[{i}]") for i in range(8)]
    if 'multiplexer2to1' in gate:
        prefix = "combinational.multiplexer2to1"
        registry.register(f"{prefix}.$a")
        registry.register(f"{prefix}.$b")
        registry.register(f"{prefix}.$sel")
        if '.not_s' in gate:
            return [registry.get_id(f"{prefix}.$sel")]
        if '.and0' in gate:
            not_s = registry.register(f"{prefix}.not_s")
            return [registry.get_id(f"{prefix}.$a"), not_s]
        if '.and1' in gate:
            return [registry.get_id(f"{prefix}.$b"), registry.get_id(f"{prefix}.$sel")]
        if '.or' in gate:
            return [registry.register(f"{prefix}.and0"), registry.register(f"{prefix}.and1")]
    if 'demultiplexer1to2' in gate:
        prefix = "combinational.demultiplexer1to2"
        registry.register(f"{prefix}.$in")
        registry.register(f"{prefix}.$sel")
        return [registry.get_id(f"{prefix}.$in"), registry.get_id(f"{prefix}.$sel")]
    return []


def infer_inputs_for_gate(gate: str, registry: SignalRegistry, routing: dict) -> List[int]:
    """Infer inputs for any gate.

    Dispatch order: explicit routing metadata wins; then the per-family
    heuristics.  Substring checks are ordered widest-name-first where names
    can shadow each other.  Falls through to [] when nothing matches.
    """
    # Check routing first for complex circuits
    if routing:
        circuits = routing.get('circuits', {})
        for circuit_name, circuit_data in circuits.items():
            if gate.startswith(circuit_name):
                internal = circuit_data.get('internal', {})
                # Find the gate's local name
                local_name = gate[len(circuit_name)+1:] if gate.startswith(circuit_name + '.') else gate
                if local_name in internal:
                    sources = internal[local_name]
                    inputs = []
                    for src in sources:
                        # '$name' inputs and plain gate names are both scoped
                        # to the circuit; '#' constants are global.
                        if src.startswith('$'):
                            full_src = f"{circuit_name}.{src}"
                        elif src.startswith('#'):
                            full_src = src
                        else:
                            full_src = f"{circuit_name}.{src}"
                        inputs.append(registry.register(full_src))
                    return inputs
    # Boolean gates
    if gate.startswith('boolean.'):
        return infer_boolean_inputs(gate, registry)
    # Threshold gates
    if gate.startswith('threshold.'):
        return infer_threshold_inputs(gate, registry)
    # Modular arithmetic
    if gate.startswith('modular.'):
        return infer_modular_inputs(gate, registry)
    # Pattern recognition
    if gate.startswith('pattern_recognition.'):
        return infer_pattern_recognition_inputs(gate, registry)
    # Combinational
    if gate.startswith('combinational.'):
        return infer_combinational_inputs(gate, registry)
    # Arithmetic circuits
    if gate.startswith('arithmetic.'):
        # Half adder (standalone only; ripple/multiplier embed their own)
        if 'halfadder' in gate and 'ripple' not in gate and 'multiplier' not in gate:
            return infer_halfadder_inputs(gate, 'arithmetic.halfadder', registry)
        # Full adder
        if gate.startswith('arithmetic.fulladder.') and 'ripple' not in gate:
            return infer_fulladder_inputs(gate, 'arithmetic.fulladder', registry)
        # Ripple carry adders
        if 'ripplecarry8bit' in gate:
            return infer_ripplecarry_inputs(gate, 'arithmetic.ripplecarry8bit', 8, registry)
        if 'ripplecarry32bit' in gate:
            return infer_ripplecarry_inputs(gate, 'arithmetic.ripplecarry32bit', 32, registry)
        if 'ripplecarry16bit' in gate:
            return infer_ripplecarry_inputs(gate, 'arithmetic.ripplecarry16bit', 16, registry)
        if 'ripplecarry4bit' in gate:
            return infer_ripplecarry_inputs(gate, 'arithmetic.ripplecarry4bit', 4, registry)
        if 'ripplecarry2bit' in gate:
            return infer_ripplecarry_inputs(gate, 'arithmetic.ripplecarry2bit', 2, registry)
        # ADC/SBC
        if 'adc8bit' in gate:
            return infer_adc_sbc_inputs(gate, 'arithmetic.adc8bit', registry, bits=8)
        if 'adc32bit' in gate:
            return infer_adc_sbc_inputs(gate, 'arithmetic.adc32bit', registry, bits=32)
        if 'adc16bit' in gate:
            return infer_adc_sbc_inputs(gate, 'arithmetic.adc16bit', registry, bits=16)
        if 'sbc8bit' in gate:
            return infer_adc_sbc_inputs(gate, 'arithmetic.sbc8bit', registry, bits=8)
        if 'sbc32bit' in gate:
            return infer_adc_sbc_inputs(gate, 'arithmetic.sbc32bit', registry, bits=32)
        if 'sbc16bit' in gate:
            return infer_adc_sbc_inputs(gate, 'arithmetic.sbc16bit', registry, bits=16)
        # SUB
        if 'sub8bit' in gate:
            return infer_sub8bit_inputs(gate, registry)
        if 'sub32bit' in gate:
            return infer_sub_inputs(gate, "arithmetic.sub32bit", 32, registry)
        if 'sub16bit' in gate:
            return infer_sub16bit_inputs(gate, registry)
        # CMP
        if 'cmp8bit' in gate:
            return infer_cmp8bit_inputs(gate, registry)
        if 'cmp32bit' in gate:
            return infer_cmp_inputs(gate, "arithmetic.cmp32bit", 32, registry)
        if 'cmp16bit' in gate:
            return infer_cmp16bit_inputs(gate, registry)
        # Equality
        if 'equality8bit' in gate:
            return infer_equality8bit_inputs(gate, registry)
        if 'equality32bit' in gate:
            return infer_equality_inputs(gate, "arithmetic.equality32bit", 32, registry)
        if 'equality16bit' in gate:
            return infer_equality16bit_inputs(gate, registry)
        # Negate
        if 'neg8bit' in gate:
            return infer_neg8bit_inputs(gate, registry)
        if 'neg32bit' in gate:
            return infer_neg_inputs(gate, "arithmetic.neg32bit", 32, registry)
        if 'neg16bit' in gate:
            return infer_neg16bit_inputs(gate, registry)
        # Shifts and rotates
        if ('asr8bit' in gate or 'rol8bit' in gate or 'ror8bit' in gate or
                'asr16bit' in gate or 'rol16bit' in gate or 'ror16bit' in gate or
                'asr32bit' in gate or 'rol32bit' in gate or 'ror32bit' in gate):
            return infer_shift_rotate_inputs(gate, registry)
        # Multipliers
        if 'multiplier' in gate:
            return infer_multiplier_inputs(gate, registry)
Incrementer/Decrementer if 'incrementer' in gate or 'decrementer' in gate: return infer_incr_decr_inputs(gate, registry) # Min/Max/AbsoluteDifference if 'max8bit' in gate or 'min8bit' in gate or 'absolutedifference' in gate: return infer_minmax_inputs(gate, registry) # Comparators if 'greaterthan8bit' in gate or 'lessthan8bit' in gate or \ 'greaterorequal8bit' in gate or 'lessorequal8bit' in gate or \ 'greaterthan16bit' in gate or 'lessthan16bit' in gate or \ 'greaterorequal16bit' in gate or 'lessorequal16bit' in gate or \ 'greaterthan32bit' in gate or 'lessthan32bit' in gate or \ 'greaterorequal32bit' in gate or 'lessorequal32bit' in gate: return infer_comparator_inputs(gate, registry) # CLZ (count leading zeros) if 'clz16bit' in gate: return infer_clz16bit_inputs(gate, registry) if 'clz8bit' in gate: return infer_clz8bit_inputs(gate, registry) # Float32 circuits if gate.startswith('float32.'): if gate.startswith('float32.unpack'): return infer_float32_unpack_inputs(gate, registry) if gate.startswith('float32.pack'): return infer_float32_pack_inputs(gate, registry) if gate.startswith('float32.cmp'): return infer_float32_cmp_inputs(gate, registry) if gate.startswith('float32.neg'): return infer_float32_neg_inputs(gate, registry) if gate.startswith('float32.abs'): return infer_float32_abs_inputs(gate, registry) # Float16 circuits if gate.startswith('float16.'): if gate.startswith('float16.const_'): return [registry.get_id("#1")] if gate.endswith('.domain_not'): prefix = gate[:-len('.domain_not')] registry.register(f"{prefix}.domain") return [registry.get_id(f"{prefix}.domain")] checked_match = re.search(r'\.checked_out(\d+)$', gate) if checked_match: idx = int(checked_match.group(1)) prefix = gate[:gate.rfind('.checked_out')] nan_bits = 0x7E00 registry.register(f"{prefix}.out{idx}") if (nan_bits >> idx) & 1: registry.register(f"{prefix}.domain") return [registry.get_id(f"{prefix}.out{idx}"), registry.get_id(f"{prefix}.domain")] 
registry.register(f"{prefix}.domain_not") return [registry.get_id(f"{prefix}.out{idx}"), registry.get_id(f"{prefix}.domain_not")] if gate.startswith('float16.lut'): return infer_float16_lut_inputs(gate, registry) if gate.startswith('float16.pow'): return infer_float16_pow_inputs(gate, registry) if gate.startswith('float16.sqrt') or gate.startswith('float16.rsqrt') or \ gate.startswith('float16.exp') or gate.startswith('float16.ln') or \ gate.startswith('float16.log2') or gate.startswith('float16.log10') or \ gate.startswith('float16.deg2rad') or gate.startswith('float16.rad2deg') or \ gate.startswith('float16.is_nan') or gate.startswith('float16.is_inf') or \ gate.startswith('float16.is_finite') or gate.startswith('float16.is_zero') or \ gate.startswith('float16.is_subnormal') or gate.startswith('float16.is_normal') or \ gate.startswith('float16.is_negative') or \ gate.startswith('float16.sin') or gate.startswith('float16.cos') or \ gate.startswith('float16.tan') or gate.startswith('float16.tanh') or \ gate.startswith('float16.sin_deg') or gate.startswith('float16.cos_deg') or \ gate.startswith('float16.tan_deg') or gate.startswith('float16.asin_deg') or \ gate.startswith('float16.acos_deg') or gate.startswith('float16.atan_deg') or \ gate.startswith('float16.asin') or gate.startswith('float16.acos') or \ gate.startswith('float16.atan') or gate.startswith('float16.sinh') or \ gate.startswith('float16.cosh') or gate.startswith('float16.floor') or \ gate.startswith('float16.ceil') or gate.startswith('float16.round'): return infer_float16_lut_out_inputs(gate, registry, "float16.lut") if 'unpack' in gate: return infer_float16_unpack_inputs(gate, registry) if 'pack' in gate: return infer_float16_pack_inputs(gate, registry) if 'cmp' in gate: return infer_float16_cmp_inputs(gate, registry) if 'normalize' in gate: return infer_float16_normalize_inputs(gate, registry) if gate.startswith('float16.neg'): return infer_float16_neg_inputs(gate, registry) if 
gate.startswith('float16.abs'): return infer_float16_abs_inputs(gate, registry) if gate.startswith('float16.add'): return infer_float16_add_inputs(gate, registry) if gate.startswith('float16.sub'): return infer_float16_sub_inputs(gate, registry) if gate.startswith('float16.mul'): return infer_float16_mul_inputs(gate, registry) if gate.startswith('float16.div'): return infer_float16_div_inputs(gate, registry) if gate.startswith('float16.toint'): return infer_float16_toint_inputs(gate, registry) if gate.startswith('float16.fromint'): return infer_float16_fromint_inputs(gate, registry) # Default: couldn't infer, return empty (will need manual fix or routing) return [] def infer_float16_add_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for float16.add circuit.""" prefix = "float16.add" # Register 32 input bits (two 16-bit operands) for i in range(16): registry.register(f"{prefix}.$a[{i}]") registry.register(f"{prefix}.$b[{i}]") # Extract exponent bits (10-14) exp_a_bits = [f"{prefix}.$a[{10+i}]" for i in range(5)] exp_b_bits = [f"{prefix}.$b[{10+i}]" for i in range(5)] mant_a_bits = [f"{prefix}.$a[{i}]" for i in range(10)] mant_b_bits = [f"{prefix}.$b[{i}]" for i in range(10)] # Stage 0: Special case detection if '.exp_a_all_ones' in gate: return [registry.get_id(b) for b in exp_a_bits] if '.exp_b_all_ones' in gate: return [registry.get_id(b) for b in exp_b_bits] if '.exp_a_zero' in gate: return [registry.get_id(b) for b in exp_a_bits] if '.exp_b_zero' in gate: return [registry.get_id(b) for b in exp_b_bits] if '.mant_a_nonzero' in gate: return [registry.get_id(b) for b in mant_a_bits] if '.mant_b_nonzero' in gate: return [registry.get_id(b) for b in mant_b_bits] if '.mant_a_zero' in gate: return [registry.get_id(b) for b in mant_a_bits] if '.mant_b_zero' in gate: return [registry.get_id(b) for b in mant_b_bits] registry.register(f"{prefix}.exp_a_all_ones") registry.register(f"{prefix}.exp_b_all_ones") 
registry.register(f"{prefix}.exp_a_zero") registry.register(f"{prefix}.exp_b_zero") # Adjusted exponent bit 0 for subnormals if '.a_adj_exp0' in gate: return [registry.get_id(exp_a_bits[0]), registry.get_id(f"{prefix}.exp_a_zero")] if '.b_adj_exp0' in gate: return [registry.get_id(exp_b_bits[0]), registry.get_id(f"{prefix}.exp_b_zero")] if '.not_a_adj_exp0' in gate: return [registry.get_id(f"{prefix}.a_adj_exp0")] if '.not_b_adj_exp0' in gate: return [registry.get_id(f"{prefix}.b_adj_exp0")] registry.register(f"{prefix}.a_adj_exp0") registry.register(f"{prefix}.b_adj_exp0") registry.register(f"{prefix}.not_a_adj_exp0") registry.register(f"{prefix}.not_b_adj_exp0") registry.register(f"{prefix}.mant_a_nonzero") registry.register(f"{prefix}.mant_b_nonzero") registry.register(f"{prefix}.mant_a_zero") registry.register(f"{prefix}.mant_b_zero") if '.a_is_nan' in gate: return [registry.get_id(f"{prefix}.exp_a_all_ones"), registry.get_id(f"{prefix}.mant_a_nonzero")] if '.b_is_nan' in gate: return [registry.get_id(f"{prefix}.exp_b_all_ones"), registry.get_id(f"{prefix}.mant_b_nonzero")] if '.a_is_inf' in gate: return [registry.get_id(f"{prefix}.exp_a_all_ones"), registry.get_id(f"{prefix}.mant_a_zero")] if '.b_is_inf' in gate: return [registry.get_id(f"{prefix}.exp_b_all_ones"), registry.get_id(f"{prefix}.mant_b_zero")] if '.a_is_zero' in gate: return [registry.get_id(f"{prefix}.exp_a_zero"), registry.get_id(f"{prefix}.mant_a_zero")] if '.b_is_zero' in gate: return [registry.get_id(f"{prefix}.exp_b_zero"), registry.get_id(f"{prefix}.mant_b_zero")] if '.a_is_subnormal' in gate: return [registry.get_id(f"{prefix}.exp_a_zero"), registry.get_id(f"{prefix}.mant_a_nonzero")] if '.b_is_subnormal' in gate: return [registry.get_id(f"{prefix}.exp_b_zero"), registry.get_id(f"{prefix}.mant_b_nonzero")] registry.register(f"{prefix}.a_is_nan") registry.register(f"{prefix}.b_is_nan") registry.register(f"{prefix}.a_is_inf") registry.register(f"{prefix}.b_is_inf") if '.either_is_nan' in 
gate: return [registry.get_id(f"{prefix}.a_is_nan"), registry.get_id(f"{prefix}.b_is_nan")] # Register a_is_zero, b_is_zero before checking both_are_zero registry.register(f"{prefix}.a_is_zero") registry.register(f"{prefix}.b_is_zero") if '.both_are_zero' in gate: return [registry.get_id(f"{prefix}.a_is_zero"), registry.get_id(f"{prefix}.b_is_zero")] if '.either_is_zero' in gate: return [registry.get_id(f"{prefix}.a_is_zero"), registry.get_id(f"{prefix}.b_is_zero")] registry.register(f"{prefix}.both_are_zero") registry.register(f"{prefix}.either_is_zero") if '.both_exp_zero' in gate: return [registry.get_id(f"{prefix}.exp_a_zero"), registry.get_id(f"{prefix}.exp_b_zero")] registry.register(f"{prefix}.both_exp_zero") if '.subnorm_condition' in gate: return [registry.register(f"{prefix}.exp_underflow")] registry.register(f"{prefix}.subnorm_condition") if '.not_both_exp_zero' in gate: return [registry.get_id(f"{prefix}.both_exp_zero")] registry.register(f"{prefix}.not_both_exp_zero") if '.both_are_inf' in gate: return [registry.get_id(f"{prefix}.a_is_inf"), registry.get_id(f"{prefix}.b_is_inf")] # Sign extraction if gate == f"{prefix}.sign_a": return [registry.get_id(f"{prefix}.$a[15]")] if gate == f"{prefix}.sign_b": return [registry.get_id(f"{prefix}.$b[15]")] registry.register(f"{prefix}.sign_a") registry.register(f"{prefix}.sign_b") if '.signs_differ.layer1' in gate: return [registry.get_id(f"{prefix}.sign_a"), registry.get_id(f"{prefix}.sign_b")] if '.signs_differ.layer2' in gate: return [registry.register(f"{prefix}.signs_differ.layer1.or"), registry.register(f"{prefix}.signs_differ.layer1.nand")] registry.register(f"{prefix}.signs_differ.layer2") registry.register(f"{prefix}.either_is_nan") registry.register(f"{prefix}.both_are_inf") if '.inf_cancellation' in gate: return [registry.get_id(f"{prefix}.both_are_inf"), registry.get_id(f"{prefix}.signs_differ.layer2")] registry.register(f"{prefix}.inf_cancellation") if '.result_is_nan' in gate: return 
[registry.get_id(f"{prefix}.either_is_nan"), registry.get_id(f"{prefix}.inf_cancellation")] if '.either_is_inf' in gate: return [registry.get_id(f"{prefix}.a_is_inf"), registry.get_id(f"{prefix}.b_is_inf")] registry.register(f"{prefix}.result_is_nan") registry.register(f"{prefix}.either_is_inf") if '.not_result_is_nan' in gate: return [registry.get_id(f"{prefix}.result_is_nan")] registry.register(f"{prefix}.not_result_is_nan") if '.result_is_inf' in gate and '.not_result_is_inf' not in gate: return [registry.get_id(f"{prefix}.either_is_inf"), registry.get_id(f"{prefix}.exp_overflow_to_inf"), registry.get_id(f"{prefix}.not_result_is_nan")] # Implicit bit if '.implicit_a' in gate: return [registry.get_id(f"{prefix}.exp_a_zero")] if '.implicit_b' in gate: return [registry.get_id(f"{prefix}.exp_b_zero")] registry.register(f"{prefix}.implicit_a") registry.register(f"{prefix}.implicit_b") for i in range(10): if f'.not_div_b{i}' in gate: return [registry.get_id(f"{prefix}.$b[{i}]")] registry.register(f"{prefix}.not_div_b{i}") if '.not_implicit_b' in gate: return [registry.get_id(f"{prefix}.implicit_b")] registry.register(f"{prefix}.not_implicit_b") # Exponent comparison (using adjusted bit 0 for subnormals) # For subnormals, effective exp = 1, so we use a_adj_exp0 instead of raw bit 0 if '.a_exp_ge_b' in gate or '.a_exp_gt_b' in gate: a_adj = [registry.get_id(f"{prefix}.a_adj_exp0")] + \ [registry.get_id(b) for b in exp_a_bits[1:]] b_adj = [registry.get_id(f"{prefix}.b_adj_exp0")] + \ [registry.get_id(b) for b in exp_b_bits[1:]] return a_adj + b_adj if '.b_exp_gt_a' in gate and 'sel' not in gate: a_adj = [registry.get_id(f"{prefix}.a_adj_exp0")] + \ [registry.get_id(b) for b in exp_a_bits[1:]] b_adj = [registry.get_id(f"{prefix}.b_adj_exp0")] + \ [registry.get_id(b) for b in exp_b_bits[1:]] return b_adj + a_adj registry.register(f"{prefix}.a_exp_ge_b") registry.register(f"{prefix}.a_exp_gt_b") registry.register(f"{prefix}.b_exp_gt_a") if '.b_exp_gt_a_sel' in gate: return 
[registry.get_id(f"{prefix}.a_exp_ge_b")] registry.register(f"{prefix}.b_exp_gt_a_sel") # NOT gates for exponent bits (use adjusted bit 0 for subnormals) match = re.search(r'\.not_exp_b(\d+)', gate) if match: i = int(match.group(1)) if i == 0: return [registry.get_id(f"{prefix}.b_adj_exp0")] return [registry.get_id(f"{prefix}.$b[{10+i}]")] match = re.search(r'\.not_exp_a(\d+)', gate) if match: i = int(match.group(1)) if i == 0: return [registry.get_id(f"{prefix}.a_adj_exp0")] return [registry.get_id(f"{prefix}.$a[{10+i}]")] for i in range(5): registry.register(f"{prefix}.not_exp_b{i}") registry.register(f"{prefix}.not_exp_a{i}") # Exp diff subtractors (diff_ab and diff_ba) using adjusted bit 0 if '.diff_ab.fa' in gate or '.diff_ba.fa' in gate: is_ab = '.diff_ab' in gate match = re.search(r'\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.diff_{'ab' if is_ab else 'ba'}.fa{i}" if is_ab: if i == 0: a_bit = registry.get_id(f"{prefix}.a_adj_exp0") else: a_bit = registry.get_id(f"{prefix}.$a[{10+i}]") not_b = registry.get_id(f"{prefix}.not_exp_b{i}") else: if i == 0: a_bit = registry.get_id(f"{prefix}.b_adj_exp0") else: a_bit = registry.get_id(f"{prefix}.$b[{10+i}]") not_b = registry.get_id(f"{prefix}.not_exp_a{i}") if i == 0: cin = registry.get_id("#1") else: cin = registry.register(f"{prefix}.diff_{'ab' if is_ab else 'ba'}.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, not_b] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, not_b] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] # 
Register diff outputs for i in range(5): registry.register(f"{prefix}.diff_ab.fa{i}.xor2.layer2") registry.register(f"{prefix}.diff_ba.fa{i}.xor2.layer2") # Exp diff mux match = re.search(r'\.exp_diff_mux(\d+)\.', gate) if match: i = int(match.group(1)) if '.and_ab' in gate: return [registry.get_id(f"{prefix}.diff_ab.fa{i}.xor2.layer2"), registry.get_id(f"{prefix}.a_exp_ge_b")] if '.and_ba' in gate: return [registry.get_id(f"{prefix}.diff_ba.fa{i}.xor2.layer2"), registry.get_id(f"{prefix}.b_exp_gt_a_sel")] match = re.search(r'\.exp_diff(\d+)$', gate) if match: i = int(match.group(1)) return [registry.register(f"{prefix}.exp_diff_mux{i}.and_ab"), registry.register(f"{prefix}.exp_diff_mux{i}.and_ba")] for i in range(5): registry.register(f"{prefix}.exp_diff{i}") # Exp larger mux (using adjusted bit 0 for subnormals) match = re.search(r'\.exp_larger_mux(\d+)\.', gate) if match: i = int(match.group(1)) if '.and_a' in gate: if i == 0: return [registry.get_id(f"{prefix}.a_adj_exp0"), registry.get_id(f"{prefix}.a_exp_ge_b")] return [registry.get_id(f"{prefix}.$a[{10+i}]"), registry.get_id(f"{prefix}.a_exp_ge_b")] if '.and_b' in gate: if i == 0: return [registry.get_id(f"{prefix}.b_adj_exp0"), registry.get_id(f"{prefix}.b_exp_gt_a_sel")] return [registry.get_id(f"{prefix}.$b[{10+i}]"), registry.get_id(f"{prefix}.b_exp_gt_a_sel")] match = re.search(r'\.exp_larger(\d+)$', gate) if match: i = int(match.group(1)) return [registry.register(f"{prefix}.exp_larger_mux{i}.and_a"), registry.register(f"{prefix}.exp_larger_mux{i}.and_b")] for i in range(5): registry.register(f"{prefix}.exp_larger{i}") # Mantissa source selection (which mantissa to shift) # Use magnitude compare (exp, then mantissa) so subtraction uses larger magnitude. # mant_shift_src = a_magnitude_ge_b ? mant_b : mant_a # mant_larger = a_magnitude_ge_b ? 
mant_a : mant_b match = re.search(r'\.mant_shift_src(\d+)\.', gate) if match: i = int(match.group(1)) if i < 10: mant_a = registry.get_id(f"{prefix}.$a[{i}]") mant_b = registry.get_id(f"{prefix}.$b[{i}]") else: mant_a = registry.get_id(f"{prefix}.implicit_a") mant_b = registry.get_id(f"{prefix}.implicit_b") if '.and_b' in gate: return [mant_b, registry.register(f"{prefix}.a_magnitude_ge_b")] if '.and_a' in gate: return [mant_a, registry.register(f"{prefix}.not_a_mag_ge_b")] match = re.search(r'\.mant_shift_src(\d+)$', gate) if match: i = int(match.group(1)) return [registry.register(f"{prefix}.mant_shift_src{i}.and_b"), registry.register(f"{prefix}.mant_shift_src{i}.and_a")] match = re.search(r'\.mant_larger(\d+)\.', gate) if match: i = int(match.group(1)) if i < 10: mant_a = registry.get_id(f"{prefix}.$a[{i}]") mant_b = registry.get_id(f"{prefix}.$b[{i}]") else: mant_a = registry.get_id(f"{prefix}.implicit_a") mant_b = registry.get_id(f"{prefix}.implicit_b") if '.and_a' in gate: return [mant_a, registry.register(f"{prefix}.a_magnitude_ge_b")] if '.and_b' in gate: return [mant_b, registry.register(f"{prefix}.not_a_mag_ge_b")] match = re.search(r'\.mant_larger(\d+)$', gate) if match: i = int(match.group(1)) return [registry.register(f"{prefix}.mant_larger{i}.and_a"), registry.register(f"{prefix}.mant_larger{i}.and_b")] for i in range(11): registry.register(f"{prefix}.mant_shift_src{i}") registry.register(f"{prefix}.mant_larger{i}") # NOT gates for exp_diff bits (barrel shifter control) for i in range(5): if f'.not_exp_diff{i}' in gate and f'.not_exp_diff{i}.' 
not in gate: return [registry.get_id(f"{prefix}.exp_diff{i}")] registry.register(f"{prefix}.not_exp_diff{i}") # Barrel shifter stage 0 (shift by 1) match = re.search(r'\.shift_s0_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.mant_shift_src{i}"), registry.get_id(f"{prefix}.not_exp_diff0")] if '.shift' in gate and i < 10: return [registry.get_id(f"{prefix}.mant_shift_src{i+1}"), registry.get_id(f"{prefix}.exp_diff0")] match = re.search(r'\.shift_s0_(\d+)$', gate) if match: i = int(match.group(1)) if i < 10: return [registry.register(f"{prefix}.shift_s0_{i}.pass"), registry.register(f"{prefix}.shift_s0_{i}.shift")] else: return [registry.register(f"{prefix}.shift_s0_{i}.pass")] for i in range(11): registry.register(f"{prefix}.shift_s0_{i}") # Barrel shifter stage 1 (shift by 2) match = re.search(r'\.shift_s1_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.shift_s0_{i}"), registry.get_id(f"{prefix}.not_exp_diff1")] if '.shift' in gate and i < 9: return [registry.get_id(f"{prefix}.shift_s0_{i+2}"), registry.get_id(f"{prefix}.exp_diff1")] match = re.search(r'\.shift_s1_(\d+)$', gate) if match: i = int(match.group(1)) if i < 9: return [registry.register(f"{prefix}.shift_s1_{i}.pass"), registry.register(f"{prefix}.shift_s1_{i}.shift")] else: return [registry.register(f"{prefix}.shift_s1_{i}.pass")] for i in range(11): registry.register(f"{prefix}.shift_s1_{i}") # Barrel shifter stage 2 (shift by 4) match = re.search(r'\.shift_s2_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.shift_s1_{i}"), registry.get_id(f"{prefix}.not_exp_diff2")] if '.shift' in gate and i < 7: return [registry.get_id(f"{prefix}.shift_s1_{i+4}"), registry.get_id(f"{prefix}.exp_diff2")] match = re.search(r'\.shift_s2_(\d+)$', gate) if match: i = int(match.group(1)) if i < 7: return [registry.register(f"{prefix}.shift_s2_{i}.pass"), 
registry.register(f"{prefix}.shift_s2_{i}.shift")] else: return [registry.register(f"{prefix}.shift_s2_{i}.pass")] for i in range(11): registry.register(f"{prefix}.shift_s2_{i}") # Barrel shifter stage 3 (shift by 8) match = re.search(r'\.shift_s3_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.shift_s2_{i}"), registry.get_id(f"{prefix}.not_exp_diff3")] if '.shift' in gate and i < 3: return [registry.get_id(f"{prefix}.shift_s2_{i+8}"), registry.get_id(f"{prefix}.exp_diff3")] match = re.search(r'\.shift_s3_(\d+)$', gate) if match: i = int(match.group(1)) if i < 3: return [registry.register(f"{prefix}.shift_s3_{i}.pass"), registry.register(f"{prefix}.shift_s3_{i}.shift")] else: return [registry.register(f"{prefix}.shift_s3_{i}.pass")] for i in range(11): registry.register(f"{prefix}.shift_s3_{i}") # mant_aligned (masked by not_exp_diff4) match = re.search(r'\.mant_aligned(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.shift_s3_{i}"), registry.get_id(f"{prefix}.not_exp_diff4")] for i in range(11): registry.register(f"{prefix}.mant_aligned{i}") # signs_same = NOT signs_differ if '.signs_same' in gate: return [registry.get_id(f"{prefix}.signs_differ.layer2")] registry.register(f"{prefix}.signs_same") # Mantissa comparison (for equal exponent case) if '.mant_a_ge_b' in gate: mant_a_full = [registry.get_id(f"{prefix}.$a[{i}]") for i in range(10)] + \ [registry.get_id(f"{prefix}.implicit_a")] mant_b_full = [registry.get_id(f"{prefix}.$b[{i}]") for i in range(10)] + \ [registry.get_id(f"{prefix}.implicit_b")] return mant_a_full + mant_b_full registry.register(f"{prefix}.mant_a_ge_b") # NOT gates for mant_aligned (for subtraction) match = re.search(r'\.not_mant_aligned(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.mant_aligned{i}")] for i in range(11): registry.register(f"{prefix}.not_mant_aligned{i}") # 
========================================================================= # GUARD/ROUND/STICKY BIT INFERENCE # ========================================================================= # exp_diff_eq[k]: exp_diff == k match = re.search(r'\.exp_diff_eq(\d+)$', gate) if match: return [registry.get_id(f"{prefix}.exp_diff{i}") for i in range(5)] for k in range(1, 12): registry.register(f"{prefix}.exp_diff_eq{k}") # guard_sel[k]: mant_shift_src[k-1] AND exp_diff_eq[k] match = re.search(r'\.guard_sel(\d+)$', gate) if match: k = int(match.group(1)) return [registry.get_id(f"{prefix}.mant_shift_src{k-1}"), registry.get_id(f"{prefix}.exp_diff_eq{k}")] for k in range(1, 12): registry.register(f"{prefix}.guard_sel{k}") # guard_bit: OR of all guard_sel[k] if gate.endswith('.guard_bit'): return [registry.get_id(f"{prefix}.guard_sel{k}") for k in range(1, 12)] registry.register(f"{prefix}.guard_bit") # round_sel[k]: mant_shift_src[k-2] AND exp_diff_eq[k] match = re.search(r'\.round_sel(\d+)$', gate) if match: k = int(match.group(1)) return [registry.get_id(f"{prefix}.mant_shift_src{k-2}"), registry.get_id(f"{prefix}.exp_diff_eq{k}")] for k in range(2, 12): registry.register(f"{prefix}.round_sel{k}") # round_bit: OR of all round_sel[k] if gate.endswith('.round_bit'): return [registry.get_id(f"{prefix}.round_sel{k}") for k in range(2, 12)] registry.register(f"{prefix}.round_bit") # exp_diff_gt[k]: exp_diff > k match = re.search(r'\.exp_diff_gt(\d+)$', gate) if match: return [registry.get_id(f"{prefix}.exp_diff{i}") for i in range(5)] for k in range(13): registry.register(f"{prefix}.exp_diff_gt{k}") # sticky_part[i]: mant_shift_src[i] AND exp_diff > i+2 (bits below round) match = re.search(r'\.sticky_part(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.mant_shift_src{i}"), registry.get_id(f"{prefix}.exp_diff_gt{i+2}")] for i in range(11): registry.register(f"{prefix}.sticky_part{i}") # sticky_bit: OR of all sticky_part[i] if 
gate.endswith('.sticky_bit'): return [registry.get_id(f"{prefix}.sticky_part{i}") for i in range(11)] registry.register(f"{prefix}.sticky_bit") # not_sticky_bit for subtraction if '.not_sticky_bit' in gate: return [registry.get_id(f"{prefix}.sticky_bit")] registry.register(f"{prefix}.not_sticky_bit") # not_round_bit for subtraction if '.not_round_bit' in gate: return [registry.get_id(f"{prefix}.round_bit")] registry.register(f"{prefix}.not_round_bit") # not_guard_bit for subtraction if '.not_guard_bit' in gate: return [registry.get_id(f"{prefix}.guard_bit")] registry.register(f"{prefix}.not_guard_bit") # sub_cin = signs_differ (carry-in for 2's complement subtraction) if '.sub_cin' in gate: return [registry.get_id(f"{prefix}.signs_differ.layer2")] registry.register(f"{prefix}.sub_cin") # addsub_b_s: sticky bit operand for position 0 if '.addsub_b_s.add' in gate: return [registry.get_id("#0"), registry.get_id(f"{prefix}.signs_same")] if '.addsub_b_s.sub' in gate: return [registry.get_id(f"{prefix}.not_sticky_bit"), registry.get_id(f"{prefix}.signs_differ.layer2")] if gate.endswith('.addsub_b_s'): return [registry.register(f"{prefix}.addsub_b_s.add"), registry.register(f"{prefix}.addsub_b_s.sub")] registry.register(f"{prefix}.addsub_b_s") # addsub_b_r: round bit operand for position 1 if '.addsub_b_r.add' in gate: return [registry.get_id(f"{prefix}.round_bit"), registry.get_id(f"{prefix}.signs_same")] if '.addsub_b_r.sub' in gate: return [registry.get_id(f"{prefix}.not_round_bit"), registry.get_id(f"{prefix}.signs_differ.layer2")] if gate.endswith('.addsub_b_r'): return [registry.register(f"{prefix}.addsub_b_r.add"), registry.register(f"{prefix}.addsub_b_r.sub")] registry.register(f"{prefix}.addsub_b_r") # addsub_b_g: guard bit operand for position 2 if '.addsub_b_g.add' in gate: return [registry.get_id(f"{prefix}.guard_bit"), registry.get_id(f"{prefix}.signs_same")] if '.addsub_b_g.sub' in gate: return [registry.get_id(f"{prefix}.not_guard_bit"), 
registry.get_id(f"{prefix}.signs_differ.layer2")] if gate.endswith('.addsub_b_g'): return [registry.register(f"{prefix}.addsub_b_g.add"), registry.register(f"{prefix}.addsub_b_g.sub")] registry.register(f"{prefix}.addsub_b_g") # addsub_b[0:10] selection for mantissa positions match = re.search(r'\.addsub_b(\d+)\.', gate) if match: i = int(match.group(1)) if gate.endswith('.sub'): return [registry.get_id(f"{prefix}.not_mant_aligned{i}"), registry.get_id(f"{prefix}.signs_differ.layer2")] if gate.endswith('.add'): return [registry.get_id(f"{prefix}.mant_aligned{i}"), registry.get_id(f"{prefix}.signs_same")] match = re.search(r'\.addsub_b(\d+)$', gate) if match: i = int(match.group(1)) return [registry.register(f"{prefix}.addsub_b{i}.add"), registry.register(f"{prefix}.addsub_b{i}.sub")] for i in range(11): registry.register(f"{prefix}.addsub_b{i}") # 15-bit mantissa adder with sticky+round+guard bits # Bit 0: A=#0, B=addsub_b_s (sticky position) # Bit 1: A=#0, B=addsub_b_r (round position) # Bit 2: A=#0, B=addsub_b_g (guard position) # Bits 3-13: A=mant_larger[i-3], B=addsub_b[i-3] # Bit 14: A=#0, B=#0 (overflow) if '.mant_add.fa' in gate: match = re.search(r'\.mant_add\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.mant_add.fa{i}" if i == 0: a_bit = registry.get_id("#0") b_bit = registry.get_id(f"{prefix}.addsub_b_s") elif i == 1: a_bit = registry.get_id("#0") b_bit = registry.get_id(f"{prefix}.addsub_b_r") elif i == 2: a_bit = registry.get_id("#0") b_bit = registry.get_id(f"{prefix}.addsub_b_g") elif i <= 13: a_bit = registry.get_id(f"{prefix}.mant_larger{i-3}") b_bit = registry.get_id(f"{prefix}.addsub_b{i-3}") else: a_bit = registry.get_id("#0") b_bit = registry.get_id("#0") if i == 0: cin = registry.get_id(f"{prefix}.sub_cin") else: cin = registry.register(f"{prefix}.mant_add.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), 
registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(15): registry.register(f"{prefix}.mant_add.fa{i}.xor2.layer2") registry.register(f"{prefix}.mant_add.fa{i}.cout") # Result sign determination if '.not_a_exp_gt_b' in gate: return [registry.get_id(f"{prefix}.a_exp_gt_b")] registry.register(f"{prefix}.not_a_exp_gt_b") # exp_a_eq_b: NOT(a > b) AND NOT(b > a) = NOT(a_exp_gt_b) AND a_exp_ge_b # a_exp_ge_b = (exp_a >= exp_b) = NOT(exp_b > exp_a) if '.exp_a_eq_b' in gate: return [registry.get_id(f"{prefix}.not_a_exp_gt_b"), registry.get_id(f"{prefix}.a_exp_ge_b")] registry.register(f"{prefix}.exp_a_eq_b") if '.exp_eq_and_mant_a_ge' in gate: return [registry.get_id(f"{prefix}.exp_a_eq_b"), registry.get_id(f"{prefix}.mant_a_ge_b")] registry.register(f"{prefix}.exp_eq_and_mant_a_ge") if '.a_magnitude_ge_b' in gate: return [registry.get_id(f"{prefix}.a_exp_gt_b"), registry.get_id(f"{prefix}.exp_eq_and_mant_a_ge")] registry.register(f"{prefix}.a_magnitude_ge_b") if '.not_a_mag_ge_b' in gate: return [registry.get_id(f"{prefix}.a_magnitude_ge_b")] registry.register(f"{prefix}.not_a_mag_ge_b") if '.diff_sign_sel_a' in gate: return [registry.get_id(f"{prefix}.sign_a"), registry.get_id(f"{prefix}.a_magnitude_ge_b")] if '.diff_sign_sel_b' in gate: return [registry.get_id(f"{prefix}.sign_b"), registry.get_id(f"{prefix}.not_a_mag_ge_b")] registry.register(f"{prefix}.diff_sign_sel_a") registry.register(f"{prefix}.diff_sign_sel_b") if '.diff_result_sign' in gate: return [registry.get_id(f"{prefix}.diff_sign_sel_a"), 
registry.get_id(f"{prefix}.diff_sign_sel_b")] registry.register(f"{prefix}.diff_result_sign") if '.result_sign_same' in gate: return [registry.get_id(f"{prefix}.sign_a"), registry.get_id(f"{prefix}.signs_same")] if '.result_sign_diff' in gate: return [registry.get_id(f"{prefix}.diff_result_sign"), registry.get_id(f"{prefix}.signs_differ.layer2")] registry.register(f"{prefix}.result_sign_same") registry.register(f"{prefix}.result_sign_diff") if gate == f"{prefix}.result_sign": return [registry.get_id(f"{prefix}.result_sign_same"), registry.get_id(f"{prefix}.result_sign_diff")] registry.register(f"{prefix}.result_sign") # Normalization - sum overflow detection (now bit 14) if gate.endswith('.sum_bit12'): return [registry.get_id(f"{prefix}.mant_add.fa14.xor2.layer2")] registry.register(f"{prefix}.sum_bit12") # sum_overflow: bit 12 AND signs_same if '.sum_overflow' in gate and '.not_sum_overflow' not in gate: return [registry.get_id(f"{prefix}.sum_bit12"), registry.get_id(f"{prefix}.signs_same")] registry.register(f"{prefix}.sum_overflow") # Detect subtraction resulting in zero (bits 13:0 all zero AND signs_differ) sum_bits = [f"{prefix}.mant_add.fa{i}.xor2.layer2" for i in range(14)] # sum_bits_zero: NOR of bits 13:0 (14 bits including sticky/round/guard) if gate.endswith('.sum_bits_zero'): return [registry.get_id(b) for b in sum_bits] registry.register(f"{prefix}.sum_bits_zero") if gate.endswith('.sum_is_zero'): return [registry.get_id(f"{prefix}.sum_bits_zero"), registry.get_id(f"{prefix}.signs_differ.layer2")] registry.register(f"{prefix}.sum_is_zero") if gate.endswith('.not_sum_is_zero'): return [registry.get_id(f"{prefix}.sum_is_zero")] registry.register(f"{prefix}.not_sum_is_zero") if '.subnorm_overflow' in gate: return [registry.get_id(f"{prefix}.mant_add.fa13.xor2.layer2")] registry.register(f"{prefix}.subnorm_overflow") if '.subnorm_enable' in gate: return [registry.register(f"{prefix}.subnorm_condition"), registry.get_id(f"{prefix}.not_sum_is_zero"), 
registry.get_id(f"{prefix}.not_result_is_nan"), registry.get_id(f"{prefix}.not_result_is_inf")] registry.register(f"{prefix}.subnorm_enable") # CLZ on 14-bit sum (bits 13:0) match = re.search(r'\.sum_pz(\d+)$', gate) if match: k = int(match.group(1)) return [registry.get_id(sum_bits[13-i]) for i in range(k)] for k in range(1, 15): registry.register(f"{prefix}.sum_pz{k}") pz_ids = [registry.get_id(f"{prefix}.sum_pz{k}") for k in range(1, 15)] match = re.search(r'\.sum_ge(\d+)$', gate) if match: return pz_ids for k in range(1, 15): registry.register(f"{prefix}.sum_ge{k}") match = re.search(r'\.sum_not_ge(\d+)$', gate) if match: k = int(match.group(1)) return [registry.get_id(f"{prefix}.sum_ge{k}")] for k in [2, 4, 6, 8, 10, 12]: registry.register(f"{prefix}.sum_not_ge{k}") if '.norm_shift3' in gate: return [registry.get_id(f"{prefix}.sum_ge8")] if '.norm_and_4_7' in gate: return [registry.get_id(f"{prefix}.sum_ge4"), registry.get_id(f"{prefix}.sum_not_ge8")] registry.register(f"{prefix}.norm_and_4_7") if '.norm_and_12' in gate: return [registry.get_id(f"{prefix}.sum_ge12")] registry.register(f"{prefix}.norm_and_12") # shift2 = norm_and_4_7 OR norm_and_12 if '.norm_shift2' in gate: return [registry.get_id(f"{prefix}.norm_and_4_7"), registry.get_id(f"{prefix}.norm_and_12")] if '.norm_and_2_3' in gate: return [registry.get_id(f"{prefix}.sum_ge2"), registry.get_id(f"{prefix}.sum_not_ge4")] if '.norm_and_6_7' in gate: return [registry.get_id(f"{prefix}.sum_ge6"), registry.get_id(f"{prefix}.sum_not_ge8")] if '.norm_and_10_11' in gate: return [registry.get_id(f"{prefix}.sum_ge10"), registry.get_id(f"{prefix}.sum_not_ge12")] registry.register(f"{prefix}.norm_and_2_3") registry.register(f"{prefix}.norm_and_6_7") registry.register(f"{prefix}.norm_and_10_11") if '.norm_shift1' in gate: return [registry.get_id(f"{prefix}.norm_and_2_3"), registry.get_id(f"{prefix}.norm_and_6_7"), registry.get_id(f"{prefix}.norm_and_10_11")] match = re.search(r'\.norm_and_(\d+)$', gate) if match: 
i = int(match.group(1)) if i in [1, 3, 5, 7, 9, 11]: return [registry.get_id(f"{prefix}.sum_ge{i}"), registry.get_id(f"{prefix}.sum_not_ge{i+1}")] for i in [1, 3, 5, 7, 9, 11]: registry.register(f"{prefix}.norm_and_{i}") if '.norm_shift0' in gate: return [registry.get_id(f"{prefix}.norm_and_{i}") for i in [1, 3, 5, 7, 9, 11]] + \ [registry.get_id(f"{prefix}.sum_ge13")] for i in range(4): registry.register(f"{prefix}.norm_shift{i}") # Stage 10: Normalization application if '.not_sum_overflow' in gate: return [registry.get_id(f"{prefix}.sum_overflow")] registry.register(f"{prefix}.not_sum_overflow") # Overflow mantissa: sum[i+4] (skipping overflow bit at 14 and sticky/round/guard at 0/1/2) match = re.search(r'\.norm_mant_overflow(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.mant_add.fa{i+4}.xor2.layer2")] for i in range(10): registry.register(f"{prefix}.norm_mant_overflow{i}") # Left barrel shifter NOT gates for i in range(4): if f'.not_norm_shift{i}' in gate and '.not_norm_shift_sub' not in gate: return [registry.get_id(f"{prefix}.norm_shift{i}")] registry.register(f"{prefix}.not_norm_shift{i}") # 12-bit left barrel shifter stage 0 match = re.search(r'\.lshift_s0_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.mant_add.fa{i}.xor2.layer2"), registry.get_id(f"{prefix}.not_norm_shift0")] if '.shift' in gate and i > 0: return [registry.get_id(f"{prefix}.mant_add.fa{i-1}.xor2.layer2"), registry.get_id(f"{prefix}.norm_shift0")] match = re.search(r'\.lshift_s0_(\d+)$', gate) if match: i = int(match.group(1)) if i > 0: return [registry.register(f"{prefix}.lshift_s0_{i}.pass"), registry.register(f"{prefix}.lshift_s0_{i}.shift")] else: return [registry.register(f"{prefix}.lshift_s0_{i}.pass")] for i in range(14): registry.register(f"{prefix}.lshift_s0_{i}") # 12-bit left barrel shifter stage 1 match = re.search(r'\.lshift_s1_(\d+)\.', gate) if match: i = int(match.group(1)) if 
'.pass' in gate: return [registry.get_id(f"{prefix}.lshift_s0_{i}"), registry.get_id(f"{prefix}.not_norm_shift1")] if '.shift' in gate and i > 1: return [registry.get_id(f"{prefix}.lshift_s0_{i-2}"), registry.get_id(f"{prefix}.norm_shift1")] match = re.search(r'\.lshift_s1_(\d+)$', gate) if match: i = int(match.group(1)) if i > 1: return [registry.register(f"{prefix}.lshift_s1_{i}.pass"), registry.register(f"{prefix}.lshift_s1_{i}.shift")] else: return [registry.register(f"{prefix}.lshift_s1_{i}.pass")] for i in range(14): registry.register(f"{prefix}.lshift_s1_{i}") # 12-bit left barrel shifter stage 2 match = re.search(r'\.lshift_s2_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.lshift_s1_{i}"), registry.get_id(f"{prefix}.not_norm_shift2")] if '.shift' in gate and i > 3: return [registry.get_id(f"{prefix}.lshift_s1_{i-4}"), registry.get_id(f"{prefix}.norm_shift2")] match = re.search(r'\.lshift_s2_(\d+)$', gate) if match: i = int(match.group(1)) if i > 3: return [registry.register(f"{prefix}.lshift_s2_{i}.pass"), registry.register(f"{prefix}.lshift_s2_{i}.shift")] else: return [registry.register(f"{prefix}.lshift_s2_{i}.pass")] for i in range(14): registry.register(f"{prefix}.lshift_s2_{i}") # 12-bit left barrel shifter stage 3 match = re.search(r'\.lshift_s3_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.lshift_s2_{i}"), registry.get_id(f"{prefix}.not_norm_shift3")] if '.shift' in gate and i > 7: return [registry.get_id(f"{prefix}.lshift_s2_{i-8}"), registry.get_id(f"{prefix}.norm_shift3")] match = re.search(r'\.lshift_s3_(\d+)$', gate) if match: i = int(match.group(1)) if i > 7: return [registry.register(f"{prefix}.lshift_s3_{i}.pass"), registry.register(f"{prefix}.lshift_s3_{i}.shift")] else: return [registry.register(f"{prefix}.lshift_s3_{i}.pass")] for i in range(14): registry.register(f"{prefix}.lshift_s3_{i}") # norm_mant[i] = overflow ? 
sum[i+4] : lshift[i+3] match = re.search(r'\.norm_mant(\d+)\.', gate) if match: i = int(match.group(1)) if '.overflow_path' in gate: return [registry.get_id(f"{prefix}.norm_mant_overflow{i}"), registry.get_id(f"{prefix}.sum_overflow")] if '.normal_path' in gate: return [registry.get_id(f"{prefix}.lshift_s3_{i+3}"), registry.get_id(f"{prefix}.not_sum_overflow")] match = re.search(r'\.norm_mant(\d+)$', gate) if match: i = int(match.group(1)) return [registry.register(f"{prefix}.norm_mant{i}.overflow_path"), registry.register(f"{prefix}.norm_mant{i}.normal_path")] for i in range(10): registry.register(f"{prefix}.norm_mant{i}") # Exponent increment (for overflow) if '.exp_inc.ha0.sum' in gate: return [registry.get_id(f"{prefix}.exp_larger0")] if '.exp_inc.ha0.cout' in gate: return [registry.get_id(f"{prefix}.exp_larger0")] registry.register(f"{prefix}.exp_inc.ha0.sum") registry.register(f"{prefix}.exp_inc.ha0.cout") for i in range(1, 5): if f'.exp_inc.ha{i}.xor.layer1' in gate: return [registry.get_id(f"{prefix}.exp_larger{i}"), registry.get_id(f"{prefix}.exp_inc.ha{i-1}.cout")] if f'.exp_inc.ha{i}.sum' in gate: return [registry.register(f"{prefix}.exp_inc.ha{i}.xor.layer1.or"), registry.register(f"{prefix}.exp_inc.ha{i}.xor.layer1.nand")] if f'.exp_inc.ha{i}.cout' in gate: return [registry.get_id(f"{prefix}.exp_larger{i}"), registry.get_id(f"{prefix}.exp_inc.ha{i-1}.cout")] registry.register(f"{prefix}.exp_inc.ha{i}.sum") registry.register(f"{prefix}.exp_inc.ha{i}.cout") # Exponent decrement NOT gates for i in range(4): if f'.not_norm_shift_sub{i}' in gate: return [registry.get_id(f"{prefix}.norm_shift{i}")] registry.register(f"{prefix}.not_norm_shift_sub{i}") # Exponent decrement (for no overflow) if '.exp_dec.fa' in gate: match = re.search(r'\.exp_dec\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.exp_dec.fa{i}" exp_bit = registry.get_id(f"{prefix}.exp_larger{i}") if i < 4: not_shift = registry.get_id(f"{prefix}.not_norm_shift_sub{i}") 
else: not_shift = registry.get_id("#1") if i == 0: cin = registry.get_id("#1") else: cin = registry.register(f"{prefix}.exp_dec.fa{i-1}.cout") if '.xor1.layer1' in gate: return [exp_bit, not_shift] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [exp_bit, not_shift] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(5): registry.register(f"{prefix}.exp_dec.fa{i}.xor2.layer2") registry.register(f"{prefix}.exp_dec.fa{i}.cout") # Exponent underflow detection (exp_dec <= 0) when no overflow if '.exp_dec_borrow' in gate: return [registry.get_id(f"{prefix}.exp_dec.fa4.cout")] if '.exp_dec_zero' in gate: return [registry.get_id(f"{prefix}.exp_dec.fa{i}.xor2.layer2") for i in range(5)] if '.exp_underflow_or_zero' in gate: return [registry.get_id(f"{prefix}.exp_dec_borrow"), registry.get_id(f"{prefix}.exp_dec_zero")] if '.exp_underflow' in gate: return [registry.get_id(f"{prefix}.exp_underflow_or_zero"), registry.get_id(f"{prefix}.not_sum_overflow")] registry.register(f"{prefix}.exp_dec_borrow") registry.register(f"{prefix}.exp_dec_zero") registry.register(f"{prefix}.exp_underflow_or_zero") registry.register(f"{prefix}.exp_underflow") # Result exponent selection match = re.search(r'\.result_exp(\d+)\.', gate) if match: i = int(match.group(1)) if '.overflow_path' in gate: if i == 0: return [registry.get_id(f"{prefix}.exp_inc.ha0.sum"), registry.get_id(f"{prefix}.sum_overflow")] else: return [registry.get_id(f"{prefix}.exp_inc.ha{i}.sum"), registry.get_id(f"{prefix}.sum_overflow")] if '.normal_path' in gate: return 
[registry.get_id(f"{prefix}.exp_dec.fa{i}.xor2.layer2"), registry.get_id(f"{prefix}.not_sum_overflow")] match = re.search(r'\.result_exp(\d+)$', gate) if match: i = int(match.group(1)) return [registry.register(f"{prefix}.result_exp{i}.overflow_path"), registry.register(f"{prefix}.result_exp{i}.normal_path")] for i in range(5): registry.register(f"{prefix}.result_exp{i}") # sub_shift = norm_shift + 1 - exp_larger (for subnormal right shift) if '.sub_shift_add.fa' in gate: match = re.search(r'\.sub_shift_add\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.sub_shift_add.fa{i}" a_bit = registry.get_id(f"{prefix}.norm_shift{i}") if i < 4 else registry.get_id("#0") b_bit = registry.get_id("#1") if i == 0 else registry.get_id("#0") cin = registry.get_id("#0") if i == 0 else registry.register(f"{prefix}.sub_shift_add.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(5): registry.register(f"{prefix}.sub_shift_add.fa{i}.xor2.layer2") registry.register(f"{prefix}.sub_shift_add.fa{i}.cout") match = re.search(r'\.sub_shift_not_exp(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.exp_larger{i}")] for i in range(5): registry.register(f"{prefix}.sub_shift_not_exp{i}") if '.sub_shift.fa' in gate: match = re.search(r'\.sub_shift\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.sub_shift.fa{i}" a_bit = 
registry.get_id(f"{prefix}.sub_shift_add.fa{i}.xor2.layer2") b_bit = registry.get_id(f"{prefix}.sub_shift_not_exp{i}") cin = registry.get_id("#1") if i == 0 else registry.register(f"{prefix}.sub_shift.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(5): registry.register(f"{prefix}.sub_shift.fa{i}.xor2.layer2") registry.register(f"{prefix}.sub_shift.fa{i}.cout") match = re.search(r'\.not_sub_shift(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.sub_shift.fa{i}.xor2.layer2")] for i in range(5): registry.register(f"{prefix}.not_sub_shift{i}") # Subnormal right barrel shifter (14-bit input from lshift_s3) match = re.search(r'\.sub_rshift_s0_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.lshift_s3_{i}"), registry.get_id(f"{prefix}.not_sub_shift0")] if '.shift' in gate and i < 13: return [registry.get_id(f"{prefix}.lshift_s3_{i+1}"), registry.get_id(f"{prefix}.sub_shift.fa0.xor2.layer2")] match = re.search(r'\.sub_rshift_s0_(\d+)$', gate) if match: i = int(match.group(1)) if i < 13: return [registry.register(f"{prefix}.sub_rshift_s0_{i}.pass"), registry.register(f"{prefix}.sub_rshift_s0_{i}.shift")] else: return [registry.register(f"{prefix}.sub_rshift_s0_{i}.pass")] for i in range(14): registry.register(f"{prefix}.sub_rshift_s0_{i}") match = re.search(r'\.sub_rshift_s1_(\d+)\.', gate) if match: i 
= int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.sub_rshift_s0_{i}"), registry.get_id(f"{prefix}.not_sub_shift1")] if '.shift' in gate and i < 12: return [registry.get_id(f"{prefix}.sub_rshift_s0_{i+2}"), registry.get_id(f"{prefix}.sub_shift.fa1.xor2.layer2")] match = re.search(r'\.sub_rshift_s1_(\d+)$', gate) if match: i = int(match.group(1)) if i < 12: return [registry.register(f"{prefix}.sub_rshift_s1_{i}.pass"), registry.register(f"{prefix}.sub_rshift_s1_{i}.shift")] else: return [registry.register(f"{prefix}.sub_rshift_s1_{i}.pass")] for i in range(14): registry.register(f"{prefix}.sub_rshift_s1_{i}") match = re.search(r'\.sub_rshift_s2_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.sub_rshift_s1_{i}"), registry.get_id(f"{prefix}.not_sub_shift2")] if '.shift' in gate and i < 10: return [registry.get_id(f"{prefix}.sub_rshift_s1_{i+4}"), registry.get_id(f"{prefix}.sub_shift.fa2.xor2.layer2")] match = re.search(r'\.sub_rshift_s2_(\d+)$', gate) if match: i = int(match.group(1)) if i < 10: return [registry.register(f"{prefix}.sub_rshift_s2_{i}.pass"), registry.register(f"{prefix}.sub_rshift_s2_{i}.shift")] else: return [registry.register(f"{prefix}.sub_rshift_s2_{i}.pass")] for i in range(14): registry.register(f"{prefix}.sub_rshift_s2_{i}") match = re.search(r'\.sub_rshift_s3_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.sub_rshift_s2_{i}"), registry.get_id(f"{prefix}.not_sub_shift3")] if '.shift' in gate and i < 6: return [registry.get_id(f"{prefix}.sub_rshift_s2_{i+8}"), registry.get_id(f"{prefix}.sub_shift.fa3.xor2.layer2")] match = re.search(r'\.sub_rshift_s3_(\d+)$', gate) if match: i = int(match.group(1)) if i < 6: return [registry.register(f"{prefix}.sub_rshift_s3_{i}.pass"), registry.register(f"{prefix}.sub_rshift_s3_{i}.shift")] else: return [registry.register(f"{prefix}.sub_rshift_s3_{i}.pass")] for i in 
range(14): registry.register(f"{prefix}.sub_rshift_s3_{i}") match = re.search(r'\.sub_shifted(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.sub_rshift_s3_{i}"), registry.get_id(f"{prefix}.not_sub_shift4")] for i in range(14): registry.register(f"{prefix}.sub_shifted{i}") match = re.search(r'\.sub_mant(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.sub_shifted{i+3}")] for i in range(10): registry.register(f"{prefix}.sub_mant{i}") if gate == f"{prefix}.sub_guard": return [registry.get_id(f"{prefix}.sub_shifted2")] registry.register(f"{prefix}.sub_guard") match = re.search(r'\.sub_shift_gt(\d+)$', gate) if match: return [registry.get_id(f"{prefix}.sub_shift.fa{i}.xor2.layer2") for i in range(5)] for k in range(14): registry.register(f"{prefix}.sub_shift_gt{k}") match = re.search(r'\.sub_sticky_part(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.lshift_s3_{i}"), registry.get_id(f"{prefix}.sub_shift_gt{i}")] for i in range(14): registry.register(f"{prefix}.sub_sticky_part{i}") if gate.endswith('.sub_sticky_raw'): return [registry.get_id(f"{prefix}.sub_sticky_part{i}") for i in range(14)] registry.register(f"{prefix}.sub_sticky_raw") if gate.endswith('.sub_sticky'): return [registry.get_id(f"{prefix}.sub_sticky_raw"), registry.get_id(f"{prefix}.sticky_bit")] registry.register(f"{prefix}.sub_sticky") if gate.endswith('.sub_round_lsb_or_sticky'): return [registry.get_id(f"{prefix}.sub_sticky"), registry.get_id(f"{prefix}.sub_mant0")] registry.register(f"{prefix}.sub_round_lsb_or_sticky") if gate.endswith('.sub_round_inc'): return [registry.get_id(f"{prefix}.sub_guard"), registry.get_id(f"{prefix}.sub_round_lsb_or_sticky")] registry.register(f"{prefix}.sub_round_inc") if '.sub_round.fa' in gate: match = re.search(r'\.sub_round\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.sub_round.fa{i}" a_bit = registry.get_id(f"{prefix}.sub_mant{i}") 
b_bit = registry.get_id(f"{prefix}.sub_round_inc") if i == 0 else registry.get_id("#0") cin = registry.get_id("#0") if i == 0 else registry.register(f"{prefix}.sub_round.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(10): registry.register(f"{prefix}.sub_round.fa{i}.xor2.layer2") registry.register(f"{prefix}.sub_round.fa{i}.cout") if gate.endswith('.sub_round_overflow'): return [registry.get_id(f"{prefix}.sub_round.fa9.cout")] registry.register(f"{prefix}.sub_round_overflow") if '.subnorm_enable' in gate: return [registry.get_id(f"{prefix}.exp_underflow"), registry.get_id(f"{prefix}.not_result_is_nan"), registry.get_id(f"{prefix}.not_result_is_inf"), registry.get_id(f"{prefix}.not_result_is_zero")] registry.register(f"{prefix}.subnorm_enable") # Rounding: derive guard/round/sticky from post-normalization bits if '.round_guard_overflow.and' in gate: return [registry.register(f"{prefix}.round_guard_overflow"), registry.get_id(f"{prefix}.sum_overflow")] if '.round_guard_norm.and' in gate: return [registry.register(f"{prefix}.round_guard_norm"), registry.get_id(f"{prefix}.not_sum_overflow")] if gate.endswith('.round_guard_overflow'): return [registry.get_id(f"{prefix}.mant_add.fa3.xor2.layer2")] if gate.endswith('.round_guard_norm'): return [registry.get_id(f"{prefix}.lshift_s3_2")] if gate.endswith('.round_guard'): return [registry.register(f"{prefix}.round_guard_overflow.and"), 
registry.register(f"{prefix}.round_guard_norm.and")] registry.register(f"{prefix}.round_guard_overflow") registry.register(f"{prefix}.round_guard_norm") registry.register(f"{prefix}.round_guard_overflow.and") registry.register(f"{prefix}.round_guard_norm.and") registry.register(f"{prefix}.round_guard") if gate.endswith('.round_post_overflow'): return [registry.get_id(f"{prefix}.mant_add.fa2.xor2.layer2")] if gate.endswith('.round_post_norm'): return [registry.get_id(f"{prefix}.lshift_s3_1")] if '.round_post_overflow.and' in gate: return [registry.register(f"{prefix}.round_post_overflow"), registry.get_id(f"{prefix}.sum_overflow")] if '.round_post_norm.and' in gate: return [registry.register(f"{prefix}.round_post_norm"), registry.get_id(f"{prefix}.not_sum_overflow")] if gate.endswith('.round_post'): return [registry.register(f"{prefix}.round_post_overflow.and"), registry.register(f"{prefix}.round_post_norm.and")] registry.register(f"{prefix}.round_post_overflow") registry.register(f"{prefix}.round_post_norm") registry.register(f"{prefix}.round_post_overflow.and") registry.register(f"{prefix}.round_post_norm.and") registry.register(f"{prefix}.round_post") if gate.endswith('.sticky_overflow'): return [registry.get_id(f"{prefix}.mant_add.fa1.xor2.layer2"), registry.get_id(f"{prefix}.mant_add.fa0.xor2.layer2"), registry.get_id(f"{prefix}.sticky_bit")] if '.sticky_norm.same' in gate: return [registry.get_id(f"{prefix}.sticky_bit"), registry.get_id(f"{prefix}.signs_same")] if '.sticky_norm.diff' in gate: return [registry.get_id(f"{prefix}.lshift_s3_0"), registry.get_id(f"{prefix}.signs_differ.layer2")] if gate.endswith('.sticky_norm'): return [registry.register(f"{prefix}.sticky_norm.same"), registry.register(f"{prefix}.sticky_norm.diff")] if gate.endswith('.round_overflow_or'): return [registry.register(f"{prefix}.round_post_overflow"), registry.register(f"{prefix}.sticky_overflow")] if gate.endswith('.round_norm_or'): return 
[registry.register(f"{prefix}.round_post_norm"), registry.register(f"{prefix}.sticky_norm")] if '.round_sticky_overflow.and' in gate: return [registry.register(f"{prefix}.round_overflow_or"), registry.get_id(f"{prefix}.sum_overflow")] if '.round_sticky_norm.and' in gate: return [registry.register(f"{prefix}.round_norm_or"), registry.get_id(f"{prefix}.not_sum_overflow")] if gate.endswith('.round_sticky'): return [registry.register(f"{prefix}.round_sticky_overflow.and"), registry.register(f"{prefix}.round_sticky_norm.and")] registry.register(f"{prefix}.sticky_overflow") registry.register(f"{prefix}.sticky_norm.same") registry.register(f"{prefix}.sticky_norm.diff") registry.register(f"{prefix}.sticky_norm") registry.register(f"{prefix}.round_overflow_or") registry.register(f"{prefix}.round_norm_or") registry.register(f"{prefix}.round_sticky_overflow.and") registry.register(f"{prefix}.round_sticky_norm.and") registry.register(f"{prefix}.round_sticky") if '.round_lsb_or_sticky' in gate: return [registry.register(f"{prefix}.round_sticky"), registry.get_id(f"{prefix}.norm_mant0")] registry.register(f"{prefix}.round_lsb_or_sticky") if '.round_inc' in gate: return [registry.get_id(f"{prefix}.round_guard"), registry.get_id(f"{prefix}.round_lsb_or_sticky")] registry.register(f"{prefix}.round_inc") # Mantissa rounding adder if '.round_norm.fa' in gate: match = re.search(r'\.round_norm\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.round_norm.fa{i}" a_bit = registry.get_id(f"{prefix}.norm_mant{i}") b_bit = registry.get_id(f"{prefix}.round_inc") if i == 0 else registry.get_id("#0") cin = registry.get_id("#0") if i == 0 else registry.register(f"{prefix}.round_norm.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return 
[xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(10): registry.register(f"{prefix}.round_norm.fa{i}.xor2.layer2") registry.register(f"{prefix}.round_norm.fa{i}.cout") if '.round_overflow' in gate and '.not_' not in gate: return [registry.get_id(f"{prefix}.round_norm.fa9.cout")] registry.register(f"{prefix}.round_overflow") if '.not_round_overflow' in gate: return [registry.get_id(f"{prefix}.round_overflow")] registry.register(f"{prefix}.not_round_overflow") # Exponent increment on rounding overflow if '.round_exp.fa' in gate: match = re.search(r'\.round_exp\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.round_exp.fa{i}" a_bit = registry.get_id(f"{prefix}.result_exp{i}") b_bit = registry.get_id("#0") cin = registry.get_id(f"{prefix}.round_overflow") if i == 0 else registry.register(f"{prefix}.round_exp.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(5): registry.register(f"{prefix}.round_exp.fa{i}.xor2.layer2") registry.register(f"{prefix}.round_exp.fa{i}.cout") # Final exponent mux after rounding match = re.search(r'\.final_exp(\d+)\.', gate) if match: i = 
int(match.group(1)) if '.overflow_path' in gate: return [registry.get_id(f"{prefix}.round_exp.fa{i}.xor2.layer2"), registry.get_id(f"{prefix}.round_overflow")] if '.normal_path' in gate: return [registry.get_id(f"{prefix}.result_exp{i}"), registry.get_id(f"{prefix}.not_round_overflow")] match = re.search(r'\.final_exp(\d+)$', gate) if match: i = int(match.group(1)) return [registry.register(f"{prefix}.final_exp{i}.overflow_path"), registry.register(f"{prefix}.final_exp{i}.normal_path")] # Final mantissa (zero on rounding overflow) match = re.search(r'\.final_mant(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.round_norm.fa{i}.xor2.layer2"), registry.get_id(f"{prefix}.not_round_overflow")] # Detect exponent overflow to infinity (all 5 final_exp bits = 1) if '.final_exp_all_ones' in gate: return [registry.register(f"{prefix}.final_exp{i}") for i in range(5)] if '.round_exp_all_ones' in gate: return [registry.register(f"{prefix}.round_exp.fa{i}.xor2.layer2") for i in range(5)] if '.exp_overflow_any' in gate: return [registry.get_id(f"{prefix}.round_exp.fa4.cout"), registry.register(f"{prefix}.round_exp_all_ones"), registry.register(f"{prefix}.final_exp_all_ones")] if '.exp_overflow_to_inf' in gate: return [registry.get_id(f"{prefix}.exp_overflow_any"), registry.register(f"{prefix}.not_exp_underflow")] registry.register(f"{prefix}.final_exp_all_ones") registry.register(f"{prefix}.round_exp_all_ones") registry.register(f"{prefix}.exp_overflow_any") registry.register(f"{prefix}.exp_overflow_to_inf") # Output assembly if '.not_result_is_inf' in gate: return [registry.get_id(f"{prefix}.result_is_inf")] if '.not_both_are_zero' in gate: return [registry.get_id(f"{prefix}.both_are_zero")] registry.register(f"{prefix}.not_result_is_inf") registry.register(f"{prefix}.result_is_inf") registry.register(f"{prefix}.not_both_are_zero") registry.register(f"{prefix}.not_both_exp_zero") registry.register(f"{prefix}.not_exp_underflow") # 
both_neg_zeros: both inputs are -0 → result sign is 1 if '.both_neg_zeros' in gate: return [registry.get_id(f"{prefix}.both_are_zero"), registry.get_id(f"{prefix}.sign_a"), registry.get_id(f"{prefix}.sign_b")] registry.register(f"{prefix}.both_neg_zeros") if '.not_exp_underflow' in gate: return [registry.register(f"{prefix}.exp_underflow")] if '.is_normal_result' in gate: return [registry.get_id(f"{prefix}.not_result_is_nan"), registry.get_id(f"{prefix}.not_result_is_inf"), registry.get_id(f"{prefix}.not_both_are_zero"), registry.get_id(f"{prefix}.not_sum_is_zero"), registry.register(f"{prefix}.not_exp_underflow")] registry.register(f"{prefix}.is_normal_result") # Inf sign selection if '.inf_sign_sel_a' in gate: return [registry.get_id(f"{prefix}.sign_a"), registry.get_id(f"{prefix}.a_is_inf")] if '.inf_sign_sel_b' in gate: return [registry.get_id(f"{prefix}.sign_b"), registry.get_id(f"{prefix}.b_is_inf")] registry.register(f"{prefix}.inf_sign_sel_a") registry.register(f"{prefix}.inf_sign_sel_b") if '.overflow_sign_sel' in gate: return [registry.get_id(f"{prefix}.result_sign"), registry.get_id(f"{prefix}.exp_overflow_to_inf"), registry.get_id(f"{prefix}.not_result_is_nan")] registry.register(f"{prefix}.overflow_sign_sel") if '.inf_sign' in gate and '.inf_sign_sel' not in gate: return [registry.get_id(f"{prefix}.inf_sign_sel_a"), registry.get_id(f"{prefix}.inf_sign_sel_b"), registry.get_id(f"{prefix}.overflow_sign_sel")] registry.register(f"{prefix}.inf_sign") # NaN bits nan_bits = [0]*9 + [1] + [1]*5 + [0] match = re.search(r'\.out_nan(\d+)$', gate) if match: return [registry.get_id(f"{prefix}.result_is_nan")] # Inf bits match = re.search(r'\.out_inf(\d+)$', gate) if match: return [registry.get_id(f"{prefix}.result_is_inf")] # Normal output path match = re.search(r'\.out_normal(\d+)$', gate) if match: i = int(match.group(1)) if i == 15: return [registry.get_id(f"{prefix}.result_sign")] elif i >= 10: return [registry.register(f"{prefix}.final_exp{i-10}")] else: 
return [registry.register(f"{prefix}.final_mant{i}")] for i in range(16): registry.register(f"{prefix}.out_normal{i}") # Subnormal output path (both exponent bits zero) match = re.search(r'\.out_sub(\d+)$', gate) if match: i = int(match.group(1)) if i == 15: return [registry.get_id(f"{prefix}.result_sign")] elif i == 10: return [registry.get_id(f"{prefix}.sub_round_overflow")] elif 10 < i < 15: return [registry.get_id("#0")] else: return [registry.get_id(f"{prefix}.sub_round.fa{i}.xor2.layer2")] for i in range(16): registry.register(f"{prefix}.out_sub{i}") # Final output gates match = re.search(r'\.out(\d+)\.(nan_gate|inf_gate|normal_gate|sub_gate)$', gate) if match: i = int(match.group(1)) gate_type = match.group(2) if gate_type == 'nan_gate': nan_val = registry.register(f"{prefix}.out_nan{i}") if nan_bits[i] else registry.get_id("#0") return [nan_val, registry.get_id(f"{prefix}.result_is_nan")] elif gate_type == 'inf_gate': if i >= 10 and i < 15: inf_val = registry.register(f"{prefix}.out_inf{i}") elif i == 15: inf_val = registry.get_id(f"{prefix}.inf_sign") else: inf_val = registry.get_id("#0") return [inf_val, registry.get_id(f"{prefix}.result_is_inf")] elif gate_type == 'normal_gate': return [registry.get_id(f"{prefix}.out_normal{i}"), registry.get_id(f"{prefix}.is_normal_result")] elif gate_type == 'sub_gate': return [registry.get_id(f"{prefix}.out_sub{i}"), registry.get_id(f"{prefix}.subnorm_enable")] match = re.search(r'\.out(\d+)$', gate) if match: i = int(match.group(1)) if i == 15: # Sign bit includes both_neg_zeros for -0 + -0 = -0 case return [registry.register(f"{prefix}.out{i}.nan_gate"), registry.register(f"{prefix}.out{i}.inf_gate"), registry.register(f"{prefix}.out{i}.normal_gate"), registry.register(f"{prefix}.out{i}.sub_gate"), registry.get_id(f"{prefix}.both_neg_zeros")] else: return [registry.register(f"{prefix}.out{i}.nan_gate"), registry.register(f"{prefix}.out{i}.inf_gate"), registry.register(f"{prefix}.out{i}.normal_gate"), 
registry.register(f"{prefix}.out{i}.sub_gate")] return [] def infer_float16_sub_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for float16.sub circuit. Subtraction is implemented as a + (-b), so we flip b's sign and use add logic. This is a thin wrapper that references float16.add with modified input. """ prefix = "float16.sub" for i in range(16): registry.register(f"{prefix}.$a[{i}]") registry.register(f"{prefix}.$b[{i}]") # Register float16.add outputs so we can reference them for i in range(16): registry.register(f"float16.add.out{i}") if '.b_neg_sign' in gate: return [registry.get_id(f"{prefix}.$b[15]")] registry.register(f"{prefix}.b_neg_sign") match = re.search(r'\.out(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"float16.add.out{i}")] return [] def infer_float16_mul_inputs(gate: str, registry: SignalRegistry, prefix: str = "float16.mul", a_bits: Optional[List[str]] = None, b_bits: Optional[List[str]] = None) -> List[int]: """Infer inputs for float16.mul circuit (optionally with custom input sources).""" if a_bits is None: a_bits = [f"{prefix}.$a[{i}]" for i in range(16)] if b_bits is None: b_bits = [f"{prefix}.$b[{i}]" for i in range(16)] for name in a_bits: registry.register(name) for name in b_bits: registry.register(name) exp_a_bits = [a_bits[10 + i] for i in range(5)] exp_b_bits = [b_bits[10 + i] for i in range(5)] mant_a_bits = [a_bits[i] for i in range(10)] mant_b_bits = [b_bits[i] for i in range(10)] if '.exp_a_all_ones' in gate: return [registry.get_id(b) for b in exp_a_bits] if '.exp_b_all_ones' in gate: return [registry.get_id(b) for b in exp_b_bits] if '.exp_a_zero' in gate: return [registry.get_id(b) for b in exp_a_bits] if '.exp_b_zero' in gate: return [registry.get_id(b) for b in exp_b_bits] registry.register(f"{prefix}.exp_a_all_ones") registry.register(f"{prefix}.exp_b_all_ones") registry.register(f"{prefix}.exp_a_zero") registry.register(f"{prefix}.exp_b_zero") if '.a_adj_exp0' in gate: 
return [registry.get_id(exp_a_bits[0]), registry.get_id(f"{prefix}.exp_a_zero")] if '.b_adj_exp0' in gate: return [registry.get_id(exp_b_bits[0]), registry.get_id(f"{prefix}.exp_b_zero")] registry.register(f"{prefix}.a_adj_exp0") registry.register(f"{prefix}.b_adj_exp0") if '.mant_a_nonzero' in gate: return [registry.get_id(b) for b in mant_a_bits] if '.mant_b_nonzero' in gate: return [registry.get_id(b) for b in mant_b_bits] registry.register(f"{prefix}.mant_a_nonzero") registry.register(f"{prefix}.mant_b_nonzero") if '.mant_a_zero' in gate: return [registry.get_id(f"{prefix}.mant_a_nonzero")] if '.mant_b_zero' in gate: return [registry.get_id(f"{prefix}.mant_b_nonzero")] registry.register(f"{prefix}.mant_a_zero") registry.register(f"{prefix}.mant_b_zero") match = re.search(r'\.mant_a_norm(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(a_bits[i])] match = re.search(r'\.mant_b_norm(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(b_bits[i])] for i in range(10): registry.register(f"{prefix}.mant_a_norm{i}") registry.register(f"{prefix}.mant_b_norm{i}") if '.a_is_nan' in gate: return [registry.get_id(f"{prefix}.exp_a_all_ones"), registry.get_id(f"{prefix}.mant_a_nonzero")] if '.b_is_nan' in gate: return [registry.get_id(f"{prefix}.exp_b_all_ones"), registry.get_id(f"{prefix}.mant_b_nonzero")] if '.a_is_inf' in gate: return [registry.get_id(f"{prefix}.exp_a_all_ones"), registry.get_id(f"{prefix}.mant_a_zero")] if '.b_is_inf' in gate: return [registry.get_id(f"{prefix}.exp_b_all_ones"), registry.get_id(f"{prefix}.mant_b_zero")] if '.a_is_zero' in gate: return [registry.get_id(f"{prefix}.exp_a_zero"), registry.get_id(f"{prefix}.mant_a_zero")] if '.b_is_zero' in gate: return [registry.get_id(f"{prefix}.exp_b_zero"), registry.get_id(f"{prefix}.mant_b_zero")] registry.register(f"{prefix}.a_is_nan") registry.register(f"{prefix}.b_is_nan") registry.register(f"{prefix}.a_is_inf") registry.register(f"{prefix}.b_is_inf") 
registry.register(f"{prefix}.a_is_zero") registry.register(f"{prefix}.b_is_zero") if '.either_is_nan' in gate: return [registry.get_id(f"{prefix}.a_is_nan"), registry.get_id(f"{prefix}.b_is_nan")] if '.either_is_inf' in gate: return [registry.get_id(f"{prefix}.a_is_inf"), registry.get_id(f"{prefix}.b_is_inf")] if '.either_is_zero' in gate: return [registry.get_id(f"{prefix}.a_is_zero"), registry.get_id(f"{prefix}.b_is_zero")] if '.inf_times_zero' in gate: return [registry.get_id(f"{prefix}.either_is_inf"), registry.get_id(f"{prefix}.either_is_zero")] registry.register(f"{prefix}.either_is_nan") registry.register(f"{prefix}.either_is_inf") registry.register(f"{prefix}.either_is_zero") registry.register(f"{prefix}.inf_times_zero") if gate.endswith('.result_is_nan'): return [registry.get_id(f"{prefix}.either_is_nan"), registry.get_id(f"{prefix}.inf_times_zero")] if '.not_result_is_nan' in gate: return [registry.get_id(f"{prefix}.result_is_nan")] if '.not_either_is_zero' in gate: return [registry.get_id(f"{prefix}.either_is_zero")] if gate.endswith('.result_is_inf'): return [registry.get_id(f"{prefix}.either_is_inf"), registry.register(f"{prefix}.exp_overflow_to_inf"), registry.get_id(f"{prefix}.not_result_is_nan"), registry.get_id(f"{prefix}.not_either_is_zero")] if gate.endswith('.result_is_zero'): return [registry.get_id(f"{prefix}.either_is_zero"), registry.get_id(f"{prefix}.not_result_is_nan")] registry.register(f"{prefix}.result_is_nan") registry.register(f"{prefix}.not_result_is_nan") registry.register(f"{prefix}.not_either_is_zero") registry.register(f"{prefix}.result_is_inf") registry.register(f"{prefix}.result_is_zero") if '.result_sign.layer1.or' in gate: return [registry.get_id(a_bits[15]), registry.get_id(b_bits[15])] if '.result_sign.layer1.nand' in gate: return [registry.get_id(a_bits[15]), registry.get_id(b_bits[15])] if '.result_sign.layer2' in gate: return [registry.register(f"{prefix}.result_sign.layer1.or"), 
registry.register(f"{prefix}.result_sign.layer1.nand")] registry.register(f"{prefix}.result_sign.layer2") if '.implicit_a' in gate: return [registry.get_id(f"{prefix}.exp_a_zero")] if '.implicit_b' in gate: return [registry.get_id(f"{prefix}.exp_b_zero")] registry.register(f"{prefix}.implicit_a") registry.register(f"{prefix}.implicit_b") match = re.search(r'\.pp(\d+)_(\d+)$', gate) if match: i, j = int(match.group(1)), int(match.group(2)) if i == 10: a_bit = registry.get_id(f"{prefix}.implicit_a") else: a_bit = registry.get_id(a_bits[i]) if j == 10: b_bit = registry.get_id(f"{prefix}.implicit_b") else: b_bit = registry.get_id(b_bits[j]) return [a_bit, b_bit] for i in range(11): for j in range(11): registry.register(f"{prefix}.pp{i}_{j}") for col in range(22): pps = [] for i in range(11): j = col - i if 0 <= j < 11: pps.append(f"{prefix}.pp{i}_{j}") count = len(pps) if count == 1: if f'.col{col}' in gate and f'.col{col}_' not in gate: return [registry.get_id(pps[0])] registry.register(f"{prefix}.col{col}") elif count > 1: # ge{t} gates: threshold >= t match = re.search(rf'\.col{col}_ge(\d+)$', gate) if match: return [registry.get_id(pp) for pp in pps] for t in range(1, count + 1): registry.register(f"{prefix}.col{col}_ge{t}") # not_ge{t} for even t match = re.search(rf'\.col{col}_not_ge(\d+)$', gate) if match: t = int(match.group(1)) return [registry.get_id(f"{prefix}.col{col}_ge{t}")] for t in range(2, count + 1, 2): registry.register(f"{prefix}.col{col}_not_ge{t}") # odd{t} gates: ge{t} AND (NOT ge{t+1} or just ge{t} if t+1 > count) match = re.search(rf'\.col{col}_odd(\d+)$', gate) if match: t = int(match.group(1)) if t + 1 <= count: return [registry.get_id(f"{prefix}.col{col}_ge{t}"), registry.get_id(f"{prefix}.col{col}_not_ge{t+1}")] else: return [registry.get_id(f"{prefix}.col{col}_ge{t}")] odd_ranges = [] for t in range(1, count + 1, 2): registry.register(f"{prefix}.col{col}_odd{t}") odd_ranges.append(f"{prefix}.col{col}_odd{t}") # col_sum = OR of all odd 
gates (parity) if f'.col{col}_sum' in gate: return [registry.get_id(r) for r in odd_ranges] registry.register(f"{prefix}.col{col}_sum") # col_bit1 gates (floor(sum/2) mod 2) if count >= 2: match = re.search(rf'\.col{col}_bit1_(\d+)$', gate) if match: t = int(match.group(1)) upper = t + 2 if upper <= count: return [registry.get_id(f"{prefix}.col{col}_ge{t}"), registry.get_id(f"{prefix}.col{col}_not_ge{upper}")] else: return [registry.get_id(f"{prefix}.col{col}_ge{t}")] bit1_ranges = [] for t in range(2, count + 1, 4): registry.register(f"{prefix}.col{col}_bit1_{t}") bit1_ranges.append(f"{prefix}.col{col}_bit1_{t}") if f'.col{col}_bit1' in gate and f'.col{col}_bit1_' not in gate: return [registry.get_id(r) for r in bit1_ranges] if bit1_ranges: registry.register(f"{prefix}.col{col}_bit1") # col_bit2 gates (floor(sum/4) mod 2) if count >= 4: match = re.search(rf'\.col{col}_bit2_(\d+)$', gate) if match: t = int(match.group(1)) upper = t + 4 if upper <= count: return [registry.get_id(f"{prefix}.col{col}_ge{t}"), registry.get_id(f"{prefix}.col{col}_not_ge{upper}")] else: return [registry.get_id(f"{prefix}.col{col}_ge{t}")] bit2_ranges = [] for t in range(4, count + 1, 8): registry.register(f"{prefix}.col{col}_bit2_{t}") bit2_ranges.append(f"{prefix}.col{col}_bit2_{t}") if f'.col{col}_bit2' in gate and f'.col{col}_bit2_' not in gate: return [registry.get_id(r) for r in bit2_ranges] if bit2_ranges: registry.register(f"{prefix}.col{col}_bit2") # col_bit3 gates (floor(sum/8) mod 2) if count >= 8: match = re.search(rf'\.col{col}_bit3_(\d+)$', gate) if match: t = int(match.group(1)) upper = t + 8 if upper <= count: return [registry.get_id(f"{prefix}.col{col}_ge{t}"), registry.get_id(f"{prefix}.col{col}_not_ge{upper}")] else: return [registry.get_id(f"{prefix}.col{col}_ge{t}")] bit3_ranges = [] for t in range(8, count + 1, 16): registry.register(f"{prefix}.col{col}_bit3_{t}") bit3_ranges.append(f"{prefix}.col{col}_bit3_{t}") if f'.col{col}_bit3' in gate and f'.col{col}_bit3_' 
not in gate: return [registry.get_id(r) for r in bit3_ranges] if bit3_ranges: registry.register(f"{prefix}.col{col}_bit3") # Handle carry accumulator gates if '.carry_acc' in gate: match = re.search(r'\.carry_acc(\d+)_', gate) if match: i = int(match.group(1)) def get_pp_count(col): if col < 0 or col > 20: return 0 return min(col + 1, 21 - col) # Determine which carry bits come into position i carry_inputs = [] if i >= 1 and get_pp_count(i-1) >= 2: carry_inputs.append(registry.get_id(f"{prefix}.col{i-1}_bit1")) if i >= 2 and get_pp_count(i-2) >= 4: carry_inputs.append(registry.get_id(f"{prefix}.col{i-2}_bit2")) if i >= 3 and get_pp_count(i-3) >= 8: carry_inputs.append(registry.get_id(f"{prefix}.col{i-3}_bit3")) n = len(carry_inputs) # ge{t} gates match_ge = re.search(rf'\.carry_acc{i}_ge(\d+)$', gate) if match_ge: return carry_inputs # not_ge{t} gates match_not = re.search(rf'\.carry_acc{i}_not_ge(\d+)$', gate) if match_not: t = int(match_not.group(1)) return [registry.get_id(f"{prefix}.carry_acc{i}_ge{t}")] # Register ge gates for t in range(1, n + 1): registry.register(f"{prefix}.carry_acc{i}_ge{t}") for t in range(2, n + 1, 2): registry.register(f"{prefix}.carry_acc{i}_not_ge{t}") # odd{t} gates match_odd = re.search(rf'\.carry_acc{i}_odd(\d+)$', gate) if match_odd: t = int(match_odd.group(1)) if t + 1 <= n: return [registry.get_id(f"{prefix}.carry_acc{i}_ge{t}"), registry.get_id(f"{prefix}.carry_acc{i}_not_ge{t+1}")] else: return [registry.get_id(f"{prefix}.carry_acc{i}_ge{t}")] # Register odd gates odd_ranges = [] for t in range(1, n + 1, 2): registry.register(f"{prefix}.carry_acc{i}_odd{t}") odd_ranges.append(f"{prefix}.carry_acc{i}_odd{t}") # carry_acc_sum = OR of odd gates registry.register(f"{prefix}.carry_acc{i}_sum") if f'.carry_acc{i}_sum' in gate: return [registry.get_id(r) for r in odd_ranges] # carry_acc_carry = ge2 (register before checking gate match) if n >= 2: registry.register(f"{prefix}.carry_acc{i}_carry") if f'.carry_acc{i}_carry' in gate: 
return carry_inputs if '.prod_fa' in gate: match = re.search(r'\.prod_fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.prod_fa{i}" def get_pp_count(col): if col < 0 or col > 20: return 0 return min(col + 1, 21 - col) # Count partial products in each column to determine signal names # col 0 and col 20 have 1 PP each, others have more def get_col_sum(col): if col == 0 or col == 20: return registry.get_id(f"{prefix}.col{col}") elif col < 21: return registry.get_id(f"{prefix}.col{col}_sum") return registry.get_id("#0") def get_b_bit(pos): # Determine incoming carries for position pos carry_inputs = [] if pos >= 1 and get_pp_count(pos-1) >= 2: carry_inputs.append("bit1") if pos >= 2 and get_pp_count(pos-2) >= 4: carry_inputs.append("bit2") if pos >= 3 and get_pp_count(pos-3) >= 8: carry_inputs.append("bit3") if len(carry_inputs) == 0: return registry.get_id("#0") elif len(carry_inputs) == 1: # Single carry, use it directly if carry_inputs[0] == "bit1": return registry.get_id(f"{prefix}.col{pos-1}_bit1") elif carry_inputs[0] == "bit2": return registry.get_id(f"{prefix}.col{pos-2}_bit2") else: return registry.get_id(f"{prefix}.col{pos-3}_bit3") else: # Multiple carries, use accumulator sum return registry.register(f"{prefix}.carry_acc{pos}_sum") def get_extra_cin(pos): # Extra carry from accumulator (when sum of carries >= 2) carry_inputs = [] if pos >= 1 and get_pp_count(pos-1) >= 2: carry_inputs.append("bit1") if pos >= 2 and get_pp_count(pos-2) >= 4: carry_inputs.append("bit2") if pos >= 3 and get_pp_count(pos-3) >= 8: carry_inputs.append("bit3") if len(carry_inputs) >= 2: return registry.register(f"{prefix}.carry_acc{pos}_carry") return None if i == 0: a_bit = get_col_sum(0) b_bit = registry.get_id("#0") cin = registry.get_id("#0") else: a_bit = get_col_sum(i) if i < 21 else registry.get_id("#0") b_bit = get_b_bit(i) cin = registry.register(f"{prefix}.prod_fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in 
gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(22): registry.register(f"{prefix}.prod_fa{i}.xor2.layer2") registry.register(f"{prefix}.prod_fa{i}.cout") # Second pass: prod2_fa adds prod_fa output + carry_acc_carry (from i-1) if '.prod2_fa' in gate: match = re.search(r'\.prod2_fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.prod2_fa{i}" # a = intermediate sum from first pass a_bit = registry.get_id(f"{prefix}.prod_fa{i}.xor2.layer2") # b = carry_acc_carry from position i-1 (if exists) # carry_acc_carry exists when position i-1 had multiple incoming carries def get_pp_count(col): if col < 0 or col > 20: return 0 return min(col + 1, 21 - col) def has_secondary_carry(pos): carry_count = 0 if pos >= 1 and get_pp_count(pos-1) >= 2: carry_count += 1 if pos >= 2 and get_pp_count(pos-2) >= 4: carry_count += 1 if pos >= 3 and get_pp_count(pos-3) >= 8: carry_count += 1 return carry_count >= 2 if i > 0 and has_secondary_carry(i - 1): b_bit = registry.get_id(f"{prefix}.carry_acc{i-1}_carry") else: b_bit = registry.get_id("#0") cin = registry.get_id("#0") if i == 0 else registry.register(f"{prefix}.prod2_fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return 
[registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(22): registry.register(f"{prefix}.prod2_fa{i}.xor2.layer2") registry.register(f"{prefix}.prod2_fa{i}.cout") if '.exp_add.fa' in gate: match = re.search(r'\.exp_add\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.exp_add.fa{i}" if i == 0: a_bit = registry.get_id(f"{prefix}.a_adj_exp0") b_bit = registry.get_id(f"{prefix}.b_adj_exp0") else: a_bit = registry.get_id(exp_a_bits[i]) if i < 5 else registry.get_id("#0") b_bit = registry.get_id(exp_b_bits[i]) if i < 5 else registry.get_id("#0") cin = registry.get_id("#0") if i == 0 else registry.register(f"{prefix}.exp_add.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(6): registry.register(f"{prefix}.exp_add.fa{i}.xor2.layer2") registry.register(f"{prefix}.exp_add.fa{i}.cout") # NOT(15) = NOT(001111) = 110000 in 6-bit, little-endian: [0, 0, 0, 0, 1, 1] not_15_bits = [0, 0, 0, 0, 1, 1] if '.exp_sub.fa' in gate: match = re.search(r'\.exp_sub\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.exp_sub.fa{i}" a_bit = registry.get_id(f"{prefix}.exp_add.fa{i}.xor2.layer2") b_bit = 
registry.get_id(f"#{not_15_bits[i]}") cin = registry.get_id("#1") if i == 0 else registry.register(f"{prefix}.exp_sub.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(6): registry.register(f"{prefix}.exp_sub.fa{i}.xor2.layer2") registry.register(f"{prefix}.exp_sub.fa{i}.cout") # Use prod2_fa (second pass) for final product bits if '.prod_overflow' in gate and '.not_prod_overflow' not in gate: return [registry.get_id(f"{prefix}.prod2_fa21.xor2.layer2")] if '.not_prod_overflow' in gate: return [registry.get_id(f"{prefix}.prod_overflow")] registry.register(f"{prefix}.prod_overflow") registry.register(f"{prefix}.not_prod_overflow") match = re.search(r'\.norm_mant(\d+)\.', gate) if match: i = int(match.group(1)) if '.overflow_path' in gate: return [registry.get_id(f"{prefix}.lshift_s3_{i+1}"), registry.get_id(f"{prefix}.prod_overflow")] if '.normal_path' in gate: return [registry.get_id(f"{prefix}.prod_lshift_s3_{i+10}"), registry.get_id(f"{prefix}.not_prod_overflow")] if '.eq10_path' in gate: return [registry.get_id(f"{prefix}.prod2_fa{i}.xor2.layer2"), registry.get_id(f"{prefix}.norm_shift_eq10")] match = re.search(r'\.norm_mant(\d+)$', gate) if match: i = int(match.group(1)) return [registry.register(f"{prefix}.norm_mant{i}.overflow_path"), registry.register(f"{prefix}.norm_mant{i}.normal_path"), registry.register(f"{prefix}.norm_mant{i}.eq10_path")] for i in range(10): 
registry.register(f"{prefix}.norm_mant{i}") if '.result_exp_fa' in gate: match = re.search(r'\.result_exp_fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.result_exp_fa{i}" a_bit = registry.get_id(f"{prefix}.exp_sub.fa{i}.xor2.layer2") b_bit = registry.get_id("#0") cin = registry.get_id(f"{prefix}.prod_overflow") if i == 0 else registry.register(f"{prefix}.result_exp_fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(5): registry.register(f"{prefix}.result_exp_fa{i}.xor2.layer2") registry.register(f"{prefix}.result_exp_fa{i}.cout") # Left barrel shifter for guard/sticky bits from low product bits (11 bits) match = re.search(r'\.guard_lshift_s0_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.prod2_fa{i}.xor2.layer2"), registry.register(f"{prefix}.not_norm_shift0")] if '.shift' in gate and i > 0: return [registry.get_id(f"{prefix}.prod2_fa{i-1}.xor2.layer2"), registry.register(f"{prefix}.norm_shift0")] match = re.search(r'\.guard_lshift_s0_(\d+)$', gate) if match: i = int(match.group(1)) if i > 0: return [registry.register(f"{prefix}.guard_lshift_s0_{i}.pass"), registry.register(f"{prefix}.guard_lshift_s0_{i}.shift")] return [registry.register(f"{prefix}.guard_lshift_s0_{i}.pass")] for i in range(11): registry.register(f"{prefix}.guard_lshift_s0_{i}") match = re.search(r'\.guard_lshift_s1_(\d+)\.', gate) if 
match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.guard_lshift_s0_{i}"), registry.register(f"{prefix}.not_norm_shift1")] if '.shift' in gate and i > 1: return [registry.get_id(f"{prefix}.guard_lshift_s0_{i-2}"), registry.register(f"{prefix}.norm_shift1")] match = re.search(r'\.guard_lshift_s1_(\d+)$', gate) if match: i = int(match.group(1)) if i > 1: return [registry.register(f"{prefix}.guard_lshift_s1_{i}.pass"), registry.register(f"{prefix}.guard_lshift_s1_{i}.shift")] return [registry.register(f"{prefix}.guard_lshift_s1_{i}.pass")] for i in range(11): registry.register(f"{prefix}.guard_lshift_s1_{i}") match = re.search(r'\.guard_lshift_s2_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.guard_lshift_s1_{i}"), registry.register(f"{prefix}.not_norm_shift2")] if '.shift' in gate and i > 3: return [registry.get_id(f"{prefix}.guard_lshift_s1_{i-4}"), registry.register(f"{prefix}.norm_shift2")] match = re.search(r'\.guard_lshift_s2_(\d+)$', gate) if match: i = int(match.group(1)) if i > 3: return [registry.register(f"{prefix}.guard_lshift_s2_{i}.pass"), registry.register(f"{prefix}.guard_lshift_s2_{i}.shift")] return [registry.register(f"{prefix}.guard_lshift_s2_{i}.pass")] for i in range(11): registry.register(f"{prefix}.guard_lshift_s2_{i}") match = re.search(r'\.guard_lshift_s3_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.guard_lshift_s2_{i}"), registry.register(f"{prefix}.not_norm_shift3")] if '.shift' in gate and i > 7: return [registry.get_id(f"{prefix}.guard_lshift_s2_{i-8}"), registry.register(f"{prefix}.norm_shift3")] match = re.search(r'\.guard_lshift_s3_(\d+)$', gate) if match: i = int(match.group(1)) if i > 7: return [registry.register(f"{prefix}.guard_lshift_s3_{i}.pass"), registry.register(f"{prefix}.guard_lshift_s3_{i}.shift")] return [registry.register(f"{prefix}.guard_lshift_s3_{i}.pass")] for i in 
range(11): registry.register(f"{prefix}.guard_lshift_s3_{i}") # Left barrel shifter for normal-path mantissa from product bits (20 bits) match = re.search(r'\.prod_lshift_s0_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.prod2_fa{i}.xor2.layer2"), registry.register(f"{prefix}.not_norm_shift0")] if '.shift' in gate and i > 0: return [registry.get_id(f"{prefix}.prod2_fa{i-1}.xor2.layer2"), registry.register(f"{prefix}.norm_shift0")] match = re.search(r'\.prod_lshift_s0_(\d+)$', gate) if match: i = int(match.group(1)) if i > 0: return [registry.register(f"{prefix}.prod_lshift_s0_{i}.pass"), registry.register(f"{prefix}.prod_lshift_s0_{i}.shift")] return [registry.register(f"{prefix}.prod_lshift_s0_{i}.pass")] for i in range(22): registry.register(f"{prefix}.prod_lshift_s0_{i}") match = re.search(r'\.prod_lshift_s1_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.prod_lshift_s0_{i}"), registry.register(f"{prefix}.not_norm_shift1")] if '.shift' in gate and i > 1: return [registry.get_id(f"{prefix}.prod_lshift_s0_{i-2}"), registry.register(f"{prefix}.norm_shift1")] match = re.search(r'\.prod_lshift_s1_(\d+)$', gate) if match: i = int(match.group(1)) if i > 1: return [registry.register(f"{prefix}.prod_lshift_s1_{i}.pass"), registry.register(f"{prefix}.prod_lshift_s1_{i}.shift")] return [registry.register(f"{prefix}.prod_lshift_s1_{i}.pass")] for i in range(22): registry.register(f"{prefix}.prod_lshift_s1_{i}") match = re.search(r'\.prod_lshift_s2_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.prod_lshift_s1_{i}"), registry.register(f"{prefix}.not_norm_shift2")] if '.shift' in gate and i > 3: return [registry.get_id(f"{prefix}.prod_lshift_s1_{i-4}"), registry.register(f"{prefix}.norm_shift2")] match = re.search(r'\.prod_lshift_s2_(\d+)$', gate) if match: i = int(match.group(1)) if i > 3: return 
[registry.register(f"{prefix}.prod_lshift_s2_{i}.pass"), registry.register(f"{prefix}.prod_lshift_s2_{i}.shift")] return [registry.register(f"{prefix}.prod_lshift_s2_{i}.pass")] for i in range(22): registry.register(f"{prefix}.prod_lshift_s2_{i}") match = re.search(r'\.prod_lshift_s3_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.prod_lshift_s2_{i}"), registry.register(f"{prefix}.not_norm_shift3")] if '.shift' in gate and i > 7: return [registry.get_id(f"{prefix}.prod_lshift_s2_{i-8}"), registry.register(f"{prefix}.norm_shift3")] match = re.search(r'\.prod_lshift_s3_(\d+)$', gate) if match: i = int(match.group(1)) if i > 7: return [registry.register(f"{prefix}.prod_lshift_s3_{i}.pass"), registry.register(f"{prefix}.prod_lshift_s3_{i}.shift")] return [registry.register(f"{prefix}.prod_lshift_s3_{i}.pass")] for i in range(22): registry.register(f"{prefix}.prod_lshift_s3_{i}") # Rounding and underflow handling if '.round_guard_overflow.and' in gate: return [registry.register(f"{prefix}.round_guard_overflow"), registry.get_id(f"{prefix}.prod_overflow")] if '.round_guard_norm.and' in gate: return [registry.register(f"{prefix}.round_guard_norm"), registry.get_id(f"{prefix}.not_prod_overflow")] if gate.endswith('.round_guard_overflow'): return [registry.get_id(f"{prefix}.lshift_s3_0")] if gate.endswith('.round_guard_norm'): return [registry.get_id(f"{prefix}.guard_lshift_s3_9")] if gate.endswith('.round_guard'): return [registry.register(f"{prefix}.round_guard_overflow.and"), registry.register(f"{prefix}.round_guard_norm.and")] registry.register(f"{prefix}.round_guard_overflow") registry.register(f"{prefix}.round_guard_norm") registry.register(f"{prefix}.round_guard_overflow.and") registry.register(f"{prefix}.round_guard_norm.and") registry.register(f"{prefix}.round_guard") if gate.endswith('.sticky_overflow'): return [registry.get_id(f"{prefix}.prod2_fa{i}.xor2.layer2") for i in range(10)] if 
gate.endswith('.sticky_norm'): return [registry.get_id(f"{prefix}.guard_lshift_s3_{i}") for i in range(9)] if '.sticky_norm_ext' in gate: return [registry.get_id(f"{prefix}.sticky_norm")] if '.round_sticky_overflow.and' in gate: return [registry.register(f"{prefix}.sticky_overflow"), registry.get_id(f"{prefix}.prod_overflow")] if '.round_sticky_norm.and' in gate: return [registry.register(f"{prefix}.sticky_norm_ext"), registry.get_id(f"{prefix}.not_prod_overflow")] if gate.endswith('.round_sticky'): return [registry.register(f"{prefix}.round_sticky_overflow.and"), registry.register(f"{prefix}.round_sticky_norm.and")] registry.register(f"{prefix}.sticky_overflow") registry.register(f"{prefix}.sticky_norm") registry.register(f"{prefix}.sticky_norm_ext") registry.register(f"{prefix}.round_sticky_overflow.and") registry.register(f"{prefix}.round_sticky_norm.and") registry.register(f"{prefix}.round_sticky") if gate.endswith('.round_lsb_or_sticky'): return [registry.register(f"{prefix}.round_sticky"), registry.get_id(f"{prefix}.norm_mant0")] registry.register(f"{prefix}.round_lsb_or_sticky") if gate.endswith('.round_inc'): return [registry.get_id(f"{prefix}.round_guard"), registry.get_id(f"{prefix}.round_lsb_or_sticky")] registry.register(f"{prefix}.round_inc") if '.round_norm.fa' in gate: match = re.search(r'\.round_norm\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.round_norm.fa{i}" a_bit = registry.get_id(f"{prefix}.norm_mant{i}") b_bit = registry.get_id(f"{prefix}.round_inc") if i == 0 else registry.get_id("#0") cin = registry.get_id("#0") if i == 0 else registry.register(f"{prefix}.round_norm.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return 
[registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(10): registry.register(f"{prefix}.round_norm.fa{i}.xor2.layer2") registry.register(f"{prefix}.round_norm.fa{i}.cout") if gate.endswith('.round_overflow'): return [registry.get_id(f"{prefix}.round_norm.fa9.cout")] registry.register(f"{prefix}.round_overflow") if '.not_round_overflow' in gate: return [registry.get_id(f"{prefix}.round_overflow")] registry.register(f"{prefix}.not_round_overflow") # Exponent decrement by normalization shift amount for i in range(4): if f'.not_norm_shift_sub{i}' in gate: return [registry.get_id(f"{prefix}.norm_shift{i}")] registry.register(f"{prefix}.not_norm_shift_sub{i}") if '.exp_dec.fa' in gate: match = re.search(r'\.exp_dec\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.exp_dec.fa{i}" exp_bit = registry.get_id(f"{prefix}.result_exp_fa{i}.xor2.layer2") if i < 4: not_shift = registry.get_id(f"{prefix}.not_norm_shift_sub{i}") else: not_shift = registry.get_id("#1") cin = registry.get_id("#1") if i == 0 else registry.register(f"{prefix}.exp_dec.fa{i-1}.cout") if '.xor1.layer1' in gate: return [exp_bit, not_shift] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [exp_bit, not_shift] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(5): 
registry.register(f"{prefix}.exp_dec.fa{i}.xor2.layer2") registry.register(f"{prefix}.exp_dec.fa{i}.cout") if '.round_exp.fa' in gate: match = re.search(r'\.round_exp\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.round_exp.fa{i}" a_bit = registry.get_id(f"{prefix}.exp_dec.fa{i}.xor2.layer2") b_bit = registry.get_id("#0") cin = registry.get_id(f"{prefix}.round_overflow") if i == 0 else registry.register(f"{prefix}.round_exp.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(5): registry.register(f"{prefix}.round_exp.fa{i}.xor2.layer2") registry.register(f"{prefix}.round_exp.fa{i}.cout") match = re.search(r'\.final_exp(\d+)\.', gate) if match: i = int(match.group(1)) if '.overflow_path' in gate: return [registry.get_id(f"{prefix}.round_exp.fa{i}.xor2.layer2"), registry.get_id(f"{prefix}.round_overflow")] if '.normal_path' in gate: return [registry.get_id(f"{prefix}.exp_dec.fa{i}.xor2.layer2"), registry.get_id(f"{prefix}.not_round_overflow")] match = re.search(r'\.final_exp(\d+)$', gate) if match: i = int(match.group(1)) return [registry.register(f"{prefix}.final_exp{i}.overflow_path"), registry.register(f"{prefix}.final_exp{i}.normal_path")] for i in range(5): registry.register(f"{prefix}.final_exp{i}") match = re.search(r'\.final_mant(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.round_norm.fa{i}.xor2.layer2"), 
registry.get_id(f"{prefix}.not_round_overflow")] for i in range(10): registry.register(f"{prefix}.final_mant{i}") if '.final_exp_all_ones' in gate: return [registry.register(f"{prefix}.final_exp{i}") for i in range(5)] if '.round_exp_all_ones' in gate: return [registry.register(f"{prefix}.round_exp.fa{i}.xor2.layer2") for i in range(5)] if '.exp_overflow_any' in gate: return [registry.get_id(f"{prefix}.exp_sub.fa5.xor2.layer2"), registry.get_id(f"{prefix}.round_exp.fa4.cout"), registry.get_id(f"{prefix}.round_exp_all_ones"), registry.get_id(f"{prefix}.result_exp_fa4.cout")] if '.exp_overflow_to_inf' in gate: return [registry.get_id(f"{prefix}.exp_sub.fa5.cout"), registry.get_id(f"{prefix}.exp_overflow_any"), registry.register(f"{prefix}.not_exp_underflow")] registry.register(f"{prefix}.final_exp_all_ones") registry.register(f"{prefix}.round_exp_all_ones") registry.register(f"{prefix}.exp_overflow_any") registry.register(f"{prefix}.exp_overflow_to_inf") if '.exp_sub_borrow' in gate: return [registry.get_id(f"{prefix}.exp_sub.fa5.cout")] if '.exp_sub_zero_and_npo' in gate: return [registry.get_id(f"{prefix}.exp_sub_zero"), registry.get_id(f"{prefix}.not_prod_overflow")] if gate.endswith('.exp_sub_zero'): return [registry.get_id(f"{prefix}.exp_sub.fa{i}.xor2.layer2") for i in range(6)] if '.exp_dec_borrow' in gate: return [registry.get_id(f"{prefix}.exp_dec.fa4.cout")] if '.exp_dec_zero_and_no_overflow' in gate: return [registry.get_id(f"{prefix}.exp_dec_zero"), registry.register(f"{prefix}.not_exp_overflow_any")] if '.exp_dec_zero' in gate: return [registry.get_id(f"{prefix}.exp_dec.fa{i}.xor2.layer2") for i in range(5)] if '.exp_underflow' in gate: return [registry.get_id(f"{prefix}.exp_sub_borrow"), registry.get_id(f"{prefix}.exp_dec_borrow"), registry.register(f"{prefix}.exp_dec_zero_and_no_overflow")] if '.not_exp_overflow_any' in gate: return [registry.get_id(f"{prefix}.exp_overflow_any")] if '.not_exp_underflow' in gate: return 
[registry.register(f"{prefix}.exp_underflow")] registry.register(f"{prefix}.exp_sub_borrow") registry.register(f"{prefix}.exp_sub_zero") registry.register(f"{prefix}.exp_sub_zero_and_npo") registry.register(f"{prefix}.exp_dec_borrow") registry.register(f"{prefix}.exp_dec_zero") registry.register(f"{prefix}.exp_underflow") registry.register(f"{prefix}.not_exp_overflow_any") registry.register(f"{prefix}.exp_dec_zero_and_no_overflow") registry.register(f"{prefix}.not_exp_underflow") match = re.search(r'\.not_exp_add(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.exp_add.fa{i}.xor2.layer2")] for i in range(5): registry.register(f"{prefix}.not_exp_add{i}") match = re.search(r'\.not_exp_dec(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.exp_dec.fa{i}.xor2.layer2")] for i in range(5): registry.register(f"{prefix}.not_exp_dec{i}") if '.sub_shift_base.fa' in gate: match = re.search(r'\.sub_shift_base\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.sub_shift_base.fa{i}" const_1 = [1, 0, 0, 0, 0] a_bit = registry.get_id(f"#{const_1[i]}") b_bit = registry.get_id(f"{prefix}.not_exp_dec{i}") cin = registry.get_id("#1") if i == 0 else registry.register(f"{prefix}.sub_shift_base.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(5): registry.register(f"{prefix}.sub_shift_base.fa{i}.xor2.layer2") 
registry.register(f"{prefix}.sub_shift_base.fa{i}.cout") if '.sub_shift.fa' in gate: match = re.search(r'\.sub_shift\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.sub_shift.fa{i}" a_bit = registry.get_id(f"{prefix}.sub_shift_base.fa{i}.xor2.layer2") b_bit = registry.get_id("#0") cin = registry.get_id("#0") if i == 0 else registry.register(f"{prefix}.sub_shift.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(5): registry.register(f"{prefix}.sub_shift.fa{i}.xor2.layer2") registry.register(f"{prefix}.sub_shift.fa{i}.cout") match = re.search(r'\.not_sub_shift(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.sub_shift.fa{i}.xor2.layer2")] for i in range(5): registry.register(f"{prefix}.not_sub_shift{i}") match = re.search(r'\.norm_full(\d+)\.', gate) if match: i = int(match.group(1)) if '.overflow_path' in gate: if i == 11: src = registry.get_id(f"{prefix}.prod2_fa21.xor2.layer2") else: src = registry.get_id(f"{prefix}.prod2_fa{i+10}.xor2.layer2") return [src, registry.get_id(f"{prefix}.prod_overflow")] if '.normal_path' in gate: if i == 11: src = registry.get_id(f"{prefix}.prod2_fa20.xor2.layer2") else: src = registry.get_id(f"{prefix}.prod2_fa{i+9}.xor2.layer2") return [src, registry.get_id(f"{prefix}.not_prod_overflow")] match = re.search(r'\.norm_full(\d+)$', gate) if match: i = int(match.group(1)) return 
[registry.register(f"{prefix}.norm_full{i}.overflow_path"), registry.register(f"{prefix}.norm_full{i}.normal_path")] for i in range(12): registry.register(f"{prefix}.norm_full{i}") # CLZ on norm_full (bits 11:0) for left normalization match = re.search(r'\.norm_pz(\d+)$', gate) if match: k = int(match.group(1)) return [registry.get_id(f"{prefix}.norm_full{11-i}") for i in range(k)] for k in range(1, 13): registry.register(f"{prefix}.norm_pz{k}") pz_ids = [registry.get_id(f"{prefix}.norm_pz{k}") for k in range(1, 13)] match = re.search(r'\.norm_ge(\d+)$', gate) if match: return pz_ids for k in range(1, 13): registry.register(f"{prefix}.norm_ge{k}") match = re.search(r'\.norm_not_ge(\d+)$', gate) if match: k = int(match.group(1)) return [registry.get_id(f"{prefix}.norm_ge{k}")] for k in [2, 4, 6, 8, 10, 12]: registry.register(f"{prefix}.norm_not_ge{k}") if '.norm_shift3' in gate: return [registry.get_id(f"{prefix}.norm_ge8")] if '.norm_and_4_7' in gate: return [registry.get_id(f"{prefix}.norm_ge4"), registry.get_id(f"{prefix}.norm_not_ge8")] registry.register(f"{prefix}.norm_and_4_7") if '.norm_and_12' in gate: return [registry.get_id(f"{prefix}.norm_ge12")] registry.register(f"{prefix}.norm_and_12") if '.norm_shift2' in gate: return [registry.get_id(f"{prefix}.norm_and_4_7"), registry.get_id(f"{prefix}.norm_and_12")] if '.norm_and_2_3' in gate: return [registry.get_id(f"{prefix}.norm_ge2"), registry.get_id(f"{prefix}.norm_not_ge4")] if '.norm_and_6_7' in gate: return [registry.get_id(f"{prefix}.norm_ge6"), registry.get_id(f"{prefix}.norm_not_ge8")] if '.norm_and_10_11' in gate: return [registry.get_id(f"{prefix}.norm_ge10"), registry.get_id(f"{prefix}.norm_not_ge12")] registry.register(f"{prefix}.norm_and_2_3") registry.register(f"{prefix}.norm_and_6_7") registry.register(f"{prefix}.norm_and_10_11") if '.norm_shift1' in gate: return [registry.get_id(f"{prefix}.norm_and_2_3"), registry.get_id(f"{prefix}.norm_and_6_7"), registry.get_id(f"{prefix}.norm_and_10_11")] 
match = re.search(r'\.norm_and_(\d+)$', gate) if match: i = int(match.group(1)) if i in [1, 3, 5, 7, 9, 11]: return [registry.get_id(f"{prefix}.norm_ge{i}"), registry.get_id(f"{prefix}.norm_not_ge{i+1}")] for i in [1, 3, 5, 7, 9, 11]: registry.register(f"{prefix}.norm_and_{i}") if '.norm_shift0' in gate: return [registry.get_id(f"{prefix}.norm_and_{i}") for i in [1, 3, 5, 7, 9, 11]] for i in range(4): registry.register(f"{prefix}.norm_shift{i}") if '.norm_shift_eq10' in gate: return [registry.get_id(f"{prefix}.norm_shift{i}") for i in range(4)] registry.register(f"{prefix}.norm_shift_eq10") # NOT shift bits for i in range(4): if f'.not_norm_shift{i}' in gate and '.not_norm_shift_sub' not in gate: return [registry.get_id(f"{prefix}.norm_shift{i}")] registry.register(f"{prefix}.not_norm_shift{i}") # 12-bit left barrel shifter stages for norm_full match = re.search(r'\.lshift_s0_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.norm_full{i}"), registry.get_id(f"{prefix}.not_norm_shift0")] if '.shift' in gate and i > 0: return [registry.get_id(f"{prefix}.norm_full{i-1}"), registry.get_id(f"{prefix}.norm_shift0")] match = re.search(r'\.lshift_s0_(\d+)$', gate) if match: i = int(match.group(1)) if i > 0: return [registry.register(f"{prefix}.lshift_s0_{i}.pass"), registry.register(f"{prefix}.lshift_s0_{i}.shift")] else: return [registry.register(f"{prefix}.lshift_s0_{i}.pass")] for i in range(12): registry.register(f"{prefix}.lshift_s0_{i}") match = re.search(r'\.lshift_s1_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.lshift_s0_{i}"), registry.get_id(f"{prefix}.not_norm_shift1")] if '.shift' in gate and i > 1: return [registry.get_id(f"{prefix}.lshift_s0_{i-2}"), registry.get_id(f"{prefix}.norm_shift1")] match = re.search(r'\.lshift_s1_(\d+)$', gate) if match: i = int(match.group(1)) if i > 1: return [registry.register(f"{prefix}.lshift_s1_{i}.pass"), 
registry.register(f"{prefix}.lshift_s1_{i}.shift")] else: return [registry.register(f"{prefix}.lshift_s1_{i}.pass")] for i in range(12): registry.register(f"{prefix}.lshift_s1_{i}") match = re.search(r'\.lshift_s2_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.lshift_s1_{i}"), registry.get_id(f"{prefix}.not_norm_shift2")] if '.shift' in gate and i > 3: return [registry.get_id(f"{prefix}.lshift_s1_{i-4}"), registry.get_id(f"{prefix}.norm_shift2")] match = re.search(r'\.lshift_s2_(\d+)$', gate) if match: i = int(match.group(1)) if i > 3: return [registry.register(f"{prefix}.lshift_s2_{i}.pass"), registry.register(f"{prefix}.lshift_s2_{i}.shift")] else: return [registry.register(f"{prefix}.lshift_s2_{i}.pass")] for i in range(12): registry.register(f"{prefix}.lshift_s2_{i}") match = re.search(r'\.lshift_s3_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.lshift_s2_{i}"), registry.get_id(f"{prefix}.not_norm_shift3")] if '.shift' in gate and i > 7: return [registry.get_id(f"{prefix}.lshift_s2_{i-8}"), registry.get_id(f"{prefix}.norm_shift3")] match = re.search(r'\.lshift_s3_(\d+)$', gate) if match: i = int(match.group(1)) if i > 7: return [registry.register(f"{prefix}.lshift_s3_{i}.pass"), registry.register(f"{prefix}.lshift_s3_{i}.shift")] else: return [registry.register(f"{prefix}.lshift_s3_{i}.pass")] for i in range(12): registry.register(f"{prefix}.lshift_s3_{i}") # Track bits shifted out during normalization for sticky match = re.search(r'\.norm_shift_gt(\d+)$', gate) if match: k = int(match.group(1)) return [registry.get_id(f"{prefix}.norm_shift{i}") for i in range(4)] for k in range(12): registry.register(f"{prefix}.norm_shift_gt{k}") match = re.search(r'\.norm_sticky_part(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.norm_full{i}"), registry.get_id(f"{prefix}.norm_shift_gt{i}")] for i in range(12): 
registry.register(f"{prefix}.norm_sticky_part{i}") if '.norm_shifted_out_or' in gate: return [registry.get_id(f"{prefix}.norm_sticky_part{i}") for i in range(12)] registry.register(f"{prefix}.norm_shifted_out_or") # Subnormal shift source: select prod_lshift bits with correct overflow offset match = re.search(r'\.sub_src(\d+)\.', gate) if match: i = int(match.group(1)) if '.overflow_path' in gate: return [registry.get_id(f"{prefix}.prod_lshift_s3_{i+10}"), registry.get_id(f"{prefix}.prod_overflow")] if '.normal_path' in gate: return [registry.get_id(f"{prefix}.prod_lshift_s3_{i+9}"), registry.get_id(f"{prefix}.not_prod_overflow")] match = re.search(r'\.sub_src(\d+)$', gate) if match: i = int(match.group(1)) return [registry.register(f"{prefix}.sub_src{i}.overflow_path"), registry.register(f"{prefix}.sub_src{i}.normal_path")] for i in range(12): registry.register(f"{prefix}.sub_src{i}") match = re.search(r'\.sub_rshift_s0_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.sub_src{i}"), registry.get_id(f"{prefix}.not_sub_shift0")] if '.shift' in gate and i < 11: return [registry.get_id(f"{prefix}.sub_src{i+1}"), registry.get_id(f"{prefix}.sub_shift.fa0.xor2.layer2")] match = re.search(r'\.sub_rshift_s0_(\d+)$', gate) if match: i = int(match.group(1)) if i < 11: return [registry.register(f"{prefix}.sub_rshift_s0_{i}.pass"), registry.register(f"{prefix}.sub_rshift_s0_{i}.shift")] else: return [registry.register(f"{prefix}.sub_rshift_s0_{i}.pass")] for i in range(12): registry.register(f"{prefix}.sub_rshift_s0_{i}") match = re.search(r'\.sub_rshift_s1_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.sub_rshift_s0_{i}"), registry.get_id(f"{prefix}.not_sub_shift1")] if '.shift' in gate and i < 10: return [registry.get_id(f"{prefix}.sub_rshift_s0_{i+2}"), registry.get_id(f"{prefix}.sub_shift.fa1.xor2.layer2")] match = re.search(r'\.sub_rshift_s1_(\d+)$', gate) if 
match: i = int(match.group(1)) if i < 10: return [registry.register(f"{prefix}.sub_rshift_s1_{i}.pass"), registry.register(f"{prefix}.sub_rshift_s1_{i}.shift")] else: return [registry.register(f"{prefix}.sub_rshift_s1_{i}.pass")] for i in range(12): registry.register(f"{prefix}.sub_rshift_s1_{i}") match = re.search(r'\.sub_rshift_s2_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.sub_rshift_s1_{i}"), registry.get_id(f"{prefix}.not_sub_shift2")] if '.shift' in gate and i < 8: return [registry.get_id(f"{prefix}.sub_rshift_s1_{i+4}"), registry.get_id(f"{prefix}.sub_shift.fa2.xor2.layer2")] match = re.search(r'\.sub_rshift_s2_(\d+)$', gate) if match: i = int(match.group(1)) if i < 8: return [registry.register(f"{prefix}.sub_rshift_s2_{i}.pass"), registry.register(f"{prefix}.sub_rshift_s2_{i}.shift")] else: return [registry.register(f"{prefix}.sub_rshift_s2_{i}.pass")] for i in range(12): registry.register(f"{prefix}.sub_rshift_s2_{i}") match = re.search(r'\.sub_rshift_s3_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.sub_rshift_s2_{i}"), registry.get_id(f"{prefix}.not_sub_shift3")] if '.shift' in gate and i < 4: return [registry.get_id(f"{prefix}.sub_rshift_s2_{i+8}"), registry.get_id(f"{prefix}.sub_shift.fa3.xor2.layer2")] match = re.search(r'\.sub_rshift_s3_(\d+)$', gate) if match: i = int(match.group(1)) if i < 4: return [registry.register(f"{prefix}.sub_rshift_s3_{i}.pass"), registry.register(f"{prefix}.sub_rshift_s3_{i}.shift")] else: return [registry.register(f"{prefix}.sub_rshift_s3_{i}.pass")] for i in range(12): registry.register(f"{prefix}.sub_rshift_s3_{i}") match = re.search(r'\.sub_shifted(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.sub_rshift_s3_{i}"), registry.get_id(f"{prefix}.not_sub_shift4")] for i in range(12): registry.register(f"{prefix}.sub_shifted{i}") match = re.search(r'\.sub_mant(\d+)$', 
gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.sub_shifted{i+1}")] for i in range(10): registry.register(f"{prefix}.sub_mant{i}") if gate == f"{prefix}.sub_guard": return [registry.get_id(f"{prefix}.sub_shifted0")] registry.register(f"{prefix}.sub_guard") match = re.search(r'\.sub_shift_gt(\d+)$', gate) if match: return [registry.get_id(f"{prefix}.sub_shift.fa{i}.xor2.layer2") for i in range(5)] for k in range(12): registry.register(f"{prefix}.sub_shift_gt{k}") match = re.search(r'\.sub_sticky_part(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.sub_src{i}"), registry.get_id(f"{prefix}.sub_shift_gt{i}")] for i in range(12): registry.register(f"{prefix}.sub_sticky_part{i}") if gate.endswith('.sub_sticky_raw'): return [registry.get_id(f"{prefix}.sub_sticky_part{i}") for i in range(12)] registry.register(f"{prefix}.sub_sticky_raw") if gate.endswith('.sub_sticky'): return [registry.get_id(f"{prefix}.sub_sticky_raw"), registry.register(f"{prefix}.round_sticky")] registry.register(f"{prefix}.sub_sticky") if gate.endswith('.sub_round_lsb_or_sticky'): return [registry.get_id(f"{prefix}.sub_sticky"), registry.get_id(f"{prefix}.sub_mant0")] registry.register(f"{prefix}.sub_round_lsb_or_sticky") if gate.endswith('.sub_round_inc'): return [registry.get_id(f"{prefix}.sub_guard"), registry.get_id(f"{prefix}.sub_round_lsb_or_sticky")] registry.register(f"{prefix}.sub_round_inc") if '.sub_round.fa' in gate: match = re.search(r'\.sub_round\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.sub_round.fa{i}" a_bit = registry.get_id(f"{prefix}.sub_mant{i}") b_bit = registry.get_id(f"{prefix}.sub_round_inc") if i == 0 else registry.get_id("#0") cin = registry.get_id("#0") if i == 0 else registry.register(f"{prefix}.sub_round.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), 
registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(10): registry.register(f"{prefix}.sub_round.fa{i}.xor2.layer2") registry.register(f"{prefix}.sub_round.fa{i}.cout") if gate.endswith('.sub_round_overflow'): return [registry.get_id(f"{prefix}.sub_round.fa9.cout")] registry.register(f"{prefix}.sub_round_overflow") if '.not_result_is_inf' in gate: return [registry.get_id(f"{prefix}.result_is_inf")] if '.not_result_is_zero' in gate: return [registry.get_id(f"{prefix}.result_is_zero")] if '.not_exp_underflow' in gate: return [registry.get_id(f"{prefix}.exp_underflow")] registry.register(f"{prefix}.not_result_is_inf") registry.register(f"{prefix}.not_result_is_zero") registry.register(f"{prefix}.not_exp_underflow") if '.subnorm_enable' in gate: return [registry.get_id(f"{prefix}.exp_underflow"), registry.get_id(f"{prefix}.not_result_is_nan"), registry.get_id(f"{prefix}.not_result_is_inf"), registry.get_id(f"{prefix}.not_result_is_zero")] registry.register(f"{prefix}.subnorm_enable") if '.is_normal_result' in gate: return [registry.get_id(f"{prefix}.not_result_is_nan"), registry.get_id(f"{prefix}.not_result_is_inf"), registry.get_id(f"{prefix}.not_result_is_zero"), registry.get_id(f"{prefix}.not_exp_underflow")] registry.register(f"{prefix}.is_normal_result") # out_nan{i}: constant NaN bits (canonical NaN = 0x7E00) # bits 9 and 10-14 are 1, others are 0 match = re.search(r'\.out_nan(\d+)$', gate) if match: i = int(match.group(1)) # NaN = 0x7E00 = 0111111000000000, bit 9 and bits 10-14 are 1 if i == 9 or (i >= 10 and i < 15): return 
[registry.get_id("#1")] else: return [registry.get_id("#0")] # out_normal{i}: pass-through from rounded result match = re.search(r'\.out_normal(\d+)$', gate) if match: i = int(match.group(1)) if i < 10: return [registry.get_id(f"{prefix}.final_mant{i}")] elif i < 15: return [registry.get_id(f"{prefix}.final_exp{i-10}")] else: return [registry.get_id(f"{prefix}.result_sign.layer2")] # out_sub{i}: subnormal path match = re.search(r'\.out_sub(\d+)$', gate) if match: i = int(match.group(1)) if i == 15: return [registry.get_id(f"{prefix}.result_sign.layer2")] elif i == 10: return [registry.get_id(f"{prefix}.sub_round_overflow")] elif 10 < i < 15: return [registry.get_id("#0")] else: return [registry.get_id(f"{prefix}.sub_round.fa{i}.xor2.layer2")] match = re.search(r'\.out(\d+)\.', gate) if match: i = int(match.group(1)) if '.nan_gate' in gate: # Canonical NaN = 0x7E00 = 0_11111_1000000000, bits 9-14 are 1 nan_bit = registry.get_id("#1") if (i >= 9 and i < 15) else registry.get_id("#0") return [nan_bit, registry.get_id(f"{prefix}.result_is_nan")] if '.inf_gate' in gate: # Inf = 0x7C00 = 0_11111_0000000000, bits 10-14 are 1 if i == 15: inf_bit = registry.get_id(f"{prefix}.result_sign.layer2") else: inf_bit = registry.get_id("#1") if i >= 10 and i < 15 else registry.get_id("#0") return [inf_bit, registry.get_id(f"{prefix}.result_is_inf")] if '.zero_gate' in gate: zero_bit = registry.get_id(f"{prefix}.result_sign.layer2") if i == 15 else registry.get_id("#0") return [zero_bit, registry.get_id(f"{prefix}.result_is_zero")] if '.normal_gate' in gate: if i < 10: normal_bit = registry.get_id(f"{prefix}.final_mant{i}") elif i < 15: normal_bit = registry.get_id(f"{prefix}.final_exp{i-10}") else: normal_bit = registry.get_id(f"{prefix}.result_sign.layer2") return [normal_bit, registry.get_id(f"{prefix}.is_normal_result")] if '.sub_gate' in gate: return [registry.register(f"{prefix}.out_sub{i}"), registry.get_id(f"{prefix}.subnorm_enable")] match = re.search(r'\.out(\d+)$', gate) 
if match: i = int(match.group(1)) return [registry.register(f"{prefix}.out{i}.nan_gate"), registry.register(f"{prefix}.out{i}.inf_gate"), registry.register(f"{prefix}.out{i}.zero_gate"), registry.register(f"{prefix}.out{i}.normal_gate"), registry.register(f"{prefix}.out{i}.sub_gate")] return [] def infer_float16_div_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for float16.div circuit.""" prefix = "float16.div" for i in range(16): registry.register(f"{prefix}.$a[{i}]") registry.register(f"{prefix}.$b[{i}]") exp_a_bits = [f"{prefix}.$a[{10+i}]" for i in range(5)] exp_b_bits = [f"{prefix}.$b[{10+i}]" for i in range(5)] mant_a_bits = [f"{prefix}.$a[{i}]" for i in range(10)] mant_b_bits = [f"{prefix}.$b[{i}]" for i in range(10)] if '.exp_a_all_ones' in gate: return [registry.get_id(b) for b in exp_a_bits] if '.exp_b_all_ones' in gate: return [registry.get_id(b) for b in exp_b_bits] if '.exp_a_zero' in gate: return [registry.get_id(b) for b in exp_a_bits] if '.exp_b_zero' in gate: return [registry.get_id(b) for b in exp_b_bits] registry.register(f"{prefix}.exp_a_all_ones") registry.register(f"{prefix}.exp_b_all_ones") registry.register(f"{prefix}.exp_a_zero") registry.register(f"{prefix}.exp_b_zero") if '.a_adj_exp0' in gate: return [registry.get_id(exp_a_bits[0]), registry.get_id(f"{prefix}.exp_a_zero")] if '.b_adj_exp0' in gate: return [registry.get_id(exp_b_bits[0]), registry.get_id(f"{prefix}.exp_b_zero")] registry.register(f"{prefix}.a_adj_exp0") registry.register(f"{prefix}.b_adj_exp0") if '.mant_a_nonzero' in gate: return [registry.get_id(b) for b in mant_a_bits] if '.mant_b_nonzero' in gate: return [registry.get_id(b) for b in mant_b_bits] registry.register(f"{prefix}.mant_a_nonzero") registry.register(f"{prefix}.mant_b_nonzero") if '.mant_a_zero' in gate: return [registry.get_id(f"{prefix}.mant_a_nonzero")] if '.mant_b_zero' in gate: return [registry.get_id(f"{prefix}.mant_b_nonzero")] registry.register(f"{prefix}.mant_a_zero") 
registry.register(f"{prefix}.mant_b_zero") if '.a_is_subnormal' in gate: return [registry.get_id(f"{prefix}.exp_a_zero"), registry.get_id(f"{prefix}.mant_a_nonzero")] if '.b_is_subnormal' in gate: return [registry.get_id(f"{prefix}.exp_b_zero"), registry.get_id(f"{prefix}.mant_b_nonzero")] registry.register(f"{prefix}.a_is_subnormal") registry.register(f"{prefix}.b_is_subnormal") if '.implicit_a_raw' in gate: return [registry.get_id(f"{prefix}.exp_a_zero")] if '.implicit_b_raw' in gate: return [registry.get_id(f"{prefix}.exp_b_zero")] registry.register(f"{prefix}.implicit_a_raw") registry.register(f"{prefix}.implicit_b_raw") match = re.search(r'\.a_pz(\d+)$', gate) if match: k = int(match.group(1)) return [registry.get_id(mant_a_bits[9 - j]) for j in range(k)] match = re.search(r'\.b_pz(\d+)$', gate) if match: k = int(match.group(1)) return [registry.get_id(mant_b_bits[9 - j]) for j in range(k)] for k in range(1, 10): registry.register(f"{prefix}.a_pz{k}") registry.register(f"{prefix}.b_pz{k}") match = re.search(r'\.a_lead(\d+)$', gate) if match: i = int(match.group(1)) if i == 9: return [registry.get_id(mant_a_bits[9])] return [registry.get_id(f"{prefix}.a_pz{9 - i}"), registry.get_id(mant_a_bits[i])] match = re.search(r'\.b_lead(\d+)$', gate) if match: i = int(match.group(1)) if i == 9: return [registry.get_id(mant_b_bits[9])] return [registry.get_id(f"{prefix}.b_pz{9 - i}"), registry.get_id(mant_b_bits[i])] for i in range(10): registry.register(f"{prefix}.a_lead{i}") registry.register(f"{prefix}.b_lead{i}") if '.a_shift_raw0' in gate: return [registry.get_id(f"{prefix}.a_lead{i}") for i in [9, 7, 5, 3, 1]] if '.a_shift_raw1' in gate: return [registry.get_id(f"{prefix}.a_lead{i}") for i in [8, 7, 4, 3, 0]] if '.a_shift_raw2' in gate: return [registry.get_id(f"{prefix}.a_lead{i}") for i in [6, 5, 4, 3]] if '.a_shift_raw3' in gate: return [registry.get_id(f"{prefix}.a_lead{i}") for i in [2, 1, 0]] if '.b_shift_raw0' in gate: return 
[registry.get_id(f"{prefix}.b_lead{i}") for i in [9, 7, 5, 3, 1]] if '.b_shift_raw1' in gate: return [registry.get_id(f"{prefix}.b_lead{i}") for i in [8, 7, 4, 3, 0]] if '.b_shift_raw2' in gate: return [registry.get_id(f"{prefix}.b_lead{i}") for i in [6, 5, 4, 3]] if '.b_shift_raw3' in gate: return [registry.get_id(f"{prefix}.b_lead{i}") for i in [2, 1, 0]] for i in range(4): registry.register(f"{prefix}.a_shift_raw{i}") registry.register(f"{prefix}.b_shift_raw{i}") match = re.search(r'\.a_shift(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.a_shift_raw{i}"), registry.get_id(f"{prefix}.a_is_subnormal")] match = re.search(r'\.b_shift(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.b_shift_raw{i}"), registry.get_id(f"{prefix}.b_is_subnormal")] for i in range(4): registry.register(f"{prefix}.a_shift{i}") registry.register(f"{prefix}.b_shift{i}") match = re.search(r'\.not_a_shift(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.a_shift{i}")] match = re.search(r'\.not_b_shift(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.b_shift{i}")] for i in range(4): registry.register(f"{prefix}.not_a_shift{i}") registry.register(f"{prefix}.not_b_shift{i}") match = re.search(r'\.a_norm_s(\d+)_(\d+)\.(pass|shift)$', gate) if match: stage = int(match.group(1)) bit = int(match.group(2)) kind = match.group(3) shift_amt = 1 << stage if stage == 0: in_bit = registry.get_id(f"{prefix}.implicit_a_raw") if bit == 10 else registry.get_id(mant_a_bits[bit]) else: in_bit = registry.get_id(f"{prefix}.a_norm_s{stage-1}_{bit}") if kind == "pass": return [in_bit, registry.get_id(f"{prefix}.not_a_shift{stage}")] if bit < shift_amt: return [registry.get_id("#0"), registry.get_id(f"{prefix}.a_shift{stage}")] if stage == 0: shift_in = registry.get_id(f"{prefix}.implicit_a_raw") if bit - shift_amt == 10 else registry.get_id(mant_a_bits[bit - shift_amt]) else: 
shift_in = registry.get_id(f"{prefix}.a_norm_s{stage-1}_{bit - shift_amt}") return [shift_in, registry.get_id(f"{prefix}.a_shift{stage}")] match = re.search(r'\.b_norm_s(\d+)_(\d+)\.(pass|shift)$', gate) if match: stage = int(match.group(1)) bit = int(match.group(2)) kind = match.group(3) shift_amt = 1 << stage if stage == 0: in_bit = registry.get_id(f"{prefix}.implicit_b_raw") if bit == 10 else registry.get_id(mant_b_bits[bit]) else: in_bit = registry.get_id(f"{prefix}.b_norm_s{stage-1}_{bit}") if kind == "pass": return [in_bit, registry.get_id(f"{prefix}.not_b_shift{stage}")] if bit < shift_amt: return [registry.get_id("#0"), registry.get_id(f"{prefix}.b_shift{stage}")] if stage == 0: shift_in = registry.get_id(f"{prefix}.implicit_b_raw") if bit - shift_amt == 10 else registry.get_id(mant_b_bits[bit - shift_amt]) else: shift_in = registry.get_id(f"{prefix}.b_norm_s{stage-1}_{bit - shift_amt}") return [shift_in, registry.get_id(f"{prefix}.b_shift{stage}")] match = re.search(r'\.a_norm_s(\d+)_(\d+)$', gate) if match: stage = int(match.group(1)) bit = int(match.group(2)) shift_amt = 1 << stage pass_gate = registry.register(f"{prefix}.a_norm_s{stage}_{bit}.pass") if bit >= shift_amt: shift_gate = registry.register(f"{prefix}.a_norm_s{stage}_{bit}.shift") return [pass_gate, shift_gate] return [pass_gate] match = re.search(r'\.b_norm_s(\d+)_(\d+)$', gate) if match: stage = int(match.group(1)) bit = int(match.group(2)) shift_amt = 1 << stage pass_gate = registry.register(f"{prefix}.b_norm_s{stage}_{bit}.pass") if bit >= shift_amt: shift_gate = registry.register(f"{prefix}.b_norm_s{stage}_{bit}.shift") return [pass_gate, shift_gate] return [pass_gate] for stage in range(4): for bit in range(11): registry.register(f"{prefix}.a_norm_s{stage}_{bit}") registry.register(f"{prefix}.b_norm_s{stage}_{bit}") match = re.search(r'\.mant_a_norm(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.a_norm_s3_{i}")] match = 
re.search(r'\.mant_b_norm(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.b_norm_s3_{i}")] for i in range(10): registry.register(f"{prefix}.mant_a_norm{i}") registry.register(f"{prefix}.mant_b_norm{i}") if '.a_is_nan' in gate: return [registry.get_id(f"{prefix}.exp_a_all_ones"), registry.get_id(f"{prefix}.mant_a_nonzero")] if '.b_is_nan' in gate: return [registry.get_id(f"{prefix}.exp_b_all_ones"), registry.get_id(f"{prefix}.mant_b_nonzero")] if '.a_is_inf' in gate: return [registry.get_id(f"{prefix}.exp_a_all_ones"), registry.get_id(f"{prefix}.mant_a_zero")] if '.b_is_inf' in gate: return [registry.get_id(f"{prefix}.exp_b_all_ones"), registry.get_id(f"{prefix}.mant_b_zero")] if '.a_is_zero' in gate: return [registry.get_id(f"{prefix}.exp_a_zero"), registry.get_id(f"{prefix}.mant_a_zero")] if '.b_is_zero' in gate: return [registry.get_id(f"{prefix}.exp_b_zero"), registry.get_id(f"{prefix}.mant_b_zero")] registry.register(f"{prefix}.a_is_nan") registry.register(f"{prefix}.b_is_nan") registry.register(f"{prefix}.a_is_inf") registry.register(f"{prefix}.b_is_inf") registry.register(f"{prefix}.a_is_zero") registry.register(f"{prefix}.b_is_zero") if '.either_is_nan' in gate: return [registry.get_id(f"{prefix}.a_is_nan"), registry.get_id(f"{prefix}.b_is_nan")] if '.both_inf' in gate: return [registry.get_id(f"{prefix}.a_is_inf"), registry.get_id(f"{prefix}.b_is_inf")] if '.both_zero' in gate: return [registry.get_id(f"{prefix}.a_is_zero"), registry.get_id(f"{prefix}.b_is_zero")] registry.register(f"{prefix}.either_is_nan") registry.register(f"{prefix}.both_inf") registry.register(f"{prefix}.both_zero") if gate.endswith('.result_is_nan'): return [registry.get_id(f"{prefix}.either_is_nan"), registry.get_id(f"{prefix}.both_inf"), registry.get_id(f"{prefix}.both_zero")] if '.not_result_is_nan' in gate: return [registry.get_id(f"{prefix}.result_is_nan")] if '.not_a_is_zero' in gate: return [registry.get_id(f"{prefix}.a_is_zero")] if 
'.not_b_is_zero' in gate: return [registry.get_id(f"{prefix}.b_is_zero")] if '.not_a_is_inf' in gate: return [registry.get_id(f"{prefix}.a_is_inf")] if '.not_b_is_inf' in gate: return [registry.get_id(f"{prefix}.b_is_inf")] registry.register(f"{prefix}.result_is_nan") registry.register(f"{prefix}.not_result_is_nan") registry.register(f"{prefix}.not_a_is_zero") registry.register(f"{prefix}.not_b_is_zero") registry.register(f"{prefix}.not_a_is_inf") registry.register(f"{prefix}.not_b_is_inf") if '.finite_div_zero' in gate: return [registry.get_id(f"{prefix}.not_a_is_zero"), registry.get_id(f"{prefix}.b_is_zero")] if '.inf_div_finite' in gate: return [registry.get_id(f"{prefix}.a_is_inf"), registry.get_id(f"{prefix}.not_b_is_inf")] if gate.endswith('.result_is_inf'): return [registry.get_id(f"{prefix}.finite_div_zero"), registry.get_id(f"{prefix}.inf_div_finite"), registry.get_id(f"{prefix}.exp_overflow_to_inf"), registry.get_id(f"{prefix}.not_result_is_nan"), registry.get_id(f"{prefix}.not_result_is_zero")] registry.register(f"{prefix}.finite_div_zero") registry.register(f"{prefix}.inf_div_finite") registry.register(f"{prefix}.result_is_inf") if '.zero_div_finite' in gate: return [registry.get_id(f"{prefix}.a_is_zero"), registry.get_id(f"{prefix}.not_b_is_zero")] if '.finite_div_inf' in gate: return [registry.get_id(f"{prefix}.not_a_is_inf"), registry.get_id(f"{prefix}.b_is_inf")] if gate.endswith('.result_is_zero'): return [registry.get_id(f"{prefix}.zero_div_finite"), registry.get_id(f"{prefix}.finite_div_inf"), registry.get_id(f"{prefix}.not_result_is_nan")] registry.register(f"{prefix}.zero_div_finite") registry.register(f"{prefix}.finite_div_inf") registry.register(f"{prefix}.result_is_zero") if '.result_sign.layer1.or' in gate: return [registry.get_id(f"{prefix}.$a[15]"), registry.get_id(f"{prefix}.$b[15]")] if '.result_sign.layer1.nand' in gate: return [registry.get_id(f"{prefix}.$a[15]"), registry.get_id(f"{prefix}.$b[15]")] if '.result_sign.layer2' in gate: 
return [registry.register(f"{prefix}.result_sign.layer1.or"), registry.register(f"{prefix}.result_sign.layer1.nand")] registry.register(f"{prefix}.result_sign.layer2") match = re.search(r'\.not_exp_b(\d+)$', gate) if match: i = int(match.group(1)) if i == 0: return [registry.get_id(f"{prefix}.b_adj_exp0")] return [registry.get_id(exp_b_bits[i])] for i in range(5): registry.register(f"{prefix}.not_exp_b{i}") if '.exp_sub.fa' in gate: match = re.search(r'\.exp_sub\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.exp_sub.fa{i}" if i == 0: a_bit = registry.get_id(f"{prefix}.a_adj_exp0") else: a_bit = registry.get_id(exp_a_bits[i]) if i < 5 else registry.get_id("#0") b_bit = registry.get_id(f"{prefix}.not_exp_b{i}") if i < 5 else registry.get_id("#1") cin = registry.get_id("#1") if i == 0 else registry.register(f"{prefix}.exp_sub.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(6): registry.register(f"{prefix}.exp_sub.fa{i}.xor2.layer2") registry.register(f"{prefix}.exp_sub.fa{i}.cout") if '.exp_sub_a.fa' in gate: match = re.search(r'\.exp_sub_a\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.exp_sub_a.fa{i}" a_bit = registry.get_id(f"{prefix}.exp_sub.fa{i}.xor2.layer2") if i < 4: b_bit = registry.get_id(f"{prefix}.not_a_shift{i}") else: b_bit = registry.get_id("#1") cin = registry.get_id("#1") if i == 0 else 
registry.register(f"{prefix}.exp_sub_a.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(6): registry.register(f"{prefix}.exp_sub_a.fa{i}.xor2.layer2") registry.register(f"{prefix}.exp_sub_a.fa{i}.cout") if '.exp_sub_ab.fa' in gate: match = re.search(r'\.exp_sub_ab\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.exp_sub_ab.fa{i}" a_bit = registry.get_id(f"{prefix}.exp_sub_a.fa{i}.xor2.layer2") if i < 4: b_bit = registry.get_id(f"{prefix}.b_shift{i}") else: b_bit = registry.get_id("#0") cin = registry.get_id("#0") if i == 0 else registry.register(f"{prefix}.exp_sub_ab.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(6): registry.register(f"{prefix}.exp_sub_ab.fa{i}.xor2.layer2") registry.register(f"{prefix}.exp_sub_ab.fa{i}.cout") bits_15 = [1, 1, 1, 1, 0, 0] if '.exp_add15.fa' 
in gate: match = re.search(r'\.exp_add15\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.exp_add15.fa{i}" a_bit = registry.get_id(f"{prefix}.exp_sub_ab.fa{i}.xor2.layer2") b_bit = registry.get_id(f"#{bits_15[i]}") cin = registry.get_id("#0") if i == 0 else registry.register(f"{prefix}.exp_add15.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(6): registry.register(f"{prefix}.exp_add15.fa{i}.xor2.layer2") registry.register(f"{prefix}.exp_add15.fa{i}.cout") if '.implicit_a' in gate: return [registry.get_id(f"{prefix}.a_is_zero")] if '.implicit_b' in gate: return [registry.get_id(f"{prefix}.b_is_zero")] registry.register(f"{prefix}.implicit_a") registry.register(f"{prefix}.implicit_b") for i in range(10): if f'.not_div_b{i}' in gate: return [registry.get_id(f"{prefix}.mant_b_norm{i}")] registry.register(f"{prefix}.not_div_b{i}") if '.not_implicit_b' in gate: return [registry.get_id(f"{prefix}.implicit_b")] registry.register(f"{prefix}.not_implicit_b") for step in range(13): sp = f"{prefix}.div_step{step}" # Pre-register subtractor outputs so q_bit can reference them for i in range(12): registry.register(f"{sp}.sub.fa{i}.xor2.layer2") registry.register(f"{sp}.sub.fa{i}.cout") if f'.div_step{step}.q_bit' in gate: return [registry.get_id(f"{sp}.sub.fa11.cout")] registry.register(f"{sp}.q_bit") if f'.div_step{step}.not_q_bit' in gate: return 
[registry.get_id(f"{sp}.q_bit")] registry.register(f"{sp}.not_q_bit") if f'.div_step{step}.sub.fa' in gate: match = re.search(rf'\.div_step{step}\.sub\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{sp}.sub.fa{i}" if step == 0: if i <= 9: a_bit = registry.get_id(f"{prefix}.mant_a_norm{i}") elif i == 10: a_bit = registry.get_id(f"{prefix}.implicit_a") else: a_bit = registry.get_id("#0") else: if i == 0: a_bit = registry.get_id("#0") else: a_bit = registry.get_id(f"{prefix}.div_step{step-1}.rem{i-1}") if i <= 10: b_bit = registry.get_id(f"{prefix}.not_div_b{i}") if i < 10 else registry.get_id(f"{prefix}.not_implicit_b") else: b_bit = registry.get_id("#1") cin = registry.get_id("#1") if i == 0 else registry.register(f"{sp}.sub.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] match = re.search(rf'\.div_step{step}\.rem(\d+)\.', gate) if match: i = int(match.group(1)) if '.sub_path' in gate: return [registry.get_id(f"{sp}.sub.fa{i}.xor2.layer2"), registry.get_id(f"{sp}.q_bit")] if '.shift_path' in gate: if step == 0: if i < 10: shift_in = registry.get_id(f"{prefix}.mant_a_norm{i}") elif i == 10: shift_in = registry.get_id(f"{prefix}.implicit_a") else: shift_in = registry.get_id("#0") else: if i == 0: shift_in = registry.get_id("#0") else: shift_in = registry.get_id(f"{prefix}.div_step{step-1}.rem{i-1}") return [shift_in, registry.get_id(f"{sp}.not_q_bit")] match = 
re.search(rf'\.div_step{step}\.rem(\d+)$', gate) if match: i = int(match.group(1)) return [registry.register(f"{sp}.rem{i}.sub_path"), registry.register(f"{sp}.rem{i}.shift_path")] for i in range(12): registry.register(f"{sp}.rem{i}") if '.need_norm' in gate: return [registry.get_id(f"{prefix}.div_step0.q_bit")] registry.register(f"{prefix}.need_norm") if '.not_need_norm' in gate: return [registry.get_id(f"{prefix}.need_norm")] registry.register(f"{prefix}.not_need_norm") if '.exp_norm.fa' in gate: match = re.search(r'\.exp_norm\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.exp_norm.fa{i}" a_bit = registry.get_id(f"{prefix}.exp_add15.fa{i}.xor2.layer2") b_bit = registry.get_id(f"{prefix}.not_need_norm") if i == 0 else registry.get_id("#1") cin = registry.get_id("#1") if i == 0 else registry.register(f"{prefix}.exp_norm.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(6): registry.register(f"{prefix}.exp_norm.fa{i}.xor2.layer2") registry.register(f"{prefix}.exp_norm.fa{i}.cout") if '.exp_underflow_borrow' in gate: return [registry.get_id(f"{prefix}.exp_norm.fa5.xor2.layer2")] registry.register(f"{prefix}.exp_underflow_borrow") if '.exp_norm_zero' in gate: return [registry.get_id(f"{prefix}.exp_norm.fa{i}.xor2.layer2") for i in range(6)] registry.register(f"{prefix}.exp_norm_zero") if '.exp_underflow' in gate: return 
[registry.get_id(f"{prefix}.exp_underflow_borrow"), registry.get_id(f"{prefix}.exp_norm_zero")] registry.register(f"{prefix}.exp_underflow") if '.exp_norm_all_ones' in gate: return [registry.register(f"{prefix}.exp_norm.fa{i}.xor2.layer2") for i in range(5)] registry.register(f"{prefix}.exp_norm_all_ones") if '.exp_out_all_ones' in gate: return [registry.register(f"{prefix}.exp_out{i}") for i in range(5)] if '.exp_overflow_any' in gate: return [registry.get_id(f"{prefix}.exp_norm.fa5.xor2.layer2"), registry.get_id(f"{prefix}.exp_out_all_ones")] if '.exp_overflow_to_inf' in gate: return [registry.get_id(f"{prefix}.exp_sub.fa5.cout"), registry.get_id(f"{prefix}.exp_overflow_any")] registry.register(f"{prefix}.exp_out_all_ones") registry.register(f"{prefix}.exp_overflow_any") registry.register(f"{prefix}.exp_overflow_to_inf") match = re.search(r'\.norm_mant(\d+)\.', gate) if match: i = int(match.group(1)) if '.norm_path' in gate: q_src = registry.get_id(f"{prefix}.div_step{11-i}.q_bit") return [q_src, registry.get_id(f"{prefix}.need_norm")] if '.direct_path' in gate: return [registry.get_id(f"{prefix}.div_step{10-i}.q_bit") if i < 10 else registry.get_id("#0"), registry.get_id(f"{prefix}.not_need_norm")] match = re.search(r'\.norm_mant(\d+)$', gate) if match: i = int(match.group(1)) return [registry.register(f"{prefix}.norm_mant{i}.norm_path"), registry.register(f"{prefix}.norm_mant{i}.direct_path")] for i in range(10): registry.register(f"{prefix}.norm_mant{i}") if '.rem_nonzero' in gate: return [registry.get_id(f"{prefix}.div_step12.rem{i}") for i in range(12)] registry.register(f"{prefix}.rem_nonzero") if '.guard_direct' in gate: return [registry.get_id(f"{prefix}.div_step11.q_bit"), registry.get_id(f"{prefix}.not_need_norm")] if '.guard_norm' in gate: return [registry.get_id(f"{prefix}.div_step12.q_bit"), registry.get_id(f"{prefix}.need_norm")] if '.guard_bit' in gate: return [registry.get_id(f"{prefix}.guard_direct"), registry.get_id(f"{prefix}.guard_norm")] 
registry.register(f"{prefix}.guard_direct") registry.register(f"{prefix}.guard_norm") registry.register(f"{prefix}.guard_bit") if '.sticky_direct_or' in gate: return [registry.get_id(f"{prefix}.div_step12.q_bit"), registry.get_id(f"{prefix}.rem_nonzero")] if '.sticky_direct' in gate: return [registry.get_id(f"{prefix}.sticky_direct_or"), registry.get_id(f"{prefix}.not_need_norm")] if '.sticky_norm' in gate: return [registry.get_id(f"{prefix}.rem_nonzero"), registry.get_id(f"{prefix}.need_norm")] if '.sticky_bit' in gate: return [registry.get_id(f"{prefix}.sticky_direct"), registry.get_id(f"{prefix}.sticky_norm")] registry.register(f"{prefix}.sticky_direct_or") registry.register(f"{prefix}.sticky_direct") registry.register(f"{prefix}.sticky_norm") registry.register(f"{prefix}.sticky_bit") if '.round_or' in gate: return [registry.get_id(f"{prefix}.sticky_bit"), registry.get_id(f"{prefix}.norm_mant0")] if '.round_inc' in gate: return [registry.get_id(f"{prefix}.guard_bit"), registry.get_id(f"{prefix}.round_or")] registry.register(f"{prefix}.round_or") registry.register(f"{prefix}.round_inc") if '.round_add.fa' in gate: match = re.search(r'\.round_add\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.round_add.fa{i}" a_bit = registry.get_id(f"{prefix}.norm_mant{i}") b_bit = registry.get_id(f"{prefix}.round_inc") if i == 0 else registry.get_id("#0") cin = registry.get_id("#0") if i == 0 else registry.register(f"{prefix}.round_add.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if 
'.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(10): registry.register(f"{prefix}.round_add.fa{i}.xor2.layer2") registry.register(f"{prefix}.round_add.fa{i}.cout") if '.mant_overflow' in gate and '.not_' not in gate: return [registry.get_id(f"{prefix}.round_add.fa9.cout")] registry.register(f"{prefix}.mant_overflow") if '.not_mant_overflow' in gate: return [registry.get_id(f"{prefix}.mant_overflow")] registry.register(f"{prefix}.not_mant_overflow") if '.exp_inc.fa' in gate: match = re.search(r'\.exp_inc\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.exp_inc.fa{i}" a_bit = registry.get_id(f"{prefix}.exp_norm.fa{i}.xor2.layer2") b_bit = registry.get_id("#0") if i == 0: cin = registry.get_id(f"{prefix}.mant_overflow") else: cin = registry.register(f"{prefix}.exp_inc.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(5): registry.register(f"{prefix}.exp_inc.fa{i}.xor2.layer2") registry.register(f"{prefix}.exp_inc.fa{i}.cout") match = re.search(r'\.exp_out(\d+)\.', gate) if match: i = int(match.group(1)) if '.overflow_path' in gate: return [registry.get_id(f"{prefix}.exp_inc.fa{i}.xor2.layer2"), registry.get_id(f"{prefix}.mant_overflow")] if '.normal_path' in gate: return [registry.get_id(f"{prefix}.exp_norm.fa{i}.xor2.layer2"), registry.get_id(f"{prefix}.not_mant_overflow")] match = 
re.search(r'\.exp_out(\d+)$', gate) if match: i = int(match.group(1)) return [registry.register(f"{prefix}.exp_out{i}.overflow_path"), registry.register(f"{prefix}.exp_out{i}.normal_path")] match = re.search(r'\.mant_out(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.round_add.fa{i}.xor2.layer2"), registry.get_id(f"{prefix}.not_mant_overflow")] if '.quot_implicit_norm' in gate: return [registry.get_id(f"{prefix}.div_step1.q_bit"), registry.get_id(f"{prefix}.need_norm")] if '.quot_implicit_direct' in gate: return [registry.get_id(f"{prefix}.div_step0.q_bit"), registry.get_id(f"{prefix}.not_need_norm")] if gate.endswith('.quot_implicit'): return [registry.get_id(f"{prefix}.quot_implicit_norm"), registry.get_id(f"{prefix}.quot_implicit_direct")] registry.register(f"{prefix}.quot_implicit_norm") registry.register(f"{prefix}.quot_implicit_direct") registry.register(f"{prefix}.quot_implicit") match = re.search(r'\.norm_full(\d+)$', gate) if match: i = int(match.group(1)) if i == 0: return [registry.get_id(f"{prefix}.guard_bit")] if i == 11: return [registry.get_id(f"{prefix}.quot_implicit")] return [registry.get_id(f"{prefix}.norm_mant{i-1}")] for i in range(12): registry.register(f"{prefix}.norm_full{i}") match = re.search(r'\.not_exp_norm(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.exp_norm.fa{i}.xor2.layer2")] for i in range(5): registry.register(f"{prefix}.not_exp_norm{i}") if '.sub_shift_base.fa' in gate: match = re.search(r'\.sub_shift_base\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.sub_shift_base.fa{i}" const_1 = [1, 0, 0, 0, 0] a_bit = registry.get_id(f"#{const_1[i]}") b_bit = registry.get_id(f"{prefix}.not_exp_norm{i}") cin = registry.get_id("#1") if i == 0 else registry.register(f"{prefix}.sub_shift_base.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), 
registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(5): registry.register(f"{prefix}.sub_shift_base.fa{i}.xor2.layer2") registry.register(f"{prefix}.sub_shift_base.fa{i}.cout") if '.sub_shift.fa' in gate: match = re.search(r'\.sub_shift\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.sub_shift.fa{i}" a_bit = registry.get_id(f"{prefix}.sub_shift_base.fa{i}.xor2.layer2") b_bit = registry.get_id("#0") cin = registry.get_id("#0") if i == 0 else registry.register(f"{prefix}.sub_shift.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(5): registry.register(f"{prefix}.sub_shift.fa{i}.xor2.layer2") registry.register(f"{prefix}.sub_shift.fa{i}.cout") match = re.search(r'\.not_sub_shift(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.sub_shift.fa{i}.xor2.layer2")] for i in range(5): registry.register(f"{prefix}.not_sub_shift{i}") match = re.search(r'\.sub_rshift_s0_(\d+)\.', gate) if match: i = 
int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.norm_full{i}"), registry.get_id(f"{prefix}.not_sub_shift0")] if '.shift' in gate and i < 11: return [registry.get_id(f"{prefix}.norm_full{i+1}"), registry.get_id(f"{prefix}.sub_shift.fa0.xor2.layer2")] match = re.search(r'\.sub_rshift_s0_(\d+)$', gate) if match: i = int(match.group(1)) if i < 11: return [registry.register(f"{prefix}.sub_rshift_s0_{i}.pass"), registry.register(f"{prefix}.sub_rshift_s0_{i}.shift")] else: return [registry.register(f"{prefix}.sub_rshift_s0_{i}.pass")] for i in range(12): registry.register(f"{prefix}.sub_rshift_s0_{i}") match = re.search(r'\.sub_rshift_s1_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.sub_rshift_s0_{i}"), registry.get_id(f"{prefix}.not_sub_shift1")] if '.shift' in gate and i < 10: return [registry.get_id(f"{prefix}.sub_rshift_s0_{i+2}"), registry.get_id(f"{prefix}.sub_shift.fa1.xor2.layer2")] match = re.search(r'\.sub_rshift_s1_(\d+)$', gate) if match: i = int(match.group(1)) if i < 10: return [registry.register(f"{prefix}.sub_rshift_s1_{i}.pass"), registry.register(f"{prefix}.sub_rshift_s1_{i}.shift")] else: return [registry.register(f"{prefix}.sub_rshift_s1_{i}.pass")] for i in range(12): registry.register(f"{prefix}.sub_rshift_s1_{i}") match = re.search(r'\.sub_rshift_s2_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.sub_rshift_s1_{i}"), registry.get_id(f"{prefix}.not_sub_shift2")] if '.shift' in gate and i < 8: return [registry.get_id(f"{prefix}.sub_rshift_s1_{i+4}"), registry.get_id(f"{prefix}.sub_shift.fa2.xor2.layer2")] match = re.search(r'\.sub_rshift_s2_(\d+)$', gate) if match: i = int(match.group(1)) if i < 8: return [registry.register(f"{prefix}.sub_rshift_s2_{i}.pass"), registry.register(f"{prefix}.sub_rshift_s2_{i}.shift")] else: return [registry.register(f"{prefix}.sub_rshift_s2_{i}.pass")] for i in range(12): 
registry.register(f"{prefix}.sub_rshift_s2_{i}") match = re.search(r'\.sub_rshift_s3_(\d+)\.', gate) if match: i = int(match.group(1)) if '.pass' in gate: return [registry.get_id(f"{prefix}.sub_rshift_s2_{i}"), registry.get_id(f"{prefix}.not_sub_shift3")] if '.shift' in gate and i < 4: return [registry.get_id(f"{prefix}.sub_rshift_s2_{i+8}"), registry.get_id(f"{prefix}.sub_shift.fa3.xor2.layer2")] match = re.search(r'\.sub_rshift_s3_(\d+)$', gate) if match: i = int(match.group(1)) if i < 4: return [registry.register(f"{prefix}.sub_rshift_s3_{i}.pass"), registry.register(f"{prefix}.sub_rshift_s3_{i}.shift")] else: return [registry.register(f"{prefix}.sub_rshift_s3_{i}.pass")] for i in range(12): registry.register(f"{prefix}.sub_rshift_s3_{i}") match = re.search(r'\.sub_shifted(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.sub_rshift_s3_{i}"), registry.get_id(f"{prefix}.not_sub_shift4")] for i in range(12): registry.register(f"{prefix}.sub_shifted{i}") match = re.search(r'\.sub_mant(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.sub_shifted{i+1}")] for i in range(10): registry.register(f"{prefix}.sub_mant{i}") if gate == f"{prefix}.sub_guard": return [registry.get_id(f"{prefix}.sub_shifted0")] registry.register(f"{prefix}.sub_guard") match = re.search(r'\.sub_shift_gt(\d+)$', gate) if match: return [registry.get_id(f"{prefix}.sub_shift.fa{i}.xor2.layer2") for i in range(5)] for k in range(12): registry.register(f"{prefix}.sub_shift_gt{k}") match = re.search(r'\.sub_sticky_part(\d+)$', gate) if match: i = int(match.group(1)) return [registry.get_id(f"{prefix}.norm_full{i}"), registry.get_id(f"{prefix}.sub_shift_gt{i}")] for i in range(12): registry.register(f"{prefix}.sub_sticky_part{i}") if gate.endswith('.sub_sticky_raw'): return [registry.get_id(f"{prefix}.sub_sticky_part{i}") for i in range(12)] registry.register(f"{prefix}.sub_sticky_raw") if gate.endswith('.sub_sticky'): return 
[registry.get_id(f"{prefix}.sub_sticky_raw"), registry.get_id(f"{prefix}.rem_nonzero")] registry.register(f"{prefix}.sub_sticky") if gate.endswith('.sub_round_lsb_or_sticky'): return [registry.get_id(f"{prefix}.sub_sticky"), registry.get_id(f"{prefix}.sub_mant0")] registry.register(f"{prefix}.sub_round_lsb_or_sticky") if gate.endswith('.sub_round_inc'): return [registry.get_id(f"{prefix}.sub_guard"), registry.get_id(f"{prefix}.sub_round_lsb_or_sticky")] registry.register(f"{prefix}.sub_round_inc") if '.sub_round.fa' in gate: match = re.search(r'\.sub_round\.fa(\d+)\.', gate) if match: i = int(match.group(1)) fa_prefix = f"{prefix}.sub_round.fa{i}" a_bit = registry.get_id(f"{prefix}.sub_mant{i}") b_bit = registry.get_id(f"{prefix}.sub_round_inc") if i == 0 else registry.get_id("#0") cin = registry.get_id("#0") if i == 0 else registry.register(f"{prefix}.sub_round.fa{i-1}.cout") if '.xor1.layer1' in gate: return [a_bit, b_bit] if '.xor1.layer2' in gate: return [registry.register(f"{fa_prefix}.xor1.layer1.or"), registry.register(f"{fa_prefix}.xor1.layer1.nand")] xor1 = registry.register(f"{fa_prefix}.xor1.layer2") if '.xor2.layer1' in gate: return [xor1, cin] if '.xor2.layer2' in gate: return [registry.register(f"{fa_prefix}.xor2.layer1.or"), registry.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, b_bit] if '.and2' in gate: return [xor1, cin] if '.cout' in gate: return [registry.register(f"{fa_prefix}.and1"), registry.register(f"{fa_prefix}.and2")] for i in range(10): registry.register(f"{prefix}.sub_round.fa{i}.xor2.layer2") registry.register(f"{prefix}.sub_round.fa{i}.cout") if gate.endswith('.sub_round_overflow'): return [registry.get_id(f"{prefix}.sub_round.fa9.cout")] registry.register(f"{prefix}.sub_round_overflow") if '.not_result_is_inf' in gate: return [registry.get_id(f"{prefix}.result_is_inf")] if '.not_result_is_zero' in gate: return [registry.get_id(f"{prefix}.result_is_zero")] if '.not_exp_underflow' in gate: return 
[registry.get_id(f"{prefix}.exp_underflow")] registry.register(f"{prefix}.not_result_is_inf") registry.register(f"{prefix}.not_result_is_zero") registry.register(f"{prefix}.not_exp_underflow") if '.subnorm_enable' in gate: return [registry.get_id(f"{prefix}.exp_underflow"), registry.get_id(f"{prefix}.not_result_is_nan"), registry.get_id(f"{prefix}.not_result_is_inf"), registry.get_id(f"{prefix}.not_result_is_zero")] registry.register(f"{prefix}.subnorm_enable") if '.is_normal_result' in gate: return [registry.get_id(f"{prefix}.not_result_is_nan"), registry.get_id(f"{prefix}.not_result_is_inf"), registry.get_id(f"{prefix}.not_result_is_zero"), registry.get_id(f"{prefix}.not_exp_underflow")] registry.register(f"{prefix}.is_normal_result") match = re.search(r'\.out_sub(\d+)$', gate) if match: i = int(match.group(1)) if i == 15: return [registry.get_id(f"{prefix}.result_sign.layer2")] elif i == 10: return [registry.get_id(f"{prefix}.sub_round_overflow")] elif 10 < i < 15: return [registry.get_id("#0")] else: return [registry.get_id(f"{prefix}.sub_round.fa{i}.xor2.layer2")] for i in range(16): registry.register(f"{prefix}.out_sub{i}") match = re.search(r'\.out(\d+)\.', gate) if match: i = int(match.group(1)) if '.nan_gate' in gate: # Canonical NaN = 0x7E00 = 0_11111_1000000000, bits 9-14 are 1 nan_bit = registry.get_id("#1") if (i >= 9 and i < 15) else registry.get_id("#0") return [nan_bit, registry.get_id(f"{prefix}.result_is_nan")] if '.inf_gate' in gate: # Inf = 0x7C00 = 0_11111_0000000000, bits 10-14 are 1 if i == 15: inf_bit = registry.get_id(f"{prefix}.result_sign.layer2") else: inf_bit = registry.get_id("#1") if i >= 10 and i < 15 else registry.get_id("#0") return [inf_bit, registry.get_id(f"{prefix}.result_is_inf")] if '.zero_gate' in gate: zero_bit = registry.get_id(f"{prefix}.result_sign.layer2") if i == 15 else registry.get_id("#0") return [zero_bit, registry.get_id(f"{prefix}.result_is_zero")] if '.normal_gate' in gate: if i < 10: normal_bit = 
registry.get_id(f"{prefix}.mant_out{i}") elif i < 15: normal_bit = registry.get_id(f"{prefix}.exp_out{i-10}") else: normal_bit = registry.get_id(f"{prefix}.result_sign.layer2") return [normal_bit, registry.get_id(f"{prefix}.is_normal_result")] if '.sub_gate' in gate: return [registry.register(f"{prefix}.out_sub{i}"), registry.get_id(f"{prefix}.subnorm_enable")] match = re.search(r'\.out(\d+)$', gate) if match: i = int(match.group(1)) return [registry.register(f"{prefix}.out{i}.nan_gate"), registry.register(f"{prefix}.out{i}.inf_gate"), registry.register(f"{prefix}.out{i}.zero_gate"), registry.register(f"{prefix}.out{i}.normal_gate"), registry.register(f"{prefix}.out{i}.sub_gate")] return [] def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]: """Infer inputs for float16.toint circuit (with right-shift barrel shifter).""" prefix = "float16.toint" for i in range(16): registry.register(f"{prefix}.$x[{i}]") exp_bits = [f"{prefix}.$x[{10+i}]" for i in range(5)] mant_bits = [f"{prefix}.$x[{i}]" for i in range(10)] # === SPECIAL CASE DETECTION === if '.exp_all_ones' in gate: return [registry.get_id(b) for b in exp_bits] if '.exp_zero' in gate: return [registry.get_id(b) for b in exp_bits] registry.register(f"{prefix}.exp_all_ones") registry.register(f"{prefix}.exp_zero") if '.mant_nonzero' in gate: return [registry.get_id(b) for b in mant_bits] registry.register(f"{prefix}.mant_nonzero") if '.is_nan' in gate: return [registry.get_id(f"{prefix}.exp_all_ones"), registry.get_id(f"{prefix}.mant_nonzero")] if '.mant_zero' in gate: return [registry.get_id(f"{prefix}.mant_nonzero")] registry.register(f"{prefix}.mant_zero") registry.register(f"{prefix}.is_nan") if '.is_inf' in gate: return [registry.get_id(f"{prefix}.exp_all_ones"), registry.get_id(f"{prefix}.mant_zero")] if '.is_zero' in gate and '.not_' not in gate and '.result_is_zero' not in gate: return [registry.get_id(f"{prefix}.exp_zero"), registry.get_id(f"{prefix}.mant_zero")] 
    # --- continued interior of infer_float16_toint_inputs -------------------
    # NOTE(review): throughout this module the registry.register(...) calls
    # interleaved between the early-return branches are load-bearing: signal
    # IDs are assigned in the order gates are processed, and a gate's inputs
    # resolve only because an earlier gate's call already flowed past the
    # matching register() statement.  Do not reorder statements here.
    registry.register(f"{prefix}.is_inf")
    registry.register(f"{prefix}.is_zero")

    # exp < 15 (biased) means |x| < 1.0, so truncation to int yields zero.
    if '.exp_lt_15' in gate:
        return [registry.get_id(b) for b in exp_bits]
    registry.register(f"{prefix}.exp_lt_15")

    # Integer result is zero for NaN (this circuit's convention), for a
    # zero input, or when |x| < 1.
    if '.result_is_zero' in gate and '.not_' not in gate:
        return [registry.get_id(f"{prefix}.is_nan"),
                registry.get_id(f"{prefix}.is_zero"),
                registry.get_id(f"{prefix}.exp_lt_15")]
    registry.register(f"{prefix}.result_is_zero")
    if '.not_result_is_zero' in gate:
        return [registry.get_id(f"{prefix}.result_is_zero")]
    registry.register(f"{prefix}.not_result_is_zero")

    # Implicit leading mantissa bit: 1 for normals, 0 when exp == 0.
    if '.implicit_bit' in gate:
        return [registry.get_id(f"{prefix}.exp_zero")]
    registry.register(f"{prefix}.implicit_bit")

    # === THRESHOLD GATES FOR SHIFT CONTROL ===
    # Each exponent threshold reads the 5 raw exponent bits directly.
    if '.exp_ge_15' in gate:
        return [registry.get_id(b) for b in exp_bits]
    if '.exp_ge_26' in gate:
        return [registry.get_id(b) for b in exp_bits]
    if '.exp_ge_18' in gate:
        return [registry.get_id(b) for b in exp_bits]
    if '.exp_le_21' in gate:
        return [registry.get_id(b) for b in exp_bits]
    registry.register(f"{prefix}.exp_ge_15")
    registry.register(f"{prefix}.exp_ge_26")
    if '.not_exp_ge_26' in gate:
        return [registry.get_id(f"{prefix}.exp_ge_26")]
    registry.register(f"{prefix}.not_exp_ge_26")
    registry.register(f"{prefix}.exp_ge_18")
    registry.register(f"{prefix}.exp_le_21")
    if '.shift_bit3' in gate:
        return [registry.get_id(b) for b in exp_bits]
    if '.shift_bit2' in gate:
        return [registry.get_id(f"{prefix}.exp_ge_18"),
                registry.get_id(f"{prefix}.exp_le_21")]
    registry.register(f"{prefix}.shift_bit3")
    registry.register(f"{prefix}.shift_bit2")

    # === NOT OF EXPONENT BITS ===
    for i in range(5):
        if f'.not_exp{i}' in gate:
            return [registry.get_id(exp_bits[i])]
        registry.register(f"{prefix}.not_exp{i}")

    # === SHIFT CALCULATION: 25 - exp = ~exp + 26 ===
    # 26 = 0b011010, listed LSB-first over the 6-bit ripple adder.
    const_26 = [0, 1, 0, 1, 1, 0]
    if '.shift_calc.fa' in gate:
        match = re.search(r'\.shift_calc\.fa(\d+)\.', gate)
        if match:
            i = int(match.group(1))
            fa_prefix = f"{prefix}.shift_calc.fa{i}"
            # a = ~exp[i] (or 1 for i >= 5)
            a_bit = registry.get_id(f"{prefix}.not_exp{i}") if i < 5 else registry.get_id("#1")
            # b = const_26[i]
            b_bit = registry.get_id(f"#{const_26[i]}")
            cin = registry.get_id("#0") if i == 0 else registry.register(f"{prefix}.shift_calc.fa{i-1}.cout")
            # Full adder decomposed into two XORs (each an OR/NAND pair plus
            # a combining layer) and two ANDs feeding the carry-out.
            if '.xor1.layer1' in gate:
                return [a_bit, b_bit]
            if '.xor1.layer2' in gate:
                return [registry.register(f"{fa_prefix}.xor1.layer1.or"),
                        registry.register(f"{fa_prefix}.xor1.layer1.nand")]
            xor1 = registry.register(f"{fa_prefix}.xor1.layer2")
            if '.xor2.layer1' in gate:
                return [xor1, cin]
            if '.xor2.layer2' in gate:
                return [registry.register(f"{fa_prefix}.xor2.layer1.or"),
                        registry.register(f"{fa_prefix}.xor2.layer1.nand")]
            if '.and1' in gate:
                return [a_bit, b_bit]
            if '.and2' in gate:
                return [xor1, cin]
            if '.cout' in gate:
                return [registry.register(f"{fa_prefix}.and1"),
                        registry.register(f"{fa_prefix}.and2")]
    for i in range(6):
        registry.register(f"{prefix}.shift_calc.fa{i}.xor2.layer2")
        registry.register(f"{prefix}.shift_calc.fa{i}.cout")

    # === EXP_MINUS_25: exp + 7 ===
    # Adding 7 in 5-bit arithmetic realizes exp - 25 (mod 32); LSB-first.
    const_7 = [1, 1, 1, 0, 0]
    if '.exp_minus_25.fa' in gate:
        match = re.search(r'\.exp_minus_25\.fa(\d+)\.', gate)
        if match:
            i = int(match.group(1))
            fa_prefix = f"{prefix}.exp_minus_25.fa{i}"
            a_bit = registry.get_id(exp_bits[i]) if i < 5 else registry.get_id("#0")
            b_bit = registry.get_id(f"#{const_7[i]}")
            cin = registry.get_id("#0") if i == 0 else registry.register(f"{prefix}.exp_minus_25.fa{i-1}.cout")
            if '.xor1.layer1' in gate:
                return [a_bit, b_bit]
            if '.xor1.layer2' in gate:
                return [registry.register(f"{fa_prefix}.xor1.layer1.or"),
                        registry.register(f"{fa_prefix}.xor1.layer1.nand")]
            xor1 = registry.register(f"{fa_prefix}.xor1.layer2")
            if '.xor2.layer1' in gate:
                return [xor1, cin]
            if '.xor2.layer2' in gate:
                return [registry.register(f"{fa_prefix}.xor2.layer1.or"),
                        registry.register(f"{fa_prefix}.xor2.layer1.nand")]
            if '.and1' in gate:
                return [a_bit, b_bit]
            if '.and2' in gate:
                return [xor1, cin]
            if '.cout' in gate:
                return [registry.register(f"{fa_prefix}.and1"),
                        registry.register(f"{fa_prefix}.and2")]
    for i in range(5):
        registry.register(f"{prefix}.exp_minus_25.fa{i}.xor2.layer2")
        registry.register(f"{prefix}.exp_minus_25.fa{i}.cout")

    # === RIGHT-SHIFT BARREL SHIFTER ===
    # Four stages shifting by 1/2/4/8, controlled by the shift_calc sum bits.
    for stage in range(4):
        shift_amt = 1 << stage
        if f'.not_shift{stage}' in gate:
            return [registry.get_id(f"{prefix}.shift_calc.fa{stage}.xor2.layer2")]
        registry.register(f"{prefix}.not_shift{stage}")
        match = re.search(rf'\.rshift_s{stage}_(\d+)\.', gate)
        if match:
            i = int(match.group(1))
            src_pos = i + shift_amt
            if '.pass' in gate:
                # Current value (from previous stage or input)
                if stage == 0:
                    if i < 10:
                        val = registry.get_id(mant_bits[i])
                    elif i == 10:
                        val = registry.get_id(f"{prefix}.implicit_bit")
                    else:
                        val = registry.get_id("#0")
                else:
                    val = registry.get_id(f"{prefix}.rshift_s{stage-1}_{i}")
                return [val, registry.get_id(f"{prefix}.not_shift{stage}")]
            if '.shift' in gate and src_pos < 16:
                # Value from higher position
                if stage == 0:
                    if src_pos < 10:
                        val = registry.get_id(mant_bits[src_pos])
                    elif src_pos == 10:
                        val = registry.get_id(f"{prefix}.implicit_bit")
                    else:
                        val = registry.get_id("#0")
                else:
                    val = registry.get_id(f"{prefix}.rshift_s{stage-1}_{src_pos}")
                return [val, registry.get_id(f"{prefix}.shift_calc.fa{stage}.xor2.layer2")]
        match = re.search(rf'\.rshift_s{stage}_(\d+)$', gate)
        if match:
            i = int(match.group(1))
            src_pos = i + shift_amt
            if src_pos < 16:
                return [registry.register(f"{prefix}.rshift_s{stage}_{i}.pass"),
                        registry.register(f"{prefix}.rshift_s{stage}_{i}.shift")]
            else:
                # Nothing can shift in from above the MSB: pass-only mux.
                return [registry.register(f"{prefix}.rshift_s{stage}_{i}.pass")]
        for i in range(16):
            registry.register(f"{prefix}.rshift_s{stage}_{i}")

    # === LEFT SHIFT (exp > 25) ===
    # Three stages shifting by 1/2/4, controlled by exp_minus_25 sum bits.
    for stage in range(3):
        shift_amt = 1 << stage
        if f'.not_lshift{stage}' in gate:
            return [registry.get_id(f"{prefix}.exp_minus_25.fa{stage}.xor2.layer2")]
        registry.register(f"{prefix}.not_lshift{stage}")
        match = re.search(rf'\.lshift_s{stage}_(\d+)\.', gate)
        if match:
            i = int(match.group(1))
            src_pos = i - shift_amt
            if '.pass' in gate:
                if stage == 0:
                    if i < 10:
                        val = registry.get_id(mant_bits[i])
                    elif i == 10:
                        val = registry.get_id(f"{prefix}.implicit_bit")
                    else:
                        val = registry.get_id("#0")
                else:
                    val = registry.get_id(f"{prefix}.lshift_s{stage-1}_{i}")
                return [val, registry.get_id(f"{prefix}.not_lshift{stage}")]
            if '.shift' in gate and src_pos >= 0:
                if stage == 0:
                    if src_pos < 10:
                        val = registry.get_id(mant_bits[src_pos])
                    elif src_pos == 10:
                        val = registry.get_id(f"{prefix}.implicit_bit")
                    else:
                        val = registry.get_id("#0")
                else:
                    val = registry.get_id(f"{prefix}.lshift_s{stage-1}_{src_pos}")
                return [val, registry.get_id(f"{prefix}.exp_minus_25.fa{stage}.xor2.layer2")]
        match = re.search(rf'\.lshift_s{stage}_(\d+)$', gate)
        if match:
            i = int(match.group(1))
            src_pos = i - shift_amt
            if src_pos >= 0:
                return [registry.register(f"{prefix}.lshift_s{stage}_{i}.pass"),
                        registry.register(f"{prefix}.lshift_s{stage}_{i}.shift")]
            else:
                # Nothing can shift in from below bit 0: pass-only mux.
                return [registry.register(f"{prefix}.lshift_s{stage}_{i}.pass")]
        for i in range(16):
            registry.register(f"{prefix}.lshift_s{stage}_{i}")

    # Magnitude selector: choose the left-shift result when exp >= 26,
    # otherwise the right-shift result.
    match = re.search(r'\.mag_sel(\d+)\.', gate)
    if match:
        i = int(match.group(1))
        if '.left' in gate:
            return [registry.get_id(f"{prefix}.lshift_s2_{i}"),
                    registry.get_id(f"{prefix}.exp_ge_26")]
        if '.right' in gate:
            return [registry.get_id(f"{prefix}.rshift_s3_{i}"),
                    registry.get_id(f"{prefix}.not_exp_ge_26")]
    match = re.search(r'\.mag_sel(\d+)$', gate)
    if match:
        i = int(match.group(1))
        return [registry.register(f"{prefix}.mag_sel{i}.left"),
                registry.register(f"{prefix}.mag_sel{i}.right")]

    # === NEGATION ===
    # Two's complement of the magnitude: invert every bit then add 1
    # (the +1 enters as the carry-in of fa0).
    match = re.search(r'\.not_mag(\d+)$', gate)
    if match:
        i = int(match.group(1))
        return [registry.register(f"{prefix}.mag_sel{i}")]
    for i in range(16):
        registry.register(f"{prefix}.not_mag{i}")
    if '.neg.fa' in gate:
        match = re.search(r'\.neg\.fa(\d+)\.', gate)
        if match:
            i = int(match.group(1))
            fa_prefix = f"{prefix}.neg.fa{i}"
            a_bit = registry.get_id(f"{prefix}.not_mag{i}")
            cin = registry.get_id("#1") if i == 0 else registry.register(f"{prefix}.neg.fa{i-1}.cout")
            if '.xor.layer1' in gate:
                return [a_bit, cin]
            if '.xor.layer2' in gate:
                return [registry.register(f"{fa_prefix}.xor.layer1.or"),
                        registry.register(f"{fa_prefix}.xor.layer1.nand")]
            if '.and' in gate:
                return [a_bit, cin]
            if '.cout' in gate:
                # Second operand is constant 0, so carry-out reduces to AND.
                return [registry.register(f"{fa_prefix}.and"), registry.get_id("#0")]
    for i in range(16):
        registry.register(f"{prefix}.neg.fa{i}.xor.layer2")
        registry.register(f"{prefix}.neg.fa{i}.cout")

    # === OUTPUT SELECTION ===
    # out{i} = pos_path OR neg_path, each masked by sign / nonzero flags.
    match = re.search(r'\.out(\d+)\.', gate)
    if match:
        i = int(match.group(1))
        sign = registry.get_id(f"{prefix}.$x[15]")
        not_sign = registry.register(f"{prefix}.not_sign")
        not_result_zero = registry.get_id(f"{prefix}.not_result_is_zero")
        if '.pos_path' in gate:
            return [registry.register(f"{prefix}.mag_sel{i}"), not_sign, not_result_zero]
        if '.neg_path' in gate:
            return [registry.get_id(f"{prefix}.neg.fa{i}.xor.layer2"), sign, not_result_zero]
    # not_sign gate
    if '.not_sign' in gate:
        return [registry.get_id(f"{prefix}.$x[15]")]
    registry.register(f"{prefix}.not_sign")
    match = re.search(r'\.out(\d+)$', gate)
    if match:
        i = int(match.group(1))
        return [registry.register(f"{prefix}.out{i}.pos_path"),
                registry.register(f"{prefix}.out{i}.neg_path")]
    return []


def infer_float16_fromint_inputs(gate: str, registry: SignalRegistry) -> List[int]:
    """Infer inputs for float16.fromint circuit.

    Converts a signed 16-bit integer to float16: absolute value, count
    leading zeros, normalize, round, then assemble sign/exponent/mantissa.
    """
    prefix = "float16.fromint"
    for i in range(16):
        registry.register(f"{prefix}.$x[{i}]")
    in_bits = [f"{prefix}.$x[{i}]" for i in range(16)]
    # NOTE(review): '.not_is_zero' resolves is_zero via get_id() before any
    # call has flowed past the register(f"{prefix}.is_zero") below -- confirm
    # the gate processing order guarantees a non-matching gate is processed
    # first, otherwise this returns -1.
    if '.not_is_zero' in gate:
        return [registry.get_id(f"{prefix}.is_zero")]
    if '.is_zero' in gate:
        return [registry.get_id(b) for b in in_bits]
    if '.is_negative' in gate:
        return [registry.get_id(f"{prefix}.$x[15]")]
    if '.not_negative' in gate:
        return [registry.get_id(f"{prefix}.is_negative")]
    registry.register(f"{prefix}.is_zero")
    registry.register(f"{prefix}.not_is_zero")
    registry.register(f"{prefix}.is_negative")
    registry.register(f"{prefix}.not_negative")

    # Bitwise NOT of the raw input (first half of |x| = ~x + 1 when negative).
    match = re.search(r'\.not_in(\d+)$', gate)
    if match:
        i = int(match.group(1))
        return [registry.get_id(in_bits[i])]
    for i in range(16):
        registry.register(f"{prefix}.not_in{i}")

    # Conditional negate: ripple adder computing ~x + is_negative.
    if '.abs.fa' in gate:
        match = re.search(r'\.abs\.fa(\d+)\.', gate)
        if match:
            i = int(match.group(1))
            fa_prefix = f"{prefix}.abs.fa{i}"
            a_bit = registry.get_id(f"{prefix}.not_in{i}")
            # is_negative doubles as the +1 carry-in of the two's complement.
            cin = registry.get_id(f"{prefix}.is_negative") if i == 0 else registry.register(f"{prefix}.abs.fa{i-1}.cout")
            if '.xor.layer1' in gate:
                return [a_bit, cin]
            if '.xor.layer2' in gate:
                return [registry.register(f"{fa_prefix}.xor.layer1.or"),
                        registry.register(f"{fa_prefix}.xor.layer1.nand")]
            if '.and' in gate:
                return [a_bit, cin]
            if '.cout' in gate:
                return [registry.register(f"{fa_prefix}.and"), registry.get_id("#0")]
    for i in range(16):
        registry.register(f"{prefix}.abs.fa{i}.xor.layer2")
        registry.register(f"{prefix}.abs.fa{i}.cout")

    # abs{i}: mux between the negated value and the raw input.
    match = re.search(r'\.abs(\d+)\.', gate)
    if match:
        i = int(match.group(1))
        if '.neg_path' in gate:
            return [registry.get_id(f"{prefix}.abs.fa{i}.xor.layer2"),
                    registry.get_id(f"{prefix}.is_negative")]
        if '.pos_path' in gate:
            return [registry.get_id(in_bits[i]),
                    registry.get_id(f"{prefix}.not_negative")]
    match = re.search(r'\.abs(\d+)$', gate)
    if match:
        i = int(match.group(1))
        return [registry.register(f"{prefix}.abs{i}.neg_path"),
                registry.register(f"{prefix}.abs{i}.pos_path")]
    for i in range(16):
        registry.register(f"{prefix}.abs{i}")
    abs_bits = [f"{prefix}.abs{i}" for i in range(16)]

    # === COUNT LEADING ZEROS ===
    # pz{k}: top k bits of |x| are all zero.
    match = re.search(r'\.pz(\d+)$', gate)
    if match:
        k = int(match.group(1))
        return [registry.get_id(abs_bits[15 - i]) for i in range(k)]
    for k in range(1, 17):
        registry.register(f"{prefix}.pz{k}")
    pz_ids = [registry.get_id(f"{prefix}.pz{k}") for k in range(1, 17)]
    # ge{k}: clz >= k; every ge gate reads the whole pz vector.
    match = re.search(r'\.ge(\d+)$', gate)
    if match:
        return pz_ids
    for k in range(1, 17):
        registry.register(f"{prefix}.ge{k}")
    match = re.search(r'\.not_ge(\d+)$', gate)
    if match:
        k = int(match.group(1))
        return [registry.get_id(f"{prefix}.ge{k}")]
    for k in [2, 4, 6, 8, 10, 12, 14, 16]:
        registry.register(f"{prefix}.not_ge{k}")

    # Binary encoding of the leading-zero count (clz3..clz0).
    if '.clz3' in gate:
        return [registry.get_id(f"{prefix}.ge8")]
    if '.clz_and_4_7' in gate:
        return [registry.get_id(f"{prefix}.ge4"), registry.get_id(f"{prefix}.not_ge8")]
    if '.clz_and_12_15' in gate:
        return [registry.get_id(f"{prefix}.ge12"), registry.get_id(f"{prefix}.not_ge16")]
    if '.clz2' in gate:
        return [registry.get_id(f"{prefix}.clz_and_4_7"),
                registry.get_id(f"{prefix}.clz_and_12_15")]
    registry.register(f"{prefix}.clz3")
    registry.register(f"{prefix}.clz_and_4_7")
    registry.register(f"{prefix}.clz_and_12_15")
    registry.register(f"{prefix}.clz2")
    if '.clz_and_2_3' in gate:
        return [registry.get_id(f"{prefix}.ge2"), registry.get_id(f"{prefix}.not_ge4")]
    if '.clz_and_6_7' in gate:
        return [registry.get_id(f"{prefix}.ge6"), registry.get_id(f"{prefix}.not_ge8")]
    if '.clz_and_10_11' in gate:
        return [registry.get_id(f"{prefix}.ge10"), registry.get_id(f"{prefix}.not_ge12")]
    if '.clz_and_14_15' in gate:
        return [registry.get_id(f"{prefix}.ge14"), registry.get_id(f"{prefix}.not_ge16")]
    if '.clz1' in gate:
        return [registry.get_id(f"{prefix}.clz_and_2_3"),
                registry.get_id(f"{prefix}.clz_and_6_7"),
                registry.get_id(f"{prefix}.clz_and_10_11"),
                registry.get_id(f"{prefix}.clz_and_14_15")]
    registry.register(f"{prefix}.clz_and_2_3")
    registry.register(f"{prefix}.clz_and_6_7")
    registry.register(f"{prefix}.clz_and_10_11")
    registry.register(f"{prefix}.clz_and_14_15")
    registry.register(f"{prefix}.clz1")
    match = re.search(r'\.clz_and_(\d+)$', gate)
    if match:
        i = int(match.group(1))
        if i in [1, 3, 5, 7, 9, 11, 13, 15]:
            return [registry.get_id(f"{prefix}.ge{i}"),
                    registry.get_id(f"{prefix}.not_ge{i+1}")]
    for i in [1, 3, 5, 7, 9, 11, 13, 15]:
        registry.register(f"{prefix}.clz_and_{i}")
    if '.clz0' in gate:
        return [registry.get_id(f"{prefix}.clz_and_{i}") for i in [1, 3, 5, 7, 9, 11, 13, 15]]
    for i in range(4):
        registry.register(f"{prefix}.clz{i}")
    for i in range(5):
        if f'.not_clz{i}' in gate:
            # Bit 4 of the 5-bit clz value is always 0 here.
            return [registry.get_id(f"{prefix}.clz{i}") if i < 4 else registry.get_id("#0")]
        registry.register(f"{prefix}.not_clz{i}")

    # === EXPONENT CALC: 30 - clz = 30 + ~clz + 1 (LSB-first; 30 = 0b011110,
    # carry-in of fa0 supplies the +1) ===
    bits_30 = [0, 1, 1, 1, 1, 0]
    if '.exp_calc.fa' in gate:
        match = re.search(r'\.exp_calc\.fa(\d+)\.', gate)
        if match:
            i = int(match.group(1))
            fa_prefix = f"{prefix}.exp_calc.fa{i}"
            a_bit = registry.get_id(f"#{bits_30[i]}")
            b_bit = registry.get_id(f"{prefix}.not_clz{i}") if i < 5 else registry.get_id("#1")
            cin = registry.get_id("#1") if i == 0 else registry.register(f"{prefix}.exp_calc.fa{i-1}.cout")
            if '.xor1.layer1' in gate:
                return [a_bit, b_bit]
            if '.xor1.layer2' in gate:
                return [registry.register(f"{fa_prefix}.xor1.layer1.or"),
                        registry.register(f"{fa_prefix}.xor1.layer1.nand")]
            xor1 = registry.register(f"{fa_prefix}.xor1.layer2")
            if '.xor2.layer1' in gate:
                return [xor1, cin]
            if '.xor2.layer2' in gate:
                return [registry.register(f"{fa_prefix}.xor2.layer1.or"),
                        registry.register(f"{fa_prefix}.xor2.layer1.nand")]
            if '.and1' in gate:
                return [a_bit, b_bit]
            if '.and2' in gate:
                return [xor1, cin]
            if '.cout' in gate:
                return [registry.register(f"{fa_prefix}.and1"),
                        registry.register(f"{fa_prefix}.and2")]
    for i in range(5):
        registry.register(f"{prefix}.exp_calc.fa{i}.xor2.layer2")
        registry.register(f"{prefix}.exp_calc.fa{i}.cout")

    # === NORMALIZATION BARREL SHIFTER === (left shift by clz, 4 stages)
    for stage in range(4):
        shift_amt = 1 << stage
        if f'.not_norm_shift{stage}' in gate:
            return [registry.get_id(f"{prefix}.clz{stage}")]
        registry.register(f"{prefix}.not_norm_shift{stage}")
        match = re.search(rf'\.norm_s{stage}_(\d+)\.', gate)
        if match:
            i = int(match.group(1))
            if '.pass' in gate:
                if stage == 0:
                    val = registry.get_id(abs_bits[i])
                else:
                    val = registry.get_id(f"{prefix}.norm_s{stage-1}_{i}")
                return [val, registry.get_id(f"{prefix}.not_norm_shift{stage}")]
            if '.shift' in gate and i >= shift_amt:
                if stage == 0:
                    val = registry.get_id(abs_bits[i - shift_amt])
                else:
                    val = registry.get_id(f"{prefix}.norm_s{stage-1}_{i-shift_amt}")
                return [val, registry.get_id(f"{prefix}.clz{stage}")]
        match = re.search(rf'\.norm_s{stage}_(\d+)$', gate)
        if match:
            i = int(match.group(1))
            if i >= shift_amt:
                return [registry.register(f"{prefix}.norm_s{stage}_{i}.pass"),
                        registry.register(f"{prefix}.norm_s{stage}_{i}.shift")]
            else:
                return [registry.register(f"{prefix}.norm_s{stage}_{i}.pass")]
        for i in range(16):
            registry.register(f"{prefix}.norm_s{stage}_{i}")

    # === ROUNDING (guard/round/sticky) ===
    # After normalization the mantissa sits in norm_s3_15..norm_s3_5;
    # bit 4 is guard, bit 3 is round, bits 2..0 fold into sticky.
    if '.guard_bit' in gate:
        return [registry.get_id(f"{prefix}.norm_s3_4")]
    registry.register(f"{prefix}.guard_bit")
    if '.round_bit' in gate:
        return [registry.get_id(f"{prefix}.norm_s3_3")]
    registry.register(f"{prefix}.round_bit")
    if '.sticky_bit' in gate:
        return [registry.get_id(f"{prefix}.norm_s3_0"),
                registry.get_id(f"{prefix}.norm_s3_1"),
                registry.get_id(f"{prefix}.norm_s3_2")]
    registry.register(f"{prefix}.sticky_bit")
    # round_or also reads the mantissa LSB (norm_s3_5) for ties-to-even.
    if '.round_or' in gate:
        return [registry.get_id(f"{prefix}.round_bit"),
                registry.get_id(f"{prefix}.sticky_bit"),
                registry.get_id(f"{prefix}.norm_s3_5")]
    registry.register(f"{prefix}.round_or")
    if '.round_inc' in gate:
        return [registry.get_id(f"{prefix}.guard_bit"),
                registry.get_id(f"{prefix}.round_or")]
    registry.register(f"{prefix}.round_inc")

    # Mantissa rounding adder (10 bits)
    if '.round_add.fa' in gate:
        match = re.search(r'\.round_add\.fa(\d+)\.', gate)
        if match:
            i = int(match.group(1))
            fa_prefix = f"{prefix}.round_add.fa{i}"
            a_bit = registry.get_id(f"{prefix}.norm_s3_{i+5}")
            b_bit = registry.get_id("#0")
            if i == 0:
                cin = registry.get_id(f"{prefix}.round_inc")
            else:
                cin = registry.register(f"{prefix}.round_add.fa{i-1}.cout")
            if '.xor1.layer1' in gate:
                return [a_bit, b_bit]
            if '.xor1.layer2' in gate:
                return [registry.register(f"{fa_prefix}.xor1.layer1.or"),
                        registry.register(f"{fa_prefix}.xor1.layer1.nand")]
            xor1 = registry.register(f"{fa_prefix}.xor1.layer2")
            if '.xor2.layer1' in gate:
                return [xor1, cin]
            if '.xor2.layer2' in gate:
                return [registry.register(f"{fa_prefix}.xor2.layer1.or"),
                        registry.register(f"{fa_prefix}.xor2.layer1.nand")]
            if '.and1' in gate:
                return [a_bit, b_bit]
            if '.and2' in gate:
                return [xor1, cin]
            if '.cout' in gate:
                return [registry.register(f"{fa_prefix}.and1"),
                        registry.register(f"{fa_prefix}.and2")]
    for i in range(10):
        registry.register(f"{prefix}.round_add.fa{i}.xor2.layer2")
        registry.register(f"{prefix}.round_add.fa{i}.cout")

    # Mantissa overflow (all-ones mantissa rounded up) bumps the exponent.
    if '.mant_overflow' in gate and '.not_' not in gate:
        return [registry.get_id(f"{prefix}.round_add.fa9.cout")]
    registry.register(f"{prefix}.mant_overflow")
    if '.not_mant_overflow' in gate:
        return [registry.get_id(f"{prefix}.mant_overflow")]
    registry.register(f"{prefix}.not_mant_overflow")

    # Exponent increment when mantissa overflows
    if '.exp_inc.fa' in gate:
        match = re.search(r'\.exp_inc\.fa(\d+)\.', gate)
        if match:
            i = int(match.group(1))
            fa_prefix = f"{prefix}.exp_inc.fa{i}"
            a_bit = registry.get_id(f"{prefix}.exp_calc.fa{i}.xor2.layer2")
            b_bit = registry.get_id("#0")
            if i == 0:
                cin = registry.get_id(f"{prefix}.mant_overflow")
            else:
                cin = registry.register(f"{prefix}.exp_inc.fa{i-1}.cout")
            if '.xor1.layer1' in gate:
                return [a_bit, b_bit]
            if '.xor1.layer2' in gate:
                return [registry.register(f"{fa_prefix}.xor1.layer1.or"),
                        registry.register(f"{fa_prefix}.xor1.layer1.nand")]
            xor1 = registry.register(f"{fa_prefix}.xor1.layer2")
            if '.xor2.layer1' in gate:
                return [xor1, cin]
            if '.xor2.layer2' in gate:
                return [registry.register(f"{fa_prefix}.xor2.layer1.or"),
                        registry.register(f"{fa_prefix}.xor2.layer1.nand")]
            if '.and1' in gate:
                return [a_bit, b_bit]
            if '.and2' in gate:
                return [xor1, cin]
            if '.cout' in gate:
                return [registry.register(f"{fa_prefix}.and1"),
                        registry.register(f"{fa_prefix}.and2")]
    for i in range(5):
        registry.register(f"{prefix}.exp_inc.fa{i}.xor2.layer2")
        registry.register(f"{prefix}.exp_inc.fa{i}.cout")

    # Exponent output mux
    match = re.search(r'\.exp_out(\d+)\.', gate)
    if match:
        i = int(match.group(1))
        if '.overflow_path' in gate:
            return [registry.get_id(f"{prefix}.exp_inc.fa{i}.xor2.layer2"),
                    registry.get_id(f"{prefix}.mant_overflow")]
        if '.normal_path' in gate:
            return [registry.get_id(f"{prefix}.exp_calc.fa{i}.xor2.layer2"),
                    registry.get_id(f"{prefix}.not_mant_overflow")]
    match = re.search(r'\.exp_out(\d+)$', gate)
    if match:
        i = int(match.group(1))
        return [registry.register(f"{prefix}.exp_out{i}.overflow_path"),
                registry.register(f"{prefix}.exp_out{i}.normal_path")]

    # Mantissa output (zero on overflow)
    match = re.search(r'\.mant_out(\d+)$', gate)
    if match:
        i = int(match.group(1))
        return [registry.get_id(f"{prefix}.round_add.fa{i}.xor2.layer2"),
                registry.get_id(f"{prefix}.not_mant_overflow")]

    # Final assembly: out0-9 mantissa, out10-14 exponent, out15 sign.
    match = re.search(r'\.out(\d+)\.', gate)
    if match:
        i = int(match.group(1))
        if '.zero_gate' in gate:
            return [registry.get_id("#0"), registry.get_id(f"{prefix}.is_zero")]
        if '.normal_gate' in gate:
            if i < 10:
                val = registry.register(f"{prefix}.mant_out{i}")
            elif i < 15:
                val = registry.register(f"{prefix}.exp_out{i-10}")
            else:
                val = registry.get_id(f"{prefix}.is_negative")
            not_zero = registry.get_id(f"{prefix}.not_is_zero")
            return [val, not_zero]
    match = re.search(r'\.out(\d+)$', gate)
    if match:
        i = int(match.group(1))
        if i < 15:
            return [registry.register(f"{prefix}.out{i}.zero_gate"),
                    registry.register(f"{prefix}.out{i}.normal_gate")]
        else:
            # Sign bit routes straight from is_negative.
            return [registry.get_id(f"{prefix}.is_negative")]
    return []


def infer_float16_neg_inputs(gate: str, registry: SignalRegistry) -> List[int]:
    """Infer inputs for float16.neg circuit.

    Every output bit reads exactly the corresponding input bit; the sign
    inversion lives in the gate weights, not in the wiring.
    """
    prefix = "float16.neg"
    # Register 16-bit input
    for i in range(16):
        registry.register(f"{prefix}.$x[{i}]")
    # Output gates
    match = re.search(r'\.out(\d+)', gate)
    if match:
        i = int(match.group(1))
        return [registry.get_id(f"{prefix}.$x[{i}]")]
    return []


def infer_float16_abs_inputs(gate: str, registry: SignalRegistry) -> List[int]:
    """Infer inputs for float16.abs circuit."""
    prefix = "float16.abs"
    # Register 16-bit input
    for i in range(16):
        registry.register(f"{prefix}.$x[{i}]")
    # Output gates
    match = re.search(r'\.out(\d+)', gate)
    if match:
        i = int(match.group(1))
        if i == 15:
            # Sign bit output doesn't depend on input
            # (always 0), but the gate still needs a wired input for its
            # structure, so feed it the original sign bit.
            return [registry.get_id(f"{prefix}.$x[15]")]
        return [registry.get_id(f"{prefix}.$x[{i}]")]
    return []


def infer_float32_neg_inputs(gate: str, registry: SignalRegistry) -> List[int]:
    """Infer inputs for float32.neg circuit."""
    prefix = "float32.neg"
    # Register 32-bit input
    for i in range(32):
        registry.register(f"{prefix}.$x[{i}]")
    # Output gates
    match = re.search(r'\.out(\d+)', gate)
    if match:
        i = int(match.group(1))
        return [registry.get_id(f"{prefix}.$x[{i}]")]
    return []


def infer_float32_abs_inputs(gate: str, registry: SignalRegistry) -> List[int]:
    """Infer inputs for float32.abs circuit."""
    prefix = "float32.abs"
    # Register 32-bit input
    for i in range(32):
        registry.register(f"{prefix}.$x[{i}]")
    # Output gates
    match = re.search(r'\.out(\d+)', gate)
    if match:
        i = int(match.group(1))
        return [registry.get_id(f"{prefix}.$x[{i}]")]
    return []


def infer_float16_normalize_inputs(gate: str, registry: SignalRegistry) -> List[int]:
    """Infer inputs for float16.normalize circuit.

    Takes a 13-bit mantissa ($m[0..12]) and derives the left-shift amount
    from a CLZ over m[11:0]; the shift outputs are masked off when the
    overflow bit (m[12]) is set.
    """
    prefix = "float16.normalize"
    # Register 13-bit mantissa input
    for i in range(13):
        registry.register(f"{prefix}.$m[{i}]")
    # Overflow detection (bit 12)
    if '.overflow' in gate and '.not_overflow' not in gate:
        return [registry.get_id(f"{prefix}.$m[12]")]
    registry.register(f"{prefix}.overflow")
    # is_zero (NOR of all mantissa bits)
    if '.is_zero' in gate:
        return [registry.get_id(f"{prefix}.$m[{i}]") for i in range(13)]
    # pz gates (CLZ on bits 11:0)
    if '.pz' in gate:
        match = re.search(r'\.pz(\d+)', gate)
        if match:
            k = int(match.group(1))
            # Check top k bits of m[11:0]
            return [registry.get_id(f"{prefix}.$m[{11-i}]") for i in range(k)]
    # Register pz outputs
    for i in range(1, 13):
        registry.register(f"{prefix}.pz{i}")
    pz_ids = [registry.get_id(f"{prefix}.pz{i}") for i in range(1, 13)]
    # ge gates: each reads the whole pz vector.
    if '.ge' in gate and '.not_ge' not in gate:
        match = re.search(r'\.ge(\d+)', gate)
        if match:
            return pz_ids
    # Register ge outputs
    for k in range(1, 13):
        registry.register(f"{prefix}.ge{k}")
    # NOT gates
    if '.not_ge' in gate:
        match = re.search(r'\.not_ge(\d+)', gate)
        if match:
            k = int(match.group(1))
            return [registry.get_id(f"{prefix}.ge{k}")]
    for k in [2, 4, 6, 8, 10, 12]:
        registry.register(f"{prefix}.not_ge{k}")
    # AND gates for ranges
    if '.and_4_7' in gate:
        return [registry.get_id(f"{prefix}.ge4"), registry.get_id(f"{prefix}.not_ge8")]
    if '.and_2_3' in gate:
        return [registry.get_id(f"{prefix}.ge2"), registry.get_id(f"{prefix}.not_ge4")]
    if '.and_6_7' in gate:
        return [registry.get_id(f"{prefix}.ge6"), registry.get_id(f"{prefix}.not_ge8")]
    if '.and_10_11' in gate:
        return [registry.get_id(f"{prefix}.ge10"), registry.get_id(f"{prefix}.not_ge12")]
    # Odd AND gates
    match = re.search(r'\.and_(\d+)$', gate)
    if match:
        i = int(match.group(1))
        if i in [1, 3, 5, 7, 9, 11]:
            next_even = i + 1
            if next_even in [2, 4, 8]:
                return [registry.get_id(f"{prefix}.ge{i}"),
                        registry.get_id(f"{prefix}.not_ge{next_even}")]
            else:
                # Need to register not_ge for this value
                # NOTE(review): not_ge{6,10,12} are already registered above
                # and register() is idempotent, so both branches behave the
                # same -- the [2, 4, 8] special case looks vestigial.
                registry.register(f"{prefix}.not_ge{next_even}")
                return [registry.get_id(f"{prefix}.ge{i}"),
                        registry.get_id(f"{prefix}.not_ge{next_even}")]
    # Register AND outputs
    for name in ['and_4_7', 'and_2_3', 'and_6_7', 'and_10_11']:
        registry.register(f"{prefix}.{name}")
    for i in [1, 3, 5, 7, 9, 11]:
        registry.register(f"{prefix}.and_{i}")
    # Shift bit gates (binary encoding of the leading-zero count)
    if '.shift3' in gate:
        return [registry.get_id(f"{prefix}.ge8")]
    if '.shift2' in gate:
        return [registry.get_id(f"{prefix}.and_4_7"), registry.get_id(f"{prefix}.ge12")]
    if '.shift1' in gate:
        return [registry.get_id(f"{prefix}.and_2_3"),
                registry.get_id(f"{prefix}.and_6_7"),
                registry.get_id(f"{prefix}.and_10_11")]
    if '.shift0' in gate:
        return [registry.get_id(f"{prefix}.and_{i}") for i in [1, 3, 5, 7, 9, 11]]
    for i in range(4):
        registry.register(f"{prefix}.shift{i}")
    # not_overflow
    if '.not_overflow' in gate:
        return [registry.get_id(f"{prefix}.overflow")]
    registry.register(f"{prefix}.not_overflow")
    # Output shift bits are masked by not_overflow (continued below).
    # Output shift bits (masked by not_overflow)
    if '.out_shift' in gate:
        match = re.search(r'\.out_shift(\d+)', gate)
        if match:
            i = int(match.group(1))
            return [registry.get_id(f"{prefix}.shift{i}"),
                    registry.get_id(f"{prefix}.not_overflow")]
    return []


def infer_float16_cmp_inputs(gate: str, registry: SignalRegistry) -> List[int]:
    """Infer inputs for float16.cmp circuit.

    Produces gt = (a > b) from sign bits plus a magnitude comparison; the
    either_nonzero guard keeps +0 vs -0 from comparing as greater.
    """
    prefix = "float16.cmp"
    # Register inputs: 16 bits for a, 16 bits for b
    for i in range(16):
        registry.register(f"{prefix}.$a[{i}]")
        registry.register(f"{prefix}.$b[{i}]")
    # Sign extraction
    if '.sign_a' in gate:
        return [registry.get_id(f"{prefix}.$a[15]")]
    if '.sign_b' in gate:
        return [registry.get_id(f"{prefix}.$b[15]")]
    # Register sign outputs
    registry.register(f"{prefix}.sign_a")
    registry.register(f"{prefix}.sign_b")
    # NOT sign gates
    if '.not_sign_a' in gate:
        return [registry.get_id(f"{prefix}.sign_a")]
    if '.not_sign_b' in gate:
        return [registry.get_id(f"{prefix}.sign_b")]
    registry.register(f"{prefix}.not_sign_a")
    registry.register(f"{prefix}.not_sign_b")
    # Magnitude comparison (bits 14-0 of both)
    if '.mag_cmp' in gate:
        inputs = []
        for i in range(15):
            inputs.append(registry.get_id(f"{prefix}.$a[{i}]"))
        for i in range(15):
            inputs.append(registry.get_id(f"{prefix}.$b[{i}]"))
        return inputs
    registry.register(f"{prefix}.mag_cmp")
    # a_gt_b_mag (pass-through from mag_cmp)
    if '.a_gt_b_mag' in gate:
        return [registry.get_id(f"{prefix}.mag_cmp")]
    # b_gt_a_mag (reversed comparison: operands fed in swapped order)
    if '.b_gt_a_mag' in gate:
        inputs = []
        for i in range(15):
            inputs.append(registry.get_id(f"{prefix}.$b[{i}]"))
        for i in range(15):
            inputs.append(registry.get_id(f"{prefix}.$a[{i}]"))
        return inputs
    registry.register(f"{prefix}.a_gt_b_mag")
    registry.register(f"{prefix}.b_gt_a_mag")
    # both_pos_gt: AND(not_sign_a, not_sign_b, a_gt_b_mag)
    if '.both_pos_gt' in gate:
        return [registry.get_id(f"{prefix}.not_sign_a"),
                registry.get_id(f"{prefix}.not_sign_b"),
                registry.get_id(f"{prefix}.a_gt_b_mag")]
    # both_neg_gt: AND(sign_a, sign_b, b_gt_a_mag)
    # (with both negative, the smaller magnitude is the greater value)
    if '.both_neg_gt' in gate:
        return [registry.get_id(f"{prefix}.sign_a"),
                registry.get_id(f"{prefix}.sign_b"),
                registry.get_id(f"{prefix}.b_gt_a_mag")]
    # mag_a_nonzero: OR of bits 0-14 of a
    if '.mag_a_nonzero' in gate:
        return [registry.get_id(f"{prefix}.$a[{i}]") for i in range(15)]
    # mag_b_nonzero: OR of bits 0-14 of b
    if '.mag_b_nonzero' in gate:
        return [registry.get_id(f"{prefix}.$b[{i}]") for i in range(15)]
    registry.register(f"{prefix}.mag_a_nonzero")
    registry.register(f"{prefix}.mag_b_nonzero")
    # either_nonzero: OR(mag_a_nonzero, mag_b_nonzero)
    if '.either_nonzero' in gate:
        return [registry.get_id(f"{prefix}.mag_a_nonzero"),
                registry.get_id(f"{prefix}.mag_b_nonzero")]
    registry.register(f"{prefix}.either_nonzero")
    # a_pos_b_neg: AND(not_sign_a, sign_b, either_nonzero)
    if '.a_pos_b_neg' in gate:
        return [registry.get_id(f"{prefix}.not_sign_a"),
                registry.get_id(f"{prefix}.sign_b"),
                registry.get_id(f"{prefix}.either_nonzero")]
    registry.register(f"{prefix}.both_pos_gt")
    registry.register(f"{prefix}.both_neg_gt")
    registry.register(f"{prefix}.a_pos_b_neg")
    # Final gt: OR(both_pos_gt, both_neg_gt, a_pos_b_neg)
    if '.gt' in gate:
        return [registry.get_id(f"{prefix}.both_pos_gt"),
                registry.get_id(f"{prefix}.both_neg_gt"),
                registry.get_id(f"{prefix}.a_pos_b_neg")]
    return []


def infer_float32_cmp_inputs(gate: str, registry: SignalRegistry) -> List[int]:
    """Infer inputs for float32.cmp circuit."""
    prefix = "float32.cmp"
    # Register 32-bit inputs
    for i in range(32):
        registry.register(f"{prefix}.$a[{i}]")
        registry.register(f"{prefix}.$b[{i}]")
    # Sign bits
    if '.sign_a' in gate:
        return [registry.get_id(f"{prefix}.$a[31]")]
    if '.sign_b' in gate:
        return [registry.get_id(f"{prefix}.$b[31]")]
    # NOT sign gates
    # NOTE(review): unlike infer_float16_cmp_inputs, sign_a/sign_b are
    # resolved here BEFORE the register() calls below can have run for any
    # earlier gate -- confirm the gate processing order, otherwise these
    # get_id() lookups yield -1.
    if '.not_sign_a' in gate:
        return [registry.get_id(f"{prefix}.sign_a")]
    if '.not_sign_b' in gate:
        return [registry.get_id(f"{prefix}.sign_b")]
    registry.register(f"{prefix}.sign_a")
    registry.register(f"{prefix}.sign_b")
    registry.register(f"{prefix}.not_sign_a")
    registry.register(f"{prefix}.not_sign_b")
    # Magnitude comparison continues below (bits 30-0 of both operands).
    # Magnitude comparison (bits 30-0 of both)
    if '.mag_cmp' in gate:
        inputs = []
        for i in range(31):
            inputs.append(registry.get_id(f"{prefix}.$a[{i}]"))
        for i in range(31):
            inputs.append(registry.get_id(f"{prefix}.$b[{i}]"))
        return inputs
    registry.register(f"{prefix}.mag_cmp")
    # a_gt_b_mag (pass-through)
    if '.a_gt_b_mag' in gate:
        return [registry.get_id(f"{prefix}.mag_cmp")]
    # b_gt_a_mag (reversed comparison: operands fed in swapped order)
    if '.b_gt_a_mag' in gate:
        inputs = []
        for i in range(31):
            inputs.append(registry.get_id(f"{prefix}.$b[{i}]"))
        for i in range(31):
            inputs.append(registry.get_id(f"{prefix}.$a[{i}]"))
        return inputs
    registry.register(f"{prefix}.a_gt_b_mag")
    registry.register(f"{prefix}.b_gt_a_mag")
    # both_pos_gt: AND(not_sign_a, not_sign_b, a_gt_b_mag)
    if '.both_pos_gt' in gate:
        return [registry.get_id(f"{prefix}.not_sign_a"),
                registry.get_id(f"{prefix}.not_sign_b"),
                registry.get_id(f"{prefix}.a_gt_b_mag")]
    # both_neg_gt: AND(sign_a, sign_b, b_gt_a_mag)
    if '.both_neg_gt' in gate:
        return [registry.get_id(f"{prefix}.sign_a"),
                registry.get_id(f"{prefix}.sign_b"),
                registry.get_id(f"{prefix}.b_gt_a_mag")]
    # mag_a_nonzero: OR of bits 0-30 of a
    if '.mag_a_nonzero' in gate:
        return [registry.get_id(f"{prefix}.$a[{i}]") for i in range(31)]
    # mag_b_nonzero: OR of bits 0-30 of b
    if '.mag_b_nonzero' in gate:
        return [registry.get_id(f"{prefix}.$b[{i}]") for i in range(31)]
    registry.register(f"{prefix}.mag_a_nonzero")
    registry.register(f"{prefix}.mag_b_nonzero")
    # either_nonzero: OR(mag_a_nonzero, mag_b_nonzero)
    if '.either_nonzero' in gate:
        return [registry.get_id(f"{prefix}.mag_a_nonzero"),
                registry.get_id(f"{prefix}.mag_b_nonzero")]
    registry.register(f"{prefix}.either_nonzero")
    # a_pos_b_neg: AND(not_sign_a, sign_b, either_nonzero)
    if '.a_pos_b_neg' in gate:
        return [registry.get_id(f"{prefix}.not_sign_a"),
                registry.get_id(f"{prefix}.sign_b"),
                registry.get_id(f"{prefix}.either_nonzero")]
    # Final gt: OR(both_pos_gt, both_neg_gt, a_pos_b_neg)
    # NOTE(review): infer_float16_cmp_inputs registers both_pos_gt,
    # both_neg_gt and a_pos_b_neg before its final gt branch; here they are
    # never registered, so these get_id() calls may return -1 -- verify.
    if '.gt' in gate:
        return [registry.get_id(f"{prefix}.both_pos_gt"),
                registry.get_id(f"{prefix}.both_neg_gt"),
                registry.get_id(f"{prefix}.a_pos_b_neg")]
    return []


def infer_float16_pack_inputs(gate: str, registry: SignalRegistry) -> List[int]:
    """Infer inputs for float16.pack circuit.

    Routes sign/exp/mant component signals to the 16 packed output bits.
    """
    prefix = "float16.pack"
    # Register inputs: sign, exp[0:4], mant[0:9]
    registry.register(f"{prefix}.$sign")
    for i in range(5):
        registry.register(f"{prefix}.$exp[{i}]")
    for i in range(10):
        registry.register(f"{prefix}.$mant[{i}]")
    # Output bits: out15 = sign, out14-10 = exponent, out9-0 = mantissa.
    if '.out' in gate:
        match = re.search(r'\.out(\d+)', gate)
        if match:
            i = int(match.group(1))
            if i == 15:
                return [registry.get_id(f"{prefix}.$sign")]
            elif i >= 10:
                return [registry.get_id(f"{prefix}.$exp[{i-10}]")]
            else:
                return [registry.get_id(f"{prefix}.$mant[{i}]")]
    return []


def infer_float16_unpack_inputs(gate: str, registry: SignalRegistry) -> List[int]:
    """Infer inputs for float16.unpack circuit."""
    prefix = "float16.unpack"
    # Register 16-bit input
    for i in range(16):
        registry.register(f"{prefix}.$x[{i}]")
    # Sign bit (bit 15)
    if '.sign' in gate:
        return [registry.get_id(f"{prefix}.$x[15]")]
    # Exponent bits (bits 14-10)
    if '.exp' in gate:
        match = re.search(r'\.exp(\d+)', gate)
        if match:
            i = int(match.group(1))
            # exp0 = bit 10, exp1 = bit 11, ..., exp4 = bit 14
            return [registry.get_id(f"{prefix}.$x[{10+i}]")]
    # Mantissa bits (bits 9-0)
    if '.mant' in gate:
        match = re.search(r'\.mant(\d+)', gate)
        if match:
            i = int(match.group(1))
            # mant0 = bit 0, mant1 = bit 1, ..., mant9 = bit 9
            return [registry.get_id(f"{prefix}.$x[{i}]")]
    return []


def infer_float32_pack_inputs(gate: str, registry: SignalRegistry) -> List[int]:
    """Infer inputs for float32.pack circuit."""
    prefix = "float32.pack"
    # Register inputs: sign, exp[0:7], mant[0:22]
    registry.register(f"{prefix}.$sign")
    for i in range(8):
        registry.register(f"{prefix}.$exp[{i}]")
    for i in range(23):
        registry.register(f"{prefix}.$mant[{i}]")
    # Output bits: out31 = sign, out30-23 = exponent, out22-0 = mantissa.
    if '.out' in gate:
        match = re.search(r'\.out(\d+)', gate)
        if match:
            i = int(match.group(1))
            if i == 31:
                return [registry.get_id(f"{prefix}.$sign")]
            if i >= 23:
                return [registry.get_id(f"{prefix}.$exp[{i-23}]")]
            return [registry.get_id(f"{prefix}.$mant[{i}]")]
    return []


def infer_float32_unpack_inputs(gate: str, registry: SignalRegistry) -> List[int]:
    """Infer inputs for float32.unpack circuit."""
    prefix = "float32.unpack"
    # Register 32-bit input
    for i in range(32):
        registry.register(f"{prefix}.$x[{i}]")
    # Sign bit (bit 31)
    if '.sign' in gate:
        return [registry.get_id(f"{prefix}.$x[31]")]
    # Exponent bits (bits 30-23)
    if '.exp' in gate:
        match = re.search(r'\.exp(\d+)', gate)
        if match:
            i = int(match.group(1))
            # exp0 = bit 23, exp7 = bit 30
            return [registry.get_id(f"{prefix}.$x[{23+i}]")]
    # Mantissa bits (bits 22-0)
    if '.mant' in gate:
        match = re.search(r'\.mant(\d+)', gate)
        if match:
            i = int(match.group(1))
            return [registry.get_id(f"{prefix}.$x[{i}]")]
    return []


def build_float16_neg_tensors() -> Dict[str, torch.Tensor]:
    """Build tensors for float16.neg circuit.

    Negates a float16 by flipping the sign bit; all other bits pass
    through unchanged.  out15 gets the inverting weight/bias pair
    (-1.0, 0.0); the remaining bits get the pass-through pair (1.0, -0.5),
    matching the encodings used by build_float16_const_tensors.
    """
    tensors = {}
    prefix = "float16.neg"
    # Sign bit: NOT of input sign
    tensors[f"{prefix}.out15.weight"] = torch.tensor([-1.0])
    tensors[f"{prefix}.out15.bias"] = torch.tensor([0.0])
    # All other bits: pass through
    for i in range(15):
        tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0])
        tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-0.5])
    return tensors


def build_float16_abs_tensors() -> Dict[str, torch.Tensor]:
    """Build tensors for float16.abs circuit.

    Absolute value: clear the sign bit, pass all others.
""" tensors = {} prefix = "float16.abs" # Sign bit: always 0 (use constant #0) # Actually, we can just not output bit 15, or output 0 # For consistency, let's output 0 by using bias that never fires tensors[f"{prefix}.out15.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out15.bias"] = torch.tensor([-2.0]) # never fires # All other bits: pass through for i in range(15): tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-0.5]) return tensors def build_float16_normalize_tensors() -> Dict[str, torch.Tensor]: """Build tensors for float16.normalize circuit. Normalizes an extended mantissa by finding leading 1 and shifting. Used after float16 addition/subtraction. Inputs: - 13-bit extended mantissa ($m[12:0], where $m[12] is overflow bit) - 8-bit raw exponent ($e[7:0]) - 1-bit sign ($sign) Outputs: - shift_amt[3:0]: how many positions to shift left (0-12) - is_zero: mantissa is all zeros - overflow: mantissa bit 12 is set (need right shift) The actual shifting and exponent adjustment are done externally since a full barrel shifter is complex. 
""" tensors = {} prefix = "float16.normalize" # Detect overflow (bit 12 set) - needs right shift, not left tensors[f"{prefix}.overflow.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.overflow.bias"] = torch.tensor([-0.5]) # Detect all-zero mantissa # is_zero = NOR of all 13 mantissa bits tensors[f"{prefix}.is_zero.weight"] = torch.tensor([-1.0] * 13) tensors[f"{prefix}.is_zero.bias"] = torch.tensor([0.0]) # CLZ on bits 11:0 (excluding overflow bit) to find shift amount # If overflow, shift amount is 0 (actually -1, handled specially) # pz[k] = 1 if top k bits of m[11:0] are all zero for k in range(1, 13): tensors[f"{prefix}.pz{k}.weight"] = torch.tensor([-1.0] * k) tensors[f"{prefix}.pz{k}.bias"] = torch.tensor([0.0]) # ge[k] = sum of pz >= k (CLZ >= k) for k in range(1, 13): tensors[f"{prefix}.ge{k}.weight"] = torch.tensor([1.0] * 12) tensors[f"{prefix}.ge{k}.bias"] = torch.tensor([-float(k)]) # NOT gates for binary encoding (need all even values for odd AND gates) for k in [2, 4, 6, 8, 10, 12]: tensors[f"{prefix}.not_ge{k}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_ge{k}.bias"] = torch.tensor([0.0]) # Shift amount is min(CLZ, 12) encoded in 4 bits # bit3: CLZ >= 8 tensors[f"{prefix}.shift3.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.shift3.bias"] = torch.tensor([-0.5]) # pass ge8 # bit2: CLZ in {4-7, 12} = (ge4 AND NOT ge8) OR ge12 tensors[f"{prefix}.and_4_7.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.and_4_7.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.shift2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.shift2.bias"] = torch.tensor([-1.0]) # bit1: CLZ in {2,3,6,7,10,11} tensors[f"{prefix}.and_2_3.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.and_2_3.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.and_6_7.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.and_6_7.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.and_10_11.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.and_10_11.bias"] = 
torch.tensor([-2.0]) tensors[f"{prefix}.shift1.weight"] = torch.tensor([1.0, 1.0, 1.0]) tensors[f"{prefix}.shift1.bias"] = torch.tensor([-1.0]) # bit0: CLZ is odd {1,3,5,7,9,11} for i in [1, 3, 5, 7, 9, 11]: tensors[f"{prefix}.and_{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.and_{i}.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.shift0.weight"] = torch.tensor([1.0] * 6) tensors[f"{prefix}.shift0.bias"] = torch.tensor([-1.0]) # When overflow is set, shift amount should be 0 (we'll right-shift by 1 externally) # Mask shift bits with NOT overflow tensors[f"{prefix}.not_overflow.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_overflow.bias"] = torch.tensor([0.0]) for i in range(4): tensors[f"{prefix}.out_shift{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.out_shift{i}.bias"] = torch.tensor([-2.0]) return tensors def build_float16_cmp_tensors() -> Dict[str, torch.Tensor]: """Build tensors for float16.cmp circuit. Computes a > b for two float16 values. IEEE 754 comparison trick: - If both positive: compare as unsigned integers - If signs differ: positive > negative - If both negative: compare reversed Architecture: 1. sign_a, sign_b extraction 2. Magnitude comparison using existing 8-bit comparators (high/low bytes) 3. 
Sign-based result selection """ tensors = {} prefix = "float16.cmp" # Sign extraction (pass-through from bit 15) tensors[f"{prefix}.sign_a.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sign_a.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.sign_b.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sign_b.bias"] = torch.tensor([-0.5]) # NOT sign gates tensors[f"{prefix}.not_sign_a.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_sign_a.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.not_sign_b.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_sign_b.bias"] = torch.tensor([0.0]) # Magnitude comparison: compare bits 14-0 of a vs b # Use weighted comparison (higher bits have higher weight) # a_mag > b_mag when weighted(a) - weighted(b) > 0 # Weights: bit 14 = 16384, bit 13 = 8192, ..., bit 0 = 1 weights_a = [float(2**i) for i in range(15)] weights_b = [-float(2**i) for i in range(15)] tensors[f"{prefix}.mag_cmp.weight"] = torch.tensor(weights_a + weights_b) tensors[f"{prefix}.mag_cmp.bias"] = torch.tensor([-0.5]) # strict > (not >=) # a_mag > b_mag (pass-through) tensors[f"{prefix}.a_gt_b_mag.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.a_gt_b_mag.bias"] = torch.tensor([-0.5]) # b_mag > a_mag (for negative case) # Inputs are [b bits, a bits], want b - a > 0 # So weights are [+2^i for b, -2^i for a] tensors[f"{prefix}.b_gt_a_mag.weight"] = torch.tensor(weights_a + weights_b) tensors[f"{prefix}.b_gt_a_mag.bias"] = torch.tensor([-0.5]) # strict > # Case: both positive (sign_a=0, sign_b=0) -> result = a_mag > b_mag # AND(not_sign_a, not_sign_b, a_gt_b_mag) tensors[f"{prefix}.both_pos_gt.weight"] = torch.tensor([1.0, 1.0, 1.0]) tensors[f"{prefix}.both_pos_gt.bias"] = torch.tensor([-3.0]) # Case: both negative (sign_a=1, sign_b=1) -> result = b_mag > a_mag (reversed) # AND(sign_a, sign_b, b_gt_a_mag) tensors[f"{prefix}.both_neg_gt.weight"] = torch.tensor([1.0, 1.0, 1.0]) tensors[f"{prefix}.both_neg_gt.bias"] = torch.tensor([-3.0]) # Check if both 
magnitudes are zero (for +0 == -0 case) # mag_a_nonzero: OR of bits 0-14 of a tensors[f"{prefix}.mag_a_nonzero.weight"] = torch.tensor([1.0] * 15) tensors[f"{prefix}.mag_a_nonzero.bias"] = torch.tensor([-1.0]) # mag_b_nonzero: OR of bits 0-14 of b tensors[f"{prefix}.mag_b_nonzero.weight"] = torch.tensor([1.0] * 15) tensors[f"{prefix}.mag_b_nonzero.bias"] = torch.tensor([-1.0]) # either_nonzero: OR(mag_a_nonzero, mag_b_nonzero) tensors[f"{prefix}.either_nonzero.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.either_nonzero.bias"] = torch.tensor([-1.0]) # Case: a positive, b negative (sign_a=0, sign_b=1) -> a > b # BUT only if at least one is non-zero (to handle +0 vs -0) # AND(not_sign_a, sign_b, either_nonzero) tensors[f"{prefix}.a_pos_b_neg.weight"] = torch.tensor([1.0, 1.0, 1.0]) tensors[f"{prefix}.a_pos_b_neg.bias"] = torch.tensor([-3.0]) # Final result: OR of all true cases tensors[f"{prefix}.gt.weight"] = torch.tensor([1.0, 1.0, 1.0]) tensors[f"{prefix}.gt.bias"] = torch.tensor([-1.0]) return tensors def build_float16_pack_tensors() -> Dict[str, torch.Tensor]: """Build tensors for float16.pack circuit. Takes sign (1 bit), exponent (5 bits), mantissa (10 bits) and assembles them into a 16-bit output. Output layout: - out[15] = sign - out[14:10] = exp[4:0] - out[9:0] = mant[9:0] """ tensors = {} prefix = "float16.pack" # Output bits are pass-throughs from inputs for i in range(16): tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-0.5]) return tensors def build_float16_unpack_tensors() -> Dict[str, torch.Tensor]: """Build tensors for float16.unpack circuit. IEEE 754 half-precision (float16) format: - Bit 15: sign (1 bit) - Bits 14-10: exponent (5 bits) - Bits 9-0: mantissa (10 bits) This circuit extracts each field as a separate output. Uses simple pass-through gates (weight=1, bias=-0.5). 
""" tensors = {} prefix = "float16.unpack" # Sign bit extraction (bit 15) tensors[f"{prefix}.sign.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sign.bias"] = torch.tensor([-0.5]) # Exponent extraction (bits 14-10, 5 bits) for i in range(5): tensors[f"{prefix}.exp{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.exp{i}.bias"] = torch.tensor([-0.5]) # Mantissa extraction (bits 9-0, 10 bits) for i in range(10): tensors[f"{prefix}.mant{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.mant{i}.bias"] = torch.tensor([-0.5]) return tensors def build_float32_neg_tensors() -> Dict[str, torch.Tensor]: """Build tensors for float32.neg circuit. Negates a float32 by flipping the sign bit. All other bits pass through unchanged. """ tensors = {} prefix = "float32.neg" # Sign bit: NOT of input sign tensors[f"{prefix}.out31.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.out31.bias"] = torch.tensor([0.0]) # All other bits: pass through for i in range(31): tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-0.5]) return tensors def build_float32_abs_tensors() -> Dict[str, torch.Tensor]: """Build tensors for float32.abs circuit. Absolute value: clear the sign bit, pass all others. """ tensors = {} prefix = "float32.abs" # Sign bit forced to 0 tensors[f"{prefix}.out31.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out31.bias"] = torch.tensor([-2.0]) # never fires # All other bits: pass through for i in range(31): tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-0.5]) return tensors def build_float32_pack_tensors() -> Dict[str, torch.Tensor]: """Build tensors for float32.pack circuit. Takes sign (1 bit), exponent (8 bits), mantissa (23 bits) and assembles them into a 32-bit output. 
Output layout: - out[31] = sign - out[30:23] = exp[7:0] - out[22:0] = mant[22:0] """ tensors = {} prefix = "float32.pack" # Output bits are pass-throughs from inputs for i in range(32): tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-0.5]) return tensors def build_float32_unpack_tensors() -> Dict[str, torch.Tensor]: """Build tensors for float32.unpack circuit. IEEE 754 single-precision (float32) format: - Bit 31: sign (1 bit) - Bits 30-23: exponent (8 bits) - Bits 22-0: mantissa (23 bits) This circuit extracts each field as a separate output. Uses simple pass-through gates (weight=1, bias=-0.5). """ tensors = {} prefix = "float32.unpack" # Sign bit extraction (bit 31) tensors[f"{prefix}.sign.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sign.bias"] = torch.tensor([-0.5]) # Exponent extraction (bits 30-23, 8 bits) for i in range(8): tensors[f"{prefix}.exp{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.exp{i}.bias"] = torch.tensor([-0.5]) # Mantissa extraction (bits 22-0, 23 bits) for i in range(23): tensors[f"{prefix}.mant{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.mant{i}.bias"] = torch.tensor([-0.5]) return tensors def build_float32_cmp_tensors() -> Dict[str, torch.Tensor]: """Build tensors for float32.cmp circuit. Computes a > b for two float32 values using IEEE-754 sign/magnitude rules. 
""" tensors = {} prefix = "float32.cmp" # Sign extraction (pass-through from bit 31) tensors[f"{prefix}.sign_a.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sign_a.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.sign_b.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sign_b.bias"] = torch.tensor([-0.5]) # NOT sign gates tensors[f"{prefix}.not_sign_a.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_sign_a.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.not_sign_b.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_sign_b.bias"] = torch.tensor([0.0]) # Magnitude comparison: compare bits 30-0 of a vs b weights_a = [float(2**i) for i in range(31)] weights_b = [-float(2**i) for i in range(31)] tensors[f"{prefix}.mag_cmp.weight"] = torch.tensor(weights_a + weights_b) tensors[f"{prefix}.mag_cmp.bias"] = torch.tensor([-0.5]) # strict > # a_mag > b_mag (pass-through) tensors[f"{prefix}.a_gt_b_mag.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.a_gt_b_mag.bias"] = torch.tensor([-0.5]) # b_mag > a_mag (reversed comparison) tensors[f"{prefix}.b_gt_a_mag.weight"] = torch.tensor(weights_a + weights_b) tensors[f"{prefix}.b_gt_a_mag.bias"] = torch.tensor([-0.5]) # both positive: AND(not_sign_a, not_sign_b, a_gt_b_mag) tensors[f"{prefix}.both_pos_gt.weight"] = torch.tensor([1.0, 1.0, 1.0]) tensors[f"{prefix}.both_pos_gt.bias"] = torch.tensor([-3.0]) # both negative: AND(sign_a, sign_b, b_gt_a_mag) tensors[f"{prefix}.both_neg_gt.weight"] = torch.tensor([1.0, 1.0, 1.0]) tensors[f"{prefix}.both_neg_gt.bias"] = torch.tensor([-3.0]) # mag_a_nonzero: OR of bits 0-30 of a tensors[f"{prefix}.mag_a_nonzero.weight"] = torch.tensor([1.0] * 31) tensors[f"{prefix}.mag_a_nonzero.bias"] = torch.tensor([-1.0]) # mag_b_nonzero: OR of bits 0-30 of b tensors[f"{prefix}.mag_b_nonzero.weight"] = torch.tensor([1.0] * 31) tensors[f"{prefix}.mag_b_nonzero.bias"] = torch.tensor([-1.0]) # either_nonzero: OR(mag_a_nonzero, mag_b_nonzero) tensors[f"{prefix}.either_nonzero.weight"] = 
torch.tensor([1.0, 1.0]) tensors[f"{prefix}.either_nonzero.bias"] = torch.tensor([-1.0]) # a_pos_b_neg: AND(not_sign_a, sign_b, either_nonzero) tensors[f"{prefix}.a_pos_b_neg.weight"] = torch.tensor([1.0, 1.0, 1.0]) tensors[f"{prefix}.a_pos_b_neg.bias"] = torch.tensor([-3.0]) # Final result: OR(both_pos_gt, both_neg_gt, a_pos_b_neg) tensors[f"{prefix}.gt.weight"] = torch.tensor([1.0, 1.0, 1.0]) tensors[f"{prefix}.gt.bias"] = torch.tensor([-1.0]) return tensors def build_float16_pow_tensors(mul_tensors: Dict[str, torch.Tensor], ln_outputs: List[int], exp_outputs: List[int]) -> Dict[str, torch.Tensor]: """Build tensors for float16.pow via ln -> mul -> exp.""" tensors: Dict[str, torch.Tensor] = {} # ln(a) LUT tensors.update(build_float16_lut_match_tensors("float16.pow.ln")) tensors.update(build_float16_lut_output_tensors("float16.pow.ln", ln_outputs)) # mul(ln(a), b) tensors.update(clone_prefix_tensors(mul_tensors, "float16.mul", "float16.pow.mul")) # exp(mul) tensors.update(build_float16_lut_match_tensors("float16.pow.exp")) tensors.update(build_float16_lut_output_tensors("float16.pow.exp", exp_outputs)) # Final outputs (pass-through from exp) prefix = "float16.pow" for i in range(16): tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-0.5]) return tensors def build_clz16bit_tensors() -> Dict[str, torch.Tensor]: """Build tensors for arithmetic.clz16bit circuit. CLZ16BIT counts leading zeros in a 16-bit input. Output is 0-16 (5 bits). Architecture (same as CLZ8BIT): 1. pz[k] gates: NOR of top k bits (fires if top k bits are all zero) 2. ge[k] gates: sum of pz >= k (threshold gates) 3. 
Logic gates to convert thermometer code to binary """ tensors = {} prefix = "arithmetic.clz16bit" # === PREFIX ZERO GATES (NOR of top k bits) === for k in range(1, 17): tensors[f"{prefix}.pz{k}.weight"] = torch.tensor([-1.0] * k) tensors[f"{prefix}.pz{k}.bias"] = torch.tensor([0.0]) # === GE GATES (sum of pz >= k) === for k in range(1, 17): tensors[f"{prefix}.ge{k}.weight"] = torch.tensor([1.0] * 16) tensors[f"{prefix}.ge{k}.bias"] = torch.tensor([-float(k)]) # === NOT GATES (for all values used in range detection) === for k in [2, 4, 6, 8, 10, 12, 14, 16]: tensors[f"{prefix}.not_ge{k}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_ge{k}.bias"] = torch.tensor([0.0]) # === AND GATES for range detection === # For 5-bit output (0-16), need to detect ranges for each bit # bit4 (16's place): CLZ >= 16, just ge16 # bit3 (8's place): CLZ in {8-15} = ge8 AND NOT ge16 tensors[f"{prefix}.and_8_15.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.and_8_15.bias"] = torch.tensor([-2.0]) # bit2 (4's place): CLZ in {4-7, 12-15} # = (ge4 AND NOT ge8) OR (ge12 AND NOT ge16) tensors[f"{prefix}.and_4_7.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.and_4_7.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.and_12_15.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.and_12_15.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.or_bit2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.or_bit2.bias"] = torch.tensor([-1.0]) # bit1 (2's place): CLZ in {2,3,6,7,10,11,14,15} tensors[f"{prefix}.and_2_3.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.and_2_3.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.and_6_7.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.and_6_7.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.and_10_11.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.and_10_11.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.and_14_15.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.and_14_15.bias"] = torch.tensor([-2.0]) 
tensors[f"{prefix}.or_bit1.weight"] = torch.tensor([1.0, 1.0, 1.0, 1.0]) tensors[f"{prefix}.or_bit1.bias"] = torch.tensor([-1.0]) # bit0 (1's place): CLZ is odd {1,3,5,7,9,11,13,15} for i in [1, 3, 5, 7, 9, 11, 13, 15]: tensors[f"{prefix}.and_{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.and_{i}.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.or_bit0.weight"] = torch.tensor([1.0] * 8) tensors[f"{prefix}.or_bit0.bias"] = torch.tensor([-1.0]) # === OUTPUT GATES === tensors[f"{prefix}.out4.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out4.bias"] = torch.tensor([-0.5]) # pass-through ge16 tensors[f"{prefix}.out3.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out3.bias"] = torch.tensor([-0.5]) # pass-through and_8_15 tensors[f"{prefix}.out2.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out2.bias"] = torch.tensor([-0.5]) # pass-through or_bit2 tensors[f"{prefix}.out1.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out1.bias"] = torch.tensor([-0.5]) # pass-through or_bit1 tensors[f"{prefix}.out0.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out0.bias"] = torch.tensor([-0.5]) # pass-through or_bit0 return tensors def build_float16_add_tensors() -> Dict[str, torch.Tensor]: """Build tensors for float16.add circuit. IEEE 754 half-precision addition with full special case handling: 1. Detect special cases (NaN, infinity, zero, subnormal) 2. Extract sign, exponent, mantissa from both operands 3. Add implicit bit (1 for normal, 0 for subnormal) 4. Compare exponents to find which is larger 5. Align mantissas by shifting smaller exponent's mantissa right 6. Add or subtract mantissas based on signs 7. Normalize result and adjust exponent 8. Handle overflow (to infinity) and underflow (to zero/subnormal) 9. 
    Pack result with correct special case outputs

    Inputs: $a[0:15], $b[0:15] (two float16 values)
    Outputs: out[0:15] (float16 result)
    """
    tensors = {}
    prefix = "float16.add"

    # =========================================================================
    # STAGE 0: SPECIAL CASE DETECTION
    # =========================================================================
    # Detect NaN, infinity, zero, and subnormal inputs.
    # float16 encoding:
    # - Zero: exp=0, mant=0
    # - Subnormal: exp=0, mant≠0
    # - Normal: 0 < exp < 31
    # - Infinity: exp=31, mant=0
    # - NaN: exp=31, mant≠0

    # exp_a_all_ones: all 5 exponent bits are 1 (exp >= 31)
    # Threshold gate: sum of exp bits >= 5
    tensors[f"{prefix}.exp_a_all_ones.weight"] = torch.tensor([1.0] * 5)
    tensors[f"{prefix}.exp_a_all_ones.bias"] = torch.tensor([-5.0])
    tensors[f"{prefix}.exp_b_all_ones.weight"] = torch.tensor([1.0] * 5)
    tensors[f"{prefix}.exp_b_all_ones.bias"] = torch.tensor([-5.0])

    # exp_a_zero: all 5 exponent bits are 0 (NOR gate)
    tensors[f"{prefix}.exp_a_zero.weight"] = torch.tensor([-1.0] * 5)
    tensors[f"{prefix}.exp_a_zero.bias"] = torch.tensor([0.0])
    tensors[f"{prefix}.exp_b_zero.weight"] = torch.tensor([-1.0] * 5)
    tensors[f"{prefix}.exp_b_zero.bias"] = torch.tensor([0.0])

    # NOTE(review): the a_adj_exp0/b_adj_exp0 pair below is assigned three
    # times in a row with identical values. The dict overwrites make the
    # duplicates behaviorally harmless, but this looks like copy-paste
    # residue worth cleaning up once the full function is in view.

    # Adjusted exp bit0 for subnormals (effective exponent = 1)
    tensors[f"{prefix}.a_adj_exp0.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.a_adj_exp0.bias"] = torch.tensor([-1.0])
    tensors[f"{prefix}.b_adj_exp0.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.b_adj_exp0.bias"] = torch.tensor([-1.0])

    # Adjusted exp bit0 for subnormals (effective exponent = 1)
    tensors[f"{prefix}.a_adj_exp0.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.a_adj_exp0.bias"] = torch.tensor([-1.0])
    tensors[f"{prefix}.b_adj_exp0.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.b_adj_exp0.bias"] = torch.tensor([-1.0])

    # Adjusted exponent bit 0 for subnormals:
    # Subnormals have exp=0 but effective exp=1, so adjust bit 0
    # a_adj_exp0 = a_exp[0] OR exp_a_zero
    tensors[f"{prefix}.a_adj_exp0.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.a_adj_exp0.bias"] = torch.tensor([-1.0])
    tensors[f"{prefix}.b_adj_exp0.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.b_adj_exp0.bias"] = torch.tensor([-1.0])

    # NOT of adjusted exp bit 0 (for subtractor)
    tensors[f"{prefix}.not_a_adj_exp0.weight"] = torch.tensor([-1.0])
    tensors[f"{prefix}.not_a_adj_exp0.bias"] = torch.tensor([0.0])
    tensors[f"{prefix}.not_b_adj_exp0.weight"] = torch.tensor([-1.0])
    tensors[f"{prefix}.not_b_adj_exp0.bias"] = torch.tensor([0.0])

    # mant_a_nonzero: OR of all 10 mantissa bits
    tensors[f"{prefix}.mant_a_nonzero.weight"] = torch.tensor([1.0] * 10)
    tensors[f"{prefix}.mant_a_nonzero.bias"] = torch.tensor([-1.0])
    tensors[f"{prefix}.mant_b_nonzero.weight"] = torch.tensor([1.0] * 10)
    tensors[f"{prefix}.mant_b_nonzero.bias"] = torch.tensor([-1.0])

    # mant_a_zero: NOR of all mantissa bits
    tensors[f"{prefix}.mant_a_zero.weight"] = torch.tensor([-1.0] * 10)
    tensors[f"{prefix}.mant_a_zero.bias"] = torch.tensor([0.0])
    tensors[f"{prefix}.mant_b_zero.weight"] = torch.tensor([-1.0] * 10)
    tensors[f"{prefix}.mant_b_zero.bias"] = torch.tensor([0.0])

    # a_is_nan: exp_a_all_ones AND mant_a_nonzero
    tensors[f"{prefix}.a_is_nan.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.a_is_nan.bias"] = torch.tensor([-2.0])
    tensors[f"{prefix}.b_is_nan.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.b_is_nan.bias"] = torch.tensor([-2.0])

    # a_is_inf: exp_a_all_ones AND mant_a_zero
    tensors[f"{prefix}.a_is_inf.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.a_is_inf.bias"] = torch.tensor([-2.0])
    tensors[f"{prefix}.b_is_inf.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.b_is_inf.bias"] = torch.tensor([-2.0])

    # a_is_zero: exp_a_zero AND mant_a_zero
    tensors[f"{prefix}.a_is_zero.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.a_is_zero.bias"] = torch.tensor([-2.0])
    tensors[f"{prefix}.b_is_zero.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.b_is_zero.bias"] = torch.tensor([-2.0])

    # a_is_subnormal: exp_a_zero AND mant_a_nonzero
    tensors[f"{prefix}.a_is_subnormal.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.a_is_subnormal.bias"] = torch.tensor([-2.0])
    tensors[f"{prefix}.b_is_subnormal.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.b_is_subnormal.bias"] = torch.tensor([-2.0])

    # either_is_nan: a_is_nan OR b_is_nan
    tensors[f"{prefix}.either_is_nan.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.either_is_nan.bias"] = torch.tensor([-1.0])

    # both_are_zero: a_is_zero AND b_is_zero (special case: output zero)
    tensors[f"{prefix}.both_are_zero.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.both_are_zero.bias"] = torch.tensor([-2.0])

    # either_is_zero: a_is_zero OR b_is_zero (for zero+x = x case)
    tensors[f"{prefix}.either_is_zero.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.either_is_zero.bias"] = torch.tensor([-1.0])

    # both_are_inf: a_is_inf AND b_is_inf
    tensors[f"{prefix}.both_are_inf.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.both_are_inf.bias"] = torch.tensor([-2.0])

    # signs_differ: sign_a XOR sign_b (for inf + (-inf) = NaN case)
    # XOR layer 1
    tensors[f"{prefix}.signs_differ.layer1.or.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.signs_differ.layer1.or.bias"] = torch.tensor([-1.0])
    tensors[f"{prefix}.signs_differ.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
    tensors[f"{prefix}.signs_differ.layer1.nand.bias"] = torch.tensor([1.0])
    tensors[f"{prefix}.signs_differ.layer2.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.signs_differ.layer2.bias"] = torch.tensor([-2.0])

    # inf_cancellation: both_are_inf AND signs_differ (produces NaN)
    tensors[f"{prefix}.inf_cancellation.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.inf_cancellation.bias"] = torch.tensor([-2.0])

    # result_is_nan: either_is_nan OR inf_cancellation
    tensors[f"{prefix}.result_is_nan.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.result_is_nan.bias"] = torch.tensor([-1.0])

    # either_is_inf: a_is_inf OR b_is_inf
    tensors[f"{prefix}.either_is_inf.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.either_is_inf.bias"] = torch.tensor([-1.0])

    # NOT result_is_nan (for masking inf result)
    tensors[f"{prefix}.not_result_is_nan.weight"] = torch.tensor([-1.0])
    tensors[f"{prefix}.not_result_is_nan.bias"] = torch.tensor([0.0])

    # result_is_inf: (either_is_inf OR exp_overflow_to_inf) AND NOT result_is_nan
    # Note: exp_overflow_to_inf is defined later, after result_exp calculation
    # We use a 3-input weighted gate here:
    # either_is_inf + exp_overflow_to_inf + 2*not_result_is_nan >= 2.5
    # This means: need at least one inf source AND not_result_is_nan
    tensors[f"{prefix}.result_is_inf.weight"] = torch.tensor([1.0, 1.0, 2.0])
    tensors[f"{prefix}.result_is_inf.bias"] = torch.tensor([-2.5])

    # =========================================================================
    # STAGE 1: EXTRACT COMPONENTS
    # =========================================================================
    # sign_a = a[15], sign_b = b[15]
    # exp_a[0:4] = a[10:14], exp_b[0:4] = b[10:14]
    # mant_a[0:9] = a[0:9], mant_b[0:9] = b[0:9]

    # Pass-through gates for sign extraction
    tensors[f"{prefix}.sign_a.weight"] = torch.tensor([1.0])
    tensors[f"{prefix}.sign_a.bias"] = torch.tensor([-0.5])
    tensors[f"{prefix}.sign_b.weight"] = torch.tensor([1.0])
    tensors[f"{prefix}.sign_b.bias"] = torch.tensor([-0.5])

    # Implicit bit calculation:
    # For normal numbers, implicit bit = 1
    # For subnormal numbers, implicit bit = 0
    # implicit_a = NOT a_is_subnormal AND NOT a_is_zero = NOT exp_a_zero
    # Actually simpler: implicit_a = NOT exp_a_zero (since exp=0 means no implicit 1)
    tensors[f"{prefix}.implicit_a.weight"] = torch.tensor([-1.0])
    tensors[f"{prefix}.implicit_a.bias"] = torch.tensor([0.0])
    tensors[f"{prefix}.implicit_b.weight"] = torch.tensor([-1.0])
    tensors[f"{prefix}.implicit_b.bias"] = torch.tensor([0.0])

    # NOTE(review): the "not_div_b*" gate names below look like copy-paste
    # residue from a divider circuit inside this adder, and the whole
    # loop + not_implicit_b pair is emitted twice back-to-back. The dict
    # overwrites make the duplication harmless at runtime, but the naming
    # should be confirmed against whatever consumes these gates.
    for i in range(10):
        tensors[f"{prefix}.not_div_b{i}.weight"] = torch.tensor([-1.0])
        tensors[f"{prefix}.not_div_b{i}.bias"] = torch.tensor([0.0])
    tensors[f"{prefix}.not_implicit_b.weight"] = torch.tensor([-1.0])
    tensors[f"{prefix}.not_implicit_b.bias"] = torch.tensor([0.0])

    for i in range(10):
        tensors[f"{prefix}.not_div_b{i}.weight"] = torch.tensor([-1.0])
        tensors[f"{prefix}.not_div_b{i}.bias"] = torch.tensor([0.0])
    tensors[f"{prefix}.not_implicit_b.weight"] = torch.tensor([-1.0])
    tensors[f"{prefix}.not_implicit_b.bias"] = torch.tensor([0.0])

    # =========================================================================
    # STAGE 2: EXPONENT COMPARISON
    # =========================================================================
    # Compare exp_a vs exp_b using weighted comparison
    # Weights: bit[i] contributes 2^i to the total
    # exp_a >= exp_b when weighted(exp_a) - weighted(exp_b) >= 0
    weights_exp_a = [float(2**i) for i in range(5)]  # +1, +2, +4, +8, +16
    weights_exp_b = [-float(2**i) for i in range(5)]  # -1, -2, -4, -8, -16

    # a_exp_ge_b: exp_a >= exp_b
    tensors[f"{prefix}.a_exp_ge_b.weight"] = torch.tensor(weights_exp_a + weights_exp_b)
    tensors[f"{prefix}.a_exp_ge_b.bias"] = torch.tensor([0.0])  # >= (not strict >)

    # a_exp_gt_b: exp_a > exp_b (for strict comparison)
    tensors[f"{prefix}.a_exp_gt_b.weight"] = torch.tensor(weights_exp_a + weights_exp_b)
    tensors[f"{prefix}.a_exp_gt_b.bias"] = torch.tensor([-0.5])  # strict >

    # b_exp_gt_a: exp_b > exp_a
    # NOTE(review): this first assignment (reversed weight lists) is dead —
    # it is immediately overwritten by the assignment two statements below.
    tensors[f"{prefix}.b_exp_gt_a.weight"] = torch.tensor(weights_exp_b[::-1] + weights_exp_a[::-1])
    # Actually, simpler: just swap the inputs conceptually
    # b > a means weights for b positive, weights for a negative
    tensors[f"{prefix}.b_exp_gt_a.weight"] = torch.tensor(weights_exp_a + weights_exp_b)
    tensors[f"{prefix}.b_exp_gt_a.bias"] = torch.tensor([-0.5])

    # NOT of a_exp_ge_b (for selecting which path)
    tensors[f"{prefix}.b_exp_gt_a_sel.weight"] = torch.tensor([-1.0])
    tensors[f"{prefix}.b_exp_gt_a_sel.bias"] = torch.tensor([0.0])

    # =========================================================================
    # STAGE 3: COMPUTE EXPONENT DIFFERENCE
    # =========================================================================
    # We need |exp_a - exp_b| for the shift amount.
    # Use 5-bit subtractors: exp_a - exp_b and exp_b - exp_a
    # Then select based on which exponent is larger.

    # 5-bit subtractor for exp_a - exp_b (using two's complement)
    # NOT gates for exp_b
    for i in range(5):
        tensors[f"{prefix}.not_exp_b{i}.weight"] = torch.tensor([-1.0])
        tensors[f"{prefix}.not_exp_b{i}.bias"] = torch.tensor([0.0])

    # Full adders for exp_a + NOT(exp_b) + 1 = exp_a - exp_b
    # FA0: bit 0
    # XOR1: exp_a[0] XOR not_exp_b[0]
    tensors[f"{prefix}.diff_ab.fa0.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.diff_ab.fa0.xor1.layer1.or.bias"] = torch.tensor([-1.0])
    tensors[f"{prefix}.diff_ab.fa0.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
    tensors[f"{prefix}.diff_ab.fa0.xor1.layer1.nand.bias"] = torch.tensor([1.0])
    tensors[f"{prefix}.diff_ab.fa0.xor1.layer2.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.diff_ab.fa0.xor1.layer2.bias"] = torch.tensor([-2.0])

    # XOR2: xor1 XOR cin (cin=1 for subtraction)
    tensors[f"{prefix}.diff_ab.fa0.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.diff_ab.fa0.xor2.layer1.or.bias"] = torch.tensor([-1.0])
    tensors[f"{prefix}.diff_ab.fa0.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
    tensors[f"{prefix}.diff_ab.fa0.xor2.layer1.nand.bias"] = torch.tensor([1.0])
    tensors[f"{prefix}.diff_ab.fa0.xor2.layer2.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.diff_ab.fa0.xor2.layer2.bias"] = torch.tensor([-2.0])

    # Carry: (a AND b) OR (xor1 AND cin)
    tensors[f"{prefix}.diff_ab.fa0.and1.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.diff_ab.fa0.and1.bias"] = torch.tensor([-2.0])
    tensors[f"{prefix}.diff_ab.fa0.and2.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.diff_ab.fa0.and2.bias"] = torch.tensor([-2.0])
    tensors[f"{prefix}.diff_ab.fa0.cout.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.diff_ab.fa0.cout.bias"] = torch.tensor([-1.0])

    # FA1-FA4: remaining bits (carry chain)
    for i in range(1, 5):
        p = f"{prefix}.diff_ab.fa{i}"
        # XOR1: exp_a[i] XOR not_exp_b[i]
        tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0])
        tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
        tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0])
        tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0])
        # XOR2: xor1 XOR carry_in
        tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0])
        tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
        tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0])
        tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0])
        # Carry
        tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.and1.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.and2.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])

    # Similarly for exp_b - exp_a
    # NOT gates for exp_a
    for i in range(5):
        tensors[f"{prefix}.not_exp_a{i}.weight"] = torch.tensor([-1.0])
        tensors[f"{prefix}.not_exp_a{i}.bias"] = torch.tensor([0.0])

    # Full adders for exp_b + NOT(exp_a) + 1 = exp_b - exp_a
    for i in range(5):
        p = f"{prefix}.diff_ba.fa{i}"
        tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0])
        tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
        tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0])
        tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0])
tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) # ========================================================================= # STAGE 4: SELECT ABSOLUTE DIFFERENCE # ========================================================================= # exp_diff = a_exp_ge_b ? (exp_a - exp_b) : (exp_b - exp_a) # Use 2-to-1 mux for each bit for i in range(5): # Mux: out = (sel AND b) OR (NOT sel AND a) # sel = b_exp_gt_a_sel (1 if b > a, meaning we want diff_ba) # Actually: sel=0 (a>=b) -> use diff_ab, sel=1 (b>a) -> use diff_ba # AND gate for diff_ab path (when a_exp_ge_b = 1) tensors[f"{prefix}.exp_diff_mux{i}.and_ab.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_diff_mux{i}.and_ab.bias"] = torch.tensor([-2.0]) # AND gate for diff_ba path (when b_exp_gt_a_sel = 1, i.e., a_exp_ge_b = 0) tensors[f"{prefix}.exp_diff_mux{i}.and_ba.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_diff_mux{i}.and_ba.bias"] = torch.tensor([-2.0]) # OR to combine tensors[f"{prefix}.exp_diff{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_diff{i}.bias"] = torch.tensor([-1.0]) # ========================================================================= # STAGE 5: SELECT LARGER EXPONENT (for result) # ========================================================================= # exp_larger = 
a_exp_ge_b ? exp_a : exp_b for i in range(5): # AND gate for exp_a path tensors[f"{prefix}.exp_larger_mux{i}.and_a.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_larger_mux{i}.and_a.bias"] = torch.tensor([-2.0]) # AND gate for exp_b path tensors[f"{prefix}.exp_larger_mux{i}.and_b.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_larger_mux{i}.and_b.bias"] = torch.tensor([-2.0]) # OR to combine tensors[f"{prefix}.exp_larger{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_larger{i}.bias"] = torch.tensor([-1.0]) # ========================================================================= # STAGE 6: MANTISSA ALIGNMENT (Barrel Shifter) # ========================================================================= # The smaller exponent's mantissa needs to be shifted right by exp_diff. # Mantissa is 11 bits: implicit bit + 10 explicit mantissa bits. # # We need to: # 1. Select which mantissa to shift (the one with smaller exponent) # 2. Shift it right by exp_diff positions # 3. The larger mantissa passes through unchanged # # For the barrel shifter, we use cascaded 2-to-1 muxes: # - Stage 0: shift by 0 or 1 (controlled by exp_diff[0]) # - Stage 1: shift by 0 or 2 (controlled by exp_diff[1]) # - Stage 2: shift by 0 or 4 (controlled by exp_diff[2]) # - Stage 3: shift by 0 or 8 (controlled by exp_diff[3]) # # If exp_diff >= 11, the shifted mantissa becomes 0 (complete loss). # First, select which mantissa gets shifted (the smaller exponent one) # mant_to_shift = a_exp_ge_b ? mant_b : mant_a (shift the smaller exp's mantissa) # mant_larger = a_exp_ge_b ? 
mant_a : mant_b # Full mantissa with implicit bit: 11 bits (bit 10 = implicit, bits 9-0 = explicit) for i in range(11): # mant_shift_src[i] = mux(a_exp_ge_b, mant_b[i], mant_a[i]) # When a_exp_ge_b=1, we shift b's mantissa (a has larger exp) # When a_exp_ge_b=0, we shift a's mantissa (b has larger exp) tensors[f"{prefix}.mant_shift_src{i}.and_b.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.mant_shift_src{i}.and_b.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.mant_shift_src{i}.and_a.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.mant_shift_src{i}.and_a.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.mant_shift_src{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.mant_shift_src{i}.bias"] = torch.tensor([-1.0]) # mant_larger[i] = mux(a_exp_ge_b, mant_a[i], mant_b[i]) tensors[f"{prefix}.mant_larger{i}.and_a.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.mant_larger{i}.and_a.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.mant_larger{i}.and_b.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.mant_larger{i}.and_b.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.mant_larger{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.mant_larger{i}.bias"] = torch.tensor([-1.0]) # Barrel shifter stages # Stage 0: shift by 1 if exp_diff[0]=1 # NOT exp_diff[0] for pass-through path tensors[f"{prefix}.not_exp_diff0.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_exp_diff0.bias"] = torch.tensor([0.0]) for i in range(11): # Output bit i comes from: # - bit i if not shifting (exp_diff[0]=0) # - bit i+1 if shifting (exp_diff[0]=1), or 0 if i+1 >= 11 tensors[f"{prefix}.shift_s0_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.shift_s0_{i}.pass.bias"] = torch.tensor([-2.0]) if i < 10: tensors[f"{prefix}.shift_s0_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.shift_s0_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.shift_s0_{i}.weight"] = torch.tensor([1.0, 1.0]) else: # MSB: shift path is 0, 
so just pass-through when not shifting tensors[f"{prefix}.shift_s0_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.shift_s0_{i}.bias"] = torch.tensor([-1.0]) # Stage 1: shift by 2 if exp_diff[1]=1 tensors[f"{prefix}.not_exp_diff1.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_exp_diff1.bias"] = torch.tensor([0.0]) for i in range(11): tensors[f"{prefix}.shift_s1_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.shift_s1_{i}.pass.bias"] = torch.tensor([-2.0]) if i < 9: tensors[f"{prefix}.shift_s1_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.shift_s1_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.shift_s1_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.shift_s1_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.shift_s1_{i}.bias"] = torch.tensor([-1.0]) # Stage 2: shift by 4 if exp_diff[2]=1 tensors[f"{prefix}.not_exp_diff2.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_exp_diff2.bias"] = torch.tensor([0.0]) for i in range(11): tensors[f"{prefix}.shift_s2_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.shift_s2_{i}.pass.bias"] = torch.tensor([-2.0]) if i < 7: tensors[f"{prefix}.shift_s2_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.shift_s2_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.shift_s2_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.shift_s2_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.shift_s2_{i}.bias"] = torch.tensor([-1.0]) # Stage 3: shift by 8 if exp_diff[3]=1 tensors[f"{prefix}.not_exp_diff3.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_exp_diff3.bias"] = torch.tensor([0.0]) for i in range(11): tensors[f"{prefix}.shift_s3_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.shift_s3_{i}.pass.bias"] = torch.tensor([-2.0]) if i < 3: tensors[f"{prefix}.shift_s3_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.shift_s3_{i}.shift.bias"] = torch.tensor([-2.0]) 
tensors[f"{prefix}.shift_s3_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.shift_s3_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.shift_s3_{i}.bias"] = torch.tensor([-1.0]) # If exp_diff[4]=1 (shift by 16 or more), result is 0 # mant_aligned = exp_diff[4] ? 0 : shift_s3 result tensors[f"{prefix}.not_exp_diff4.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_exp_diff4.bias"] = torch.tensor([0.0]) for i in range(11): # Only pass through if exp_diff[4]=0 tensors[f"{prefix}.mant_aligned{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.mant_aligned{i}.bias"] = torch.tensor([-2.0]) # ========================================================================= # GUARD/ROUND/STICKY BITS FOR IEEE 754 ROUNDING # ========================================================================= # Track bits shifted out during alignment for proper rounding # Guard (G) = MSB of shifted-out bits = mant_shift_src[exp_diff-1] # Round (R) = next bit = mant_shift_src[exp_diff-2] # Sticky (S) = OR of remaining bits # exp_diff_eq[k]: exp_diff == k (for k = 1 to 11) # Build these using exp_diff bits for k in range(1, 12): # exp_diff_eq[k] = AND of (exp_diff[i] == k_bit[i]) for all i # Use threshold gate: sum of matching bits >= 5 (all must match) weights = [] for i in range(5): bit = (k >> i) & 1 weights.append(1.0 if bit else -1.0) bias = -sum(1.0 for w in weights if w > 0) + 0.5 tensors[f"{prefix}.exp_diff_eq{k}.weight"] = torch.tensor(weights) tensors[f"{prefix}.exp_diff_eq{k}.bias"] = torch.tensor([bias]) # Guard bit selection: mux based on exp_diff # guard_sel[k] = mant_shift_src[k-1] AND exp_diff_eq[k] for k in range(1, 12): if k <= 11: tensors[f"{prefix}.guard_sel{k}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.guard_sel{k}.bias"] = torch.tensor([-2.0]) # guard_bit = OR of all guard_sel[k] tensors[f"{prefix}.guard_bit.weight"] = torch.tensor([1.0] * 11) tensors[f"{prefix}.guard_bit.bias"] = torch.tensor([-1.0]) # round_sel[k] (k=2..11): bit 
below guard for k in range(2, 12): tensors[f"{prefix}.round_sel{k}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_sel{k}.bias"] = torch.tensor([-2.0]) # round_bit = OR of all round_sel[k] tensors[f"{prefix}.round_bit.weight"] = torch.tensor([1.0] * 10) tensors[f"{prefix}.round_bit.bias"] = torch.tensor([-1.0]) # Sticky bit: OR of all bits that get shifted out (simplified) # sticky_raw[i] = mant_shift_src[i] AND exp_diff > i # For simplicity, just check if any of the lower bits are 1 and we shifted # sticky = OR of (mant_shift_src[i] for i where i < exp_diff-2) # exp_diff_gt[k]: exp_diff > k (for sticky calculation) for k in range(13): # exp_diff > k means exp_diff >= k+1 # Use threshold: sum of (exp_diff[i] * 2^i) >= k+1 tensors[f"{prefix}.exp_diff_gt{k}.weight"] = torch.tensor([2**i for i in range(5)]) tensors[f"{prefix}.exp_diff_gt{k}.bias"] = torch.tensor([-float(k+1) + 0.5]) # sticky_part[i] = mant_shift_src[i] AND exp_diff > i+2 (this bit is below round) for i in range(11): tensors[f"{prefix}.sticky_part{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sticky_part{i}.bias"] = torch.tensor([-2.0]) # sticky_bit = OR of all sticky_part[i] tensors[f"{prefix}.sticky_bit.weight"] = torch.tensor([1.0] * 11) tensors[f"{prefix}.sticky_bit.bias"] = torch.tensor([-1.0]) # ========================================================================= # STAGE 7: MANTISSA ADDITION/SUBTRACTION WITH GUARD+ROUND BITS # ========================================================================= # IEEE 754 standard: include guard+round bits in arithmetic for correct rounding, # especially when subtraction causes cancellation and left normalization. 
# # Operands extended to 13 bits (11 mantissa + guard + round): # A[12:0] = {mant_larger[10:0], 0, 0} # B[12:0] = {mant_aligned[10:0], guard, round} # # 14-bit adder handles: A + B (addition) or A - B (subtraction) # Sum bit 13 = overflow, bits 12:0 = result with guard+round tensors[f"{prefix}.signs_same.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.signs_same.bias"] = torch.tensor([0.0]) weights_mant = [float(2**i) for i in range(11)] neg_weights_mant = [-float(2**i) for i in range(11)] tensors[f"{prefix}.mant_a_ge_b.weight"] = torch.tensor(weights_mant + neg_weights_mant) tensors[f"{prefix}.mant_a_ge_b.bias"] = torch.tensor([0.0]) # NOT gates for mant_aligned (for subtraction path) for i in range(11): tensors[f"{prefix}.not_mant_aligned{i}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_mant_aligned{i}.bias"] = torch.tensor([0.0]) # NOT of sticky/round/guard bits (for subtraction) tensors[f"{prefix}.not_sticky_bit.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_sticky_bit.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.not_round_bit.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_round_bit.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.not_guard_bit.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_guard_bit.bias"] = torch.tensor([0.0]) # sub_cin = signs_differ (carry-in for 2's complement subtraction) tensors[f"{prefix}.sub_cin.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sub_cin.bias"] = torch.tensor([-0.5]) # Sticky/round/guard bit operand B selection # Sticky bit operand B selection: signs_same ? 
sticky_bit : NOT(sticky_bit) tensors[f"{prefix}.addsub_b_s.add.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.addsub_b_s.add.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.addsub_b_s.sub.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.addsub_b_s.sub.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.addsub_b_s.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.addsub_b_s.bias"] = torch.tensor([-1.0]) # Round bit operand B selection: signs_same ? round_bit : NOT(round_bit) tensors[f"{prefix}.addsub_b_r.add.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.addsub_b_r.add.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.addsub_b_r.sub.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.addsub_b_r.sub.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.addsub_b_r.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.addsub_b_r.bias"] = torch.tensor([-1.0]) # Guard bit operand B selection: signs_same ? guard_bit : NOT(guard_bit) tensors[f"{prefix}.addsub_b_g.add.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.addsub_b_g.add.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.addsub_b_g.sub.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.addsub_b_g.sub.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.addsub_b_g.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.addsub_b_g.bias"] = torch.tensor([-1.0]) # Mantissa operand B selection: signs_same ? 
mant_aligned : NOT(mant_aligned) for i in range(11): tensors[f"{prefix}.addsub_b{i}.add.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.addsub_b{i}.add.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.addsub_b{i}.sub.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.addsub_b{i}.sub.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.addsub_b{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.addsub_b{i}.bias"] = torch.tensor([-1.0]) # 15-bit ripple carry adder: # Bit 0: A=#0, B=addsub_b_s (sticky position) # Bit 1: A=#0, B=addsub_b_r (round position) # Bit 2: A=#0, B=addsub_b_g (guard position) # Bits 3-13: A=mant_larger[i-3], B=addsub_b[i-3] # Bit 14: A=#0, B=#0 (overflow detection) for i in range(15): p = f"{prefix}.mant_add.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) # ========================================================================= # STAGE 8: RESULT SIGN DETERMINATION # 
========================================================================= # When signs_same: result_sign = sign_a (= sign_b) # When signs_differ: # If a has larger magnitude: result_sign = sign_a # If b has larger magnitude: result_sign = sign_b # # Magnitude comparison: consider both exponent and mantissa # a_magnitude_ge_b: (exp_a > exp_b) OR (exp_a == exp_b AND mant_a >= mant_b) # exp_a_eq_b: NOT a_exp_gt_b AND NOT b_exp_gt_a tensors[f"{prefix}.not_a_exp_gt_b.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_a_exp_gt_b.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.exp_a_eq_b.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_a_eq_b.bias"] = torch.tensor([-2.0]) # exp_eq_and_mant_a_ge: exp_a_eq_b AND mant_a_ge_b tensors[f"{prefix}.exp_eq_and_mant_a_ge.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_eq_and_mant_a_ge.bias"] = torch.tensor([-2.0]) # a_magnitude_ge_b: a_exp_gt_b OR exp_eq_and_mant_a_ge tensors[f"{prefix}.a_magnitude_ge_b.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.a_magnitude_ge_b.bias"] = torch.tensor([-1.0]) # result_sign when signs_differ: # = a_magnitude_ge_b ? sign_a : sign_b tensors[f"{prefix}.not_a_mag_ge_b.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_a_mag_ge_b.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.diff_sign_sel_a.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.diff_sign_sel_a.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.diff_sign_sel_b.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.diff_sign_sel_b.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.diff_result_sign.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.diff_result_sign.bias"] = torch.tensor([-1.0]) # Final result sign: signs_same ? 
sign_a : diff_result_sign tensors[f"{prefix}.result_sign_same.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.result_sign_same.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.result_sign_diff.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.result_sign_diff.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.result_sign.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.result_sign.bias"] = torch.tensor([-1.0]) # ========================================================================= # STAGE 9: NORMALIZATION # ========================================================================= # With guard bit included, sum is now 13 bits: # Bit 12 = overflow (carry out) # Bits 11:1 = 11-bit mantissa (implicit + 10 explicit) # Bit 0 = guard bit # # Overflow detection: bit 12 AND signs_same tensors[f"{prefix}.sum_bit12.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sum_bit12.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.sum_overflow.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sum_overflow.bias"] = torch.tensor([-2.0]) # Zero detection: bits 13:0 all zero AND signs_differ tensors[f"{prefix}.sum_bits_zero.weight"] = torch.tensor([-1.0] * 14) tensors[f"{prefix}.sum_bits_zero.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.sum_is_zero.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sum_is_zero.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.not_sum_is_zero.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_sum_is_zero.bias"] = torch.tensor([0.0]) # CLZ on 14-bit sum (bits 13:0) to find normalization shift # pz gates: prefix zero detectors for k in range(1, 15): tensors[f"{prefix}.sum_pz{k}.weight"] = torch.tensor([-1.0] * k) tensors[f"{prefix}.sum_pz{k}.bias"] = torch.tensor([0.0]) # ge gates: sum of pz >= k (for 14-bit CLZ, max is 14) for k in range(1, 15): tensors[f"{prefix}.sum_ge{k}.weight"] = torch.tensor([1.0] * 14) tensors[f"{prefix}.sum_ge{k}.bias"] = torch.tensor([-float(k)]) # NOT gates for binary encoding for k in [2, 4, 6, 
8, 10, 12]: tensors[f"{prefix}.sum_not_ge{k}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.sum_not_ge{k}.bias"] = torch.tensor([0.0]) # Shift amount encoding (4 bits for 0-12) tensors[f"{prefix}.norm_shift3.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.norm_shift3.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.norm_and_4_7.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_and_4_7.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.norm_and_12.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.norm_and_12.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.norm_shift2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_shift2.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.norm_and_2_3.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_and_2_3.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.norm_and_6_7.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_and_6_7.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.norm_and_10_11.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_and_10_11.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.norm_shift1.weight"] = torch.tensor([1.0, 1.0, 1.0]) tensors[f"{prefix}.norm_shift1.bias"] = torch.tensor([-1.0]) for i in [1, 3, 5, 7, 9, 11]: tensors[f"{prefix}.norm_and_{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_and_{i}.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.norm_shift0.weight"] = torch.tensor([1.0] * 7) tensors[f"{prefix}.norm_shift0.bias"] = torch.tensor([-1.0]) # ========================================================================= # STAGE 10: APPLY NORMALIZATION TO MANTISSA # ========================================================================= # With 13-bit sum (bit 12=overflow, bits 11:1=mantissa, bit 0=guard): # 1. Overflow: result mantissa = sum[11:2] (right-shift by 1, skip guard) # 2. 
No overflow: left-shift sum[11:0], result = shifted[10:1] (skip guard) tensors[f"{prefix}.not_sum_overflow.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_sum_overflow.bias"] = torch.tensor([0.0]) # Overflow mantissa: sum[13:4] = 10 bits (skipping overflow bit and sticky/round/guard) for i in range(10): tensors[f"{prefix}.norm_mant_overflow{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.norm_mant_overflow{i}.bias"] = torch.tensor([-0.5]) # 14-bit left barrel shifter on sum[13:0] tensors[f"{prefix}.not_norm_shift0.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_norm_shift0.bias"] = torch.tensor([0.0]) for i in range(14): tensors[f"{prefix}.lshift_s0_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.lshift_s0_{i}.pass.bias"] = torch.tensor([-2.0]) if i > 0: tensors[f"{prefix}.lshift_s0_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.lshift_s0_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.lshift_s0_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.lshift_s0_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.lshift_s0_{i}.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_norm_shift1.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_norm_shift1.bias"] = torch.tensor([0.0]) for i in range(14): tensors[f"{prefix}.lshift_s1_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.lshift_s1_{i}.pass.bias"] = torch.tensor([-2.0]) if i > 1: tensors[f"{prefix}.lshift_s1_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.lshift_s1_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.lshift_s1_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.lshift_s1_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.lshift_s1_{i}.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_norm_shift2.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_norm_shift2.bias"] = torch.tensor([0.0]) for i in range(14): tensors[f"{prefix}.lshift_s2_{i}.pass.weight"] = 
torch.tensor([1.0, 1.0]) tensors[f"{prefix}.lshift_s2_{i}.pass.bias"] = torch.tensor([-2.0]) if i > 3: tensors[f"{prefix}.lshift_s2_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.lshift_s2_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.lshift_s2_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.lshift_s2_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.lshift_s2_{i}.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_norm_shift3.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_norm_shift3.bias"] = torch.tensor([0.0]) for i in range(14): tensors[f"{prefix}.lshift_s3_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.lshift_s3_{i}.pass.bias"] = torch.tensor([-2.0]) if i > 7: tensors[f"{prefix}.lshift_s3_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.lshift_s3_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.lshift_s3_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.lshift_s3_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.lshift_s3_{i}.bias"] = torch.tensor([-1.0]) # norm_mant[i] = overflow ? 
sum[i+4] : lshift[i+3] for i in range(10): tensors[f"{prefix}.norm_mant{i}.overflow_path.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_mant{i}.overflow_path.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.norm_mant{i}.normal_path.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_mant{i}.normal_path.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.norm_mant{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_mant{i}.bias"] = torch.tensor([-1.0]) # ========================================================================= # STAGE 11: ADJUST EXPONENT # ========================================================================= # Overflow: exp_result = exp_larger + 1 # No overflow: exp_result = exp_larger - norm_shift # Increment exponent by 1 (for overflow case) # Half adder chain: exp_larger + 1 tensors[f"{prefix}.exp_inc.ha0.sum.weight"] = torch.tensor([-1.0]) # NOT for XOR with 1 tensors[f"{prefix}.exp_inc.ha0.sum.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.exp_inc.ha0.cout.weight"] = torch.tensor([1.0]) # AND with 1 = passthrough tensors[f"{prefix}.exp_inc.ha0.cout.bias"] = torch.tensor([-0.5]) for i in range(1, 5): # XOR: exp[i] XOR carry_in tensors[f"{prefix}.exp_inc.ha{i}.xor.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_inc.ha{i}.xor.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.exp_inc.ha{i}.xor.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{prefix}.exp_inc.ha{i}.xor.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{prefix}.exp_inc.ha{i}.sum.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_inc.ha{i}.sum.bias"] = torch.tensor([-2.0]) # Carry: exp[i] AND carry_in tensors[f"{prefix}.exp_inc.ha{i}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_inc.ha{i}.cout.bias"] = torch.tensor([-2.0]) # Decrement exponent by norm_shift (for non-overflow case) # 5-bit subtractor: exp_larger - norm_shift # NOT gates for norm_shift for i in range(4): 
tensors[f"{prefix}.not_norm_shift_sub{i}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_norm_shift_sub{i}.bias"] = torch.tensor([0.0]) # Full adders for exp_larger + NOT(norm_shift) + 1 = exp_larger - norm_shift for i in range(5): p = f"{prefix}.exp_dec.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) # Select result exponent: overflow ? 
exp_inc : exp_dec for i in range(5): tensors[f"{prefix}.result_exp{i}.overflow_path.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.result_exp{i}.overflow_path.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.result_exp{i}.normal_path.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.result_exp{i}.normal_path.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.result_exp{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.result_exp{i}.bias"] = torch.tensor([-1.0]) # Detect exponent overflow (result_exp = 31 = all ones → infinity) tensors[f"{prefix}.final_exp_all_ones.weight"] = torch.tensor([1.0] * 5) tensors[f"{prefix}.final_exp_all_ones.bias"] = torch.tensor([-5.0]) tensors[f"{prefix}.round_exp_all_ones.weight"] = torch.tensor([1.0] * 5) tensors[f"{prefix}.round_exp_all_ones.bias"] = torch.tensor([-5.0]) tensors[f"{prefix}.exp_overflow_any.weight"] = torch.tensor([1.0, 1.0, 1.0]) tensors[f"{prefix}.exp_overflow_any.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.exp_overflow_to_inf.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_overflow_to_inf.bias"] = torch.tensor([-2.0]) # ========================================================================= # STAGE 11B: ROUNDING (round-to-nearest-even) # ========================================================================= # Derive guard/round/sticky from post-normalization bits. # For overflow path: guard = sum bit3, round = sum bit2, sticky = OR(sum bit1, sum bit0). # For normal path: guard = lshift_s3_2, round = lshift_s3_1, sticky = lshift_s3_0. 
tensors[f"{prefix}.round_guard_overflow.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.round_guard_overflow.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.round_guard_norm.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.round_guard_norm.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.round_guard_overflow.and.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_guard_overflow.and.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.round_guard_norm.and.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_guard_norm.and.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.round_guard.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_guard.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.round_post_overflow.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.round_post_overflow.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.round_post_norm.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.round_post_norm.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.round_post_overflow.and.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_post_overflow.and.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.round_post_norm.and.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_post_norm.and.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.round_post.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_post.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.sticky_overflow.weight"] = torch.tensor([1.0, 1.0, 1.0]) tensors[f"{prefix}.sticky_overflow.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.sticky_norm.same.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sticky_norm.same.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.sticky_norm.diff.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sticky_norm.diff.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.sticky_norm.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sticky_norm.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.round_overflow_or.weight"] = 
torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_overflow_or.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.round_norm_or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_norm_or.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.round_sticky_overflow.and.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_sticky_overflow.and.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.round_sticky_norm.and.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_sticky_norm.and.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.round_sticky.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_sticky.bias"] = torch.tensor([-1.0]) # round_inc = round_guard AND (round_sticky OR LSB) tensors[f"{prefix}.round_lsb_or_sticky.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_lsb_or_sticky.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.round_inc.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_inc.bias"] = torch.tensor([-2.0]) # Add round_inc to mantissa for i in range(10): p = f"{prefix}.round_norm.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = 
torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.round_overflow.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.round_overflow.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.not_round_overflow.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_round_overflow.bias"] = torch.tensor([0.0]) # Increment exponent on rounding overflow for i in range(5): p = f"{prefix}.round_exp.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) # Final exponent select after rounding for i in range(5): tensors[f"{prefix}.final_exp{i}.overflow_path.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.final_exp{i}.overflow_path.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.final_exp{i}.normal_path.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.final_exp{i}.normal_path.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.final_exp{i}.weight"] = torch.tensor([1.0, 1.0]) 
tensors[f"{prefix}.final_exp{i}.bias"] = torch.tensor([-1.0]) # Final mantissa select for i in range(10): tensors[f"{prefix}.final_mant{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.final_mant{i}.bias"] = torch.tensor([-2.0]) # ========================================================================= # STAGE 11C: SUBNORMAL UNDERFLOW HANDLING # ========================================================================= # Detect exponent underflow (exp_dec <= 0) when no overflow tensors[f"{prefix}.exp_dec_borrow.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.exp_dec_borrow.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.exp_dec_zero.weight"] = torch.tensor([-1.0] * 5) tensors[f"{prefix}.exp_dec_zero.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.exp_underflow_or_zero.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_underflow_or_zero.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.exp_underflow.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_underflow.bias"] = torch.tensor([-2.0]) # sub_shift = norm_shift + 1 - exp_larger for i in range(5): p = f"{prefix}.sub_shift_add.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) 
tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) for i in range(5): tensors[f"{prefix}.sub_shift_not_exp{i}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.sub_shift_not_exp{i}.bias"] = torch.tensor([0.0]) for i in range(5): p = f"{prefix}.sub_shift.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) # Right barrel shifter for subnormal mantissa (14 bits) for i in range(5): tensors[f"{prefix}.not_sub_shift{i}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_sub_shift{i}.bias"] = torch.tensor([0.0]) for i in range(14): tensors[f"{prefix}.sub_rshift_s0_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s0_{i}.pass.bias"] = torch.tensor([-2.0]) if i < 13: tensors[f"{prefix}.sub_rshift_s0_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s0_{i}.shift.bias"] = 
torch.tensor([-2.0]) tensors[f"{prefix}.sub_rshift_s0_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.sub_rshift_s0_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sub_rshift_s0_{i}.bias"] = torch.tensor([-1.0]) for i in range(14): tensors[f"{prefix}.sub_rshift_s1_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s1_{i}.pass.bias"] = torch.tensor([-2.0]) if i < 12: tensors[f"{prefix}.sub_rshift_s1_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s1_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.sub_rshift_s1_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.sub_rshift_s1_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sub_rshift_s1_{i}.bias"] = torch.tensor([-1.0]) for i in range(14): tensors[f"{prefix}.sub_rshift_s2_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s2_{i}.pass.bias"] = torch.tensor([-2.0]) if i < 10: tensors[f"{prefix}.sub_rshift_s2_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s2_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.sub_rshift_s2_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.sub_rshift_s2_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sub_rshift_s2_{i}.bias"] = torch.tensor([-1.0]) for i in range(14): tensors[f"{prefix}.sub_rshift_s3_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s3_{i}.pass.bias"] = torch.tensor([-2.0]) if i < 6: tensors[f"{prefix}.sub_rshift_s3_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s3_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.sub_rshift_s3_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.sub_rshift_s3_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sub_rshift_s3_{i}.bias"] = torch.tensor([-1.0]) for i in range(14): tensors[f"{prefix}.sub_shifted{i}.weight"] = torch.tensor([1.0, 1.0]) 
tensors[f"{prefix}.sub_shifted{i}.bias"] = torch.tensor([-2.0]) for i in range(10): tensors[f"{prefix}.sub_mant{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sub_mant{i}.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.sub_guard.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sub_guard.bias"] = torch.tensor([-0.5]) # sub_shift_gt[k]: sub_shift > k for k in range(14): tensors[f"{prefix}.sub_shift_gt{k}.weight"] = torch.tensor([2**i for i in range(5)]) tensors[f"{prefix}.sub_shift_gt{k}.bias"] = torch.tensor([-float(k+1) + 0.5]) for i in range(14): tensors[f"{prefix}.sub_sticky_part{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_sticky_part{i}.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.sub_sticky_raw.weight"] = torch.tensor([1.0] * 14) tensors[f"{prefix}.sub_sticky_raw.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.sub_sticky.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_sticky.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.sub_round_lsb_or_sticky.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_round_lsb_or_sticky.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.sub_round_inc.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_round_inc.bias"] = torch.tensor([-2.0]) for i in range(10): p = f"{prefix}.sub_round.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) 
tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.sub_round_overflow.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sub_round_overflow.bias"] = torch.tensor([-0.5]) # ========================================================================= # STAGE 12: OUTPUT ASSEMBLY # ========================================================================= # Final output combines: # - Special cases (NaN, Inf) override normal computation # - For NaN: output canonical NaN (0x7E00) # - For Inf: output Inf with correct sign # - For normal: pack normalized result # NaN output: 0x7E00 = 0111111000000000 nan_bits = [0]*9 + [1] + [1]*5 + [0] # bits 0-15 # Final output mux: nan ? nan_val : (inf ? inf_val : (zero ? 
0 : normal_val)) tensors[f"{prefix}.not_result_is_inf.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_result_is_inf.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.not_both_are_zero.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_both_are_zero.bias"] = torch.tensor([0.0]) # both_neg_zeros: both inputs are negative zero → result is -0 tensors[f"{prefix}.both_neg_zeros.weight"] = torch.tensor([1.0, 1.0, 1.0]) tensors[f"{prefix}.both_neg_zeros.bias"] = torch.tensor([-3.0]) tensors[f"{prefix}.not_exp_underflow.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_exp_underflow.bias"] = torch.tensor([0.0]) # Normal case selector: NOT nan AND NOT inf AND NOT both_zero AND NOT sum_is_zero AND NOT both_exp_zero tensors[f"{prefix}.not_both_exp_zero.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_both_exp_zero.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.is_normal_result.weight"] = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]) tensors[f"{prefix}.is_normal_result.bias"] = torch.tensor([-4.5]) # Subnormal path: both exponents zero and sum non-zero (override normalization) tensors[f"{prefix}.both_exp_zero.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.both_exp_zero.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.subnorm_condition.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.subnorm_condition.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.subnorm_overflow.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.subnorm_overflow.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.subnorm_enable.weight"] = torch.tensor([1.0, 1.0, 1.0, 1.0]) tensors[f"{prefix}.subnorm_enable.bias"] = torch.tensor([-3.5]) # Inf sign selection tensors[f"{prefix}.inf_sign_sel_a.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.inf_sign_sel_a.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.inf_sign_sel_b.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.inf_sign_sel_b.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.overflow_sign_sel.weight"] = torch.tensor([1.0, 
1.0, 1.0]) tensors[f"{prefix}.overflow_sign_sel.bias"] = torch.tensor([-2.5]) tensors[f"{prefix}.inf_sign.weight"] = torch.tensor([1.0, 1.0, 1.0]) tensors[f"{prefix}.inf_sign.bias"] = torch.tensor([-1.0]) for i in range(16): # NaN path: output NaN bits gated by result_is_nan if nan_bits[i]: tensors[f"{prefix}.out_nan{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out_nan{i}.bias"] = torch.tensor([-0.5]) # Inf path: exponent bits = 1, mantissa = 0, sign from inf operand if i >= 10 and i < 15: tensors[f"{prefix}.out_inf{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out_inf{i}.bias"] = torch.tensor([-0.5]) # Normal path if i < 10: # Mantissa bits from norm_mant tensors[f"{prefix}.out_normal{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out_normal{i}.bias"] = torch.tensor([-0.5]) elif i < 15: # Exponent bits from result_exp tensors[f"{prefix}.out_normal{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out_normal{i}.bias"] = torch.tensor([-0.5]) else: # Sign bit from result_sign tensors[f"{prefix}.out_normal{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out_normal{i}.bias"] = torch.tensor([-0.5]) # Subnormal path (both_exp_zero): use unnormalized mantissa bits, exponent=0/1 tensors[f"{prefix}.out_sub{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out_sub{i}.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.out{i}.sub_gate.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.out{i}.sub_gate.bias"] = torch.tensor([-2.0]) # Final output: 3-way mux (nan, inf, normal) + zero_sign for bit 15 tensors[f"{prefix}.out{i}.nan_gate.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.out{i}.nan_gate.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.out{i}.inf_gate.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.out{i}.inf_gate.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.out{i}.normal_gate.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.out{i}.normal_gate.bias"] = torch.tensor([-2.0]) if i == 15: # Sign bit: OR(nan_gate, 
            # inf_gate, normal_gate, sub_gate, both_neg_zeros)
            tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])
        else:
            # Non-sign bits: OR over the 4-way mux gates (nan, inf, normal, sub)
            tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0, 1.0, 1.0, 1.0])
        tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-1.0])

    return tensors


def build_float16_sub_tensors() -> Dict[str, torch.Tensor]:
    """Build tensors for float16.sub circuit.

    Subtraction: a - b = a + (-b)
    Just flip the sign bit of b and use the add circuit.

    Gates emitted:
    - ``float16.sub.b_neg_sign``: NOT gate (weight -1, bias 0) that inverts
      b's sign bit b[15].
    - ``float16.sub.out0`` .. ``out15``: buffer/passthrough gates
      (weight 1, bias -0.5) forwarding the float16.add outputs.

    Returns the gate-name -> tensor mapping; input wiring (which signals feed
    each gate) is recorded separately via the .inputs tensors.
    """
    tensors = {}
    prefix = "float16.sub"

    # Flip sign bit of b: b_neg_sign = NOT b[15]
    tensors[f"{prefix}.b_neg_sign.weight"] = torch.tensor([-1.0])
    tensors[f"{prefix}.b_neg_sign.bias"] = torch.tensor([0.0])

    # Output bits: passthrough from float16.add with modified b
    for i in range(16):
        tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0])
        tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-0.5])

    return tensors


def build_float16_mul_tensors() -> Dict[str, torch.Tensor]:
    """Build tensors for float16.mul circuit.

    IEEE 754 float16 multiplication:
    - Result sign = a_sign XOR b_sign
    - Result exponent = a_exp + b_exp - bias (15)
    - Result mantissa = a_mant * b_mant (22-bit product, normalize to 11)

    Special cases:
    - NaN * anything = NaN
    - Inf * 0 = NaN
    - Inf * finite = Inf
    - 0 * finite = 0
    """
    tensors = {}
    prefix = "float16.mul"

    # =========================================================================
    # STAGE 1: UNPACK AND DETECT SPECIAL CASES
    # =========================================================================

    # exp_a_all_ones: exponent = 31 (inf or nan)
    # AND over the 5 exponent bits: fires only when all are 1.
    tensors[f"{prefix}.exp_a_all_ones.weight"] = torch.tensor([1.0] * 5)
    tensors[f"{prefix}.exp_a_all_ones.bias"] = torch.tensor([-5.0])
    tensors[f"{prefix}.exp_b_all_ones.weight"] = torch.tensor([1.0] * 5)
    tensors[f"{prefix}.exp_b_all_ones.bias"] = torch.tensor([-5.0])

    # exp_a_zero: exponent = 0 (zero or subnormal)
    # NOR over the 5 exponent bits: fires only when all are 0.
    tensors[f"{prefix}.exp_a_zero.weight"] = torch.tensor([-1.0] * 5)
    tensors[f"{prefix}.exp_a_zero.bias"] = torch.tensor([0.0])
tensors[f"{prefix}.exp_b_zero.weight"] = torch.tensor([-1.0] * 5) tensors[f"{prefix}.exp_b_zero.bias"] = torch.tensor([0.0]) # Adjusted exp bit0 for subnormals (effective exponent = 1) tensors[f"{prefix}.a_adj_exp0.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.a_adj_exp0.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.b_adj_exp0.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.b_adj_exp0.bias"] = torch.tensor([-1.0]) # mant_a_nonzero / mant_b_nonzero tensors[f"{prefix}.mant_a_nonzero.weight"] = torch.tensor([1.0] * 10) tensors[f"{prefix}.mant_a_nonzero.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.mant_b_nonzero.weight"] = torch.tensor([1.0] * 10) tensors[f"{prefix}.mant_b_nonzero.bias"] = torch.tensor([-1.0]) # mant_a_zero = NOT mant_a_nonzero tensors[f"{prefix}.mant_a_zero.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.mant_a_zero.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.mant_b_zero.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.mant_b_zero.bias"] = torch.tensor([0.0]) for i in range(10): tensors[f"{prefix}.mant_a_norm{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.mant_a_norm{i}.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.mant_b_norm{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.mant_b_norm{i}.bias"] = torch.tensor([-0.5]) # a_is_nan = exp_all_ones AND mant_nonzero tensors[f"{prefix}.a_is_nan.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.a_is_nan.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.b_is_nan.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.b_is_nan.bias"] = torch.tensor([-2.0]) # a_is_inf = exp_all_ones AND mant_zero tensors[f"{prefix}.a_is_inf.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.a_is_inf.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.b_is_inf.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.b_is_inf.bias"] = torch.tensor([-2.0]) # a_is_zero = exp_zero AND mant_zero tensors[f"{prefix}.a_is_zero.weight"] = torch.tensor([1.0, 1.0]) 
tensors[f"{prefix}.a_is_zero.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.b_is_zero.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.b_is_zero.bias"] = torch.tensor([-2.0]) # either_is_nan = a_is_nan OR b_is_nan tensors[f"{prefix}.either_is_nan.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.either_is_nan.bias"] = torch.tensor([-1.0]) # either_is_inf = a_is_inf OR b_is_inf tensors[f"{prefix}.either_is_inf.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.either_is_inf.bias"] = torch.tensor([-1.0]) # either_is_zero = a_is_zero OR b_is_zero tensors[f"{prefix}.either_is_zero.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.either_is_zero.bias"] = torch.tensor([-1.0]) # inf_times_zero = either_is_inf AND either_is_zero (produces NaN) tensors[f"{prefix}.inf_times_zero.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.inf_times_zero.bias"] = torch.tensor([-2.0]) # result_is_nan = either_is_nan OR inf_times_zero tensors[f"{prefix}.result_is_nan.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.result_is_nan.bias"] = torch.tensor([-1.0]) # result_is_inf = (either_is_inf OR exp_overflow_to_inf) AND NOT(result_is_nan) AND NOT(either_is_zero) tensors[f"{prefix}.not_result_is_nan.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_result_is_nan.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.not_either_is_zero.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_either_is_zero.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.result_is_inf.weight"] = torch.tensor([1.0, 1.0, 2.0, 2.0]) tensors[f"{prefix}.result_is_inf.bias"] = torch.tensor([-4.5]) # result_is_zero = either_is_zero AND NOT(result_is_nan) tensors[f"{prefix}.result_is_zero.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.result_is_zero.bias"] = torch.tensor([-2.0]) # ========================================================================= # STAGE 2: RESULT SIGN (XOR of input signs) # ========================================================================= 
tensors[f"{prefix}.result_sign.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.result_sign.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.result_sign.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{prefix}.result_sign.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{prefix}.result_sign.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.result_sign.layer2.bias"] = torch.tensor([-2.0]) # ========================================================================= # STAGE 3: MANTISSA MULTIPLICATION (11x11 -> 22 bits) # ========================================================================= # Implicit bits for mantissas tensors[f"{prefix}.implicit_a.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.implicit_a.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.implicit_b.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.implicit_b.bias"] = torch.tensor([0.0]) # Full mantissas: mant_a[10:0] = {implicit_a, a[9:0]} # mant_b[10:0] = {implicit_b, b[9:0]} # 11x11 array multiplier produces 22-bit product # Partial products: pp[i][j] = mant_a[i] AND mant_b[j] for i in range(11): for j in range(11): tensors[f"{prefix}.pp{i}_{j}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.pp{i}_{j}.bias"] = torch.tensor([-2.0]) # Reduction tree using carry-save adders # Level 0: 11 partial product rows # We'll use a Wallace tree reduction # For simplicity, use a ripple reduction approach # Row i contributes to columns i through i+10 # Sum columns using full adder chains # Column sums using compressor tree for col in range(22): # Count how many partial products contribute to this column count = 0 for i in range(11): j = col - i if 0 <= j < 11: count += 1 if count == 0: continue elif count == 1: # Single bit, just pass through tensors[f"{prefix}.col{col}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.col{col}.bias"] = torch.tensor([-0.5]) else: # Multi-bit column: compute parity (sum mod 2) using threshold gates # parity = (ge1 AND 
NOT ge2) OR (ge3 AND NOT ge4) OR ... # This captures: sum is odd when in range [1], [3,4), [5,6), etc. # Threshold gates: ge{t} = 1 if sum >= t for t in range(1, count + 1): tensors[f"{prefix}.col{col}_ge{t}.weight"] = torch.tensor([1.0] * count) tensors[f"{prefix}.col{col}_ge{t}.bias"] = torch.tensor([-float(t)]) # NOT gates for even thresholds for t in range(2, count + 1, 2): tensors[f"{prefix}.col{col}_not_ge{t}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.col{col}_not_ge{t}.bias"] = torch.tensor([0.0]) # AND gates for odd ranges: (ge1 AND NOT ge2), (ge3 AND NOT ge4), ... odd_ranges = [] for t in range(1, count + 1, 2): if t + 1 <= count: # ge{t} AND NOT ge{t+1} tensors[f"{prefix}.col{col}_odd{t}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.col{col}_odd{t}.bias"] = torch.tensor([-2.0]) odd_ranges.append(t) else: # ge{t} only (no upper bound needed if t is max) tensors[f"{prefix}.col{col}_odd{t}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.col{col}_odd{t}.bias"] = torch.tensor([-0.5]) odd_ranges.append(t) # col_sum = OR of all odd ranges (parity = bit 0) num_odd = len(odd_ranges) tensors[f"{prefix}.col{col}_sum.weight"] = torch.tensor([1.0] * num_odd) tensors[f"{prefix}.col{col}_sum.bias"] = torch.tensor([-0.5]) # col_bit1 = floor(sum/2) mod 2 = parity of [2,3], [6,7], [10,11], ... # This is (ge2 AND NOT ge4) OR (ge6 AND NOT ge8) OR ... 
if count >= 2: bit1_ranges = [] for t in range(2, count + 1, 4): upper = t + 2 if upper <= count: tensors[f"{prefix}.col{col}_bit1_{t}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.col{col}_bit1_{t}.bias"] = torch.tensor([-2.0]) if f"{prefix}.col{col}_not_ge{upper}.weight" not in tensors: tensors[f"{prefix}.col{col}_not_ge{upper}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.col{col}_not_ge{upper}.bias"] = torch.tensor([0.0]) else: tensors[f"{prefix}.col{col}_bit1_{t}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.col{col}_bit1_{t}.bias"] = torch.tensor([-0.5]) bit1_ranges.append(t) if bit1_ranges: tensors[f"{prefix}.col{col}_bit1.weight"] = torch.tensor([1.0] * len(bit1_ranges)) tensors[f"{prefix}.col{col}_bit1.bias"] = torch.tensor([-0.5]) # col_bit2 = floor(sum/4) mod 2 = parity of [4,7], [12,15], ... # This is (ge4 AND NOT ge8) OR (ge12 AND NOT ge16) OR ... if count >= 4: bit2_ranges = [] for t in range(4, count + 1, 8): upper = t + 4 if upper <= count: tensors[f"{prefix}.col{col}_bit2_{t}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.col{col}_bit2_{t}.bias"] = torch.tensor([-2.0]) if f"{prefix}.col{col}_not_ge{upper}.weight" not in tensors: tensors[f"{prefix}.col{col}_not_ge{upper}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.col{col}_not_ge{upper}.bias"] = torch.tensor([0.0]) else: tensors[f"{prefix}.col{col}_bit2_{t}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.col{col}_bit2_{t}.bias"] = torch.tensor([-0.5]) bit2_ranges.append(t) if bit2_ranges: tensors[f"{prefix}.col{col}_bit2.weight"] = torch.tensor([1.0] * len(bit2_ranges)) tensors[f"{prefix}.col{col}_bit2.bias"] = torch.tensor([-0.5]) # col_bit3 = floor(sum/8) mod 2 (for col10 with 11 PPs) if count >= 8: bit3_ranges = [] for t in range(8, count + 1, 16): upper = t + 8 if upper <= count: tensors[f"{prefix}.col{col}_bit3_{t}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.col{col}_bit3_{t}.bias"] = torch.tensor([-2.0]) if f"{prefix}.col{col}_not_ge{upper}.weight" 
not in tensors: tensors[f"{prefix}.col{col}_not_ge{upper}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.col{col}_not_ge{upper}.bias"] = torch.tensor([0.0]) else: tensors[f"{prefix}.col{col}_bit3_{t}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.col{col}_bit3_{t}.bias"] = torch.tensor([-0.5]) bit3_ranges.append(t) if bit3_ranges: tensors[f"{prefix}.col{col}_bit3.weight"] = torch.tensor([1.0] * len(bit3_ranges)) tensors[f"{prefix}.col{col}_bit3.bias"] = torch.tensor([-0.5]) # Carry accumulator for multi-bit carries # For position i, incoming carries are: bit1[i-1], bit2[i-2], bit3[i-3] # We need to sum these and produce: carry_acc_sum (parity), carry_acc_carry (sum >= 2) def get_pp_count(col): if col < 0 or col > 20: return 0 return min(col + 1, 21 - col) for i in range(22): # Determine which carry bits come into position i carry_inputs = [] # bit1 from col[i-1] if i >= 1 and get_pp_count(i-1) >= 2: carry_inputs.append(f"bit1_{i-1}") # bit2 from col[i-2] if i >= 2 and get_pp_count(i-2) >= 4: carry_inputs.append(f"bit2_{i-2}") # bit3 from col[i-3] if i >= 3 and get_pp_count(i-3) >= 8: carry_inputs.append(f"bit3_{i-3}") if len(carry_inputs) == 0: # No carries, use #0 pass elif len(carry_inputs) == 1: # Single carry, no accumulator needed pass else: # Multiple carries, need accumulator n = len(carry_inputs) # Parity (sum mod 2) using threshold gates # ge{t} = sum >= t for t in range(1, n + 1): tensors[f"{prefix}.carry_acc{i}_ge{t}.weight"] = torch.tensor([1.0] * n) tensors[f"{prefix}.carry_acc{i}_ge{t}.bias"] = torch.tensor([-float(t) + 0.5]) # NOT gates for even thresholds for t in range(2, n + 1, 2): tensors[f"{prefix}.carry_acc{i}_not_ge{t}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.carry_acc{i}_not_ge{t}.bias"] = torch.tensor([0.0]) # AND gates for odd ranges: (ge1 AND NOT ge2), (ge3 AND NOT ge4), ... 
odd_ranges = [] for t in range(1, n + 1, 2): if t + 1 <= n: tensors[f"{prefix}.carry_acc{i}_odd{t}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.carry_acc{i}_odd{t}.bias"] = torch.tensor([-2.0]) else: tensors[f"{prefix}.carry_acc{i}_odd{t}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.carry_acc{i}_odd{t}.bias"] = torch.tensor([-0.5]) odd_ranges.append(t) # carry_acc_sum = OR of odd ranges tensors[f"{prefix}.carry_acc{i}_sum.weight"] = torch.tensor([1.0] * len(odd_ranges)) tensors[f"{prefix}.carry_acc{i}_sum.bias"] = torch.tensor([-0.5]) # carry_acc_carry = ge2 (sum >= 2) if n >= 2: tensors[f"{prefix}.carry_acc{i}_carry.weight"] = torch.tensor([1.0] * n) tensors[f"{prefix}.carry_acc{i}_carry.bias"] = torch.tensor([-1.5]) # Final product assembly using ripple carry # First pass: add col_sum + carry_acc_sum for i in range(22): p = f"{prefix}.prod_fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) # Second pass: add intermediate result + carry_acc_carry (secondary 
carries) # This resolves the multi-bit carry propagation issue for i in range(22): p = f"{prefix}.prod2_fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) # ========================================================================= # STAGE 4: EXPONENT ADDITION # ========================================================================= # Result exponent = exp_a + exp_b - 15 (bias adjustment) # First add exp_a + exp_b using 6-bit adder (max 31+31=62, need 6 bits) for i in range(6): p = f"{prefix}.exp_add.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = 
torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) # Subtract 15: exp_sum - 15 using subtractor # NOT of 15 = NOT(01111) = 10000 for i in range(6): p = f"{prefix}.exp_sub.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) # ========================================================================= # STAGE 5: NORMALIZE PRODUCT # ========================================================================= # Product is in range [1.0, 4.0) since each mantissa is 
[1.0, 2.0) # If product >= 2.0, right shift by 1 and increment exponent # Check if bit 21 is set (product >= 2.0) tensors[f"{prefix}.prod_overflow.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.prod_overflow.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.not_prod_overflow.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_prod_overflow.bias"] = torch.tensor([0.0]) # Normalized mantissa selection (10 bits) # If overflow: take bits 20:11, else take bits 19:10 for i in range(10): tensors[f"{prefix}.norm_mant{i}.overflow_path.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_mant{i}.overflow_path.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.norm_mant{i}.normal_path.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_mant{i}.normal_path.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.norm_mant{i}.eq10_path.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_mant{i}.eq10_path.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.norm_mant{i}.weight"] = torch.tensor([1.0, 1.0, 1.0]) tensors[f"{prefix}.norm_mant{i}.bias"] = torch.tensor([-1.0]) # Exponent adjustment: if overflow, add 1 for i in range(5): p = f"{prefix}.result_exp_fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) 
tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) # ========================================================================= # STAGE 5B: ROUNDING (round-to-nearest-even) # ========================================================================= tensors[f"{prefix}.round_guard_overflow.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.round_guard_overflow.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.round_guard_norm.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.round_guard_norm.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.round_guard_overflow.and.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_guard_overflow.and.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.round_guard_norm.and.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_guard_norm.and.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.round_guard.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_guard.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.sticky_overflow.weight"] = torch.tensor([1.0] * 10) tensors[f"{prefix}.sticky_overflow.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.sticky_norm.weight"] = torch.tensor([1.0] * 9) tensors[f"{prefix}.sticky_norm.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.round_sticky_overflow.and.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_sticky_overflow.and.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.round_sticky_norm.and.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_sticky_norm.and.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.round_sticky.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_sticky.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.round_lsb_or_sticky.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_lsb_or_sticky.bias"] = torch.tensor([-0.5]) 
tensors[f"{prefix}.round_inc.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_inc.bias"] = torch.tensor([-2.0]) for i in range(10): p = f"{prefix}.round_norm.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.round_overflow.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.round_overflow.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.not_round_overflow.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_round_overflow.bias"] = torch.tensor([0.0]) for i in range(5): p = f"{prefix}.round_exp.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) 
tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) for i in range(5): tensors[f"{prefix}.final_exp{i}.overflow_path.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.final_exp{i}.overflow_path.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.final_exp{i}.normal_path.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.final_exp{i}.normal_path.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.final_exp{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.final_exp{i}.bias"] = torch.tensor([-1.0]) for i in range(10): tensors[f"{prefix}.final_mant{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.final_mant{i}.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.final_exp_all_ones.weight"] = torch.tensor([1.0] * 5) tensors[f"{prefix}.final_exp_all_ones.bias"] = torch.tensor([-5.0]) tensors[f"{prefix}.round_exp_all_ones.weight"] = torch.tensor([1.0] * 5) tensors[f"{prefix}.round_exp_all_ones.bias"] = torch.tensor([-5.0]) tensors[f"{prefix}.exp_overflow_any.weight"] = torch.tensor([1.0, 1.0, 1.0, 1.0]) tensors[f"{prefix}.exp_overflow_any.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.exp_overflow_to_inf.weight"] = torch.tensor([1.0, 1.0, 1.0]) tensors[f"{prefix}.exp_overflow_to_inf.bias"] = torch.tensor([-2.5]) # ========================================================================= # STAGE 5C: SUBNORMAL UNDERFLOW HANDLING # 
========================================================================= tensors[f"{prefix}.exp_sub_borrow.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.exp_sub_borrow.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.exp_sub_zero.weight"] = torch.tensor([-1.0] * 6) tensors[f"{prefix}.exp_sub_zero.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.exp_sub_zero_and_npo.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_sub_zero_and_npo.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.exp_underflow.weight"] = torch.tensor([1.0, 1.0, 1.0]) tensors[f"{prefix}.exp_underflow.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_exp_overflow_any.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_exp_overflow_any.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.exp_dec_zero_and_no_overflow.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_dec_zero_and_no_overflow.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.not_exp_underflow.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_exp_underflow.bias"] = torch.tensor([0.0]) # sub_shift_base = 16 - exp_add (5-bit) for i in range(5): tensors[f"{prefix}.not_exp_add{i}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_exp_add{i}.bias"] = torch.tensor([0.0]) for i in range(5): p = f"{prefix}.sub_shift_base.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) 
tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) for i in range(5): p = f"{prefix}.sub_shift.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) for i in range(5): tensors[f"{prefix}.not_sub_shift{i}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_sub_shift{i}.bias"] = torch.tensor([0.0]) # norm_full mux (12 bits) for subnormal shifting for i in range(12): tensors[f"{prefix}.norm_full{i}.overflow_path.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_full{i}.overflow_path.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.norm_full{i}.normal_path.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_full{i}.normal_path.bias"] = 
torch.tensor([-2.0]) tensors[f"{prefix}.norm_full{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_full{i}.bias"] = torch.tensor([-1.0]) # CLZ on norm_full for normalization shift amount for k in range(1, 13): tensors[f"{prefix}.norm_pz{k}.weight"] = torch.tensor([-1.0] * k) tensors[f"{prefix}.norm_pz{k}.bias"] = torch.tensor([0.0]) for k in range(1, 13): tensors[f"{prefix}.norm_ge{k}.weight"] = torch.tensor([1.0] * 12) tensors[f"{prefix}.norm_ge{k}.bias"] = torch.tensor([-float(k)]) for k in [2, 4, 6, 8, 10, 12]: tensors[f"{prefix}.norm_not_ge{k}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.norm_not_ge{k}.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.norm_shift3.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.norm_shift3.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.norm_and_4_7.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_and_4_7.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.norm_and_12.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.norm_and_12.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.norm_shift2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_shift2.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.norm_and_2_3.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_and_2_3.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.norm_and_6_7.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_and_6_7.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.norm_and_10_11.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_and_10_11.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.norm_shift1.weight"] = torch.tensor([1.0, 1.0, 1.0]) tensors[f"{prefix}.norm_shift1.bias"] = torch.tensor([-1.0]) for i in [1, 3, 5, 7, 9, 11]: tensors[f"{prefix}.norm_and_{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_and_{i}.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.norm_shift0.weight"] = torch.tensor([1.0] * 6) tensors[f"{prefix}.norm_shift0.bias"] = torch.tensor([-1.0]) 
tensors[f"{prefix}.norm_shift_eq10.weight"] = torch.tensor([-1.0, 1.0, -1.0, 1.0]) tensors[f"{prefix}.norm_shift_eq10.bias"] = torch.tensor([-1.5]) for i in range(4): tensors[f"{prefix}.not_norm_shift{i}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_norm_shift{i}.bias"] = torch.tensor([0.0]) for i in range(12): tensors[f"{prefix}.lshift_s0_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.lshift_s0_{i}.pass.bias"] = torch.tensor([-2.0]) if i > 0: tensors[f"{prefix}.lshift_s0_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.lshift_s0_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.lshift_s0_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.lshift_s0_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.lshift_s0_{i}.bias"] = torch.tensor([-1.0]) for i in range(12): tensors[f"{prefix}.lshift_s1_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.lshift_s1_{i}.pass.bias"] = torch.tensor([-2.0]) if i > 1: tensors[f"{prefix}.lshift_s1_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.lshift_s1_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.lshift_s1_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.lshift_s1_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.lshift_s1_{i}.bias"] = torch.tensor([-1.0]) for i in range(12): tensors[f"{prefix}.lshift_s2_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.lshift_s2_{i}.pass.bias"] = torch.tensor([-2.0]) if i > 3: tensors[f"{prefix}.lshift_s2_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.lshift_s2_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.lshift_s2_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.lshift_s2_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.lshift_s2_{i}.bias"] = torch.tensor([-1.0]) for i in range(12): tensors[f"{prefix}.lshift_s3_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.lshift_s3_{i}.pass.bias"] = 
torch.tensor([-2.0]) if i > 7: tensors[f"{prefix}.lshift_s3_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.lshift_s3_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.lshift_s3_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.lshift_s3_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.lshift_s3_{i}.bias"] = torch.tensor([-1.0]) # Left barrel shifter for guard/sticky from low product bits (11 bits) for i in range(11): tensors[f"{prefix}.guard_lshift_s0_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.guard_lshift_s0_{i}.pass.bias"] = torch.tensor([-2.0]) if i > 0: tensors[f"{prefix}.guard_lshift_s0_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.guard_lshift_s0_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.guard_lshift_s0_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.guard_lshift_s0_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.guard_lshift_s0_{i}.bias"] = torch.tensor([-1.0]) for i in range(11): tensors[f"{prefix}.guard_lshift_s1_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.guard_lshift_s1_{i}.pass.bias"] = torch.tensor([-2.0]) if i > 1: tensors[f"{prefix}.guard_lshift_s1_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.guard_lshift_s1_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.guard_lshift_s1_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.guard_lshift_s1_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.guard_lshift_s1_{i}.bias"] = torch.tensor([-1.0]) for i in range(11): tensors[f"{prefix}.guard_lshift_s2_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.guard_lshift_s2_{i}.pass.bias"] = torch.tensor([-2.0]) if i > 3: tensors[f"{prefix}.guard_lshift_s2_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.guard_lshift_s2_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.guard_lshift_s2_{i}.weight"] = torch.tensor([1.0, 1.0]) else: 
tensors[f"{prefix}.guard_lshift_s2_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.guard_lshift_s2_{i}.bias"] = torch.tensor([-1.0]) for i in range(11): tensors[f"{prefix}.guard_lshift_s3_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.guard_lshift_s3_{i}.pass.bias"] = torch.tensor([-2.0]) if i > 7: tensors[f"{prefix}.guard_lshift_s3_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.guard_lshift_s3_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.guard_lshift_s3_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.guard_lshift_s3_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.guard_lshift_s3_{i}.bias"] = torch.tensor([-1.0]) # Left barrel shifter for normal-path mantissa from product bits (20 bits) for i in range(22): tensors[f"{prefix}.prod_lshift_s0_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.prod_lshift_s0_{i}.pass.bias"] = torch.tensor([-2.0]) if i > 0: tensors[f"{prefix}.prod_lshift_s0_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.prod_lshift_s0_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.prod_lshift_s0_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.prod_lshift_s0_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.prod_lshift_s0_{i}.bias"] = torch.tensor([-1.0]) for i in range(22): tensors[f"{prefix}.prod_lshift_s1_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.prod_lshift_s1_{i}.pass.bias"] = torch.tensor([-2.0]) if i > 1: tensors[f"{prefix}.prod_lshift_s1_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.prod_lshift_s1_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.prod_lshift_s1_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.prod_lshift_s1_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.prod_lshift_s1_{i}.bias"] = torch.tensor([-1.0]) for i in range(22): tensors[f"{prefix}.prod_lshift_s2_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) 
tensors[f"{prefix}.prod_lshift_s2_{i}.pass.bias"] = torch.tensor([-2.0]) if i > 3: tensors[f"{prefix}.prod_lshift_s2_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.prod_lshift_s2_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.prod_lshift_s2_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.prod_lshift_s2_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.prod_lshift_s2_{i}.bias"] = torch.tensor([-1.0]) for i in range(22): tensors[f"{prefix}.prod_lshift_s3_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.prod_lshift_s3_{i}.pass.bias"] = torch.tensor([-2.0]) if i > 7: tensors[f"{prefix}.prod_lshift_s3_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.prod_lshift_s3_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.prod_lshift_s3_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.prod_lshift_s3_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.prod_lshift_s3_{i}.bias"] = torch.tensor([-1.0]) for i in range(4): tensors[f"{prefix}.not_norm_shift_sub{i}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_norm_shift_sub{i}.bias"] = torch.tensor([0.0]) for i in range(5): p = f"{prefix}.exp_dec.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = 
torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.exp_dec_borrow.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.exp_dec_borrow.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.exp_dec_zero.weight"] = torch.tensor([-1.0] * 5) tensors[f"{prefix}.exp_dec_zero.bias"] = torch.tensor([0.0]) for i in range(5): tensors[f"{prefix}.not_exp_dec{i}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_exp_dec{i}.bias"] = torch.tensor([0.0]) for k in range(12): tensors[f"{prefix}.norm_shift_gt{k}.weight"] = torch.tensor([1.0, 2.0, 4.0, 8.0]) tensors[f"{prefix}.norm_shift_gt{k}.bias"] = torch.tensor([-float(k+1) + 0.5]) for i in range(12): tensors[f"{prefix}.norm_sticky_part{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_sticky_part{i}.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.norm_shifted_out_or.weight"] = torch.tensor([1.0] * 12) tensors[f"{prefix}.norm_shifted_out_or.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.sticky_norm_ext.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sticky_norm_ext.bias"] = torch.tensor([-0.5]) # Subnormal source selector for right shift (handles prod_overflow offset) for i in range(12): tensors[f"{prefix}.sub_src{i}.overflow_path.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_src{i}.overflow_path.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.sub_src{i}.normal_path.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_src{i}.normal_path.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.sub_src{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_src{i}.bias"] = torch.tensor([-1.0]) # Right barrel shifter for subnormal (12 bits) for i in range(12): tensors[f"{prefix}.sub_rshift_s0_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) 
tensors[f"{prefix}.sub_rshift_s0_{i}.pass.bias"] = torch.tensor([-2.0]) if i < 11: tensors[f"{prefix}.sub_rshift_s0_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s0_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.sub_rshift_s0_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.sub_rshift_s0_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sub_rshift_s0_{i}.bias"] = torch.tensor([-1.0]) for i in range(12): tensors[f"{prefix}.sub_rshift_s1_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s1_{i}.pass.bias"] = torch.tensor([-2.0]) if i < 10: tensors[f"{prefix}.sub_rshift_s1_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s1_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.sub_rshift_s1_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.sub_rshift_s1_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sub_rshift_s1_{i}.bias"] = torch.tensor([-1.0]) for i in range(12): tensors[f"{prefix}.sub_rshift_s2_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s2_{i}.pass.bias"] = torch.tensor([-2.0]) if i < 8: tensors[f"{prefix}.sub_rshift_s2_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s2_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.sub_rshift_s2_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.sub_rshift_s2_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sub_rshift_s2_{i}.bias"] = torch.tensor([-1.0]) for i in range(12): tensors[f"{prefix}.sub_rshift_s3_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s3_{i}.pass.bias"] = torch.tensor([-2.0]) if i < 4: tensors[f"{prefix}.sub_rshift_s3_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s3_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.sub_rshift_s3_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.sub_rshift_s3_{i}.weight"] 
= torch.tensor([1.0]) tensors[f"{prefix}.sub_rshift_s3_{i}.bias"] = torch.tensor([-1.0]) for i in range(12): tensors[f"{prefix}.sub_shifted{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_shifted{i}.bias"] = torch.tensor([-2.0]) for i in range(10): tensors[f"{prefix}.sub_mant{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sub_mant{i}.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.sub_guard.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sub_guard.bias"] = torch.tensor([-0.5]) for k in range(12): tensors[f"{prefix}.sub_shift_gt{k}.weight"] = torch.tensor([2**i for i in range(5)]) tensors[f"{prefix}.sub_shift_gt{k}.bias"] = torch.tensor([-float(k+1) + 0.5]) for i in range(12): tensors[f"{prefix}.sub_sticky_part{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_sticky_part{i}.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.sub_sticky_raw.weight"] = torch.tensor([1.0] * 12) tensors[f"{prefix}.sub_sticky_raw.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.sub_sticky.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_sticky.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.sub_round_lsb_or_sticky.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_round_lsb_or_sticky.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.sub_round_inc.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_round_inc.bias"] = torch.tensor([-2.0]) for i in range(10): p = f"{prefix}.sub_round.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = 
torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.sub_round_overflow.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sub_round_overflow.bias"] = torch.tensor([-0.5]) # ========================================================================= # STAGE 6: OUTPUT ASSEMBLY # ========================================================================= tensors[f"{prefix}.not_result_is_inf.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_result_is_inf.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.not_result_is_zero.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_result_is_zero.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.not_exp_underflow.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_exp_underflow.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.is_normal_result.weight"] = torch.tensor([1.0, 1.0, 1.0, 1.0]) tensors[f"{prefix}.is_normal_result.bias"] = torch.tensor([-3.5]) tensors[f"{prefix}.subnorm_enable.weight"] = torch.tensor([1.0, 1.0, 1.0, 1.0]) tensors[f"{prefix}.subnorm_enable.bias"] = torch.tensor([-3.5]) for i in range(16): if i < 10: tensors[f"{prefix}.out_nan{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out_nan{i}.bias"] = torch.tensor([-0.5]) elif i < 15: tensors[f"{prefix}.out_nan{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out_nan{i}.bias"] = torch.tensor([-0.5]) if i < 10: tensors[f"{prefix}.out_normal{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out_normal{i}.bias"] = torch.tensor([-0.5]) elif i < 15: 
tensors[f"{prefix}.out_normal{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out_normal{i}.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.out{i}.nan_gate.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.out{i}.nan_gate.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.out{i}.inf_gate.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.out{i}.inf_gate.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.out{i}.zero_gate.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.out{i}.zero_gate.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.out{i}.normal_gate.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.out{i}.normal_gate.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.out_sub{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out_sub{i}.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.out{i}.sub_gate.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.out{i}.sub_gate.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]) tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-1.0]) return tensors def build_float16_div_tensors() -> Dict[str, torch.Tensor]: """Build tensors for float16.div circuit. 
IEEE 754 float16 division: - Result sign = a_sign XOR b_sign - Result exponent = a_exp - b_exp + bias (15) - Result mantissa = a_mant / b_mant (using non-restoring division) Special cases: - NaN / anything = NaN - anything / NaN = NaN - Inf / Inf = NaN - 0 / 0 = NaN - Inf / finite = Inf - finite / Inf = 0 - finite / 0 = Inf - 0 / finite = 0 """ tensors = {} prefix = "float16.div" # Similar structure to mul but with subtraction for exponents # and iterative division for mantissas # Special case detection (same as mul) tensors[f"{prefix}.exp_a_all_ones.weight"] = torch.tensor([1.0] * 5) tensors[f"{prefix}.exp_a_all_ones.bias"] = torch.tensor([-5.0]) tensors[f"{prefix}.exp_b_all_ones.weight"] = torch.tensor([1.0] * 5) tensors[f"{prefix}.exp_b_all_ones.bias"] = torch.tensor([-5.0]) tensors[f"{prefix}.exp_a_zero.weight"] = torch.tensor([-1.0] * 5) tensors[f"{prefix}.exp_a_zero.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.exp_b_zero.weight"] = torch.tensor([-1.0] * 5) tensors[f"{prefix}.exp_b_zero.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.a_adj_exp0.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.a_adj_exp0.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.b_adj_exp0.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.b_adj_exp0.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.mant_a_nonzero.weight"] = torch.tensor([1.0] * 10) tensors[f"{prefix}.mant_a_nonzero.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.mant_b_nonzero.weight"] = torch.tensor([1.0] * 10) tensors[f"{prefix}.mant_b_nonzero.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.mant_a_zero.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.mant_a_zero.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.mant_b_zero.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.mant_b_zero.bias"] = torch.tensor([0.0]) # Subnormal input normalization tensors[f"{prefix}.a_is_subnormal.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.a_is_subnormal.bias"] = torch.tensor([-2.0]) 
tensors[f"{prefix}.b_is_subnormal.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.b_is_subnormal.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.implicit_a_raw.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.implicit_a_raw.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.implicit_b_raw.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.implicit_b_raw.bias"] = torch.tensor([0.0]) for k in range(1, 10): tensors[f"{prefix}.a_pz{k}.weight"] = torch.tensor([-1.0] * k) tensors[f"{prefix}.a_pz{k}.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.b_pz{k}.weight"] = torch.tensor([-1.0] * k) tensors[f"{prefix}.b_pz{k}.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.a_lead9.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.a_lead9.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.b_lead9.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.b_lead9.bias"] = torch.tensor([-0.5]) for i in range(9): tensors[f"{prefix}.a_lead{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.a_lead{i}.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.b_lead{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.b_lead{i}.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.a_shift_raw0.weight"] = torch.tensor([1.0] * 5) tensors[f"{prefix}.a_shift_raw0.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.a_shift_raw1.weight"] = torch.tensor([1.0] * 5) tensors[f"{prefix}.a_shift_raw1.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.a_shift_raw2.weight"] = torch.tensor([1.0] * 4) tensors[f"{prefix}.a_shift_raw2.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.a_shift_raw3.weight"] = torch.tensor([1.0] * 3) tensors[f"{prefix}.a_shift_raw3.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.b_shift_raw0.weight"] = torch.tensor([1.0] * 5) tensors[f"{prefix}.b_shift_raw0.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.b_shift_raw1.weight"] = torch.tensor([1.0] * 5) tensors[f"{prefix}.b_shift_raw1.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.b_shift_raw2.weight"] = torch.tensor([1.0] * 4) 
tensors[f"{prefix}.b_shift_raw2.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.b_shift_raw3.weight"] = torch.tensor([1.0] * 3) tensors[f"{prefix}.b_shift_raw3.bias"] = torch.tensor([-1.0]) for i in range(4): tensors[f"{prefix}.a_shift{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.a_shift{i}.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.b_shift{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.b_shift{i}.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.not_a_shift{i}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_a_shift{i}.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.not_b_shift{i}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_b_shift{i}.bias"] = torch.tensor([0.0]) for stage in range(4): shift_amt = 1 << stage for i in range(11): tensors[f"{prefix}.a_norm_s{stage}_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.a_norm_s{stage}_{i}.pass.bias"] = torch.tensor([-2.0]) if i >= shift_amt: tensors[f"{prefix}.a_norm_s{stage}_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.a_norm_s{stage}_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.a_norm_s{stage}_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.a_norm_s{stage}_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.a_norm_s{stage}_{i}.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.b_norm_s{stage}_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.b_norm_s{stage}_{i}.pass.bias"] = torch.tensor([-2.0]) if i >= shift_amt: tensors[f"{prefix}.b_norm_s{stage}_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.b_norm_s{stage}_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.b_norm_s{stage}_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.b_norm_s{stage}_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.b_norm_s{stage}_{i}.bias"] = torch.tensor([-1.0]) for i in range(10): tensors[f"{prefix}.mant_a_norm{i}.weight"] = torch.tensor([1.0]) 
tensors[f"{prefix}.mant_a_norm{i}.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.mant_b_norm{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.mant_b_norm{i}.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.a_is_nan.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.a_is_nan.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.b_is_nan.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.b_is_nan.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.a_is_inf.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.a_is_inf.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.b_is_inf.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.b_is_inf.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.a_is_zero.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.a_is_zero.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.b_is_zero.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.b_is_zero.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.either_is_nan.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.either_is_nan.bias"] = torch.tensor([-1.0]) # inf/inf = NaN, 0/0 = NaN tensors[f"{prefix}.both_inf.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.both_inf.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.both_zero.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.both_zero.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.result_is_nan.weight"] = torch.tensor([1.0, 1.0, 1.0]) tensors[f"{prefix}.result_is_nan.bias"] = torch.tensor([-1.0]) # finite/0 = inf, inf/finite = inf tensors[f"{prefix}.not_a_is_zero.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_a_is_zero.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.finite_div_zero.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.finite_div_zero.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.not_b_is_inf.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_b_is_inf.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.inf_div_finite.weight"] = torch.tensor([1.0, 1.0]) 
tensors[f"{prefix}.inf_div_finite.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.not_result_is_nan.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_result_is_nan.bias"] = torch.tensor([0.0]) # result_is_inf = (finite_div_zero OR inf_div_finite OR exp_overflow_to_inf) # AND not_result_is_nan AND not_result_is_zero # Use weighted gate: cond1 + cond2 + cond3 + 3*not_nan + 3*not_zero >= 6.5 tensors[f"{prefix}.result_is_inf.weight"] = torch.tensor([1.0, 1.0, 1.0, 3.0, 3.0]) tensors[f"{prefix}.result_is_inf.bias"] = torch.tensor([-6.5]) # 0/finite = 0, finite/inf = 0 tensors[f"{prefix}.not_b_is_zero.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_b_is_zero.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.zero_div_finite.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.zero_div_finite.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.not_a_is_inf.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_a_is_inf.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.finite_div_inf.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.finite_div_inf.bias"] = torch.tensor([-2.0]) # result_is_zero = (zero_div_finite OR finite_div_inf) AND not_result_is_nan tensors[f"{prefix}.result_is_zero.weight"] = torch.tensor([1.0, 1.0, 2.0]) tensors[f"{prefix}.result_is_zero.bias"] = torch.tensor([-2.5]) # Result sign (XOR) tensors[f"{prefix}.result_sign.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.result_sign.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.result_sign.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{prefix}.result_sign.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{prefix}.result_sign.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.result_sign.layer2.bias"] = torch.tensor([-2.0]) # Exponent subtraction: exp_a - exp_b + 15 # NOT gates for exp_b for i in range(5): tensors[f"{prefix}.not_exp_b{i}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_exp_b{i}.bias"] = torch.tensor([0.0]) # 6-bit 
subtractor for exp_a - exp_b for i in range(6): p = f"{prefix}.exp_sub.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) # Adjust exponent for subnormal normalization (exp_sub - a_shift + b_shift) for i in range(6): p = f"{prefix}.exp_sub_a.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = 
torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) for i in range(6): p = f"{prefix}.exp_sub_ab.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) # Add 15 to get biased result for i in range(6): p = f"{prefix}.exp_add15.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = 
torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.exp_underflow_borrow.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.exp_underflow_borrow.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.exp_norm_zero.weight"] = torch.tensor([-1.0] * 6) tensors[f"{prefix}.exp_norm_zero.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.exp_underflow.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_underflow.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.exp_norm_all_ones.weight"] = torch.tensor([1.0] * 5) tensors[f"{prefix}.exp_norm_all_ones.bias"] = torch.tensor([-5.0]) tensors[f"{prefix}.exp_out_all_ones.weight"] = torch.tensor([1.0] * 5) tensors[f"{prefix}.exp_out_all_ones.bias"] = torch.tensor([-5.0]) tensors[f"{prefix}.exp_overflow_any.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_overflow_any.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.exp_overflow_to_inf.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_overflow_to_inf.bias"] = torch.tensor([-2.0]) # Mantissa division: 11-bit non-restoring division # Produces 11-bit quotient tensors[f"{prefix}.implicit_a.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.implicit_a.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.implicit_b.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.implicit_b.bias"] = torch.tensor([0.0]) for i in range(10): tensors[f"{prefix}.not_div_b{i}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_div_b{i}.bias"] = 
torch.tensor([0.0]) tensors[f"{prefix}.not_implicit_b.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_implicit_b.bias"] = torch.tensor([0.0]) # Division iterations (13 steps for quotient + guard bits) for step in range(13): p = f"{prefix}.div_step{step}" # Compare and subtract for i in range(12): tensors[f"{p}.sub.fa{i}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.sub.fa{i}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.sub.fa{i}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.sub.fa{i}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.sub.fa{i}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.sub.fa{i}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.sub.fa{i}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.sub.fa{i}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.sub.fa{i}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.sub.fa{i}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.sub.fa{i}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.sub.fa{i}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.sub.fa{i}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.sub.fa{i}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.sub.fa{i}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.sub.fa{i}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.sub.fa{i}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.sub.fa{i}.cout.bias"] = torch.tensor([-1.0]) # Quotient bit tensors[f"{p}.q_bit.weight"] = torch.tensor([1.0]) tensors[f"{p}.q_bit.bias"] = torch.tensor([-0.5]) tensors[f"{p}.not_q_bit.weight"] = torch.tensor([-1.0]) tensors[f"{p}.not_q_bit.bias"] = torch.tensor([0.0]) # Remainder selection for i in range(12): tensors[f"{p}.rem{i}.sub_path.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.rem{i}.sub_path.bias"] = torch.tensor([-2.0]) tensors[f"{p}.rem{i}.shift_path.weight"] = torch.tensor([1.0, 1.0]) 
tensors[f"{p}.rem{i}.shift_path.bias"] = torch.tensor([-2.0]) tensors[f"{p}.rem{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.rem{i}.bias"] = torch.tensor([-1.0]) # Normalization tensors[f"{prefix}.need_norm.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.need_norm.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.not_need_norm.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_need_norm.bias"] = torch.tensor([0.0]) for i in range(6): p = f"{prefix}.exp_norm.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) for i in range(10): tensors[f"{prefix}.norm_mant{i}.norm_path.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_mant{i}.norm_path.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.norm_mant{i}.direct_path.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_mant{i}.direct_path.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.norm_mant{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.norm_mant{i}.bias"] = torch.tensor([-1.0]) # 
Rounding (guard/sticky) for normal path tensors[f"{prefix}.rem_nonzero.weight"] = torch.tensor([1.0] * 12) tensors[f"{prefix}.rem_nonzero.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.guard_direct.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.guard_direct.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.guard_norm.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.guard_norm.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.guard_bit.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.guard_bit.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.sticky_direct_or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sticky_direct_or.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.sticky_direct.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sticky_direct.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.sticky_norm.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sticky_norm.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.sticky_bit.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sticky_bit.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.round_or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_or.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.round_inc.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.round_inc.bias"] = torch.tensor([-2.0]) for i in range(10): p = f"{prefix}.round_add.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) 
tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.mant_overflow.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.mant_overflow.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.not_mant_overflow.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_mant_overflow.bias"] = torch.tensor([0.0]) for i in range(5): p = f"{prefix}.exp_inc.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) for i in range(5): tensors[f"{prefix}.exp_out{i}.overflow_path.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_out{i}.overflow_path.bias"] = torch.tensor([-2.0]) 
tensors[f"{prefix}.exp_out{i}.normal_path.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_out{i}.normal_path.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.exp_out{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.exp_out{i}.bias"] = torch.tensor([-1.0]) for i in range(10): tensors[f"{prefix}.mant_out{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.mant_out{i}.bias"] = torch.tensor([-2.0]) # Subnormal path (shift + round) tensors[f"{prefix}.quot_implicit_norm.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.quot_implicit_norm.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.quot_implicit_direct.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.quot_implicit_direct.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.quot_implicit.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.quot_implicit.bias"] = torch.tensor([-1.0]) for i in range(12): tensors[f"{prefix}.norm_full{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.norm_full{i}.bias"] = torch.tensor([-0.5]) for i in range(5): tensors[f"{prefix}.not_exp_norm{i}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_exp_norm{i}.bias"] = torch.tensor([0.0]) for i in range(5): p = f"{prefix}.sub_shift_base.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = 
torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) for i in range(5): p = f"{prefix}.sub_shift.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) for i in range(5): tensors[f"{prefix}.not_sub_shift{i}.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_sub_shift{i}.bias"] = torch.tensor([0.0]) # Right barrel shifter for subnormal (12 bits) for i in range(12): tensors[f"{prefix}.sub_rshift_s0_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s0_{i}.pass.bias"] = torch.tensor([-2.0]) if i < 11: tensors[f"{prefix}.sub_rshift_s0_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s0_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.sub_rshift_s0_{i}.weight"] = torch.tensor([1.0, 1.0]) else: 
tensors[f"{prefix}.sub_rshift_s0_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sub_rshift_s0_{i}.bias"] = torch.tensor([-1.0]) for i in range(12): tensors[f"{prefix}.sub_rshift_s1_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s1_{i}.pass.bias"] = torch.tensor([-2.0]) if i < 10: tensors[f"{prefix}.sub_rshift_s1_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s1_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.sub_rshift_s1_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.sub_rshift_s1_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sub_rshift_s1_{i}.bias"] = torch.tensor([-1.0]) for i in range(12): tensors[f"{prefix}.sub_rshift_s2_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s2_{i}.pass.bias"] = torch.tensor([-2.0]) if i < 8: tensors[f"{prefix}.sub_rshift_s2_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s2_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.sub_rshift_s2_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.sub_rshift_s2_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sub_rshift_s2_{i}.bias"] = torch.tensor([-1.0]) for i in range(12): tensors[f"{prefix}.sub_rshift_s3_{i}.pass.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s3_{i}.pass.bias"] = torch.tensor([-2.0]) if i < 4: tensors[f"{prefix}.sub_rshift_s3_{i}.shift.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_rshift_s3_{i}.shift.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.sub_rshift_s3_{i}.weight"] = torch.tensor([1.0, 1.0]) else: tensors[f"{prefix}.sub_rshift_s3_{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sub_rshift_s3_{i}.bias"] = torch.tensor([-1.0]) for i in range(12): tensors[f"{prefix}.sub_shifted{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_shifted{i}.bias"] = torch.tensor([-2.0]) for i in range(10): tensors[f"{prefix}.sub_mant{i}.weight"] = 
torch.tensor([1.0]) tensors[f"{prefix}.sub_mant{i}.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.sub_guard.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sub_guard.bias"] = torch.tensor([-0.5]) for k in range(12): tensors[f"{prefix}.sub_shift_gt{k}.weight"] = torch.tensor([2**i for i in range(5)]) tensors[f"{prefix}.sub_shift_gt{k}.bias"] = torch.tensor([-float(k+1) + 0.5]) for i in range(12): tensors[f"{prefix}.sub_sticky_part{i}.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_sticky_part{i}.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.sub_sticky_raw.weight"] = torch.tensor([1.0] * 12) tensors[f"{prefix}.sub_sticky_raw.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.sub_sticky.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_sticky.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.sub_round_lsb_or_sticky.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_round_lsb_or_sticky.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.sub_round_inc.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.sub_round_inc.bias"] = torch.tensor([-2.0]) for i in range(10): p = f"{prefix}.sub_round.fa{i}" tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0]) tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0]) tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0]) tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and1.bias"] = torch.tensor([-2.0]) 
tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.and2.bias"] = torch.tensor([-2.0]) tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{p}.cout.bias"] = torch.tensor([-1.0]) tensors[f"{prefix}.sub_round_overflow.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.sub_round_overflow.bias"] = torch.tensor([-0.5]) # Output assembly tensors[f"{prefix}.not_result_is_inf.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_result_is_inf.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.not_result_is_zero.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_result_is_zero.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.not_exp_underflow.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.not_exp_underflow.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.is_normal_result.weight"] = torch.tensor([1.0, 1.0, 1.0, 1.0]) tensors[f"{prefix}.is_normal_result.bias"] = torch.tensor([-3.5]) tensors[f"{prefix}.subnorm_enable.weight"] = torch.tensor([1.0, 1.0, 1.0, 1.0]) tensors[f"{prefix}.subnorm_enable.bias"] = torch.tensor([-3.5]) for i in range(16): tensors[f"{prefix}.out{i}.nan_gate.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.out{i}.nan_gate.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.out{i}.inf_gate.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.out{i}.inf_gate.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.out{i}.zero_gate.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.out{i}.zero_gate.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.out{i}.normal_gate.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.out{i}.normal_gate.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.out_sub{i}.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.out_sub{i}.bias"] = torch.tensor([-0.5]) tensors[f"{prefix}.out{i}.sub_gate.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.out{i}.sub_gate.bias"] = torch.tensor([-2.0]) tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]) 
        tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-1.0])
    return tensors


def build_float16_toint_tensors() -> Dict[str, torch.Tensor]:
    """Build tensors for float16.toint circuit.

    Convert float16 to signed 16-bit integer (truncate toward zero).

    Algorithm:
    1. Extract mantissa M with implicit bit (11 bits, bit 10 = implicit 1)
    2. For exp < 15: result = 0 (|value| < 1)
    3. For exp >= 15: right-shift M by (25 - exp) positions
       - exp = 15: shift by 10, result = 1 for normalized
       - exp = 25: shift by 0, result = M (up to 2047)
       - exp > 25: would need left shift, but limited range
    4. Apply sign via two's complement negation
    5. Handle special cases: NaN, Inf, overflow
    """
    # Gate outputs are dict entries keyed "<circuit>.<gate>.weight"/".bias";
    # the gates' input wiring is inferred elsewhere (see module docstring).
    tensors: Dict[str, torch.Tensor] = {}
    prefix = "float16.toint"

    # === SPECIAL CASE DETECTION ===
    # exp_all_ones: exponent = 31 (NaN or Inf)
    tensors[f"{prefix}.exp_all_ones.weight"] = torch.tensor([1.0] * 5)
    tensors[f"{prefix}.exp_all_ones.bias"] = torch.tensor([-5.0])
    # exp_zero: exponent = 0 (zero or subnormal)
    tensors[f"{prefix}.exp_zero.weight"] = torch.tensor([-1.0] * 5)
    tensors[f"{prefix}.exp_zero.bias"] = torch.tensor([0.0])
    # mant_nonzero: any mantissa bit set
    tensors[f"{prefix}.mant_nonzero.weight"] = torch.tensor([1.0] * 10)
    tensors[f"{prefix}.mant_nonzero.bias"] = torch.tensor([-1.0])
    # is_nan: exp=31 AND mant!=0
    tensors[f"{prefix}.is_nan.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.is_nan.bias"] = torch.tensor([-2.0])
    # mant_zero: NOT mant_nonzero
    tensors[f"{prefix}.mant_zero.weight"] = torch.tensor([-1.0])
    tensors[f"{prefix}.mant_zero.bias"] = torch.tensor([0.0])
    # is_inf: exp=31 AND mant=0
    tensors[f"{prefix}.is_inf.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.is_inf.bias"] = torch.tensor([-2.0])
    # is_zero: exp=0 AND mant=0
    tensors[f"{prefix}.is_zero.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.is_zero.bias"] = torch.tensor([-2.0])

    # === CHECK IF |VALUE| < 1 ===
    # exp < 15 means unbiased exponent < 0, so |value| < 1
    # Use threshold: sum(exp[i] * 2^i) < 15
    weights = [-float(2**i) for i in range(5)]
    tensors[f"{prefix}.exp_lt_15.weight"] = torch.tensor(weights)
    tensors[f"{prefix}.exp_lt_15.bias"] = torch.tensor([14.0])
    # result_is_zero: exp_zero OR exp_lt_15 OR is_nan
    tensors[f"{prefix}.result_is_zero.weight"] = torch.tensor([1.0, 1.0, 1.0])
    tensors[f"{prefix}.result_is_zero.bias"] = torch.tensor([-1.0])
    # not_result_is_zero for muxing
    tensors[f"{prefix}.not_result_is_zero.weight"] = torch.tensor([-1.0])
    tensors[f"{prefix}.not_result_is_zero.bias"] = torch.tensor([0.0])

    # === COMPUTE SHIFT AMOUNT: 25 - exp ===
    # For right shift: shift_amt = 25 - exp (need 0 to 10 for normal range)
    # 25 = 0b11001, so we compute NOT(exp) + 25 + 1 = NOT(exp) + 26
    # Actually simpler: use 25 - exp directly with threshold gates
    # We'll use a different approach: compute exp directly and use threshold
    # gates to determine shift amount bits

    # Implicit bit (always 1 for normalized numbers, 0 for subnormals)
    # implicit = NOT exp_zero
    tensors[f"{prefix}.implicit_bit.weight"] = torch.tensor([-1.0])
    tensors[f"{prefix}.implicit_bit.bias"] = torch.tensor([0.0])

    # === DIRECT SHIFT USING EXPONENT VALUE ===
    # For exp in range 15-25, shift right by (25-exp)
    # For exp >= 25, no shift or left shift (overflow territory)
    #
    # Shift amounts needed: 0-10 for exp 25-15
    # shift[0] = 1 if (25-exp) is odd = exp is even when exp in {15,17,19,21,23,25}
    # This is complex. Let's use threshold gates on exp value.

    # exp_ge_15: exp >= 15 (value >= 1)
    tensors[f"{prefix}.exp_ge_15.weight"] = torch.tensor([float(2**i) for i in range(5)])
    tensors[f"{prefix}.exp_ge_15.bias"] = torch.tensor([-15.0])
    # exp_ge_26: exp >= 26 (left shift needed for large magnitudes)
    tensors[f"{prefix}.exp_ge_26.weight"] = torch.tensor([float(2**i) for i in range(5)])
    tensors[f"{prefix}.exp_ge_26.bias"] = torch.tensor([-26.0])
    tensors[f"{prefix}.not_exp_ge_26.weight"] = torch.tensor([-1.0])
    tensors[f"{prefix}.not_exp_ge_26.bias"] = torch.tensor([0.0])

    # For each shift stage, determine if we should shift
    # Right shift by 2^k if bit k of (25-exp) is set
    # 25 - exp for exp in [15, 25]: shift in [10, 0]
    # Binary of shift amounts:
    # exp=15: shift=10 = 0b1010
    # exp=16: shift=9 = 0b1001
    # exp=17: shift=8 = 0b1000
    # exp=18: shift=7 = 0b0111
    # exp=19: shift=6 = 0b0110
    # exp=20: shift=5 = 0b0101
    # exp=21: shift=4 = 0b0100
    # exp=22: shift=3 = 0b0011
    # exp=23: shift=2 = 0b0010
    # exp=24: shift=1 = 0b0001
    # exp=25: shift=0 = 0b0000
    # Use threshold on exp to determine shift control bits
    # shift_bit3 (shift by 8): exp <= 17 (shift >= 8)
    tensors[f"{prefix}.shift_bit3.weight"] = torch.tensor([-float(2**i) for i in range(5)])
    tensors[f"{prefix}.shift_bit3.bias"] = torch.tensor([17.0])
    # shift_bit2 (shift by 4): (exp <= 17) OR (18 <= exp <= 21)
    # = exp <= 21 AND NOT (18 <= exp <= 21 AND exp > 17)... complex
    # Simpler: shift_bit2 = 1 when shift in {4,5,6,7,12,13,14,15} ∩ [0,10] = {4,5,6,7}
    # = exp in {18,19,20,21}
    # Use: exp >= 18 AND exp <= 21
    tensors[f"{prefix}.exp_ge_18.weight"] = torch.tensor([float(2**i) for i in range(5)])
    tensors[f"{prefix}.exp_ge_18.bias"] = torch.tensor([-18.0])
    tensors[f"{prefix}.exp_le_21.weight"] = torch.tensor([-float(2**i) for i in range(5)])
    tensors[f"{prefix}.exp_le_21.bias"] = torch.tensor([21.0])
    tensors[f"{prefix}.shift_bit2.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.shift_bit2.bias"] = torch.tensor([-2.0])
    # shift_bit1 (shift by 2): shift in {2,3,6,7,10,11,...} ∩ [0,10] = {2,3,6,7,10}
    # = exp in {15,19,22,23} -- this is getting complex
    # Let's use a simpler direct threshold approach
    # Actually, let's compute 25-exp using subtraction, then use those bits
    # 25 = 0b011001 (6 bits), exp is 5 bits
    # 25 - exp in two's complement
    for i in range(5):
        tensors[f"{prefix}.not_exp{i}.weight"] = torch.tensor([-1.0])
        tensors[f"{prefix}.not_exp{i}.bias"] = torch.tensor([0.0])
    # 25 - exp = 25 + (~exp) + 1 = 26 + ~exp (in binary)
    # 26 = 0b011010
    # NOTE(review): const_26 is not referenced by the visible code in this
    # function; the constant wiring is presumably consumed by the input
    # inference pass that builds the .inputs tensors -- confirm.
    const_26 = [0, 1, 0, 1, 1, 0]  # bits of 26
    # 6-bit ripple-carry adder (full-adder cells built from XOR/AND/OR
    # threshold gates) computing 26 + ~exp.
    for i in range(6):
        p = f"{prefix}.shift_calc.fa{i}"
        tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0])
        tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
        tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0])
        tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0])
        tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
        tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0])
        tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.and1.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.and2.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
    # exp_minus_25 = exp + (~25) + 1 = exp + 7 (5-bit)
    # NOTE(review): const_7 is not referenced by the visible code; presumably
    # consumed by the input-inference pass -- confirm.
    const_7 = [1, 1, 1, 0, 0]  # bits of 7
    for i in range(5):
        p = f"{prefix}.exp_minus_25.fa{i}"
        tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0])
        tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
        tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0])
        tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0])
        tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
        tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0])
        tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.and1.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.and2.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
    # === RIGHT-SHIFT BARREL SHIFTER ===
    # 4 stages for shifts of 1, 2, 4, 8
    # Input: mantissa (10 bits) + implicit bit at position 10 = 11 bits
    # We'll work with 16 bits to have room
    for stage in range(4):
        shift_amt = 1 << stage
        # NOT of shift control bit for mux
        tensors[f"{prefix}.not_shift{stage}.weight"] = torch.tensor([-1.0])
        tensors[f"{prefix}.not_shift{stage}.bias"] = torch.tensor([0.0])
        for i in range(16):
            # pass: keep current position (AND with NOT shift_bit)
            tensors[f"{prefix}.rshift_s{stage}_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
            tensors[f"{prefix}.rshift_s{stage}_{i}.pass.bias"] = torch.tensor([-2.0])
            # shift: take from higher position (AND with shift_bit)
            src_pos = i + shift_amt
            if src_pos < 16:
                tensors[f"{prefix}.rshift_s{stage}_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
                tensors[f"{prefix}.rshift_s{stage}_{i}.shift.bias"] = torch.tensor([-2.0])
                # NOTE(review): no explicit .bias is written for this
                # two-input mux output (unlike mag_sel below, which sets
                # -1.0); every barrel shifter in this file follows the same
                # pattern, so the loader presumably defaults it -- verify.
                tensors[f"{prefix}.rshift_s{stage}_{i}.weight"] = torch.tensor([1.0, 1.0])
            else:
                # Shift in 0 from above
                tensors[f"{prefix}.rshift_s{stage}_{i}.weight"] = torch.tensor([1.0])
                tensors[f"{prefix}.rshift_s{stage}_{i}.bias"] = torch.tensor([-1.0])
    # === LEFT-SHIFT BARREL SHIFTER (for exp > 25) ===
    for stage in range(3):
        shift_amt = 1 << stage
        tensors[f"{prefix}.not_lshift{stage}.weight"] = torch.tensor([-1.0])
        tensors[f"{prefix}.not_lshift{stage}.bias"] = torch.tensor([0.0])
        for i in range(16):
            tensors[f"{prefix}.lshift_s{stage}_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
            tensors[f"{prefix}.lshift_s{stage}_{i}.pass.bias"] = torch.tensor([-2.0])
            src_pos = i - shift_amt
            if src_pos >= 0:
                tensors[f"{prefix}.lshift_s{stage}_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
                tensors[f"{prefix}.lshift_s{stage}_{i}.shift.bias"] = torch.tensor([-2.0])
                tensors[f"{prefix}.lshift_s{stage}_{i}.weight"] = torch.tensor([1.0, 1.0])
            else:
                tensors[f"{prefix}.lshift_s{stage}_{i}.weight"] = torch.tensor([1.0])
                tensors[f"{prefix}.lshift_s{stage}_{i}.bias"] = torch.tensor([-1.0])
    # Select magnitude: right-shift for exp<=25, left-shift for exp>=26
    for i in range(16):
        tensors[f"{prefix}.mag_sel{i}.left.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{prefix}.mag_sel{i}.left.bias"] = torch.tensor([-2.0])
        tensors[f"{prefix}.mag_sel{i}.right.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{prefix}.mag_sel{i}.right.bias"] = torch.tensor([-2.0])
        tensors[f"{prefix}.mag_sel{i}.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{prefix}.mag_sel{i}.bias"] = torch.tensor([-1.0])
    # === TWO'S COMPLEMENT NEGATION FOR NEGATIVE FLOATS ===
    # If sign bit is 1, negate the result
    for i in range(16):
        tensors[f"{prefix}.not_mag{i}.weight"] = torch.tensor([-1.0])
        tensors[f"{prefix}.not_mag{i}.bias"] = torch.tensor([0.0])
    # Incrementer cells (~mag + 1): single-XOR full adders with carry.
    for i in range(16):
        p = f"{prefix}.neg.fa{i}"
        tensors[f"{p}.xor.layer1.or.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor.layer1.or.bias"] = torch.tensor([-1.0])
        tensors[f"{p}.xor.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
        tensors[f"{p}.xor.layer1.nand.bias"] = torch.tensor([1.0])
        tensors[f"{p}.xor.layer2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor.layer2.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.and.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.and.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
    # === OUTPUT SELECTION ===
    # Select between positive path, negative path, and zero
    # Gate by not_result_is_zero to force output to 0 for |value| < 1
    # NOT of sign bit for muxing positive path
    tensors[f"{prefix}.not_sign.weight"] = torch.tensor([-1.0])
    tensors[f"{prefix}.not_sign.bias"] = torch.tensor([0.0])
    for i in range(16):
        # pos_path = shifted_value AND not_sign AND not_result_is_zero
        tensors[f"{prefix}.out{i}.pos_path.weight"] = torch.tensor([1.0, 1.0, 1.0])
        tensors[f"{prefix}.out{i}.pos_path.bias"] = torch.tensor([-3.0])
        # neg_path = negated_value AND sign AND not_result_is_zero
        tensors[f"{prefix}.out{i}.neg_path.weight"] = torch.tensor([1.0, 1.0, 1.0])
        tensors[f"{prefix}.out{i}.neg_path.bias"] = torch.tensor([-3.0])
        # out = pos_path OR neg_path
        tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-1.0])
    return tensors


def build_float16_fromint_tensors() -> Dict[str, torch.Tensor]:
    """Build tensors for float16.fromint circuit.

    Convert signed 16-bit integer to float16.

    Algorithm:
    1. Take absolute value
    2. Count leading zeros to find exponent
    3. Shift to normalize mantissa
    4. Apply sign
    """
    tensors: Dict[str, torch.Tensor] = {}
    prefix = "float16.fromint"
    # Check if zero
    tensors[f"{prefix}.is_zero.weight"] = torch.tensor([-1.0] * 16)
    tensors[f"{prefix}.is_zero.bias"] = torch.tensor([0.0])
    # NOT is_zero for gating normal output
    tensors[f"{prefix}.not_is_zero.weight"] = torch.tensor([-1.0])
    tensors[f"{prefix}.not_is_zero.bias"] = torch.tensor([0.0])
    # Check if negative (sign bit)
    tensors[f"{prefix}.is_negative.weight"] = torch.tensor([1.0])
    tensors[f"{prefix}.is_negative.bias"] = torch.tensor([-0.5])
    tensors[f"{prefix}.not_negative.weight"] = torch.tensor([-1.0])
    tensors[f"{prefix}.not_negative.bias"] = torch.tensor([0.0])
    # Absolute value: if negative, negate
    for i in range(16):
        tensors[f"{prefix}.not_in{i}.weight"] = torch.tensor([-1.0])
        tensors[f"{prefix}.not_in{i}.bias"] = torch.tensor([0.0])
    # Incrementer cells for two's complement (~x + 1).
    for i in range(16):
        p = f"{prefix}.abs.fa{i}"
        tensors[f"{p}.xor.layer1.or.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor.layer1.or.bias"] = torch.tensor([-1.0])
        tensors[f"{p}.xor.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
        tensors[f"{p}.xor.layer1.nand.bias"] = torch.tensor([1.0])
        tensors[f"{p}.xor.layer2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor.layer2.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.and.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.and.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
    # Absolute value mux
    for i in range(16):
        tensors[f"{prefix}.abs{i}.neg_path.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{prefix}.abs{i}.neg_path.bias"] = torch.tensor([-2.0])
        tensors[f"{prefix}.abs{i}.pos_path.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{prefix}.abs{i}.pos_path.bias"] = torch.tensor([-2.0])
        tensors[f"{prefix}.abs{i}.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{prefix}.abs{i}.bias"] = torch.tensor([-1.0])
    # CLZ on 16-bit absolute value
    # pz{k}: NOR of the top k bits (fires iff the top k bits are all zero).
    for k in range(1, 17):
        tensors[f"{prefix}.pz{k}.weight"] = torch.tensor([-1.0] * k)
        tensors[f"{prefix}.pz{k}.bias"] = torch.tensor([0.0])
    # ge{k}: thermometer threshold -- sum of all 16 pz outputs >= k,
    # i.e. the leading-zero count is at least k.
    for k in range(1, 17):
        tensors[f"{prefix}.ge{k}.weight"] = torch.tensor([1.0] * 16)
        tensors[f"{prefix}.ge{k}.bias"] = torch.tensor([-float(k)])
    # CLZ binary encoding
    for k in [2, 4, 6, 8, 10, 12, 14, 16]:
        tensors[f"{prefix}.not_ge{k}.weight"] = torch.tensor([-1.0])
        tensors[f"{prefix}.not_ge{k}.bias"] = torch.tensor([0.0])
    # clz3 (bit 3 of CLZ): passthrough threshold.
    tensors[f"{prefix}.clz3.weight"] = torch.tensor([1.0])
    tensors[f"{prefix}.clz3.bias"] = torch.tensor([-0.5])
    # clz2 (bit 2): OR of the two range detectors below.
    tensors[f"{prefix}.clz_and_4_7.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.clz_and_4_7.bias"] = torch.tensor([-2.0])
    tensors[f"{prefix}.clz_and_12_15.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.clz_and_12_15.bias"] = torch.tensor([-2.0])
    tensors[f"{prefix}.clz2.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.clz2.bias"] = torch.tensor([-1.0])
    # clz1 (bit 1): OR of four two-input range detectors.
    tensors[f"{prefix}.clz_and_2_3.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.clz_and_2_3.bias"] = torch.tensor([-2.0])
    tensors[f"{prefix}.clz_and_6_7.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.clz_and_6_7.bias"] = torch.tensor([-2.0])
    tensors[f"{prefix}.clz_and_10_11.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.clz_and_10_11.bias"] = torch.tensor([-2.0])
    tensors[f"{prefix}.clz_and_14_15.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.clz_and_14_15.bias"] = torch.tensor([-2.0])
    tensors[f"{prefix}.clz1.weight"] = torch.tensor([1.0, 1.0, 1.0, 1.0])
    tensors[f"{prefix}.clz1.bias"] = torch.tensor([-1.0])
    # clz0 (bit 0): OR over the eight odd-value detectors.
    for i in [1, 3, 5, 7, 9, 11, 13, 15]:
        tensors[f"{prefix}.clz_and_{i}.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{prefix}.clz_and_{i}.bias"] = torch.tensor([-2.0])
    tensors[f"{prefix}.clz0.weight"] = torch.tensor([1.0] * 8)
    tensors[f"{prefix}.clz0.bias"] = torch.tensor([-1.0])
    # Exponent = 15 + 15 - CLZ = 30 - CLZ (biased)
    # Actually: exponent = bias + (15 - CLZ) = 15 + 15 - CLZ
    for i in range(5):
        tensors[f"{prefix}.not_clz{i}.weight"] = torch.tensor([-1.0])
        tensors[f"{prefix}.not_clz{i}.bias"] = torch.tensor([0.0])
    # 5-bit ripple-carry adder computing the biased exponent (30 - CLZ,
    # via ~CLZ plus a constant wired by the input-inference pass).
    for i in range(5):
        p = f"{prefix}.exp_calc.fa{i}"
        tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0])
        tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
        tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0])
        tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0])
        tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
        tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0])
        tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.and1.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.and2.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
    # Left barrel shifter to normalize mantissa
    for stage in range(4):
        shift_amt = 1 << stage
        tensors[f"{prefix}.not_norm_shift{stage}.weight"] = torch.tensor([-1.0])
        tensors[f"{prefix}.not_norm_shift{stage}.bias"] = torch.tensor([0.0])
        for i in range(16):
            tensors[f"{prefix}.norm_s{stage}_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
            tensors[f"{prefix}.norm_s{stage}_{i}.pass.bias"] = torch.tensor([-2.0])
            if i >= shift_amt:
                tensors[f"{prefix}.norm_s{stage}_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
                tensors[f"{prefix}.norm_s{stage}_{i}.shift.bias"] = torch.tensor([-2.0])
                # NOTE(review): no explicit .bias for this two-input mux
                # output; same pattern as every other barrel shifter in the
                # file, so presumably a loader default -- verify.
                tensors[f"{prefix}.norm_s{stage}_{i}.weight"] = torch.tensor([1.0, 1.0])
            else:
                tensors[f"{prefix}.norm_s{stage}_{i}.weight"] = torch.tensor([1.0])
                tensors[f"{prefix}.norm_s{stage}_{i}.bias"] = torch.tensor([-1.0])
    # === ROUNDING (round-to-nearest-even) ===
    # Normalized value has implicit 1 at bit 15. Mantissa bits = norm_s3[14:5].
    # Guard = bit 4, Round = bit 3, Sticky = OR(bits 0..2).
    tensors[f"{prefix}.guard_bit.weight"] = torch.tensor([1.0])
    tensors[f"{prefix}.guard_bit.bias"] = torch.tensor([-0.5])
    tensors[f"{prefix}.round_bit.weight"] = torch.tensor([1.0])
    tensors[f"{prefix}.round_bit.bias"] = torch.tensor([-0.5])
    tensors[f"{prefix}.sticky_bit.weight"] = torch.tensor([1.0, 1.0, 1.0])
    tensors[f"{prefix}.sticky_bit.bias"] = torch.tensor([-0.5])
    # round_or = round_bit OR sticky_bit OR lsb_mant
    tensors[f"{prefix}.round_or.weight"] = torch.tensor([1.0, 1.0, 1.0])
    tensors[f"{prefix}.round_or.bias"] = torch.tensor([-0.5])
    # round_inc = guard_bit AND round_or
    tensors[f"{prefix}.round_inc.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.round_inc.bias"] = torch.tensor([-2.0])
    # Ripple add round_inc into mantissa bits (10 bits)
    for i in range(10):
        p = f"{prefix}.round_add.fa{i}"
        tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0])
        tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
        tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0])
        tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0])
        tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
        tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0])
        tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.and1.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.and2.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
    # mant_overflow = carry out from rounding adder
    tensors[f"{prefix}.mant_overflow.weight"] = torch.tensor([1.0])
    tensors[f"{prefix}.mant_overflow.bias"] = torch.tensor([-0.5])
    tensors[f"{prefix}.not_mant_overflow.weight"] = torch.tensor([-1.0])
    tensors[f"{prefix}.not_mant_overflow.bias"] = torch.tensor([0.0])
    # Exponent increment if mantissa overflows
    for i in range(5):
        p = f"{prefix}.exp_inc.fa{i}"
        tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0])
        tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
        tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0])
        tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0])
        tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
        tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0])
        tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.and1.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.and2.bias"] = torch.tensor([-2.0])
        tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
    # Exponent select based on mant_overflow
    for i in range(5):
        tensors[f"{prefix}.exp_out{i}.overflow_path.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{prefix}.exp_out{i}.overflow_path.bias"] = torch.tensor([-2.0])
        tensors[f"{prefix}.exp_out{i}.normal_path.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{prefix}.exp_out{i}.normal_path.bias"] = torch.tensor([-2.0])
        tensors[f"{prefix}.exp_out{i}.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{prefix}.exp_out{i}.bias"] = torch.tensor([-1.0])
    # Mantissa select (zero if overflow)
    for i in range(10):
        tensors[f"{prefix}.mant_out{i}.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{prefix}.mant_out{i}.bias"] = torch.tensor([-2.0])
    # Output assembly
    # Bits 0-9: mantissa; bits 10-14: exponent; bit 15: sign (passthrough).
    for i in range(16):
        if i < 10:
            tensors[f"{prefix}.out{i}.zero_gate.weight"] = torch.tensor([1.0, 1.0])
            tensors[f"{prefix}.out{i}.zero_gate.bias"] = torch.tensor([-2.0])
            tensors[f"{prefix}.out{i}.normal_gate.weight"] = torch.tensor([1.0, 1.0])
            tensors[f"{prefix}.out{i}.normal_gate.bias"] = torch.tensor([-2.0])
            tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0, 1.0])
            tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-1.0])
        elif i < 15:
            tensors[f"{prefix}.out{i}.zero_gate.weight"] = torch.tensor([1.0, 1.0])
            tensors[f"{prefix}.out{i}.zero_gate.bias"] = torch.tensor([-2.0])
            tensors[f"{prefix}.out{i}.normal_gate.weight"] = torch.tensor([1.0, 1.0])
            tensors[f"{prefix}.out{i}.normal_gate.bias"] = torch.tensor([-2.0])
            tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0, 1.0])
            tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-1.0])
        else:
            tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0])
            tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-0.5])
    return tensors


def build_modular_power2_tensors() -> Dict[str, torch.Tensor]:
    """Build tensors for modular.mod2, mod4, mod8 circuits.

    For powers of 2, mod is just extracting the lower bits:
    - mod2: val & 1 = bit 0
    - mod4: val & 3 = bits 0-1
    - mod8: val & 7 = bits 0-2

    Each circuit outputs the appropriate number of bits.
    Inputs are inferred by infer_modular_inputs().
    """
    tensors: Dict[str, torch.Tensor] = {}
    # mod2: single output (bit 0)
    tensors["modular.mod2.out0.weight"] = torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    tensors["modular.mod2.out0.bias"] = torch.tensor([-0.5])
    # mod4: two outputs (bits 0-1)
    tensors["modular.mod4.out0.weight"] = torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    tensors["modular.mod4.out0.bias"] = torch.tensor([-0.5])
    tensors["modular.mod4.out1.weight"] = torch.tensor([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    tensors["modular.mod4.out1.bias"] = torch.tensor([-0.5])
    # mod8: three outputs (bits 0-2)
    tensors["modular.mod8.out0.weight"] = torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    tensors["modular.mod8.out0.bias"] = torch.tensor([-0.5])
    tensors["modular.mod8.out1.weight"] = torch.tensor([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    tensors["modular.mod8.out1.bias"] = torch.tensor([-0.5])
    tensors["modular.mod8.out2.weight"] = torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    tensors["modular.mod8.out2.bias"] = torch.tensor([-0.5])
    return tensors


def build_bitwise_shift_tensors() -> Dict[str, torch.Tensor]:
    """Build tensors for arithmetic.asr8bit, rol8bit, ror8bit circuits.

    ASR (Arithmetic Shift Right by 1):
    - bit[i] = x[i+1] for i in 0..6
    - bit[7] = x[7] (sign extension)
    - shiftout = x[0]

    ROL (Rotate Left by 1):
    - bit[0] = x[7] (wrap around)
    - bit[i] = x[i-1] for i in 1..7
    - cout = x[7]

    ROR (Rotate Right by 1):
    - bit[i] = x[i+1] for i in 0..6
    - bit[7] = x[0] (wrap around)
    - cout = x[0]
    """
    tensors: Dict[str, torch.Tensor] = {}
    # ASR8BIT - Arithmetic Shift Right by 1
    prefix = "arithmetic.asr8bit"
    for i in range(8):
        if i < 7:
            # bit[i] gets x[i+1] (shift right)
            w = [0.0] * 8
            w[i + 1] = 1.0
        else:
            # bit[7] gets x[7] (sign extension)
            w = [0.0] * 8
            w[7] = 1.0
        tensors[f"{prefix}.bit{i}.weight"] = torch.tensor(w)
        tensors[f"{prefix}.bit{i}.bias"] = torch.tensor([-0.5])
    # shiftout gets x[0]
    w = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    tensors[f"{prefix}.shiftout.weight"] = torch.tensor(w)
    tensors[f"{prefix}.shiftout.bias"] = torch.tensor([-0.5])
    # ROL8BIT - Rotate Left by 1
    prefix = "arithmetic.rol8bit"
    for i in range(8):
        w = [0.0] * 8
        if i == 0:
            # bit[0] gets x[7] (wrap around)
            w[7] = 1.0
        else:
            # bit[i] gets x[i-1]
            w[i - 1] = 1.0
        tensors[f"{prefix}.bit{i}.weight"] = torch.tensor(w)
        tensors[f"{prefix}.bit{i}.bias"] = torch.tensor([-0.5])
    # cout gets x[7]
    w = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
    tensors[f"{prefix}.cout.weight"] = torch.tensor(w)
    tensors[f"{prefix}.cout.bias"] = torch.tensor([-0.5])
    # ROR8BIT - Rotate Right by 1
    prefix = "arithmetic.ror8bit"
    for i in range(8):
        w = [0.0] * 8
        if i < 7:
            # bit[i] gets x[i+1]
            w[i + 1] = 1.0
        else:
            # bit[7] gets x[0] (wrap around)
            w[0] = 1.0
        tensors[f"{prefix}.bit{i}.weight"] = torch.tensor(w)
        tensors[f"{prefix}.bit{i}.bias"] = torch.tensor([-0.5])
    # cout gets x[0]
    w = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    tensors[f"{prefix}.cout.weight"] = torch.tensor(w)
    tensors[f"{prefix}.cout.bias"] = torch.tensor([-0.5])
    return tensors


def build_symmetry8bit_tensors() -> Dict[str, torch.Tensor]:
    """Build tensors for pattern_recognition.symmetry8bit circuit.

    Checks if an 8-bit input is a palindrome (symmetric).
    bit[0] == bit[7], bit[1] == bit[6], bit[2] == bit[5], bit[3] == bit[4]

    XNOR as threshold gate: XNOR(a,b) = 1 if a==b
    This requires a 2-layer structure per XNOR:
    - Layer 1: AND(a,b) and NOR(a,b)
    - Layer 2: OR of AND and NOR outputs
    Then final AND of all 4 XNOR results.
    """
    tensors: Dict[str, torch.Tensor] = {}
    prefix = "pattern_recognition.symmetry8bit"
    # XNOR gates for comparing bit pairs: (0,7), (1,6), (2,5), (3,4)
    pairs = [(0, 7), (1, 6), (2, 5), (3, 4)]
    for i, (lo, hi) in enumerate(pairs):
        # Layer 1: AND(a,b) - fires when both are 1
        # Weight: select bits lo and hi from 8-bit input
        w_and = [0.0] * 8
        w_and[lo] = 1.0
        w_and[hi] = 1.0
        tensors[f"{prefix}.xnor{i}.layer1.and.weight"] = torch.tensor(w_and)
        tensors[f"{prefix}.xnor{i}.layer1.and.bias"] = torch.tensor([-1.5])  # Need both
        # Layer 1: NOR(a,b) - fires when both are 0
        w_nor = [0.0] * 8
        w_nor[lo] = -1.0
        w_nor[hi] = -1.0
        tensors[f"{prefix}.xnor{i}.layer1.nor.weight"] = torch.tensor(w_nor)
        tensors[f"{prefix}.xnor{i}.layer1.nor.bias"] = torch.tensor([0.0])  # Fire when sum < 0
        # Layer 2: OR of AND and NOR - fires when either is 1 (i.e., a==b)
        tensors[f"{prefix}.xnor{i}.layer2.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{prefix}.xnor{i}.layer2.bias"] = torch.tensor([-0.5])
    # Final AND of all 4 XNOR results
    tensors[f"{prefix}.and.weight"] = torch.tensor([1.0, 1.0, 1.0, 1.0])
    tensors[f"{prefix}.and.bias"] = torch.tensor([-3.5])  # Need all 4
    return tensors


def build_clz8bit_tensors() -> Dict[str, torch.Tensor]:
    """Build tensors for arithmetic.clz8bit circuit.

    CLZ8BIT counts leading zeros in an 8-bit input. Output is 0-8 (4 bits).

    Architecture:
    1. pz[k] gates: NOR of top k bits (fires if top k bits are all zero)
    2. ge[k] gates: sum of pz >= k (threshold gates)
    3. Logic gates to convert thermometer code to binary
    """
    tensors: Dict[str, torch.Tensor] = {}
    prefix = "arithmetic.clz8bit"
    # === PREFIX ZERO GATES (NOR of top k bits) ===
    for k in range(1, 9):
        tensors[f"{prefix}.pz{k}.weight"] = torch.tensor([-1.0] * k)
        tensors[f"{prefix}.pz{k}.bias"] = torch.tensor([0.0])
    # === GE GATES (sum of pz >= k) ===
    for k in range(1, 9):
        tensors[f"{prefix}.ge{k}.weight"] = torch.tensor([1.0] * 8)
        tensors[f"{prefix}.ge{k}.bias"] = torch.tensor([-float(k)])
    # === NOT GATES ===
    for k in [2, 4, 6, 8]:
        tensors[f"{prefix}.not_ge{k}.weight"] = torch.tensor([-1.0])
        tensors[f"{prefix}.not_ge{k}.bias"] = torch.tensor([0.0])
    # === AND GATES for range detection ===
    # and_2_3: ge2 AND NOT ge4 (CLZ in {2,3})
    # and_6_7: ge6 AND NOT ge8 (CLZ in {6,7})
    # and_1: ge1 AND NOT ge2 (CLZ = 1)
    # and_3: ge3 AND NOT ge4 (CLZ = 3)
    # and_5: ge5 AND NOT ge6 (CLZ = 5)
    # and_7: ge7 AND NOT ge8 (CLZ = 7)
    for name in ['and_2_3', 'and_6_7', 'and_1', 'and_3', 'and_5', 'and_7']:
        tensors[f"{prefix}.{name}.weight"] = torch.tensor([1.0, 1.0])
        tensors[f"{prefix}.{name}.bias"] = torch.tensor([-2.0])
    # === OUTPUT GATES ===
    # out3 (bit 3): CLZ >= 8, passthrough from ge8
    tensors[f"{prefix}.out3.weight"] = torch.tensor([1.0])
    tensors[f"{prefix}.out3.bias"] = torch.tensor([-0.5])
    # out2 (bit 2): CLZ in {4,5,6,7} = ge4 AND NOT ge8
    tensors[f"{prefix}.out2.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.out2.bias"] = torch.tensor([-2.0])
    # out1 (bit 1): CLZ in {2,3,6,7} = and_2_3 OR and_6_7
    tensors[f"{prefix}.out1.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{prefix}.out1.bias"] = torch.tensor([-1.0])
    # out0 (bit 0): CLZ odd = and_1 OR and_3 OR and_5 OR and_7
    tensors[f"{prefix}.out0.weight"] = torch.tensor([1.0, 1.0, 1.0, 1.0])
    tensors[f"{prefix}.out0.bias"] = torch.tensor([-1.0])
    return tensors


def add_not_gate(tensors: Dict[str, torch.Tensor], name: str) -> None:
    """Write weight/bias for a one-input NOT threshold gate under *name*."""
    tensors[f"{name}.weight"] = torch.tensor([-1.0])
    tensors[f"{name}.bias"] = torch.tensor([0.0])


def add_and_gate(tensors: Dict[str, torch.Tensor], name: str) -> None:
    """Write weight/bias for a two-input AND threshold gate under *name*."""
    tensors[f"{name}.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{name}.bias"] = torch.tensor([-2.0])


def add_or_gate(tensors: Dict[str, torch.Tensor], name: str) -> None:
    """Write weight/bias for a two-input OR threshold gate under *name*."""
    tensors[f"{name}.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{name}.bias"] = torch.tensor([-1.0])


def add_xor_gate(tensors: Dict[str, torch.Tensor], name: str) -> None:
    """Write a two-layer XOR (OR + NAND feeding an AND) under *name*."""
    tensors[f"{name}.layer1.or.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{name}.layer1.or.bias"] = torch.tensor([-1.0])
    tensors[f"{name}.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
    tensors[f"{name}.layer1.nand.bias"] = torch.tensor([1.0])
    tensors[f"{name}.layer2.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{name}.layer2.bias"] = torch.tensor([-2.0])


def add_xnor_gate(tensors: Dict[str, torch.Tensor], name: str) -> None:
    """Write a two-layer XNOR (AND + NOR feeding an OR) under *name*."""
    tensors[f"{name}.layer1.and.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{name}.layer1.and.bias"] = torch.tensor([-1.5])
    tensors[f"{name}.layer1.nor.weight"] = torch.tensor([-1.0, -1.0])
    tensors[f"{name}.layer1.nor.bias"] = torch.tensor([0.0])
    tensors[f"{name}.layer2.weight"] = torch.tensor([1.0, 1.0])
    tensors[f"{name}.layer2.bias"] = torch.tensor([-0.5])


def build_ripplecarry_tensors(prefix: str, bits: int) -> Dict[str, torch.Tensor]:
    """Build a *bits*-wide ripple-carry adder from half-adder pairs.

    Each full adder fa{i} = two half adders (XOR sum, AND carry) plus an
    OR combining the two carries.
    """
    tensors: Dict[str, torch.Tensor] = {}
    for i in range(bits):
        fa_prefix = f"{prefix}.fa{i}"
        add_xor_gate(tensors, f"{fa_prefix}.ha1.sum")
        add_and_gate(tensors, f"{fa_prefix}.ha1.carry")
        add_xor_gate(tensors, f"{fa_prefix}.ha2.sum")
        add_and_gate(tensors, f"{fa_prefix}.ha2.carry")
        add_or_gate(tensors, f"{fa_prefix}.carry_or")
    return tensors


def build_adc_sbc_tensors(prefix: str, bits: int, with_notb: bool = False) -> Dict[str, torch.Tensor]:
    tensors: Dict[str, torch.Tensor] = {}
    if with_notb:
        for i in range(bits):
            add_not_gate(tensors, f"{prefix}.notb{i}")
    for i in range(bits):
        fa_prefix = f"{prefix}.fa{i}"
        add_xor_gate(tensors, f"{fa_prefix}.xor1")
        add_xor_gate(tensors, f"{fa_prefix}.xor2")
        add_and_gate(tensors, f"{fa_prefix}.and1")
add_and_gate(tensors, f"{fa_prefix}.and2") add_or_gate(tensors, f"{fa_prefix}.or_carry") return tensors def build_sub_tensors(prefix: str, bits: int) -> Dict[str, torch.Tensor]: tensors: Dict[str, torch.Tensor] = {} for i in range(bits): add_not_gate(tensors, f"{prefix}.notb{i}") tensors[f"{prefix}.carry_in.weight"] = torch.tensor([1.0]) tensors[f"{prefix}.carry_in.bias"] = torch.tensor([-0.5]) for i in range(bits): fa_prefix = f"{prefix}.fa{i}" add_xor_gate(tensors, f"{fa_prefix}.xor1") add_xor_gate(tensors, f"{fa_prefix}.xor2") add_and_gate(tensors, f"{fa_prefix}.and1") add_and_gate(tensors, f"{fa_prefix}.and2") add_or_gate(tensors, f"{fa_prefix}.or_carry") return tensors def build_cmp_tensors(prefix: str, bits: int) -> Dict[str, torch.Tensor]: tensors: Dict[str, torch.Tensor] = {} for i in range(bits): add_not_gate(tensors, f"{prefix}.notb{i}") for i in range(bits): fa_prefix = f"{prefix}.fa{i}" add_xor_gate(tensors, f"{fa_prefix}.xor1") add_xor_gate(tensors, f"{fa_prefix}.xor2") add_and_gate(tensors, f"{fa_prefix}.and1") add_and_gate(tensors, f"{fa_prefix}.and2") add_or_gate(tensors, f"{fa_prefix}.or_carry") return tensors def build_equality_tensors(prefix: str, bits: int) -> Dict[str, torch.Tensor]: tensors: Dict[str, torch.Tensor] = {} for i in range(bits): add_xnor_gate(tensors, f"{prefix}.xnor{i}") tensors[f"{prefix}.final_and.weight"] = torch.tensor([1.0] * bits) tensors[f"{prefix}.final_and.bias"] = torch.tensor([-(bits - 0.5)]) return tensors def build_neg_tensors(prefix: str, bits: int) -> Dict[str, torch.Tensor]: tensors: Dict[str, torch.Tensor] = {} for i in range(bits): add_not_gate(tensors, f"{prefix}.not{i}") # sum0 = NOT(not0) == x0 (since ~x + 1 toggles the LSB) tensors[f"{prefix}.sum0.weight"] = torch.tensor([-1.0]) tensors[f"{prefix}.sum0.bias"] = torch.tensor([0.0]) tensors[f"{prefix}.carry0.weight"] = torch.tensor([1.0, 1.0]) tensors[f"{prefix}.carry0.bias"] = torch.tensor([-2.0]) for i in range(1, bits): add_xor_gate(tensors, 
f"{prefix}.xor{i}") add_and_gate(tensors, f"{prefix}.and{i}") return tensors def build_shift_rotate_tensors(prefix: str, bits: int, kind: str) -> Dict[str, torch.Tensor]: tensors: Dict[str, torch.Tensor] = {} for i in range(bits): if kind == "asr": src = i + 1 if i < bits - 1 else bits - 1 elif kind == "rol": src = (i - 1) % bits elif kind == "ror": src = (i + 1) % bits else: raise ValueError(f"unknown shift kind: {kind}") w = [0.0] * bits w[src] = 1.0 tensors[f"{prefix}.bit{i}.weight"] = torch.tensor(w) tensors[f"{prefix}.bit{i}.bias"] = torch.tensor([-0.5]) return tensors def build_comparator_vectors(bits: int) -> Dict[str, torch.Tensor]: tensors: Dict[str, torch.Tensor] = {} weights = [float(2 ** i) for i in range(bits - 1, -1, -1)] names = ["greaterthan", "lessthan", "greaterorequal", "lessorequal"] for name in names: tensors[f"arithmetic.{name}{bits}bit.comparator"] = torch.tensor(weights) return tensors def build_increment_decrement_constants(bits: int) -> Dict[str, torch.Tensor]: tensors: Dict[str, torch.Tensor] = {} one = [0.0] * (bits - 1) + [1.0] tensors[f"arithmetic.incrementer{bits}bit.one"] = torch.tensor(one) tensors[f"arithmetic.incrementer{bits}bit.adder"] = torch.tensor([1.0] * bits) tensors[f"arithmetic.decrementer{bits}bit.neg_one"] = torch.tensor([1.0] * bits) tensors[f"arithmetic.decrementer{bits}bit.adder"] = torch.tensor([1.0] * bits) return tensors def build_minmax_diff_constants(bits: int) -> Dict[str, torch.Tensor]: tensors: Dict[str, torch.Tensor] = {} width = bits * 2 tensors[f"arithmetic.absolutedifference{bits}bit.diff"] = torch.tensor([1.0] * width) tensors[f"arithmetic.max{bits}bit.select"] = torch.tensor([1.0] * width) tensors[f"arithmetic.min{bits}bit.select"] = torch.tensor([1.0] * width) return tensors def main(): print("Loading existing tensors...") tensors = {} with safe_open('arithmetic.safetensors', framework='pt') as f: for name in f.keys(): tensors[name] = f.get_tensor(name) print(f"Loaded {len(tensors)} tensors") # Remove 
old tensors for circuits we're rebuilding old_float16_add = [k for k in tensors.keys() if k.startswith('float16.add')] for k in old_float16_add: del tensors[k] print(f"Removed {len(old_float16_add)} old float16.add tensors") old_float16_toint = [k for k in tensors.keys() if k.startswith('float16.toint')] for k in old_float16_toint: del tensors[k] print(f"Removed {len(old_float16_toint)} old float16.toint tensors") old_float16_mul = [k for k in tensors.keys() if k.startswith('float16.mul')] for k in old_float16_mul: del tensors[k] print(f"Removed {len(old_float16_mul)} old float16.mul tensors") old_float16_div = [k for k in tensors.keys() if k.startswith('float16.div')] for k in old_float16_div: del tensors[k] print(f"Removed {len(old_float16_div)} old float16.div tensors") old_float16_lut = [k for k in tensors.keys() if k.startswith('float16.lut') or k.startswith('float16.sqrt') or k.startswith('float16.rsqrt') or k.startswith('float16.exp') or k.startswith('float16.ln') or k.startswith('float16.log2') or k.startswith('float16.log10') or k.startswith('float16.deg2rad') or k.startswith('float16.rad2deg') or k.startswith('float16.is_nan') or k.startswith('float16.is_inf') or k.startswith('float16.is_finite') or k.startswith('float16.is_zero') or k.startswith('float16.is_subnormal') or k.startswith('float16.is_normal') or k.startswith('float16.is_negative') or k.startswith('float16.sin') or k.startswith('float16.cos') or k.startswith('float16.tan') or k.startswith('float16.tanh') or k.startswith('float16.sin_deg') or k.startswith('float16.cos_deg') or k.startswith('float16.tan_deg') or k.startswith('float16.asin_deg') or k.startswith('float16.acos_deg') or k.startswith('float16.atan_deg') or k.startswith('float16.asin') or k.startswith('float16.acos') or k.startswith('float16.atan') or k.startswith('float16.sinh') or k.startswith('float16.cosh') or k.startswith('float16.floor') or k.startswith('float16.ceil') or k.startswith('float16.round') or 
k.startswith('float16.pow')] for k in old_float16_lut: del tensors[k] print(f"Removed {len(old_float16_lut)} old float16 LUT/pow tensors") old_arith_8bit = [k for k in tensors.keys() if k.startswith('arithmetic.') and '8bit' in k] for k in old_arith_8bit: del tensors[k] print(f"Removed {len(old_arith_8bit)} old arithmetic 8-bit tensors") old_mult8x8 = [k for k in tensors.keys() if k.startswith('arithmetic.multiplier8x8')] for k in old_mult8x8: del tensors[k] print(f"Removed {len(old_mult8x8)} old multiplier8x8 tensors") old_div8bit = [k for k in tensors.keys() if k.startswith('arithmetic.div8bit')] for k in old_div8bit: del tensors[k] print(f"Removed {len(old_div8bit)} old div8bit tensors") # Remove broken mod2/mod4/mod8 tensors old_mod_power2 = [k for k in tensors.keys() if k.startswith('modular.mod2') or k.startswith('modular.mod4') or k.startswith('modular.mod8')] for k in old_mod_power2: del tensors[k] print(f"Removed {len(old_mod_power2)} old mod2/mod4/mod8 tensors") # Remove broken bitwise shift tensors old_bitwise = [k for k in tensors.keys() if k.startswith('arithmetic.asr8bit') or k.startswith('arithmetic.rol8bit') or k.startswith('arithmetic.ror8bit')] for k in old_bitwise: del tensors[k] print(f"Removed {len(old_bitwise)} old asr8bit/rol8bit/ror8bit tensors") # Remove broken symmetry8bit tensors old_symmetry = [k for k in tensors.keys() if k.startswith('pattern_recognition.symmetry8bit')] for k in old_symmetry: del tensors[k] print(f"Removed {len(old_symmetry)} old symmetry8bit tensors") # Build new circuits print("Building new circuits...") clz16_tensors = build_clz16bit_tensors() tensors.update(clz16_tensors) print(f" CLZ16BIT: {len(clz16_tensors)} tensors") unpack_tensors = build_float16_unpack_tensors() tensors.update(unpack_tensors) print(f" float16.unpack: {len(unpack_tensors)} tensors") pack_tensors = build_float16_pack_tensors() tensors.update(pack_tensors) print(f" float16.pack: {len(pack_tensors)} tensors") cmp_tensors = 
build_float16_cmp_tensors() tensors.update(cmp_tensors) print(f" float16.cmp: {len(cmp_tensors)} tensors") norm_tensors = build_float16_normalize_tensors() tensors.update(norm_tensors) print(f" float16.normalize: {len(norm_tensors)} tensors") neg_tensors = build_float16_neg_tensors() tensors.update(neg_tensors) print(f" float16.neg: {len(neg_tensors)} tensors") abs_tensors = build_float16_abs_tensors() tensors.update(abs_tensors) print(f" float16.abs: {len(abs_tensors)} tensors") unpack32_tensors = build_float32_unpack_tensors() tensors.update(unpack32_tensors) print(f" float32.unpack: {len(unpack32_tensors)} tensors") pack32_tensors = build_float32_pack_tensors() tensors.update(pack32_tensors) print(f" float32.pack: {len(pack32_tensors)} tensors") cmp32_tensors = build_float32_cmp_tensors() tensors.update(cmp32_tensors) print(f" float32.cmp: {len(cmp32_tensors)} tensors") neg32_tensors = build_float32_neg_tensors() tensors.update(neg32_tensors) print(f" float32.neg: {len(neg32_tensors)} tensors") abs32_tensors = build_float32_abs_tensors() tensors.update(abs32_tensors) print(f" float32.abs: {len(abs32_tensors)} tensors") add_tensors = build_float16_add_tensors() tensors.update(add_tensors) print(f" float16.add: {len(add_tensors)} tensors") sub_tensors = build_float16_sub_tensors() tensors.update(sub_tensors) print(f" float16.sub: {len(sub_tensors)} tensors") mul_tensors = build_float16_mul_tensors() tensors.update(mul_tensors) print(f" float16.mul: {len(mul_tensors)} tensors") div_tensors = build_float16_div_tensors() tensors.update(div_tensors) print(f" float16.div: {len(div_tensors)} tensors") toint_tensors = build_float16_toint_tensors() tensors.update(toint_tensors) print(f" float16.toint: {len(toint_tensors)} tensors") fromint_tensors = build_float16_fromint_tensors() tensors.update(fromint_tensors) print(f" float16.fromint: {len(fromint_tensors)} tensors") const_map = { "pi": math.pi, "e": math.e, "deg2rad": math.pi / 180.0, "rad2deg": 180.0 / math.pi, } for 
name, value in const_map.items(): const_tensors = build_float16_const_tensors(f"float16.const_{name}", value) tensors.update(const_tensors) print(f" float16.const_{name}: {len(const_tensors)} tensors") # Shared LUT match gates lut_match_tensors = build_float16_lut_match_tensors("float16.lut") tensors.update(lut_match_tensors) print(f" float16.lut: {len(lut_match_tensors)} tensors") # Unary LUT outputs unary_ops = { "sqrt": torch.sqrt, "rsqrt": torch.rsqrt, "exp": torch.exp, "ln": torch.log, "log2": torch.log2, "log10": unary_float32(torch.log10), "deg2rad": unary_float32(lambda x: x * (math.pi / 180.0)), "rad2deg": unary_float32(lambda x: x * (180.0 / math.pi)), "sin": torch.sin, "cos": torch.cos, "tan": torch.tan, "tanh": torch.tanh, "asin": unary_float32(torch.asin), "acos": unary_float32(torch.acos), "atan": unary_float32(torch.atan), "sinh": unary_float32(torch.sinh), "cosh": unary_float32(torch.cosh), "floor": unary_float32(torch.floor), "ceil": unary_float32(torch.ceil), "round": unary_float32(torch.round), } deg_ops = { "sin_deg": wrap_deg_trig(torch.sin), "cos_deg": wrap_deg_trig(torch.cos), "tan_deg": wrap_deg_trig(torch.tan), "asin_deg": wrap_inv_trig_deg(torch.asin), "acos_deg": wrap_inv_trig_deg(torch.acos), "atan_deg": wrap_inv_trig_deg(torch.atan), } classify_ops = { "is_nan": unary_float32(torch.isnan), "is_inf": unary_float32(torch.isinf), "is_finite": unary_float32(torch.isfinite), "is_zero": unary_float32(lambda x: x == 0), "is_subnormal": unary_float32(lambda x: (torch.abs(x) != 0) & (torch.abs(x) < (2 ** -14))), "is_normal": unary_float32(lambda x: torch.isfinite(x) & (torch.abs(x) >= (2 ** -14))), "is_negative": unary_float32(torch.signbit), } domain_ops = [ "sqrt", "rsqrt", "ln", "log2", "log10", "asin", "acos", "asin_deg", "acos_deg", ] lut_outputs: Dict[str, List[int]] = {} for name, fn in unary_ops.items(): print(f" computing float16.{name} LUT...") outputs = compute_float16_unary_lut_outputs(fn) lut_outputs[name] = outputs op_tensors = 
build_float16_lut_output_tensors(f"float16.{name}", outputs) tensors.update(op_tensors) print(f" float16.{name}: {len(op_tensors)} tensors") for name, fn in deg_ops.items(): print(f" computing float16.{name} LUT...") outputs = compute_float16_unary_lut_outputs(fn) lut_outputs[name] = outputs op_tensors = build_float16_lut_output_tensors(f"float16.{name}", outputs) tensors.update(op_tensors) print(f" float16.{name}: {len(op_tensors)} tensors") for name, fn in classify_ops.items(): print(f" computing float16.{name} LUT...") outputs = compute_float16_unary_lut_outputs(fn) lut_outputs[name] = outputs op_tensors = build_float16_lut_output_tensors(f"float16.{name}", outputs) tensors.update(op_tensors) print(f" float16.{name}: {len(op_tensors)} tensors") for name in domain_ops: print(f" computing float16.{name} domain flags...") flags = compute_float16_domain_flags(name) flag_tensors = build_float16_lut_flag_tensors(f"float16.{name}", flags, flag_name="domain") tensors.update(flag_tensors) print(f" float16.{name}.domain: {len(flag_tensors)} tensors") checked_tensors = build_float16_checked_outputs(f"float16.{name}") tensors.update(checked_tensors) print(f" float16.{name}.checked_out*: {len(checked_tensors)} tensors") # float16.pow (ln -> mul -> exp) pow_tensors = build_float16_pow_tensors(mul_tensors, lut_outputs["ln"], lut_outputs["exp"]) tensors.update(pow_tensors) print(f" float16.pow: {len(pow_tensors)} tensors") # 16-bit integer arithmetic circuits rc16 = build_ripplecarry_tensors("arithmetic.ripplecarry16bit", 16) tensors.update(rc16) print(f" ripplecarry16bit: {len(rc16)} tensors") rc32 = build_ripplecarry_tensors("arithmetic.ripplecarry32bit", 32) tensors.update(rc32) print(f" ripplecarry32bit: {len(rc32)} tensors") adc16 = build_adc_sbc_tensors("arithmetic.adc16bit", 16) tensors.update(adc16) print(f" adc16bit: {len(adc16)} tensors") adc32 = build_adc_sbc_tensors("arithmetic.adc32bit", 32) tensors.update(adc32) print(f" adc32bit: {len(adc32)} tensors") sbc16 = 
build_adc_sbc_tensors("arithmetic.sbc16bit", 16, with_notb=True) tensors.update(sbc16) print(f" sbc16bit: {len(sbc16)} tensors") sbc32 = build_adc_sbc_tensors("arithmetic.sbc32bit", 32, with_notb=True) tensors.update(sbc32) print(f" sbc32bit: {len(sbc32)} tensors") sub16 = build_sub_tensors("arithmetic.sub16bit", 16) tensors.update(sub16) print(f" sub16bit: {len(sub16)} tensors") sub32 = build_sub_tensors("arithmetic.sub32bit", 32) tensors.update(sub32) print(f" sub32bit: {len(sub32)} tensors") cmp16 = build_cmp_tensors("arithmetic.cmp16bit", 16) tensors.update(cmp16) print(f" cmp16bit: {len(cmp16)} tensors") cmp32 = build_cmp_tensors("arithmetic.cmp32bit", 32) tensors.update(cmp32) print(f" cmp32bit: {len(cmp32)} tensors") eq16 = build_equality_tensors("arithmetic.equality16bit", 16) tensors.update(eq16) print(f" equality16bit: {len(eq16)} tensors") eq32 = build_equality_tensors("arithmetic.equality32bit", 32) tensors.update(eq32) print(f" equality32bit: {len(eq32)} tensors") neg16 = build_neg_tensors("arithmetic.neg16bit", 16) tensors.update(neg16) print(f" neg16bit: {len(neg16)} tensors") neg32 = build_neg_tensors("arithmetic.neg32bit", 32) tensors.update(neg32) print(f" neg32bit: {len(neg32)} tensors") asr16 = build_shift_rotate_tensors("arithmetic.asr16bit", 16, "asr") rol16 = build_shift_rotate_tensors("arithmetic.rol16bit", 16, "rol") ror16 = build_shift_rotate_tensors("arithmetic.ror16bit", 16, "ror") tensors.update(asr16) tensors.update(rol16) tensors.update(ror16) print(f" asr/rol/ror16bit: {len(asr16) + len(rol16) + len(ror16)} tensors") asr32 = build_shift_rotate_tensors("arithmetic.asr32bit", 32, "asr") rol32 = build_shift_rotate_tensors("arithmetic.rol32bit", 32, "rol") ror32 = build_shift_rotate_tensors("arithmetic.ror32bit", 32, "ror") tensors.update(asr32) tensors.update(rol32) tensors.update(ror32) print(f" asr/rol/ror32bit: {len(asr32) + len(rol32) + len(ror32)} tensors") comp16 = build_comparator_vectors(16) tensors.update(comp16) print(f" 
comparator16bit: {len(comp16)} tensors") comp32 = build_comparator_vectors(32) tensors.update(comp32) print(f" comparator32bit: {len(comp32)} tensors") incdec16 = build_increment_decrement_constants(16) tensors.update(incdec16) print(f" increment/decrement16bit: {len(incdec16)} tensors") incdec32 = build_increment_decrement_constants(32) tensors.update(incdec32) print(f" increment/decrement32bit: {len(incdec32)} tensors") minmax16 = build_minmax_diff_constants(16) tensors.update(minmax16) print(f" min/max/diff16bit: {len(minmax16)} tensors") minmax32 = build_minmax_diff_constants(32) tensors.update(minmax32) print(f" min/max/diff32bit: {len(minmax32)} tensors") mod_power2_tensors = build_modular_power2_tensors() tensors.update(mod_power2_tensors) print(f" modular.mod2/4/8: {len(mod_power2_tensors)} tensors") symmetry_tensors = build_symmetry8bit_tensors() tensors.update(symmetry_tensors) print(f" symmetry8bit: {len(symmetry_tensors)} tensors") print(f"Total tensors: {len(tensors)}") # Load routing for complex circuits print("Loading routing.json...") try: with open('routing.json', 'r') as f: routing = json.load(f) except FileNotFoundError: routing = {} # Get all gates gates = get_all_gates(tensors) print(f"Found {len(gates)} gates") # Create signal registry registry = SignalRegistry() # Infer inputs for each gate print("Inferring inputs for each gate...") gate_inputs = {} missing_inputs = [] for gate in sorted(gates): inputs = infer_inputs_for_gate(gate, registry, routing) if inputs: gate_inputs[gate] = inputs # Register the gate itself so eval can store its output registry.register(gate) else: missing_inputs.append(gate) print(f"Inferred inputs for {len(gate_inputs)} gates") print(f"Missing inputs for {len(missing_inputs)} gates") if missing_inputs: print("\nGates missing inputs (first 20):") for gate in missing_inputs[:20]: print(f" {gate}") if len(missing_inputs) > 20: print(f" ... 
and {len(missing_inputs) - 20} more") # Add .inputs tensors print("\nAdding .inputs tensors...") new_tensors = dict(tensors) # Copy existing for gate, inputs in gate_inputs.items(): input_tensor = torch.tensor(inputs, dtype=torch.int64) new_tensors[f"{gate}.inputs"] = input_tensor print(f"Total tensors: {len(new_tensors)}") # Create metadata metadata = { "signal_registry": registry.to_metadata(), "format_version": "2.0", "description": "Self-documenting threshold logic circuits with explicit .inputs tensors" } # Save to temp file then rename (avoid file locking issues) import os print("Saving arithmetic.safetensors...") save_file(new_tensors, 'arithmetic_new.safetensors', metadata=metadata) if os.path.exists('arithmetic.safetensors'): os.remove('arithmetic.safetensors') os.rename('arithmetic_new.safetensors', 'arithmetic.safetensors') size = os.path.getsize('arithmetic.safetensors') print(f"Saved: {size:,} bytes") # Summary print(f"\n=== SUMMARY ===") print(f"Original tensors: {len(tensors)}") print(f"New tensors: {len(new_tensors)}") print(f"Added .inputs tensors: {len(new_tensors) - len(tensors)}") print(f"Signal registry size: {len(registry.name_to_id)} signals") print(f"Gates with inferred inputs: {len(gate_inputs)}") print(f"Gates missing inputs: {len(missing_inputs)}") if __name__ == '__main__': main()