""" Quantize threshold-computer safetensors to the minimum signed integer dtype that exactly represents each tensor. Weights and biases in this library are integer-valued by construction, with one historical exception: a handful of legacy buffer gates use a bias of -0.5 (e.g. arithmetic.asr8bit.bit*.bias). For binary inputs, H(x - 0.5) and H(x - 1) are identical, so those biases are floored to -1 before casting. This is a packaging optimization, not a precision change: the eval pipeline already promotes weights to float32 on load, so integer storage is exact. The --ternary flag also rewrites single-input weight=+/-2 identity buffers (SHL/SHR/ROL/ROR bit gates, stack data buffers, RET address buffers, flag buffers) to weight=+/-1 with bias adjusted as needed to preserve heaviside output for binary inputs. After this pass every weight tensor in the file lies in {-1, 0, 1} except for positional comparators and a few hand-constructed modular arithmetic circuits (see the violation report); fully ternarizing those requires bit- cascading in build.py. Usage: python quantize.py path/to/file.safetensors # in-place python quantize.py path/to/file.safetensors -o out.safetensors # to new file python quantize.py variants/ # whole directory in place python quantize.py variants/ -o variants_int/ # whole directory to new dir python quantize.py file.safetensors --ternary # try ternary weights python quantize.py file.safetensors --ternary --strict # error if any weight non-ternary """ from __future__ import annotations import argparse import sys from pathlib import Path from typing import Dict, Tuple import torch from safetensors import safe_open from safetensors.torch import save_file DTYPES = [ (torch.int8, -(1 << 7), (1 << 7) - 1), (torch.int16, -(1 << 15), (1 << 15) - 1), (torch.int32, -(1 << 31), (1 << 31) - 1), (torch.int64, None, None), # always fits ] def _normalize_to_int(tensor: torch.Tensor) -> torch.Tensor: """Return a tensor with strictly integer values, floored from any half-integer values. Floor (not round) because a -0.5 bias must become -1 (not 0) to preserve H(x + bias) for binary x.""" if not tensor.dtype.is_floating_point: return tensor.to(torch.float64) # promote for range checks tf = tensor.to(torch.float64) rounded = tf.round() if torch.equal(rounded, tf): return tf doubled = tf * 2.0 if torch.equal(doubled.round(), doubled): return torch.floor(tf) raise ValueError( f"tensor has non-half-integer values; range " f"[{tf.min().item()}, {tf.max().item()}]" ) def _min_signed_int_dtype(tensor: torch.Tensor) -> torch.dtype: if tensor.numel() == 0: return torch.int8 lo = int(tensor.min().item()) hi = int(tensor.max().item()) for dtype, lo_lim, hi_lim in DTYPES: if lo_lim is None or (lo_lim <= lo and hi <= hi_lim): return dtype return torch.int64 def _ternarize_modular_and_patterns( tensors: Dict[str, torch.Tensor], ) -> Tuple[Dict[str, torch.Tensor], Dict]: """Replace seed-file modular detectors and pattern_recognition gates with bit-cascade-equivalent ternary structures. Modular: for each modulus N in {3,5,6,7,9,10,11,12}, replace the layer1.geq{i}/layer1.leq{i}/layer2.eq{i}/layer3.or chain with a bit-cascade equality detector per multiple of N in [0, 256), then OR all detectors together. Top-level gate stays named `modular.mod{N}` (a multi-input OR over per-multiple matches). Pattern_recognition.leadingones / trailingones are dropped: they are seed-file artifacts with no eval coverage and no downstream consumers in this codebase. 
""" new_tensors = dict(tensors) fixed = 0 # --- pattern_recognition: drop leadingones/trailingones --- pr_dropped = 0 for k in list(new_tensors.keys()): if (k.startswith("pattern_recognition.leadingones") or k.startswith("pattern_recognition.trailingones")): del new_tensors[k] pr_dropped += 1 # --- modular: rebuild as ternary bit-cascade equality per multiple --- moduli = [3, 5, 6, 7, 9, 10, 11, 12] mod_gates_added = 0 for mod in moduli: prefix = f"modular.mod{mod}" # Drop old structure for k in list(new_tensors.keys()): if k.startswith(prefix + "."): del new_tensors[k] multiples = list(range(0, 256, mod)) # Per-bit match gates + per-multiple AND for k in multiples: for i in range(8): # i=0 is MSB (matches inputs MSB-first ordering) k_bit = (k >> (7 - i)) & 1 if k_bit == 1: # bit_match = x[i]: H(x - 0.5) ~~ identity for binary; # use weight=1, bias=-1 -> H(x-1): x=0 -> 0, x=1 -> 1. new_tensors[f"{prefix}.eq.k{k}.bit{i}.match.weight"] = torch.tensor([1.0], dtype=torch.float64) new_tensors[f"{prefix}.eq.k{k}.bit{i}.match.bias"] = torch.tensor([-1.0], dtype=torch.float64) else: # bit_match = NOT x[i]: weight=-1, bias=0 -> H(-x): x=0 -> 1, x=1 -> 0. new_tensors[f"{prefix}.eq.k{k}.bit{i}.match.weight"] = torch.tensor([-1.0], dtype=torch.float64) new_tensors[f"{prefix}.eq.k{k}.bit{i}.match.bias"] = torch.tensor([0.0], dtype=torch.float64) mod_gates_added += 1 # AND of all 8 bit-matches: weights [1]*8, bias -8 new_tensors[f"{prefix}.eq.k{k}.all.weight"] = torch.tensor([1.0] * 8, dtype=torch.float64) new_tensors[f"{prefix}.eq.k{k}.all.bias"] = torch.tensor([-8.0], dtype=torch.float64) mod_gates_added += 1 # Final OR over all per-multiple match outputs m = len(multiples) new_tensors[f"{prefix}.weight"] = torch.tensor([1.0] * m, dtype=torch.float64) new_tensors[f"{prefix}.bias"] = torch.tensor([-1.0], dtype=torch.float64) mod_gates_added += 1 return new_tensors, { "pattern_recognition_dropped": pr_dropped, "modular_gates_added": mod_gates_added, "modular_moduli": len(moduli), } def _ternarize_buffers( tensors: Dict[str, torch.Tensor], ) -> Tuple[Dict[str, torch.Tensor], Dict]: """Rewrite single-input weight=+-2 identity buffers as weight=+-1 with bias adjusted to preserve heaviside output for binary inputs. For a single-input gate H(w*x + b) with x in {0, 1}, the only thing that matters is the pair (H(b), H(w + b)). Pick the smallest integer bias b' such that (H(b'), H(sgn + b')) matches, with sgn = sign(w). Returns (new_tensors, stats). stats has 'fixed', 'failed', 'failed_names'. 
""" new_tensors = dict(tensors) fixed = 0 failed_names = [] weight_keys = [k for k in tensors if k.endswith(".weight")] for wkey in weight_keys: w = tensors[wkey] wf = w.float() if w.dtype.is_floating_point else w.to(torch.float64).float() if (wf.abs() <= 1.0).all(): continue # already ternary gate = wkey[: -len(".weight")] bkey = gate + ".bias" # Single-input weight=+-2 buffer with single bias if ( wf.numel() == 1 and abs(wf.item()) == 2.0 and bkey in tensors and tensors[bkey].numel() == 1 ): w_val = wf.item() b_val = float(tensors[bkey].float().item()) sgn = 1.0 if w_val > 0 else -1.0 x0_target = 1 if b_val >= 0 else 0 x1_target = 1 if (w_val + b_val) >= 0 else 0 chosen = None # Prefer keeping the bias unchanged when possible for b_new in [int(b_val), int(b_val) - 1, -1, 0, -2, 1, -3, 2]: x0 = 1 if b_new >= 0 else 0 x1 = 1 if (sgn + b_new) >= 0 else 0 if x0 == x0_target and x1 == x1_target: chosen = b_new break if chosen is not None: new_tensors[wkey] = torch.tensor([sgn], dtype=torch.float64) new_tensors[bkey] = torch.tensor([float(chosen)], dtype=torch.float64) fixed += 1 continue failed_names.append(wkey) return new_tensors, {"fixed": fixed, "failed_names": failed_names} def quantize_tensors( tensors: Dict[str, torch.Tensor], ternary: bool = False, ) -> Tuple[Dict[str, torch.Tensor], Dict[str, int], Tuple[int, int], Dict]: """Quantize a dict of tensors. Returns (new_tensors, dtype_counts, (bytes_before, bytes_after), ternary_stats).""" ternary_stats: Dict = {"applied": False, "fixed": 0, "failed_names": []} if ternary: tensors, ternary_stats = _ternarize_buffers(tensors) ternary_stats["applied"] = True # Also rebuild modular detectors and drop pattern_recognition stragglers. tensors, mod_stats = _ternarize_modular_and_patterns(tensors) ternary_stats["modular_gates_added"] = mod_stats["modular_gates_added"] ternary_stats["pattern_recognition_dropped"] = mod_stats["pattern_recognition_dropped"] new_tensors: Dict[str, torch.Tensor] = {} counts: Dict[str, int] = {"int8": 0, "int16": 0, "int32": 0, "int64": 0, "manifest_kept": 0, "skipped": 0} bytes_before = 0 bytes_after = 0 for name, t in tensors.items(): bytes_before += t.numel() * t.element_size() if name.startswith("manifest."): new_tensors[name] = t counts["manifest_kept"] += 1 bytes_after += t.numel() * t.element_size() continue try: normalized = _normalize_to_int(t) except ValueError: new_tensors[name] = t counts["skipped"] += 1 bytes_after += t.numel() * t.element_size() continue target = _min_signed_int_dtype(normalized) cast = normalized.to(target) new_tensors[name] = cast bytes_after += cast.numel() * cast.element_size() counts[str(target).replace("torch.", "")] += 1 return new_tensors, counts, (bytes_before, bytes_after), ternary_stats def quantize_file(in_path: Path, out_path: Path, verbose: bool = False, ternary: bool = False, strict_ternary: bool = False) -> Dict: file_before = in_path.stat().st_size tensors: Dict[str, torch.Tensor] = {} metadata: Dict[str, str] = {} with safe_open(str(in_path), framework="pt") as f: meta = f.metadata() if meta: metadata = dict(meta) for name in f.keys(): # clone so the source mmap can be released before we write tensors[name] = f.get_tensor(name).clone() new_tensors, counts, (before, after), tstats = quantize_tensors(tensors, ternary=ternary) # Audit final ternary status (count of weight tensors with |w| > 1) final_nonternary = [] for k, v in new_tensors.items(): if not k.endswith(".weight"): continue if k.startswith("manifest."): continue vf = v.float() if v.dtype.is_floating_point else 
def quantize_file(in_path: Path, out_path: Path, verbose: bool = False,
                  ternary: bool = False, strict_ternary: bool = False) -> Dict:
    file_before = in_path.stat().st_size
    tensors: Dict[str, torch.Tensor] = {}
    metadata: Dict[str, str] = {}
    with safe_open(str(in_path), framework="pt") as f:
        meta = f.metadata()
        if meta:
            metadata = dict(meta)
        for name in f.keys():
            # clone so the source mmap can be released before we write
            tensors[name] = f.get_tensor(name).clone()

    new_tensors, counts, (before, after), tstats = quantize_tensors(tensors, ternary=ternary)

    # Audit final ternary status (count of weight tensors with |w| > 1)
    final_nonternary = []
    for k, v in new_tensors.items():
        if not k.endswith(".weight"):
            continue
        if k.startswith("manifest."):
            continue
        vf = v.float() if v.dtype.is_floating_point else v.to(torch.float64).float()
        if (vf.abs() > 1.0).any():
            final_nonternary.append(k)
    if ternary and strict_ternary and final_nonternary:
        raise ValueError(
            f"--strict failed: {len(final_nonternary)} weight tensors are not "
            f"ternary after transformation; first: {final_nonternary[:5]}"
        )

    # Drop the original mmap-backed tensors before writing in-place.
    del tensors
    out_path.parent.mkdir(parents=True, exist_ok=True)
    if ternary:
        # Note ternary mode in metadata so downstream tools can see it.
        metadata = dict(metadata)
        metadata["weight_quantization"] = (
            "ternary_partial" if final_nonternary else "ternary"
        )
    save_file(new_tensors, str(out_path), metadata=metadata or None)
    file_after = out_path.stat().st_size
    return {
        "in_path": str(in_path),
        "out_path": str(out_path),
        "tensor_counts": counts,
        "tensor_bytes_before": before,
        "tensor_bytes_after": after,
        "file_size_before": file_before,
        "file_size_after": file_after,
        "ternary": tstats,
        "final_nonternary": final_nonternary,
    }

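
# A minimal sketch of checking the recorded mode after the fact (assumes the
# file was written with --ternary; safe_open and .metadata() are the same
# safetensors APIs used above, and "out.safetensors" is a placeholder path):
#
#   with safe_open("out.safetensors", framework="pt") as f:
#       mode = (f.metadata() or {}).get("weight_quantization")
#   # mode is "ternary", "ternary_partial", or None (plain quantization)
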
def _print_summary(label: str, info: Dict) -> None:
    cb = info["tensor_bytes_before"]
    ca = info["tensor_bytes_after"]
    fb = info["file_size_before"]
    fa = info["file_size_after"]
    counts = info["tensor_counts"]
    bucket_str = " ".join(f"{k}={v}" for k, v in counts.items() if v)
    ratio_t = cb / ca if ca else 1.0
    ratio_f = fb / fa if fa else 1.0
    print(
        f"  {label}: file {fb / 1e6:6.1f} MB -> {fa / 1e6:6.1f} MB "
        f"({ratio_f:.2f}x); tensor data {cb / 1e6:6.1f} MB -> {ca / 1e6:6.1f} MB "
        f"({ratio_t:.2f}x)"
    )
    print(f"    {bucket_str}")
    if info.get("ternary", {}).get("applied"):
        ts = info["ternary"]
        nt = info["final_nonternary"]
        print(f"    ternary: {ts['fixed']} buffer gates rewritten; "
              f"{len(nt)} weight tensors remain non-ternary")


def main() -> int:
    parser = argparse.ArgumentParser(description="Quantize safetensors to min signed int dtype")
    parser.add_argument("input", type=Path, help=".safetensors file or directory of files")
    parser.add_argument("-o", "--output", type=Path, default=None,
                        help="output file or directory (default: in-place)")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--ternary", action="store_true",
                        help="Rewrite single-input weight=+/-2 buffers as +/-1 to push toward "
                             "ternary {-1, 0, 1} weights and report remaining violations")
    parser.add_argument("--strict", action="store_true",
                        help="With --ternary, fail if any weight tensor is still non-ternary")
    parser.add_argument("--report-violations", type=int, default=0, metavar="N",
                        help="Print first N non-ternary weight tensor names per file")
    args = parser.parse_args()

    if args.input.is_dir():
        inputs = sorted(p for p in args.input.glob("*.safetensors"))
    elif args.input.is_file():
        inputs = [args.input]
    else:
        print(f"not found: {args.input}", file=sys.stderr)
        return 2
    if not inputs:
        print(f"no .safetensors files under {args.input}", file=sys.stderr)
        return 2

    if args.output is None:
        outputs = inputs  # in-place
    elif args.output.suffix == ".safetensors":
        if len(inputs) != 1:
            print("output is a single file but input is a directory; "
                  "pass a directory output", file=sys.stderr)
            return 2
        outputs = [args.output]
    else:
        args.output.mkdir(parents=True, exist_ok=True)
        outputs = [args.output / p.name for p in inputs]

    total_before = 0
    total_after = 0
    print(f"Quantizing {len(inputs)} file(s)"
          + (" (ternary mode)" if args.ternary else "") + "\n")
    for src, dst in zip(inputs, outputs):
        info = quantize_file(src, dst, verbose=args.verbose,
                             ternary=args.ternary, strict_ternary=args.strict)
        _print_summary(src.name, info)
        if args.report_violations and info.get("final_nonternary"):
            for name in info["final_nonternary"][: args.report_violations]:
                print(f"      non-ternary: {name}")
            if len(info["final_nonternary"]) > args.report_violations:
                print(f"      ... and "
                      f"{len(info['final_nonternary']) - args.report_violations} more")
        total_before += info["file_size_before"]
        total_after += info["file_size_after"]
        print()

    print("=" * 76)
    print(
        f"Total: {total_before / 1e6:.1f} MB -> {total_after / 1e6:.1f} MB "
        f"({total_before / max(total_after, 1):.2f}x reduction)"
    )
    return 0


if __name__ == "__main__":
    sys.exit(main())