# unary-quantization-research / unary_convert.py
# Uploaded by OpenTransformer ("Add files using upload-large-folder tool"), commit 19ed98b.
#!/usr/bin/env python3
"""
Convert model weights to UNARY (base-1) thermometer encoding.
True unary: magnitude N = N consecutive 1-bits across N bitplanes.
Each bitplane contributes equally (value=1), NOT binary powers.
Weight 0.3 with scale -> magnitude 5 -> planes 0,1,2,3,4 have bit set
Weight -0.1 with scale -> magnitude 2, sign=neg -> planes 0,1 set + sign bit
More precision than ternary (N+1 magnitude levels, i.e. 2N+1 signed values, vs 3), still no multiplication.
(c) 2026 OpenTransformers Ltd / Scott Bisset
"""
import os
import json
import numpy as np
from pathlib import Path
import time
def load_safetensors(model_dir):
    """Load all tensors from the ``*.safetensors`` shards in *model_dir*.

    Shards are read in sorted filename order. If a key appears in more than
    one shard, the value from the last shard read wins. Every tensor is
    upcast to float32 and returned as a NumPy array.

    Args:
        model_dir: directory containing one or more ``*.safetensors`` files.

    Returns:
        dict mapping tensor name -> float32 ``numpy.ndarray``.

    Raises:
        FileNotFoundError: if *model_dir* contains no ``*.safetensors`` files.
            Without this check an empty dict is returned, and the conversion
            fails much later with an unrelated division-by-zero.
    """
    shard_paths = sorted(Path(model_dir).glob("*.safetensors"))
    if not shard_paths:
        raise FileNotFoundError(f"No *.safetensors files found in {model_dir}")

    # Imported inside the function, so importing this module does not itself
    # require safetensors/torch to be installed.
    from safetensors.torch import load_file

    tensors = {}
    for shard in shard_paths:
        print(f"Loading {shard.name}...")
        state = load_file(str(shard))
        for key, val in state.items():
            # Upcast first: NumPy has no bfloat16 dtype, so a bf16 torch
            # tensor cannot be converted with .numpy() directly.
            tensors[key] = val.float().numpy()
    return tensors
def _pack_rows_u64(bits, out_dim, chunks):
    """Pack a boolean [out_dim, chunks*64] array into uint64 [out_dim, chunks].

    Element ``64*c + j`` of a row lands in bit ``j`` of word ``c``
    (least-significant bit first).
    """
    grouped = np.ascontiguousarray(bits, dtype=bool).reshape(out_dim, chunks, 64)
    # LSB-first packing yields 8 bytes per group of 64; reading those bytes as
    # a little-endian uint64 puts element j of the group at bit j.
    packed = np.packbits(grouped, axis=-1, bitorder="little")
    words = packed.view("<u8").reshape(out_dim, chunks)
    return words.astype(np.uint64)


