# Author: CharlesCNorton
# Commit: Full ternary: rebuild modular detectors and drop pattern_recognition
# Commit hash: c844a11
"""
Quantize threshold-computer safetensors to the minimum signed integer
dtype that exactly represents each tensor.
Weights and biases in this library are integer-valued by construction,
with one historical exception: a handful of legacy buffer gates use a
bias of -0.5 (e.g. arithmetic.asr8bit.bit*.bias). For binary inputs,
H(x - 0.5) and H(x - 1) are identical, so those biases are floored to
-1 before casting.
This is a packaging optimization, not a precision change: the eval
pipeline already promotes weights to float32 on load, so integer
storage is exact.
The --ternary flag also rewrites single-input weight=+/-2 identity
buffers (SHL/SHR/ROL/ROR bit gates, stack data buffers, RET address
buffers, flag buffers) to weight=+/-1 with bias adjusted as needed to
preserve heaviside output for binary inputs. After this pass every
weight tensor in the file lies in {-1, 0, 1} except for positional
comparators and a few hand-constructed modular arithmetic circuits
(see the violation report); fully ternarizing those requires bit-
cascading in build.py.
Usage:
python quantize.py path/to/file.safetensors # in-place
python quantize.py path/to/file.safetensors -o out.safetensors # to new file
python quantize.py variants/ # whole directory in place
python quantize.py variants/ -o variants_int/ # whole directory to new dir
python quantize.py file.safetensors --ternary # try ternary weights
python quantize.py file.safetensors --ternary --strict # error if any weight non-ternary
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
from typing import Dict, Tuple
import torch
from safetensors import safe_open
from safetensors.torch import save_file
# Candidate storage dtypes, narrowest first, each with its inclusive signed
# range. The int64 entry uses (None, None) as an "always fits" sentinel so
# the selection loop below it never falls through.
DTYPES = [
    (torch.int8, -(1 << 7), (1 << 7) - 1),
    (torch.int16, -(1 << 15), (1 << 15) - 1),
    (torch.int32, -(1 << 31), (1 << 31) - 1),
    (torch.int64, None, None),  # always fits
]
def _normalize_to_int(tensor: torch.Tensor) -> torch.Tensor:
"""Return a tensor with strictly integer values, floored from any
half-integer values. Floor (not round) because a -0.5 bias must
become -1 (not 0) to preserve H(x + bias) for binary x."""
if not tensor.dtype.is_floating_point:
return tensor.to(torch.float64) # promote for range checks
tf = tensor.to(torch.float64)
rounded = tf.round()
if torch.equal(rounded, tf):
return tf
doubled = tf * 2.0
if torch.equal(doubled.round(), doubled):
return torch.floor(tf)
raise ValueError(
f"tensor has non-half-integer values; range "
f"[{tf.min().item()}, {tf.max().item()}]"
)
def _min_signed_int_dtype(tensor: torch.Tensor) -> torch.dtype:
    """Pick the narrowest signed integer dtype that holds every value.

    Empty tensors default to int8; int64 is the unconditional fallback
    when none of the narrower widths fit.
    """
    if tensor.numel() == 0:
        return torch.int8
    lo = int(tensor.min().item())
    hi = int(tensor.max().item())
    # Try widths narrowest-first; a signed w-bit int spans [-2^(w-1), 2^(w-1)-1].
    for bits, dtype in ((8, torch.int8), (16, torch.int16), (32, torch.int32)):
        limit = 1 << (bits - 1)
        if -limit <= lo and hi <= limit - 1:
            return dtype
    return torch.int64
def _ternarize_modular_and_patterns(
tensors: Dict[str, torch.Tensor],
) -> Tuple[Dict[str, torch.Tensor], Dict]:
"""Replace seed-file modular detectors and pattern_recognition gates
with bit-cascade-equivalent ternary structures.
Modular: for each modulus N in {3,5,6,7,9,10,11,12}, replace the
layer1.geq{i}/layer1.leq{i}/layer2.eq{i}/layer3.or chain with a
bit-cascade equality detector per multiple of N in [0, 256), then
OR all detectors together. Top-level gate stays named
`modular.mod{N}` (a multi-input OR over per-multiple matches).
Pattern_recognition.leadingones / trailingones are dropped: they are
seed-file artifacts with no eval coverage and no downstream
consumers in this codebase.
"""
new_tensors = dict(tensors)
fixed = 0
# --- pattern_recognition: drop leadingones/trailingones ---
pr_dropped = 0
for k in list(new_tensors.keys()):
if (k.startswith("pattern_recognition.leadingones")
or k.startswith("pattern_recognition.trailingones")):
del new_tensors[k]
pr_dropped += 1
# --- modular: rebuild as ternary bit-cascade equality per multiple ---
moduli = [3, 5, 6, 7, 9, 10, 11, 12]
mod_gates_added = 0
for mod in moduli:
prefix = f"modular.mod{mod}"
# Drop old structure
for k in list(new_tensors.keys()):
if k.startswith(prefix + "."):
del new_tensors[k]
multiples = list(range(0, 256, mod))
# Per-bit match gates + per-multiple AND
for k in multiples:
for i in range(8): # i=0 is MSB (matches inputs MSB-first ordering)
k_bit = (k >> (7 - i)) & 1
if k_bit == 1:
# bit_match = x[i]: H(x - 0.5) ~~ identity for binary;
# use weight=1, bias=-1 -> H(x-1): x=0 -> 0, x=1 -> 1.
new_tensors[f"{prefix}.eq.k{k}.bit{i}.match.weight"] = torch.tensor([1.0], dtype=torch.float64)
new_tensors[f"{prefix}.eq.k{k}.bit{i}.match.bias"] = torch.tensor([-1.0], dtype=torch.float64)
else:
# bit_match = NOT x[i]: weight=-1, bias=0 -> H(-x): x=0 -> 1, x=1 -> 0.
new_tensors[f"{prefix}.eq.k{k}.bit{i}.match.weight"] = torch.tensor([-1.0], dtype=torch.float64)
new_tensors[f"{prefix}.eq.k{k}.bit{i}.match.bias"] = torch.tensor([0.0], dtype=torch.float64)
mod_gates_added += 1
# AND of all 8 bit-matches: weights [1]*8, bias -8
new_tensors[f"{prefix}.eq.k{k}.all.weight"] = torch.tensor([1.0] * 8, dtype=torch.float64)
new_tensors[f"{prefix}.eq.k{k}.all.bias"] = torch.tensor([-8.0], dtype=torch.float64)
mod_gates_added += 1
# Final OR over all per-multiple match outputs
m = len(multiples)
new_tensors[f"{prefix}.weight"] = torch.tensor([1.0] * m, dtype=torch.float64)
new_tensors[f"{prefix}.bias"] = torch.tensor([-1.0], dtype=torch.float64)
mod_gates_added += 1
return new_tensors, {
"pattern_recognition_dropped": pr_dropped,
"modular_gates_added": mod_gates_added,
"modular_moduli": len(moduli),
}
def _ternarize_buffers(
tensors: Dict[str, torch.Tensor],
) -> Tuple[Dict[str, torch.Tensor], Dict]:
"""Rewrite single-input weight=+-2 identity buffers as weight=+-1 with
bias adjusted to preserve heaviside output for binary inputs.
For a single-input gate H(w*x + b) with x in {0, 1}, the only thing
that matters is the pair (H(b), H(w + b)). Pick the smallest integer
bias b' such that (H(b'), H(sgn + b')) matches, with sgn = sign(w).
Returns (new_tensors, stats). stats has 'fixed', 'failed', 'failed_names'.
"""
new_tensors = dict(tensors)
fixed = 0
failed_names = []
weight_keys = [k for k in tensors if k.endswith(".weight")]
for wkey in weight_keys:
w = tensors[wkey]
wf = w.float() if w.dtype.is_floating_point else w.to(torch.float64).float()
if (wf.abs() <= 1.0).all():
continue # already ternary
gate = wkey[: -len(".weight")]
bkey = gate + ".bias"
# Single-input weight=+-2 buffer with single bias
if (
wf.numel() == 1
and abs(wf.item()) == 2.0
and bkey in tensors
and tensors[bkey].numel() == 1
):
w_val = wf.item()
b_val = float(tensors[bkey].float().item())
sgn = 1.0 if w_val > 0 else -1.0
x0_target = 1 if b_val >= 0 else 0
x1_target = 1 if (w_val + b_val) >= 0 else 0
chosen = None
# Prefer keeping the bias unchanged when possible
for b_new in [int(b_val), int(b_val) - 1, -1, 0, -2, 1, -3, 2]:
x0 = 1 if b_new >= 0 else 0
x1 = 1 if (sgn + b_new) >= 0 else 0
if x0 == x0_target and x1 == x1_target:
chosen = b_new
break
if chosen is not None:
new_tensors[wkey] = torch.tensor([sgn], dtype=torch.float64)
new_tensors[bkey] = torch.tensor([float(chosen)], dtype=torch.float64)
fixed += 1
continue
failed_names.append(wkey)
return new_tensors, {"fixed": fixed, "failed_names": failed_names}
def quantize_tensors(
    tensors: Dict[str, torch.Tensor],
    ternary: bool = False,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, int], Tuple[int, int], Dict]:
    """Quantize a dict of tensors.

    Returns (new_tensors, dtype_counts, (bytes_before, bytes_after),
    ternary_stats). Manifest tensors pass through untouched; tensors with
    non-(half-)integer values are kept as-is and counted as 'skipped'.
    """

    def _nbytes(x: torch.Tensor) -> int:
        return x.numel() * x.element_size()

    ternary_stats: Dict = {"applied": False, "fixed": 0, "failed_names": []}
    if ternary:
        tensors, ternary_stats = _ternarize_buffers(tensors)
        ternary_stats["applied"] = True
        # Also rebuild modular detectors and drop pattern_recognition stragglers.
        tensors, mod_stats = _ternarize_modular_and_patterns(tensors)
        ternary_stats["modular_gates_added"] = mod_stats["modular_gates_added"]
        ternary_stats["pattern_recognition_dropped"] = mod_stats["pattern_recognition_dropped"]

    counts: Dict[str, int] = {"int8": 0, "int16": 0, "int32": 0, "int64": 0,
                              "manifest_kept": 0, "skipped": 0}
    quantized: Dict[str, torch.Tensor] = {}
    bytes_before = sum(_nbytes(t) for t in tensors.values())
    bytes_after = 0
    for name, original in tensors.items():
        if name.startswith("manifest."):
            # Manifest entries are bookkeeping; never touch them.
            kept = original
            counts["manifest_kept"] += 1
        else:
            try:
                normalized = _normalize_to_int(original)
            except ValueError:
                # Non-(half-)integer values: keep untouched rather than lose precision.
                kept = original
                counts["skipped"] += 1
            else:
                kept = normalized.to(_min_signed_int_dtype(normalized))
                counts[str(kept.dtype).replace("torch.", "")] += 1
        quantized[name] = kept
        bytes_after += _nbytes(kept)
    return quantized, counts, (bytes_before, bytes_after), ternary_stats
def quantize_file(in_path: Path, out_path: Path, verbose: bool = False,
                  ternary: bool = False, strict_ternary: bool = False) -> Dict:
    """Quantize one safetensors file and write the result to out_path.

    Args:
        in_path: source .safetensors file.
        out_path: destination; may equal in_path for an in-place rewrite.
        verbose: accepted for CLI symmetry; currently unused here.
        ternary: also run the ternarization passes before quantizing.
        strict_ternary: with ternary, raise ValueError if any weight
            tensor still has |w| > 1 after transformation.

    Returns:
        A summary dict with paths, dtype counts, byte sizes before/after,
        ternary stats, and the list of remaining non-ternary weight keys.

    Raises:
        ValueError: when strict_ternary is set and weights remain non-ternary.
    """
    file_before = in_path.stat().st_size
    tensors: Dict[str, torch.Tensor] = {}
    metadata: Dict[str, str] = {}
    with safe_open(str(in_path), framework="pt") as f:
        meta = f.metadata()
        if meta:
            metadata = dict(meta)
        for name in f.keys():
            # clone so the source mmap can be released before we write
            tensors[name] = f.get_tensor(name).clone()
    new_tensors, counts, (before, after), tstats = quantize_tensors(tensors, ternary=ternary)
    # Audit final ternary status (weight tensors with any |w| > 1).
    final_nonternary = []
    for k, v in new_tensors.items():
        if not k.endswith(".weight") or k.startswith("manifest."):
            continue
        vf = v.float() if v.dtype.is_floating_point else v.to(torch.float64).float()
        if (vf.abs() > 1.0).any():
            final_nonternary.append(k)
    if ternary and strict_ternary and final_nonternary:
        raise ValueError(
            f"--strict failed: {len(final_nonternary)} weight tensors are not "
            f"ternary after transformation; first: {final_nonternary[:5]}"
        )
    # Drop the original mmap-backed tensors before writing in-place.
    del tensors
    out_path.parent.mkdir(parents=True, exist_ok=True)
    if ternary:
        # Note ternary mode in metadata so downstream tools can see it.
        # metadata is always a fresh (possibly empty) dict here, so it is
        # safe to mutate directly.
        metadata["weight_quantization"] = (
            "ternary_partial" if final_nonternary else "ternary"
        )
    save_file(new_tensors, str(out_path), metadata=metadata or None)
    file_after = out_path.stat().st_size
    return {
        "in_path": str(in_path),
        "out_path": str(out_path),
        "tensor_counts": counts,
        "tensor_bytes_before": before,
        "tensor_bytes_after": after,
        "file_size_before": file_before,
        "file_size_after": file_after,
        "ternary": tstats,
        "final_nonternary": final_nonternary,
    }
def _print_summary(label: str, info: Dict) -> None:
cb = info["tensor_bytes_before"]
ca = info["tensor_bytes_after"]
fb = info["file_size_before"]
fa = info["file_size_after"]
counts = info["tensor_counts"]
bucket_str = " ".join(f"{k}={v}" for k, v in counts.items() if v)
ratio_t = cb / ca if ca else 1.0
ratio_f = fb / fa if fa else 1.0
print(
f" {label}: file {fb / 1e6:6.1f} MB -> {fa / 1e6:6.1f} MB "
f"({ratio_f:.2f}x); tensor data {cb / 1e6:6.1f} MB -> {ca / 1e6:6.1f} MB "
f"({ratio_t:.2f}x)"
)
print(f" {bucket_str}")
if info.get("ternary", {}).get("applied"):
ts = info["ternary"]
nt = info["final_nonternary"]
print(f" ternary: {ts['fixed']} buffer gates rewritten; "
f"{len(nt)} weight tensors remain non-ternary")
def main() -> int:
    """CLI entry point.

    Exit codes: 0 on success, 1 on a --strict ternary violation,
    2 on a usage/input error.
    """
    parser = argparse.ArgumentParser(description="Quantize safetensors to min signed int dtype")
    parser.add_argument("input", type=Path, help=".safetensors file or directory of files")
    parser.add_argument("-o", "--output", type=Path, default=None,
                        help="output file or directory (default: in-place)")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--ternary", action="store_true",
                        help="Rewrite single-input weight=+/-2 buffers as +/-1 to push toward "
                             "ternary {-1, 0, 1} weights and report remaining violations")
    parser.add_argument("--strict", action="store_true",
                        help="With --ternary, fail if any weight tensor is still non-ternary")
    parser.add_argument("--report-violations", type=int, default=0, metavar="N",
                        help="Print first N non-ternary weight tensor names per file")
    args = parser.parse_args()
    # Resolve the input set: a single file, or every .safetensors in a dir.
    if args.input.is_dir():
        inputs = sorted(p for p in args.input.glob("*.safetensors"))
    elif args.input.is_file():
        inputs = [args.input]
    else:
        print(f"not found: {args.input}", file=sys.stderr)
        return 2
    if not inputs:
        print(f"no .safetensors files under {args.input}", file=sys.stderr)
        return 2
    # Resolve outputs: in-place, a single named file, or a mirrored directory.
    if args.output is None:
        outputs = inputs  # in-place
    elif args.output.suffix == ".safetensors":
        if len(inputs) != 1:
            print("output is a single file but input is a directory; pass a directory output", file=sys.stderr)
            return 2
        outputs = [args.output]
    else:
        args.output.mkdir(parents=True, exist_ok=True)
        outputs = [args.output / p.name for p in inputs]
    total_before = 0
    total_after = 0
    print(f"Quantizing {len(inputs)} file(s)" + (" (ternary mode)" if args.ternary else "") + "\n")
    for src, dst in zip(inputs, outputs):
        try:
            info = quantize_file(src, dst, verbose=args.verbose,
                                 ternary=args.ternary, strict_ternary=args.strict)
        except ValueError as err:
            # --strict violation (or bad tensor data): report cleanly
            # instead of letting a traceback escape to the user.
            print(f"{src.name}: {err}", file=sys.stderr)
            return 1
        _print_summary(src.name, info)
        if args.report_violations and info.get("final_nonternary"):
            for name in info["final_nonternary"][: args.report_violations]:
                print(f" non-ternary: {name}")
            if len(info["final_nonternary"]) > args.report_violations:
                print(f" ... and {len(info['final_nonternary']) - args.report_violations} more")
        total_before += info["file_size_before"]
        total_after += info["file_size_after"]
        print()
    print("=" * 76)
    print(
        f"Total: {total_before / 1e6:.1f} MB -> {total_after / 1e6:.1f} MB "
        f"({total_before / max(total_after, 1):.2f}x reduction)"
    )
    return 0
# Script entry point: exit with main()'s return code (0/1/2).
if __name__ == "__main__":
    sys.exit(main())