| | |
| | """ |
| | Threshold-calculus gate-level calculator. |
| | |
| | Pure evaluation via .inputs + .weight/.bias only. No arithmetic shortcuts. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import argparse |
| | import ast |
| | import json |
| | import math |
| | import re |
| | import struct |
| | import time |
| | from dataclasses import dataclass, field |
| | from typing import Dict, Iterable, List, Optional, Sequence, Tuple |
| |
|
| | import torch |
| | from safetensors import safe_open |
| |
|
| |
|
def int_to_bits(val: int, width: int) -> List[float]:
    """Expand ``val`` into ``width`` bits, least-significant bit first.

    Each bit is returned as a float (0.0 or 1.0) so it can be written
    directly into a signal vector.
    """
    bits: List[float] = []
    for shift in range(width):
        bits.append(float((val >> shift) & 1))
    return bits
| |
|
| |
|
def bits_to_int(bits: Sequence[float]) -> int:
    """Pack an LSB-first bit sequence into an integer.

    A position counts as set when its value is at least 0.5.
    """
    total = 0
    for position, bit in enumerate(bits):
        if bit >= 0.5:
            total |= 1 << position
    return total
| |
|
| |
|
def float_to_float16_bits(val: float) -> int:
    """Convert a float to its IEEE-754 binary16 bit pattern.

    * Every NaN, whatever its sign, maps to the canonical quiet NaN 0x7E00.
    * Infinities and zeros keep their sign (handled by ``struct`` itself).
    * Finite values too large for binary16 saturate to the largest finite
      magnitude (0x7BFF / 0xFBFF) instead of raising.

    Returns:
        The 16-bit pattern as a non-negative int.
    """
    # NaN is the only value that compares unequal to itself.  Handle it up
    # front so a sign-set NaN does not leak out as 0xFE00.
    if val != val:
        return 0x7E00
    try:
        (bits,) = struct.unpack(">H", struct.pack(">e", float(val)))
        return bits
    except (OverflowError, struct.error):
        # Only reached for finite magnitudes beyond the binary16 range.
        return 0x7BFF if val > 0 else 0xFBFF
| |
|
| |
|
def float16_bits_to_float(bits: int) -> float:
    """Decode the low 16 bits of ``bits`` as an IEEE-754 binary16 value."""
    halfword = struct.pack(">H", bits & 0xFFFF)
    (value,) = struct.unpack(">e", halfword)
    return value
| |
|
| |
|
def parse_external_name(name: str) -> Tuple[Optional[str], Optional[int], Optional[str]]:
    """Split an external signal name into ``(base, index, full_key)``.

    ``full_key`` is everything before the first ``[``, ``base`` is the text
    between the first ``$`` and the following ``[``, and ``index`` is the
    integer inside the brackets (``None`` when absent or not an integer).

    Examples:
        "$a"                -> ("a", None, "$a")
        "float16.add.$a[3]" -> ("a", 3, "float16.add.$a")

    Names without a ``$`` are not external signals and yield
    ``(None, None, None)``.
    """
    if "$" not in name:
        return None, None, None

    full_key = name.partition("[")[0]
    tail = name.partition("$")[2]
    base, bracket, after_bracket = tail.partition("[")

    idx: Optional[int] = None
    if bracket and "]" in tail:
        digits = after_bracket.partition("]")[0]
        try:
            idx = int(digits)
        except ValueError:
            idx = None
    return base, idx, full_key
| |
|
| |
|
def resolve_alias_target(name: str, gates: set) -> Optional[str]:
    """Map an alias signal name onto the gate that actually produces it.

    Candidates are tried in this order and the first one present in
    ``gates`` wins:

    1. ``name`` itself
    2. ``name + ".layer2"``
    3. for ``*.sum``: the stem plus ``".xor2.layer2"``
    4. for ``*.cout``: the stem plus ``".or_carry"``, then ``".carry_or"``

    Returns ``None`` when no candidate is a known gate.
    """
    candidates = [name, name + ".layer2"]
    if name.endswith(".sum"):
        candidates.append(name[:-4] + ".xor2.layer2")
    if name.endswith(".cout"):
        stem = name[:-5]
        candidates.extend(stem + suffix for suffix in (".or_carry", ".carry_or"))
    return next((cand for cand in candidates if cand in gates), None)
| |
|
| |
|
# Unsigned decimal literal: digits with an optional fraction (or a bare
# ".digits"), followed by an optional exponent.  Signs are tokenized
# separately as operators.
_NUM_RE = re.compile(r"(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?")
# Identifier: a letter or underscore followed by letters, digits, underscores.
_IDENT_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")
# Names treated as function calls when directly followed by "(": no implicit
# "*" is inserted between such a name and the opening parenthesis.
_FUNC_NAMES = {
    "sqrt", "rsqrt", "exp", "ln", "log", "log2", "log10",
    "deg2rad", "rad2deg",
    "isnan", "is_nan", "isinf", "is_inf", "isfinite", "is_finite",
    "iszero", "is_zero", "issubnormal", "is_subnormal",
    "isnormal", "is_normal", "isneg", "is_negative", "signbit",
    "sin", "cos", "tan", "tanh",
    "asin", "acos", "atan", "sinh", "cosh",
    "floor", "ceil", "round", "abs", "neg",
}
| |
|
| |
|
def _tokenize_expr(expr: str) -> List[str]:
    """Split a calculator expression into operator, number and name tokens.

    Whitespace is dropped.  ``**`` is recognized before the single-character
    operators ``+ - * / ( ) ,``; anything else must match ``_NUM_RE`` or
    ``_IDENT_RE``.

    Raises:
        RuntimeError: when no token can be matched at the current position.
    """
    tokens: List[str] = []
    pos = 0
    end = len(expr)
    while pos < end:
        ch = expr[pos]
        if ch.isspace():
            pos += 1
        elif expr.startswith("**", pos):
            tokens.append("**")
            pos += 2
        elif ch in "+-*/(),":
            tokens.append(ch)
            pos += 1
        else:
            # Numbers are tried before identifiers, as both are anchored at pos.
            match = _NUM_RE.match(expr, pos) or _IDENT_RE.match(expr, pos)
            if match is None:
                raise RuntimeError(f"bad token near: {expr[pos:]}")
            tokens.append(match.group(0))
            pos = match.end()
    return tokens
| |
|
| |
|
| | def _needs_implicit_mul(left: str, right: str) -> bool: |
| | if left in {"+", "-", "*", "/", "**", ",", "("}: |
| | return False |
| | if right in {"+", "-", "*", "/", "**", ",", ")"}: |
| | return False |
| | if left in _FUNC_NAMES and right == "(": |
| | return False |
| | return True |
| |
|
| |
|
def _insert_implicit_mul(expr: str) -> str:
    """Re-emit ``expr`` with explicit ``*`` wherever multiplication is implied.

    The expression is tokenized, a ``*`` is placed between each adjacent pair
    for which ``_needs_implicit_mul`` is true, and the tokens are joined with
    no separator.  An expression with no tokens is returned unchanged.
    """
    tokens = _tokenize_expr(expr)
    if not tokens:
        return expr

    pieces: List[str] = [tokens[0]]
    for previous, current in zip(tokens, tokens[1:]):
        if _needs_implicit_mul(previous, current):
            pieces.append("*")
        pieces.append(current)
    return "".join(pieces)
| |
|
| |
|
def normalize_expr(expr: str) -> str:
    """Rewrite user-facing calculator syntax into Python expression syntax.

    Unicode pi, multiplication, division and minus signs are mapped to their
    ASCII forms, ``^`` becomes ``**``, and implicit multiplication is made
    explicit.
    """
    # None of the replacement strings contains a character that is itself
    # replaced, so one translate pass equals the chain of str.replace calls.
    substitutions = {
        ord("\u03c0"): "pi",
        ord("\u00d7"): "*",
        ord("\u00f7"): "/",
        ord("\u2212"): "-",
        ord("^"): "**",
    }
    return _insert_implicit_mul(expr.translate(substitutions))
| |
|
| |
|
def looks_like_expression(text: str) -> bool:
    """Return True when ``text`` contains any operator, parenthesis or pi marker.

    This is a plain substring test, so any text containing ``pi`` or ``-``
    also counts.
    """
    for marker in ("+", "-", "*", "/", "(", ")", "^", "pi", "\u03c0"):
        if marker in text:
            return True
    return False
| |
|
| |
|
@dataclass
class EvalResult:
    """Outcome of evaluating a circuit or a whole expression."""

    # Output bit values (0.0 / 1.0), in the order of the requested outputs.
    bits: List[float]
    # Time spent in the gate-evaluation loop, in seconds (summed over every
    # circuit run for expression results).
    elapsed_s: float
    # Number of gates evaluated (summed over every circuit run).
    gates_evaluated: int
    # Notes about results produced without running any gate, e.g.
    # "constant_expression_no_gates".
    non_gate_events: List[str] = field(default_factory=list)
| |
|
| |
|
@dataclass
class LevelBatch:
    """Tensors for evaluating every gate of one dependency level at once."""

    # (num_gates, max_fanin) long tensor of input signal ids, padded per row.
    input_ids: torch.Tensor
    # (num_gates, max_fanin) float32 weights; padded slots hold 0.
    weight: torch.Tensor
    # (num_gates,) float32 bias per gate.
    bias: torch.Tensor
    # (num_gates,) long tensor: signal id written by each gate.
    output_ids: torch.Tensor
    # Signal ids of aliases that mirror a gate output of this level.
    alias_ids: torch.Tensor
    # For each alias, the row index (within this level) of the gate it mirrors.
    alias_src: torch.Tensor
| |
|
| |
|
@dataclass
class CompiledLevel:
    """One dependency level of a compiled circuit."""

    # Batched tensors for all gates of this level.
    batch: LevelBatch
| |
|
| |
|
@dataclass
class ExternalSpec:
    """External (``$``-named) input signals a compiled circuit needs."""

    # One (signal id, base name, bit index, full key) tuple per external
    # signal; the bit index is 0 for names without an index.
    entries: List[Tuple[int, str, int, str]]
    # Required bit width per full key (e.g. "float16.add.$a").
    width_full: Dict[str, int]
    # Required bit width per base name (e.g. "a").
    width_base: Dict[str, int]
| |
|
| |
|
@dataclass
class CompiledCircuit:
    """A circuit prepared for level-by-level batched evaluation."""

    # Circuit prefix the compilation was requested for (e.g. "float16.add").
    prefix: str
    # Names of the output gates, in result order.
    output_names: List[str]
    # Signal ids matching ``output_names``.
    output_ids: List[int]
    # Non-empty dependency levels, in evaluation order.
    levels: List[CompiledLevel]
    # External inputs that must be supplied before evaluation.
    external_spec: ExternalSpec
    # Total number of gates across all levels.
    gate_count: int
| |
|
| |
|
class ThresholdCalculator:
    """Gate-level evaluator for threshold circuits stored in a safetensors file.

    A gate is every tensor-name stem that has a ``<gate>.weight`` tensor; its
    inputs are the signal ids in ``<gate>.inputs`` and its optional bias is in
    ``<gate>.bias``.  A gate outputs 1.0 when
    ``sum(weight * input_values) + bias >= 0`` and 0.0 otherwise.  The mapping
    between signal ids and names is read from the ``signal_registry`` entry of
    the file metadata.
    """

    # Calculator function name -> circuit prefix.  Shared by evaluate_rpn and
    # evaluate_expr so both front-ends accept exactly the same functions.
    _UNARY_PREFIXES = {
        "sqrt": "float16.sqrt",
        "rsqrt": "float16.rsqrt",
        "exp": "float16.exp",
        "ln": "float16.ln",
        "log": "float16.ln",
        "log2": "float16.log2",
        "log10": "float16.log10",
        "deg2rad": "float16.deg2rad",
        "rad2deg": "float16.rad2deg",
        "isnan": "float16.is_nan",
        "is_nan": "float16.is_nan",
        "isinf": "float16.is_inf",
        "is_inf": "float16.is_inf",
        "isfinite": "float16.is_finite",
        "is_finite": "float16.is_finite",
        "iszero": "float16.is_zero",
        "is_zero": "float16.is_zero",
        "issubnormal": "float16.is_subnormal",
        "is_subnormal": "float16.is_subnormal",
        "isnormal": "float16.is_normal",
        "is_normal": "float16.is_normal",
        "isneg": "float16.is_negative",
        "is_negative": "float16.is_negative",
        "signbit": "float16.is_negative",
        "sin": "float16.sin",
        "cos": "float16.cos",
        "tan": "float16.tan",
        "tanh": "float16.tanh",
        "asin": "float16.asin",
        "acos": "float16.acos",
        "atan": "float16.atan",
        "sinh": "float16.sinh",
        "cosh": "float16.cosh",
        "floor": "float16.floor",
        "ceil": "float16.ceil",
        "round": "float16.round",
        "abs": "float16.abs",
        "neg": "float16.neg",
    }
    # Functions that switch to the "<name>_deg" circuit in degree mode.
    _DEGREE_AWARE = frozenset({"sin", "cos", "tan", "asin", "acos", "atan"})
    # Binary operators as RPN tokens and as AST operator classes.
    _RPN_BINARY_PREFIXES = {
        "+": "float16.add",
        "-": "float16.sub",
        "*": "float16.mul",
        "/": "float16.div",
        "^": "float16.pow",
    }
    _AST_BINARY_PREFIXES = {
        ast.Add: "float16.add",
        ast.Sub: "float16.sub",
        ast.Mult: "float16.mul",
        ast.Div: "float16.div",
        ast.Pow: "float16.pow",
    }
    # Named constants and the Python value used when the model file has no
    # "float16.const_<name>" circuit.
    _NAMED_CONSTANTS = {
        "pi": math.pi,
        "e": math.e,
        "deg2rad": math.pi / 180.0,
        "rad2deg": 180.0 / math.pi,
    }

    def __init__(self, model_path: str = "./arithmetic.safetensors") -> None:
        """Load the model at ``model_path`` and index its gates and signals."""
        self.model_path = model_path
        self.tensors: Dict[str, torch.Tensor] = {}
        self.gates: List[str] = []
        self.name_to_id: Dict[str, int] = {}
        self.id_to_name: Dict[int, str] = {}
        self._gate_inputs: Dict[str, torch.Tensor] = {}
        self._gate_set: set = set()
        self._alias_to_gate: Dict[int, int] = {}
        self._gate_to_alias: Dict[int, List[int]] = {}
        self._id_to_gate: Dict[int, str] = {}
        self._topo_cache: Dict[Tuple[str, Tuple[str, ...]], List[str]] = {}
        self._compiled: Dict[Tuple[str, Tuple[str, ...]], CompiledCircuit] = {}
        self._const_cache: Dict[str, int] = {}
        self._load()

    def _load(self) -> None:
        """Read all tensors and the signal registry, then build lookup tables."""
        with safe_open(self.model_path, framework="pt") as f:
            for name in f.keys():
                self.tensors[name] = f.get_tensor(name)
            metadata = f.metadata()
        if metadata and "signal_registry" in metadata:
            registry_raw = json.loads(metadata["signal_registry"])
            self.id_to_name = {int(k): v for k, v in registry_raw.items()}
            self.name_to_id = {v: int(k) for k, v in registry_raw.items()}
        # A gate is any name stem that owns a ".weight" tensor.
        self.gates = sorted({k.rsplit(".", 1)[0] for k in self.tensors.keys() if k.endswith(".weight")})
        self._gate_set = set(self.gates)
        for gate in self.gates:
            inputs_key = f"{gate}.inputs"
            if inputs_key in self.tensors:
                self._gate_inputs[gate] = self.tensors[inputs_key].to(dtype=torch.long)
        self._build_alias_maps()
        for gate in self.gates:
            gid = self.name_to_id.get(gate)
            if gid is not None:
                self._id_to_gate[gid] = gate

    def _build_alias_maps(self) -> None:
        """Link registry names that are not gates to the gates producing them.

        Constants (``#0``/``#1``), external ``$`` signals and real gate names
        are skipped; every remaining name is resolved with
        ``resolve_alias_target``.
        """
        gates = set(self.gates)
        alias_to_gate: Dict[int, int] = {}
        gate_to_alias: Dict[int, List[int]] = {}
        for name, sid in self.name_to_id.items():
            if name in ("#0", "#1"):
                continue
            if name.startswith("$") or ".$" in name:
                continue
            if name in gates:
                continue
            target = resolve_alias_target(name, gates)
            if not target:
                continue
            target_id = self.name_to_id.get(target)
            if target_id is None:
                continue
            alias_to_gate[sid] = target_id
            gate_to_alias.setdefault(target_id, []).append(sid)
        self._alias_to_gate = alias_to_gate
        self._gate_to_alias = gate_to_alias

    def _signal_to_gate(self, sid: int) -> Optional[str]:
        """Return the gate that produces signal ``sid`` (directly or via alias)."""
        if sid in self._id_to_gate:
            return self._id_to_gate[sid]
        alias_target = self._alias_to_gate.get(sid)
        if alias_target is not None:
            return self._id_to_gate.get(alias_target)
        return None

    def _default_outputs(self, prefix: str, out_bits: int) -> List[str]:
        """Return ``prefix.out0..out{n-1}``, or ``[prefix]`` for a single gate.

        Raises:
            RuntimeError: when neither form exists in the model.
        """
        if f"{prefix}.out0.weight" in self.tensors:
            return [f"{prefix}.out{i}" for i in range(out_bits)]
        if prefix in self._gate_set:
            return [prefix]
        raise RuntimeError(f"{prefix}: no outputs found")

    def _collect_required_gates(self, output_gates: Sequence[str]) -> List[str]:
        """Return, sorted, every gate reachable backwards from ``output_gates``.

        Raises:
            RuntimeError: when a reachable gate has no ``.inputs`` tensor.
        """
        required: set = set()
        stack = list(output_gates)
        while stack:
            gate = stack.pop()
            if gate in required:
                continue
            if gate not in self._gate_inputs:
                raise RuntimeError(f"{gate}: missing .inputs tensor")
            required.add(gate)
            for sid in self._gate_inputs[gate].tolist():
                dep_gate = self._signal_to_gate(int(sid))
                if dep_gate is not None and dep_gate not in required:
                    stack.append(dep_gate)
        return sorted(required)

    def _topo_sort(self, gates: Sequence[str]) -> List[str]:
        """Order ``gates`` so every gate comes after the gates it reads.

        Among the gates that are ready at the same time the lexicographically
        smallest name is emitted first.  Results are cached per gate tuple.

        Raises:
            RuntimeError: when the gates cannot all be ordered (cycle).
        """
        import heapq

        cache_key = ("__set__", tuple(gates))
        cached = self._topo_cache.get(cache_key)
        if cached is not None:
            return cached

        gate_set = set(gates)
        deps: Dict[str, set] = {g: set() for g in gates}
        rev: Dict[str, List[str]] = {g: [] for g in gates}
        for gate in gates:
            for sid in self._gate_inputs[gate].tolist():
                dep_gate = self._signal_to_gate(int(sid))
                if dep_gate is None or dep_gate not in gate_set:
                    continue
                # A gate may read the same producer more than once (repeated
                # id, or id plus alias).  Record the edge once, otherwise the
                # second visit below would remove an already-removed entry.
                if dep_gate in deps[gate]:
                    continue
                deps[gate].add(dep_gate)
                rev[dep_gate].append(gate)

        # Min-heap keeps the "smallest ready name first" order without
        # re-sorting the whole ready list on every insertion.
        ready = [g for g in gates if not deps[g]]
        heapq.heapify(ready)
        order: List[str] = []
        while ready:
            g = heapq.heappop(ready)
            order.append(g)
            for child in rev[g]:
                pending = deps[child]
                pending.discard(g)
                if not pending:
                    heapq.heappush(ready, child)
        if len(order) != len(gates):
            raise RuntimeError("Dependency cycle or unresolved inputs in gate graph")
        self._topo_cache[cache_key] = order
        return order

    def _required_externals(self, gates: Iterable[str]) -> List[int]:
        """Return the sorted ids of all ``$``-named signals read by ``gates``."""
        externals: set = set()
        for gate in gates:
            for sid in self._gate_inputs[gate].tolist():
                sid = int(sid)
                name = self.id_to_name.get(sid, "")
                if name.startswith("$") or ".$" in name:
                    externals.add(sid)
        return sorted(externals)

    def _build_external_spec(self, gates: Iterable[str]) -> ExternalSpec:
        """Describe the external inputs ``gates`` need and their bit widths.

        The width of an input is one more than the highest bit index seen for
        it (1 when the name carries no index).
        """
        required_externals = self._required_externals(gates)
        width_full: Dict[str, int] = {}
        width_base: Dict[str, int] = {}
        entries: List[Tuple[int, str, int, str]] = []
        for sid in required_externals:
            name = self.id_to_name.get(sid, "")
            base, idx, full_key = parse_external_name(name)
            if base is None or full_key is None:
                continue
            w = (idx + 1) if idx is not None else 1
            width_full[full_key] = max(width_full.get(full_key, 1), w)
            width_base[base] = max(width_base.get(base, 1), w)
            entries.append((sid, base, idx if idx is not None else 0, full_key))
        entries.sort(key=lambda x: x[0])
        return ExternalSpec(entries=entries, width_full=width_full, width_base=width_base)

    def _normalize_inputs(self, spec: ExternalSpec, inputs: Dict[str, object]) -> Dict[int, float]:
        """Turn user inputs into a ``{signal id: bit value}`` mapping.

        Keys containing ``$`` are matched against full external keys; all
        other keys are matched against base names.  A full-key match takes
        precedence.  Values may be a list/tuple of bits (LSB first, at least
        the required width), an int, or a float (16-bit inputs only).

        Raises:
            RuntimeError: for a missing input, a bit list that is too short,
                a float given for a non-16-bit input, or an unsupported type.
        """
        exact: Dict[str, object] = {}
        base_inputs: Dict[str, object] = {}
        for key, val in inputs.items():
            if "$" in key:
                exact[key] = val
            else:
                base_inputs[key] = val

        def ensure_bit_list(val: object, width: int) -> List[float]:
            if isinstance(val, (list, tuple)):
                bits = [float(b) for b in val]
                if len(bits) < width:
                    raise RuntimeError(f"input width {len(bits)} < required {width}")
                return bits
            if isinstance(val, int):
                return int_to_bits(val, width)
            if isinstance(val, float):
                if width != 16:
                    raise RuntimeError("float inputs only supported for 16-bit values")
                bits_int = float_to_float16_bits(val)
                return int_to_bits(bits_int, width)
            raise RuntimeError("inputs must be list/tuple, int, or float16-compatible float")

        normalized: Dict[int, float] = {}
        for sid, base, idx, full_key in spec.entries:
            if full_key in exact:
                bits = ensure_bit_list(exact[full_key], spec.width_full[full_key])
            elif base in base_inputs:
                bits = ensure_bit_list(base_inputs[base], spec.width_base[base])
            else:
                raise RuntimeError(f"missing external input for {full_key}")
            normalized[sid] = float(bits[idx])
        return normalized

    def _compile_prefix(self, prefix: str, output_gates: List[str]) -> CompiledCircuit:
        """Compile the cone of logic behind ``output_gates`` into level batches.

        Gates are grouped by depth (0 = no in-cone dependencies) and each
        level is packed into padded tensors so it can be evaluated with one
        multiply-and-sum.  Results are cached per ``(prefix, outputs)``.

        Raises:
            RuntimeError: when a gate or output has no signal id, a gate is
                missing its ``.inputs`` tensor, or the graph has a cycle.
        """
        key = (prefix, tuple(output_gates))
        cached = self._compiled.get(key)
        if cached is not None:
            return cached

        required_gates = self._collect_required_gates(output_gates)
        gate_order = self._topo_sort(required_gates)
        external_spec = self._build_external_spec(required_gates)

        gate_set = set(required_gates)
        level_map: Dict[str, int] = {}
        levels: List[List[str]] = []
        for gate in gate_order:
            deps: List[str] = []
            for sid in self._gate_inputs[gate].tolist():
                dep_gate = self._signal_to_gate(int(sid))
                if dep_gate is not None and dep_gate in gate_set:
                    deps.append(dep_gate)
            lvl = max(level_map[d] for d in deps) + 1 if deps else 0
            level_map[gate] = lvl
            while lvl >= len(levels):
                levels.append([])
            levels[lvl].append(gate)

        # Rows with fewer inputs than the widest gate of their level are
        # padded.  Padded slots carry weight 0, but the signal they point at
        # is still read and NaN-checked during evaluation, so point them at
        # the constant "#0" (always seeded with 0.0) when the registry has it.
        pad_id = self.name_to_id.get("#0", 0)

        compiled_levels: List[CompiledLevel] = []
        gate_count = 0
        for level_gates in levels:
            if not level_gates:
                continue
            max_fanin = max(int(self._gate_inputs[g].numel()) for g in level_gates)
            num_gates = len(level_gates)
            input_ids = torch.full((num_gates, max_fanin), pad_id, dtype=torch.long)
            weight_mat = torch.zeros((num_gates, max_fanin), dtype=torch.float32)
            bias_vec = torch.zeros((num_gates,), dtype=torch.float32)
            level_output_ids = torch.zeros((num_gates,), dtype=torch.long)
            alias_ids: List[int] = []
            alias_src: List[int] = []
            for idx, gate in enumerate(level_gates):
                gate_inputs = self._gate_inputs[gate]
                fan_in = int(gate_inputs.numel())
                input_ids[idx, :fan_in] = gate_inputs
                weight = self.tensors[f"{gate}.weight"]
                if weight.dtype != torch.float32:
                    weight = weight.float()
                weight_mat[idx, :fan_in] = weight
                # A gate without a ".bias" tensor has bias 0.
                bias_vec[idx] = self.tensors.get(f"{gate}.bias", torch.tensor([0.0])).float().item()
                output_id = self.name_to_id.get(gate)
                if output_id is None:
                    raise RuntimeError(f"{gate}: missing signal id")
                level_output_ids[idx] = output_id
                for alias_id in self._gate_to_alias.get(output_id, []):
                    alias_ids.append(alias_id)
                    alias_src.append(idx)
            if alias_ids:
                alias_ids_vec = torch.tensor(alias_ids, dtype=torch.long)
                alias_src_vec = torch.tensor(alias_src, dtype=torch.long)
            else:
                alias_ids_vec = torch.empty((0,), dtype=torch.long)
                alias_src_vec = torch.empty((0,), dtype=torch.long)
            compiled_levels.append(
                CompiledLevel(
                    batch=LevelBatch(
                        input_ids=input_ids,
                        weight=weight_mat,
                        bias=bias_vec,
                        output_ids=level_output_ids,
                        alias_ids=alias_ids_vec,
                        alias_src=alias_src_vec,
                    )
                )
            )
            gate_count += num_gates

        output_ids: List[int] = []
        for gate in output_gates:
            gid = self.name_to_id.get(gate)
            if gid is None:
                raise RuntimeError(f"{prefix}: missing output {gate}")
            output_ids.append(gid)

        compiled = CompiledCircuit(
            prefix=prefix,
            output_names=output_gates,
            output_ids=output_ids,
            levels=compiled_levels,
            external_spec=external_spec,
            gate_count=gate_count,
        )
        self._compiled[key] = compiled
        return compiled

    def evaluate_prefix(
        self,
        prefix: str,
        inputs: Dict[str, object],
        out_bits: int = 16,
        outputs: Optional[List[str]] = None,
    ) -> EvalResult:
        """Evaluate one circuit and return its output bits.

        Args:
            prefix: Circuit prefix, e.g. ``"float16.add"``.
            inputs: External input values, see ``_normalize_inputs``.
            out_bits: Number of ``prefix.out<i>`` gates to read when
                ``outputs`` is not given.
            outputs: Explicit output gate names; overrides ``out_bits``.

        Returns:
            An ``EvalResult`` whose ``elapsed_s`` covers the gate loop only.

        Raises:
            RuntimeError: when an input or output signal is still unresolved.
        """
        output_gates = outputs if outputs is not None else self._default_outputs(prefix, out_bits)
        compiled = self._compile_prefix(prefix, output_gates)

        # Size by the largest id, not the entry count, so a registry with
        # gaps in its id range cannot index past the end of the vector.
        num_signals = max(self.id_to_name, default=-1) + 1
        # NaN marks "not computed yet".
        signals = torch.full((num_signals,), float("nan"))
        if "#0" in self.name_to_id:
            signals[self.name_to_id["#0"]] = 0.0
        if "#1" in self.name_to_id:
            signals[self.name_to_id["#1"]] = 1.0

        seeded = self._normalize_inputs(compiled.external_spec, inputs)
        for sid, val in seeded.items():
            signals[sid] = val

        # perf_counter is monotonic, so the interval cannot be distorted by
        # wall-clock adjustments.
        start = time.perf_counter()
        evaluated = 0
        for level in compiled.levels:
            batch = level.batch
            input_vals = torch.take(signals, batch.input_ids)
            if torch.isnan(input_vals).any():
                raise RuntimeError(f"{prefix}: unresolved inputs")
            totals = (batch.weight * input_vals).sum(dim=1) + batch.bias
            outs = (totals >= 0).to(dtype=signals.dtype)
            signals[batch.output_ids] = outs
            if batch.alias_ids.numel() > 0:
                # Mirror gate outputs onto their alias signal ids.
                signals.index_copy_(0, batch.alias_ids, outs.index_select(0, batch.alias_src))
            evaluated += int(batch.output_ids.numel())
        elapsed = time.perf_counter() - start

        bits: List[float] = []
        for gid in compiled.output_ids:
            if torch.isnan(signals[gid]):
                raise RuntimeError(f"{prefix}: missing output")
            bits.append(float(signals[gid]))
        return EvalResult(bits=bits, elapsed_s=elapsed, gates_evaluated=evaluated)

    def float16_binop(self, op: str, a: float, b: float) -> Tuple[float, EvalResult]:
        """Run the two-input circuit ``float16.<op>`` on two Python floats.

        Returns the result decoded as a float plus the raw ``EvalResult``.
        """
        prefix = f"float16.{op}"
        a_bits = int_to_bits(float_to_float16_bits(a), 16)
        b_bits = int_to_bits(float_to_float16_bits(b), 16)
        if op == "sub":
            # The sign bit of b is flipped before float16.sub runs, exactly as
            # evaluate_rpn and evaluate_expr do.
            # NOTE(review): presumably the sub circuit expects a pre-negated
            # operand -- confirm against the circuit builder.
            b_bits[15] = 1.0 - b_bits[15]
        result = self.evaluate_prefix(prefix, {"a": a_bits, "b": b_bits}, out_bits=16)
        out_int = bits_to_int(result.bits)
        return float16_bits_to_float(out_int), result

    def float16_unary(self, op: str, x: float) -> Tuple[float, EvalResult]:
        """Run the one-input circuit ``float16.<op>`` on a Python float."""
        prefix = f"float16.{op}"
        x_bits = int_to_bits(float_to_float16_bits(x), 16)
        result = self.evaluate_prefix(prefix, {"x": x_bits}, out_bits=16)
        out_int = bits_to_int(result.bits)
        return float16_bits_to_float(out_int), result

    def float16_pow(self, a: float, b: float) -> Tuple[float, EvalResult]:
        """Compute ``a ** b`` with the ``float16.pow`` circuit."""
        return self.float16_binop("pow", a, b)

    def _const_bits(self, name: str, fallback: float) -> int:
        """Return the float16 bits of constant ``name`` (cached).

        Uses the ``float16.const_<name>`` circuit when the model has one,
        otherwise converts ``fallback`` directly.
        """
        if name in self._const_cache:
            return self._const_cache[name]
        prefix = f"float16.const_{name}"
        if f"{prefix}.out0.weight" in self.tensors:
            res = self.evaluate_prefix(prefix, {}, out_bits=16)
            self._const_cache[name] = bits_to_int(res.bits)
        else:
            self._const_cache[name] = float_to_float16_bits(fallback)
        return self._const_cache[name]

    # ------------------------------------------------------------------
    # Helpers shared by evaluate_rpn and evaluate_expr
    # ------------------------------------------------------------------

    def _make_runner(self):
        """Return ``(run, totals)``: a circuit runner that accumulates cost.

        ``run(prefix, inputs, outputs=None)`` evaluates a 16-bit circuit and
        adds its time and gate count to ``totals``.
        """
        totals = {"elapsed_s": 0.0, "gates": 0}

        def run(prefix: str, inputs: Dict[str, object], outputs: Optional[List[str]] = None) -> EvalResult:
            res = self.evaluate_prefix(prefix, inputs, out_bits=16, outputs=outputs)
            totals["elapsed_s"] += res.elapsed_s
            totals["gates"] += res.gates_evaluated
            return res

        return run, totals

    def _named_constant_bits(self, name: str) -> Optional[int]:
        """Return float16 bits for a named constant, or None if unknown."""
        fallback = self._NAMED_CONSTANTS.get(name)
        if fallback is not None:
            return self._const_bits(name, fallback)
        if name == "inf":
            return float_to_float16_bits(float("inf"))
        if name == "nan":
            return float_to_float16_bits(float("nan"))
        return None

    def _unary_prefix(self, fname: str, use_degrees: bool) -> Optional[str]:
        """Return the circuit prefix for function ``fname``, or None if unknown."""
        prefix = self._UNARY_PREFIXES.get(fname)
        if prefix is not None and use_degrees and fname in self._DEGREE_AWARE:
            return f"float16.{fname}_deg"
        return prefix

    def _apply_unary(self, run, prefix: str, x_bits: int, label: str) -> int:
        """Evaluate a unary circuit, honouring its domain flag when present.

        When the model has a ``<prefix>.domain`` gate it is read as a 17th
        output (preferring ``checked_out<i>`` gates over ``out<i>``), and a
        set flag is reported as a domain error.

        Raises:
            RuntimeError: ``domain error: <label>(<x>)`` when the flag is set.
        """
        if f"{prefix}.domain.weight" not in self.tensors:
            return bits_to_int(run(prefix, {"x": x_bits}).bits)
        outs = [
            f"{prefix}.checked_out{i}"
            if f"{prefix}.checked_out{i}.weight" in self.tensors
            else f"{prefix}.out{i}"
            for i in range(16)
        ]
        outs.append(f"{prefix}.domain")
        res = run(prefix, {"x": x_bits}, outputs=outs)
        if res.bits[16] >= 0.5:
            x_val = float16_bits_to_float(x_bits)
            raise RuntimeError(f"domain error: {label}({x_val})")
        return bits_to_int(res.bits[:16])

    def _apply_binary(self, run, prefix: str, a_bits: int, b_bits: int) -> int:
        """Evaluate a two-operand circuit on float16 bit patterns."""
        if prefix == "float16.sub":
            # Same operand convention as float16_binop: flip b's sign bit.
            b_bits ^= 0x8000
        return bits_to_int(run(prefix, {"a": a_bits, "b": b_bits}).bits)

    def _finish_expression(self, run, totals, out_bits: int, force_gate_eval: bool) -> EvalResult:
        """Package the final bits, handling expressions that ran no gate."""
        non_gate_events: List[str] = []
        if totals["gates"] == 0:
            if force_gate_eval:
                # Pure constant: route it through float16.add with +0 so the
                # result still comes out of a circuit.
                out_bits = bits_to_int(run("float16.add", {"a": out_bits, "b": 0}).bits)
            else:
                non_gate_events.append("constant_expression_no_gates")
        return EvalResult(
            bits=int_to_bits(out_bits, 16),
            elapsed_s=totals["elapsed_s"],
            gates_evaluated=totals["gates"],
            non_gate_events=non_gate_events,
        )

    def evaluate_rpn(
        self,
        tokens: Sequence[str],
        force_gate_eval: bool = True,
        angle_mode: str = "rad",
    ) -> EvalResult:
        """Evaluate an expression from RPN tokens using float16 circuits.

        Args:
            tokens: Function names, the operators ``+ - * / ^``, named
                constants, or anything ``float()`` can parse.
            force_gate_eval: When no circuit ran, add ``+0`` through
                ``float16.add`` instead of returning the constant directly.
            angle_mode: Any value starting with ``"deg"`` selects the
                ``*_deg`` trig circuits.

        Raises:
            RuntimeError: on stack underflow, an unparsable token, leftover
                operands, or a domain error.
        """
        use_degrees = (angle_mode or "rad").lower().startswith("deg")
        run, totals = self._make_runner()

        stack: List[int] = []
        for tok in tokens:
            # Function names are matched first, so "deg2rad"/"rad2deg" always
            # act as functions here, never as constants.
            if tok in self._UNARY_PREFIXES:
                if not stack:
                    raise RuntimeError("stack underflow")
                x = stack.pop()
                prefix = self._unary_prefix(tok, use_degrees)
                stack.append(self._apply_unary(run, prefix, x, tok))
                continue
            if tok in self._RPN_BINARY_PREFIXES:
                if len(stack) < 2:
                    raise RuntimeError("stack underflow")
                b = stack.pop()
                a = stack.pop()
                stack.append(self._apply_binary(run, self._RPN_BINARY_PREFIXES[tok], a, b))
                continue
            operand = self._named_constant_bits(tok)
            if operand is None:
                try:
                    operand = float_to_float16_bits(float(tok))
                except ValueError as exc:
                    raise RuntimeError(f"bad token: {tok}") from exc
            stack.append(operand)

        if len(stack) != 1:
            raise RuntimeError("invalid expression")

        return self._finish_expression(run, totals, stack.pop(), force_gate_eval)

    def evaluate_expr(
        self,
        expr: str,
        force_gate_eval: bool = True,
        angle_mode: str = "rad",
    ) -> EvalResult:
        """Evaluate a calculator expression using float16 circuits.

        The expression is normalized with ``normalize_expr``, parsed with
        ``ast.parse`` and evaluated node by node; only numeric literals, the
        named constants, unary ``+``/``-``, the operators ``+ - * / **`` and
        single-argument calls to known functions are accepted.

        Args:
            expr: The expression text.
            force_gate_eval: When no circuit ran, add ``+0`` through
                ``float16.add`` instead of returning the constant directly.
            angle_mode: Any value starting with ``"deg"`` selects the
                ``*_deg`` trig circuits.

        Raises:
            RuntimeError: for unsupported syntax, unknown names or a domain
                error.
            SyntaxError: when the normalized text is not a valid expression.
        """
        expr = normalize_expr(expr)
        use_degrees = (angle_mode or "rad").lower().startswith("deg")
        tree = ast.parse(expr, mode="eval")
        run, totals = self._make_runner()

        def eval_node(node: ast.AST) -> int:
            if isinstance(node, ast.Expression):
                return eval_node(node.body)
            if isinstance(node, ast.Constant):
                if isinstance(node.value, (int, float)):
                    return float_to_float16_bits(float(node.value))
                raise RuntimeError("unsupported literal")
            if isinstance(node, ast.Name):
                bits = self._named_constant_bits(node.id)
                if bits is None:
                    raise RuntimeError(f"unknown identifier: {node.id}")
                return bits
            if isinstance(node, ast.UnaryOp):
                if isinstance(node.op, ast.UAdd):
                    return eval_node(node.operand)
                if isinstance(node.op, ast.USub):
                    x = eval_node(node.operand)
                    return bits_to_int(run("float16.neg", {"x": x}).bits)
                raise RuntimeError("unsupported unary operator")
            if isinstance(node, ast.BinOp):
                # Operands are evaluated before the operator is checked.
                a = eval_node(node.left)
                b = eval_node(node.right)
                prefix = self._AST_BINARY_PREFIXES.get(type(node.op))
                if prefix is None:
                    raise RuntimeError("unsupported binary operator")
                return self._apply_binary(run, prefix, a, b)
            if isinstance(node, ast.Call):
                if not isinstance(node.func, ast.Name):
                    raise RuntimeError("unsupported function")
                fname = node.func.id
                if len(node.args) != 1:
                    raise RuntimeError(f"{fname} expects one argument")
                x = eval_node(node.args[0])
                prefix = self._unary_prefix(fname, use_degrees)
                if prefix is None:
                    raise RuntimeError(f"unsupported function: {fname}")
                return self._apply_unary(run, prefix, x, fname)
            raise RuntimeError("unsupported expression")

        out_bits = eval_node(tree)
        return self._finish_expression(run, totals, out_bits, force_gate_eval)
| |
|
| |
|
def main() -> int:
    """Command-line entry point for the gate-level calculator.

    The invocation is dispatched to one of three modes, checked in order:

    1. Expression mode: ``--expr`` is given, or the positional ``prefix`` is
       the only input and ``looks_like_expression`` accepts it.
    2. Explicit-input mode: ``--inputs name=value ...`` is evaluated against
       the circuit named by ``prefix``.
    3. float16 shorthand: ``float16.<op>`` with one or two positional values.

    In ``--json`` mode stdout carries exactly one JSON document; the strict
    warning is sent to stderr and the human-readable ``bits=``/``int=`` lines
    of explicit-input mode are suppressed so the output stays parseable.

    Returns:
        0 on success.

    Raises:
        RuntimeError: on usage errors (missing prefix, malformed ``--inputs``
            item, wrong number of positional values, or a non-float16 prefix
            without ``--inputs``).
    """
    import sys  # local import: only needed to route diagnostics to stderr

    parser = argparse.ArgumentParser(description="Gate-level calculator for threshold-calculus")
    parser.add_argument("prefix", nargs="?", default="", help="Circuit prefix (e.g., float16.add) or expression")
    parser.add_argument("values", nargs="*", help="Input values (float for float16, int otherwise)")
    parser.add_argument("--model", default="./arithmetic.safetensors", help="Path to safetensors model")
    parser.add_argument("--out-bits", type=int, default=16, help="Number of output bits")
    parser.add_argument("--inputs", nargs="*", help="Explicit inputs as name=value (e.g., a=0x3c00)")
    parser.add_argument("--hex", action="store_true", help="Parse numeric inputs as hex")
    parser.add_argument("--expr", help="Evaluate expression using float16 circuits")
    parser.add_argument("--angle", default="rad", choices=["rad", "deg"], help="Angle mode for trig functions")
    parser.add_argument("--json", action="store_true", help="Output JSON result")
    parser.add_argument("--strict", action="store_true", help="Warn if any non-gate path is used")
    args = parser.parse_args()

    calc = ThresholdCalculator(args.model)

    def emit_result(prefix: str, out_int: int, result: EvalResult, expr: Optional[str] = None) -> int:
        """Print ``result`` as JSON or as text lines and return exit status 0."""
        # A float16 interpretation is only offered for 16-bit results.
        out_float = float16_bits_to_float(out_int) if len(result.bits) == 16 else None
        if args.strict and result.non_gate_events:
            # Keep stdout a single JSON document in --json mode; the events
            # are also included in the payload itself.
            warn_stream = sys.stderr if args.json else sys.stdout
            print(f"STRICT WARNING: non-gate path used: {result.non_gate_events}", file=warn_stream)
        if args.json:
            # NOTE(review): json.dumps emits NaN/Infinity for non-finite
            # floats by default, which strict JSON parsers reject.
            payload = {
                "prefix": prefix,
                "expr": expr,
                "bits": f"0x{out_int:04x}",
                "float16": out_float,
                "gates": result.gates_evaluated,
                "elapsed_s": result.elapsed_s,
                "non_gate_events": result.non_gate_events,
            }
            print(json.dumps(payload))
        else:
            if expr:
                print(f"expr={expr}")
            print(f"bits=0x{out_int:04x} float16={out_float}")
            print(f"gates={result.gates_evaluated} elapsed_s={result.elapsed_s:.4f}")
        return 0

    # Mode 1: expression evaluation.
    if args.expr or (args.prefix and not args.values and not args.inputs and looks_like_expression(args.prefix)):
        expr = args.expr if args.expr else args.prefix
        result = calc.evaluate_expr(expr, angle_mode=args.angle)
        return emit_result("expr", bits_to_int(result.bits), result, expr=expr)

    if not args.prefix:
        raise RuntimeError("Provide a circuit prefix or use --expr")

    # Mode 2: explicit name=value inputs for an arbitrary circuit prefix.
    if args.inputs:
        inputs: Dict[str, object] = {}
        for item in args.inputs:
            if "=" not in item:
                raise RuntimeError("inputs must be name=value")
            key, val = item.split("=", 1)
            # Accept both "0x" and "0X" as an explicit hex marker.
            if args.hex or val.lower().startswith("0x"):
                inputs[key] = int(val, 16)
            else:
                # Prefer an integer; fall back to float for values like "1.5".
                try:
                    inputs[key] = int(val)
                except ValueError:
                    inputs[key] = float(val)
        result = calc.evaluate_prefix(args.prefix, inputs, out_bits=args.out_bits)
        out_int = bits_to_int(result.bits)
        if not args.json:
            # Raw bit list and width-aware hex are text-mode extras only.
            print(f"bits={result.bits}")
            print(f"int=0x{out_int:0{(args.out_bits + 3) // 4}x}")
        return emit_result(args.prefix, out_int, result)

    # Mode 3: float16 shorthand with positional float operands.
    prefix = args.prefix
    if prefix.startswith("float16."):
        op = prefix.split(".", 1)[1]
        operands = args.values
        if op == "pow":
            if len(operands) != 2:
                raise RuntimeError("float16.pow requires two values")
            _, result = calc.float16_pow(float(operands[0]), float(operands[1]))
        elif op in ("add", "sub", "mul", "div"):
            if len(operands) != 2:
                raise RuntimeError(f"{prefix} requires two values")
            _, result = calc.float16_binop(op, float(operands[0]), float(operands[1]))
        else:
            if len(operands) != 1:
                raise RuntimeError(f"{prefix} requires one value")
            _, result = calc.float16_unary(op, float(operands[0]))
        return emit_result(prefix, bits_to_int(result.bits), result)

    raise RuntimeError("Provide --inputs for non-float16 circuits")
| |
|
| |
|
# Script entry point: run the CLI and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)
| |
|