def quantize_matrix_unary(weight, n_planes=7):
    """Quantize a 2-D weight matrix to unary thermometer encoding.

    Each row is scaled independently so that its largest |weight| maps to
    ``n_planes``. Magnitudes are rounded with ``np.round`` (round half to
    even) and clipped to ``[0, n_planes]``. A magnitude ``m`` is stored by
    setting the element's bit in planes ``0 .. m-1``.

    Columns are zero-padded up to a multiple of 64 and packed into uint64
    words (``chunks = ceil(in_dim / 64)``). Element ``64*c + j`` of a row
    goes to bit ``j`` of word ``c``.

    Args:
        weight: array-like of shape [out_dim, in_dim].
        n_planes: number of magnitude planes, must be >= 1. This gives
            n_planes + 1 magnitude levels, i.e. 2*n_planes + 1 distinct
            signed values (n_planes=7 -> {0,...,7} * sign = 15 values).

    Returns:
        sign_bits: uint64 [out_dim, chunks]; bit set where the weight is
            negative. This is also set when the magnitude rounds to 0.
        mag_planes: uint64 [n_planes, out_dim, chunks]; plane p has a bit
            set where magnitude > p (i.e. magnitude >= p + 1).
        scales: float32 [out_dim]; per-row scale such that
            weight ~= (-1)**sign * magnitude * scale.
        sparsity: fraction of (unpadded) weights whose magnitude is 0.

    Raises:
        ValueError: if ``n_planes < 1`` or ``weight`` is not 2-D.
    """
    if n_planes < 1:
        raise ValueError(f"n_planes must be >= 1, got {n_planes}")
    w = np.asarray(weight, dtype=np.float32)
    if w.ndim != 2:
        raise ValueError(f"expected a 2-D weight matrix, got shape {w.shape}")
    out_dim, in_dim = w.shape
    chunks = (in_dim + 63) // 64
    padded = chunks * 64

    # Per-row scale: the row's max |w| maps to n_planes. All-zero rows get a
    # dummy max of 1.0 to avoid 0/0; all their magnitudes round to 0 anyway.
    row_max = np.max(np.abs(w), axis=1, keepdims=True)
    row_max = np.where(row_max == 0, 1.0, row_max)
    scales = (row_max.flatten() / n_planes).astype(np.float32)

    # Integer magnitudes in [0, n_planes] plus a separate sign mask.
    w_scaled = w / scales[:, None]
    magnitudes = np.clip(np.round(np.abs(w_scaled)).astype(np.int32), 0, n_planes)
    signs = w < 0

    sparsity = np.mean(magnitudes == 0)

    # Zero-pad columns to a multiple of 64 so every row fills whole words.
    if in_dim < padded:
        pad_width = ((0, 0), (0, padded - in_dim))
        magnitudes = np.pad(magnitudes, pad_width)
        signs = np.pad(signs, pad_width)

    sign_bits = _pack_rows_u64(signs, out_dim, chunks)

    # Thermometer code: plane p is set wherever magnitude exceeds p.
    mag_planes = np.zeros((n_planes, out_dim, chunks), dtype=np.uint64)
    for p in range(n_planes):
        mag_planes[p] = _pack_rows_u64(magnitudes > p, out_dim, chunks)

    return sign_bits, mag_planes, scales, sparsity
# Architecture fields written to config.json unless overridden via
# save_unary_model(model_config=...). They are hard-coded, not read from the
# checkpoint. They presumably describe the DeepSeek-R1 1.5B model this script
# converts by default -- verify before converting any other model.
_DEFAULT_MODEL_CONFIG = {
    "hidden_size": 1536,
    "intermediate_size": 8960,
    "num_attention_heads": 12,
    "num_key_value_heads": 2,
    "num_hidden_layers": 28,
    "vocab_size": 151936,
    "head_dim": 128,
    "rope_theta": 1000000.0,
    "rms_norm_eps": 1e-6,
}

# A tensor whose key contains any of these substrings is unary-quantized;
# every other tensor is stored as float16.
_UNARY_WEIGHT_PATTERNS = (
    "q_proj.weight",
    "k_proj.weight",
    "v_proj.weight",
    "o_proj.weight",
    "gate_proj.weight",
    "up_proj.weight",
    "down_proj.weight",
)


def _is_unary_key(key):
    """Return True if tensor *key* names a projection weight to quantize."""
    return any(pattern in key for pattern in _UNARY_WEIGHT_PATTERNS)


