| """ |
| Quantize threshold-computer safetensors to the minimum signed integer |
| dtype that exactly represents each tensor. |
| |
| Weights and biases in this library are integer-valued by construction, |
| with one historical exception: a handful of legacy buffer gates use a |
| bias of -0.5 (e.g. arithmetic.asr8bit.bit*.bias). For binary inputs, |
| H(x - 0.5) and H(x - 1) are identical, so those biases are floored to |
| -1 before casting. |
| |
| This is a packaging optimization, not a precision change: the eval |
| pipeline already promotes weights to float32 on load, so integer |
| storage is exact. |
| |
| The --ternary flag also rewrites single-input weight=+/-2 identity |
| buffers (SHL/SHR/ROL/ROR bit gates, stack data buffers, RET address |
| buffers, flag buffers) to weight=+/-1 with bias adjusted as needed to |
| preserve heaviside output for binary inputs. After this pass every |
| weight tensor in the file lies in {-1, 0, 1} except for positional |
| comparators and a few hand-constructed modular arithmetic circuits |
| (see the violation report); fully ternarizing those requires bit- |
| cascading in build.py. |
| |
| Usage: |
| python quantize.py path/to/file.safetensors # in-place |
| python quantize.py path/to/file.safetensors -o out.safetensors # to new file |
| python quantize.py variants/ # whole directory in place |
| python quantize.py variants/ -o variants_int/ # whole directory to new dir |
| python quantize.py file.safetensors --ternary # try ternary weights |
| python quantize.py file.safetensors --ternary --strict # error if any weight non-ternary |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from pathlib import Path |
| from typing import Dict, Tuple |
|
|
| import torch |
| from safetensors import safe_open |
| from safetensors.torch import save_file |
|
|
# Candidate signed integer dtypes, narrowest first, as
# (dtype, min representable value, max representable value).
# int64's limits are None: it is the unconditional fallback, so the
# search in _min_signed_int_dtype always terminates.
DTYPES = [
    (torch.int8, -(1 << 7), (1 << 7) - 1),
    (torch.int16, -(1 << 15), (1 << 15) - 1),
    (torch.int32, -(1 << 31), (1 << 31) - 1),
    (torch.int64, None, None),
]
|
|
|
|
| def _normalize_to_int(tensor: torch.Tensor) -> torch.Tensor: |
| """Return a tensor with strictly integer values, floored from any |
| half-integer values. Floor (not round) because a -0.5 bias must |
| become -1 (not 0) to preserve H(x + bias) for binary x.""" |
| if not tensor.dtype.is_floating_point: |
| return tensor.to(torch.float64) |
| tf = tensor.to(torch.float64) |
| rounded = tf.round() |
| if torch.equal(rounded, tf): |
| return tf |
| doubled = tf * 2.0 |
| if torch.equal(doubled.round(), doubled): |
| return torch.floor(tf) |
| raise ValueError( |
| f"tensor has non-half-integer values; range " |
| f"[{tf.min().item()}, {tf.max().item()}]" |
| ) |
|
|
|
|
| def _min_signed_int_dtype(tensor: torch.Tensor) -> torch.dtype: |
| if tensor.numel() == 0: |
| return torch.int8 |
| lo = int(tensor.min().item()) |
| hi = int(tensor.max().item()) |
| for dtype, lo_lim, hi_lim in DTYPES: |
| if lo_lim is None or (lo_lim <= lo and hi <= hi_lim): |
| return dtype |
| return torch.int64 |
|
|
|
|
| def _ternarize_modular_and_patterns( |
| tensors: Dict[str, torch.Tensor], |
| ) -> Tuple[Dict[str, torch.Tensor], Dict]: |
| """Replace seed-file modular detectors and pattern_recognition gates |
| with bit-cascade-equivalent ternary structures. |
| |
| Modular: for each modulus N in {3,5,6,7,9,10,11,12}, replace the |
| layer1.geq{i}/layer1.leq{i}/layer2.eq{i}/layer3.or chain with a |
| bit-cascade equality detector per multiple of N in [0, 256), then |
| OR all detectors together. Top-level gate stays named |
| `modular.mod{N}` (a multi-input OR over per-multiple matches). |
| |
| Pattern_recognition.leadingones / trailingones are dropped: they are |
| seed-file artifacts with no eval coverage and no downstream |
| consumers in this codebase. |
| """ |
| new_tensors = dict(tensors) |
| fixed = 0 |
|
|
| |
| pr_dropped = 0 |
| for k in list(new_tensors.keys()): |
| if (k.startswith("pattern_recognition.leadingones") |
| or k.startswith("pattern_recognition.trailingones")): |
| del new_tensors[k] |
| pr_dropped += 1 |
|
|
| |
| moduli = [3, 5, 6, 7, 9, 10, 11, 12] |
| mod_gates_added = 0 |
| for mod in moduli: |
| prefix = f"modular.mod{mod}" |
| |
| for k in list(new_tensors.keys()): |
| if k.startswith(prefix + "."): |
| del new_tensors[k] |
|
|
| multiples = list(range(0, 256, mod)) |
| |
| for k in multiples: |
| for i in range(8): |
| k_bit = (k >> (7 - i)) & 1 |
| if k_bit == 1: |
| |
| |
| new_tensors[f"{prefix}.eq.k{k}.bit{i}.match.weight"] = torch.tensor([1.0], dtype=torch.float64) |
| new_tensors[f"{prefix}.eq.k{k}.bit{i}.match.bias"] = torch.tensor([-1.0], dtype=torch.float64) |
| else: |
| |
| new_tensors[f"{prefix}.eq.k{k}.bit{i}.match.weight"] = torch.tensor([-1.0], dtype=torch.float64) |
| new_tensors[f"{prefix}.eq.k{k}.bit{i}.match.bias"] = torch.tensor([0.0], dtype=torch.float64) |
| mod_gates_added += 1 |
| |
| new_tensors[f"{prefix}.eq.k{k}.all.weight"] = torch.tensor([1.0] * 8, dtype=torch.float64) |
| new_tensors[f"{prefix}.eq.k{k}.all.bias"] = torch.tensor([-8.0], dtype=torch.float64) |
| mod_gates_added += 1 |
|
|
| |
| m = len(multiples) |
| new_tensors[f"{prefix}.weight"] = torch.tensor([1.0] * m, dtype=torch.float64) |
| new_tensors[f"{prefix}.bias"] = torch.tensor([-1.0], dtype=torch.float64) |
| mod_gates_added += 1 |
|
|
| return new_tensors, { |
| "pattern_recognition_dropped": pr_dropped, |
| "modular_gates_added": mod_gates_added, |
| "modular_moduli": len(moduli), |
| } |
|
|
|
|
| def _ternarize_buffers( |
| tensors: Dict[str, torch.Tensor], |
| ) -> Tuple[Dict[str, torch.Tensor], Dict]: |
| """Rewrite single-input weight=+-2 identity buffers as weight=+-1 with |
| bias adjusted to preserve heaviside output for binary inputs. |
| |
| For a single-input gate H(w*x + b) with x in {0, 1}, the only thing |
| that matters is the pair (H(b), H(w + b)). Pick the smallest integer |
| bias b' such that (H(b'), H(sgn + b')) matches, with sgn = sign(w). |
| |
| Returns (new_tensors, stats). stats has 'fixed', 'failed', 'failed_names'. |
| """ |
| new_tensors = dict(tensors) |
| fixed = 0 |
| failed_names = [] |
|
|
| weight_keys = [k for k in tensors if k.endswith(".weight")] |
| for wkey in weight_keys: |
| w = tensors[wkey] |
| wf = w.float() if w.dtype.is_floating_point else w.to(torch.float64).float() |
| if (wf.abs() <= 1.0).all(): |
| continue |
|
|
| gate = wkey[: -len(".weight")] |
| bkey = gate + ".bias" |
|
|
| |
| if ( |
| wf.numel() == 1 |
| and abs(wf.item()) == 2.0 |
| and bkey in tensors |
| and tensors[bkey].numel() == 1 |
| ): |
| w_val = wf.item() |
| b_val = float(tensors[bkey].float().item()) |
| sgn = 1.0 if w_val > 0 else -1.0 |
| x0_target = 1 if b_val >= 0 else 0 |
| x1_target = 1 if (w_val + b_val) >= 0 else 0 |
| chosen = None |
| |
| for b_new in [int(b_val), int(b_val) - 1, -1, 0, -2, 1, -3, 2]: |
| x0 = 1 if b_new >= 0 else 0 |
| x1 = 1 if (sgn + b_new) >= 0 else 0 |
| if x0 == x0_target and x1 == x1_target: |
| chosen = b_new |
| break |
| if chosen is not None: |
| new_tensors[wkey] = torch.tensor([sgn], dtype=torch.float64) |
| new_tensors[bkey] = torch.tensor([float(chosen)], dtype=torch.float64) |
| fixed += 1 |
| continue |
|
|
| failed_names.append(wkey) |
|
|
| return new_tensors, {"fixed": fixed, "failed_names": failed_names} |
|
|
|
|
def quantize_tensors(
    tensors: Dict[str, torch.Tensor],
    ternary: bool = False,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, int], Tuple[int, int], Dict]:
    """Quantize a dict of tensors to minimal signed int dtypes.

    With ternary=True, first run the buffer and modular/pattern
    ternarization passes. Returns (new_tensors, dtype_counts,
    (bytes_before, bytes_after), ternary_stats).
    """
    ternary_stats: Dict = {"applied": False, "fixed": 0, "failed_names": []}
    if ternary:
        tensors, ternary_stats = _ternarize_buffers(tensors)
        ternary_stats["applied"] = True
        tensors, mod_stats = _ternarize_modular_and_patterns(tensors)
        ternary_stats["modular_gates_added"] = mod_stats["modular_gates_added"]
        ternary_stats["pattern_recognition_dropped"] = mod_stats["pattern_recognition_dropped"]

    quantized: Dict[str, torch.Tensor] = {}
    buckets: Dict[str, int] = {"int8": 0, "int16": 0, "int32": 0, "int64": 0,
                               "manifest_kept": 0, "skipped": 0}
    size_in = 0
    size_out = 0

    for name, tensor in tensors.items():
        nbytes = tensor.numel() * tensor.element_size()
        size_in += nbytes

        # Manifest tensors are passed through unchanged.
        if name.startswith("manifest."):
            quantized[name] = tensor
            buckets["manifest_kept"] += 1
            size_out += nbytes
            continue

        # Tensors with genuinely non-integer values are kept as-is.
        try:
            integral = _normalize_to_int(tensor)
        except ValueError:
            quantized[name] = tensor
            buckets["skipped"] += 1
            size_out += nbytes
            continue

        dtype = _min_signed_int_dtype(integral)
        narrowed = integral.to(dtype)
        quantized[name] = narrowed
        size_out += narrowed.numel() * narrowed.element_size()
        buckets[str(dtype).replace("torch.", "")] += 1

    return quantized, buckets, (size_in, size_out), ternary_stats
|
|
|
|
def quantize_file(in_path: Path, out_path: Path, verbose: bool = False,
                  ternary: bool = False, strict_ternary: bool = False) -> Dict:
    """Quantize one safetensors file from in_path to out_path.

    in_path and out_path may be the same path (in-place rewrite: the
    source file is fully read and closed before writing).

    Args:
        in_path: source .safetensors file.
        out_path: destination path; parent directories are created.
        verbose: accepted for CLI compatibility; currently unused here.
        ternary: also run the ternarization passes before quantizing.
        strict_ternary: with ternary, raise ValueError if any weight
            tensor remains non-ternary after transformation.

    Returns a summary dict: paths, per-dtype tensor counts, tensor-data
    and file sizes before/after, ternary stats, and the names of weight
    tensors still holding values outside {-1, 0, 1}.
    """
    file_before = in_path.stat().st_size
    tensors: Dict[str, torch.Tensor] = {}
    metadata: Dict[str, str] = {}
    with safe_open(str(in_path), framework="pt") as f:
        meta = f.metadata()
        if meta:
            metadata = dict(meta)
        for name in f.keys():
            # Clone so tensors stay usable after the file handle closes.
            tensors[name] = f.get_tensor(name).clone()

    new_tensors, counts, (before, after), tstats = quantize_tensors(tensors, ternary=ternary)

    # Weight tensors with any |value| > 1 (manifest tensors excluded).
    # Computed unconditionally so --report-violations works without --ternary.
    final_nonternary = []
    for k, v in new_tensors.items():
        if not k.endswith(".weight") or k.startswith("manifest."):
            continue
        if (v.float().abs() > 1.0).any():
            final_nonternary.append(k)

    if ternary and strict_ternary and final_nonternary:
        raise ValueError(
            f"--strict failed: {len(final_nonternary)} weight tensors are not "
            f"ternary after transformation; first: {final_nonternary[:5]}"
        )

    # Drop the originals before writing to keep peak memory down.
    del tensors
    out_path.parent.mkdir(parents=True, exist_ok=True)
    if ternary:
        # metadata is always a dict here (initialized to {} above), so no
        # None check is needed before tagging the quantization outcome.
        metadata["weight_quantization"] = (
            "ternary_partial" if final_nonternary else "ternary"
        )
    save_file(new_tensors, str(out_path), metadata=metadata or None)

    file_after = out_path.stat().st_size
    return {
        "in_path": str(in_path),
        "out_path": str(out_path),
        "tensor_counts": counts,
        "tensor_bytes_before": before,
        "tensor_bytes_after": after,
        "file_size_before": file_before,
        "file_size_after": file_after,
        "ternary": tstats,
        "final_nonternary": final_nonternary,
    }
|
|
|
|
| def _print_summary(label: str, info: Dict) -> None: |
| cb = info["tensor_bytes_before"] |
| ca = info["tensor_bytes_after"] |
| fb = info["file_size_before"] |
| fa = info["file_size_after"] |
| counts = info["tensor_counts"] |
| bucket_str = " ".join(f"{k}={v}" for k, v in counts.items() if v) |
| ratio_t = cb / ca if ca else 1.0 |
| ratio_f = fb / fa if fa else 1.0 |
| print( |
| f" {label}: file {fb / 1e6:6.1f} MB -> {fa / 1e6:6.1f} MB " |
| f"({ratio_f:.2f}x); tensor data {cb / 1e6:6.1f} MB -> {ca / 1e6:6.1f} MB " |
| f"({ratio_t:.2f}x)" |
| ) |
| print(f" {bucket_str}") |
| if info.get("ternary", {}).get("applied"): |
| ts = info["ternary"] |
| nt = info["final_nonternary"] |
| print(f" ternary: {ts['fixed']} buffer gates rewritten; " |
| f"{len(nt)} weight tensors remain non-ternary") |
|
|
|
|
def main() -> int:
    """CLI entry point: parse arguments, quantize every input file, and
    print per-file plus total size reports. Returns a process exit code
    (0 on success, 2 on bad input/output paths)."""
    parser = argparse.ArgumentParser(description="Quantize safetensors to min signed int dtype")
    parser.add_argument("input", type=Path, help=".safetensors file or directory of files")
    parser.add_argument("-o", "--output", type=Path, default=None,
                        help="output file or directory (default: in-place)")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--ternary", action="store_true",
                        help="Rewrite single-input weight=+/-2 buffers as +/-1 to push toward "
                             "ternary {-1, 0, 1} weights and report remaining violations")
    parser.add_argument("--strict", action="store_true",
                        help="With --ternary, fail if any weight tensor is still non-ternary")
    parser.add_argument("--report-violations", type=int, default=0, metavar="N",
                        help="Print first N non-ternary weight tensor names per file")
    args = parser.parse_args()

    # Resolve the input argument to a concrete list of source files.
    if args.input.is_dir():
        sources = sorted(p for p in args.input.glob("*.safetensors"))
    elif args.input.is_file():
        sources = [args.input]
    else:
        print(f"not found: {args.input}", file=sys.stderr)
        return 2
    if not sources:
        print(f"no .safetensors files under {args.input}", file=sys.stderr)
        return 2

    # Pair every source with a destination (in-place when -o is absent).
    if args.output is None:
        destinations = sources
    elif args.output.suffix == ".safetensors":
        if len(sources) != 1:
            print("output is a single file but input is a directory; pass a directory output", file=sys.stderr)
            return 2
        destinations = [args.output]
    else:
        args.output.mkdir(parents=True, exist_ok=True)
        destinations = [args.output / p.name for p in sources]

    grand_before = 0
    grand_after = 0
    mode_note = " (ternary mode)" if args.ternary else ""
    print(f"Quantizing {len(sources)} file(s){mode_note}\n")
    for src, dst in zip(sources, destinations):
        info = quantize_file(src, dst, verbose=args.verbose,
                             ternary=args.ternary, strict_ternary=args.strict)
        _print_summary(src.name, info)
        limit = args.report_violations
        violations = info.get("final_nonternary")
        if limit and violations:
            for name in violations[:limit]:
                print(f"      non-ternary: {name}")
            if len(violations) > limit:
                print(f"      ... and {len(violations) - limit} more")
        grand_before += info["file_size_before"]
        grand_after += info["file_size_after"]

    print()
    print("=" * 76)
    print(
        f"Total: {grand_before / 1e6:.1f} MB -> {grand_after / 1e6:.1f} MB "
        f"({grand_before / max(grand_after, 1):.2f}x reduction)"
    )
    return 0
|
|
|
|
# Script entry point: propagate main()'s return value as the exit code.
if __name__ == "__main__":
    sys.exit(main())
|
|