def save_unary_model(tensors, output_dir, n_planes=7, model_config=None):
    """Convert *tensors* to the unary on-disk format under *output_dir*.

    Tensors whose key matches a projection-weight pattern are quantized with
    ``quantize_matrix_unary``. Each is written as three raw binaries:
    ``<name>.sign``, ``<name>.planes`` and ``<name>.scales``. Every other
    tensor is written as raw float16 ``<name>.fp16``. Here ``<name>`` is the
    tensor key with ``.`` replaced by ``_``.

    The function also writes ``config.json`` (architecture fields plus
    ``n_planes``/``quant_type``) and ``manifest.json`` (original shapes
    grouped by storage format), then prints a size summary.

    Args:
        tensors: mapping of tensor name -> NumPy array. Quantized tensors
            must be 2-D.
        output_dir: destination directory; created if missing.
        n_planes: number of thermometer magnitude planes.
        model_config: optional dict of architecture fields that override
            ``_DEFAULT_MODEL_CONFIG`` in config.json. ``n_planes`` and
            ``quant_type`` always reflect the conversion actually performed.
    """
    os.makedirs(output_dir, exist_ok=True)

    config = {
        **_DEFAULT_MODEL_CONFIG,
        **(model_config or {}),
        "n_planes": n_planes,
        "quant_type": "unary",
    }

    unary_keys = [key for key in tensors if _is_unary_key(key)]
    keep_keys = [key for key in tensors if not _is_unary_key(key)]

    print(f"\nUnary layers: {len(unary_keys)} (n_planes={n_planes}, levels={n_planes+1})")
    print(f"FP16 layers: {len(keep_keys)}")

    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=2)

    fp32_itemsize = np.dtype(np.float32).itemsize

    total_unary_bytes = 0
    total_original_bytes = 0  # FP32 size of the quantized (linear) weights
    n_unary_weights = 0
    for key in unary_keys:
        w = tensors[key]
        out_dim, in_dim = w.shape
        total_original_bytes += w.size * fp32_itemsize
        n_unary_weights += w.size

        t0 = time.time()
        sign_bits, mag_planes, scales, sparsity = quantize_matrix_unary(w, n_planes)
        dt = time.time() - t0

        prefix = os.path.join(output_dir, key.replace(".", "_"))
        sign_bits.tofile(prefix + ".sign")
        mag_planes.tofile(prefix + ".planes")
        scales.tofile(prefix + ".scales")

        unary_bytes = sign_bits.nbytes + mag_planes.nbytes + scales.nbytes
        total_unary_bytes += unary_bytes
        ratio = w.nbytes / unary_bytes
        # Effective bits per weight, including padding and scale overhead.
        bpw = (unary_bytes * 8) / (out_dim * in_dim)
        print(f"  {key}: {w.shape} -> unary ({unary_bytes/1024:.0f}KB, "
              f"{ratio:.1f}x compress, {bpw:.2f} bpw, {sparsity:.1%} sparse, {dt:.1f}s)")

    total_fp16_bytes = 0
    total_kept_fp32_bytes = 0  # FP32 size of the tensors stored as fp16
    for key in keep_keys:
        total_kept_fp32_bytes += tensors[key].size * fp32_itemsize
        w = tensors[key].astype(np.float16)
        prefix = os.path.join(output_dir, key.replace(".", "_"))
        w.tofile(prefix + ".fp16")
        total_fp16_bytes += w.nbytes
        print(f"  {key}: {w.shape} -> fp16 ({w.nbytes/1024:.0f}KB)")

    manifest = {
        "unary": {k: list(tensors[k].shape) for k in unary_keys},
        "fp16": {k: list(tensors[k].shape) for k in keep_keys},
    }
    with open(os.path.join(output_dir, "manifest.json"), "w") as f:
        json.dump(manifest, f, indent=2)

    total_bytes = total_unary_bytes + total_fp16_bytes
    # Guard the ratios: a model with no quantizable (or no) tensors would
    # otherwise raise ZeroDivisionError here.
    avg_bpw = (total_unary_bytes * 8) / n_unary_weights if n_unary_weights else 0.0
    # Compare against the whole model in FP32. The fp16-stored tensors count
    # at their FP32 size, not their (halved) fp16 size.
    total_fp32_bytes = total_original_bytes + total_kept_fp32_bytes
    compression = total_fp32_bytes / total_bytes if total_bytes else 0.0

    print("\n=== Summary ===")
    print(f"Original FP32 linear weights: {total_original_bytes/1024/1024:.1f} MB")
    print(f"Unary linear weights: {total_unary_bytes/1024/1024:.1f} MB")
    print(f"FP16 other weights: {total_fp16_bytes/1024/1024:.1f} MB")
    print(f"Total model size: {total_bytes/1024/1024:.1f} MB")
    print(f"Average bits per weight (linear): {avg_bpw:.2f}")
    print(f"Compression vs FP32: {compression:.1f}x")
    print(f"Precision levels: {n_planes + 1} (vs ternary=3, INT4=16)")
if __name__ == "__main__":
    import sys

    # Positional CLI: [model_dir] [output_dir] [n_planes]; each is optional.
    cli_args = sys.argv[1:]
    model_dir = cli_args[0] if cli_args else "deepseek-r1-1.5b-hf"
    output_dir = cli_args[1] if len(cli_args) >= 2 else "deepseek-r1-1.5b-unary"
    n_planes = int(cli_args[2]) if len(cli_args) >= 3 else 7

    print(f"Loading model from {model_dir}...")
    tensors = load_safetensors(model_dir)

    print(f"Converting to unary (n_planes={n_planes})...")
    save_unary_model(tensors, output_dir, n_planes)

    print("Done!")