unary-quantization-research / unary_convert_v2.py
OpenTransformer's picture
Add files using upload-large-folder tool
19ed98b verified
#!/usr/bin/env python3
"""
Pure Unary Converter - interleaved plane layout [out_dim][chunks][n_planes]
for cache-friendly access in the kernel.
(c) 2026 OpenTransformers Ltd / Scott Bisset
"""
import os, json, sys, time
import numpy as np
from pathlib import Path
def load_safetensors(model_dir):
    """Load every ``*.safetensors`` shard in ``model_dir`` as float32 NumPy arrays.

    Shards are read in sorted filename order; if the same key appears in more
    than one shard, the value from the later shard wins.

    Args:
        model_dir: directory containing the ``*.safetensors`` shard files.

    Returns:
        dict mapping tensor name -> float32 ``np.ndarray``.

    Raises:
        FileNotFoundError: if no ``*.safetensors`` file is found in
            ``model_dir`` (for example because the path is mistyped), instead
            of silently returning an empty dict.
    """
    shard_paths = sorted(Path(model_dir).glob("*.safetensors"))
    if not shard_paths:
        raise FileNotFoundError(f"no *.safetensors files found in {str(model_dir)!r}")

    # Deferred imports: torch/safetensors are only needed for loading.
    import torch  # noqa: F401
    from safetensors.torch import load_file

    tensors = {}
    for f in shard_paths:
        print(f"Loading {f.name}...")
        for k, v in load_file(str(f)).items():
            # Upcast to float32 before converting (NumPy cannot represent bf16).
            tensors[k] = v.float().numpy()
    return tensors
def _pack_bits_le64(bits):
    """Pack a bool array of shape (rows, chunks * 64) into uint64 (rows, chunks).

    Bit ``j`` (LSB = bit 0) of word ``c`` equals ``bits[:, 64 * c + j]``.
    """
    # bitorder="little" puts element 8k+i into bit i of byte k; reading the
    # bytes as little-endian uint64 then places element j on bit j of the word.
    packed = np.packbits(bits, axis=-1, bitorder="little")
    return packed.view("<u8").astype(np.uint64, copy=False)


def quantize_unary_interleaved(weight, n_planes):
    """Quantize a 2-D weight matrix into signed unary bit-planes (interleaved).

    Each row is scaled by ``max(|row|) / n_planes`` (all-zero rows use a row
    max of 1.0), so entries map to integer magnitudes in ``[0, n_planes]``
    via ``np.round`` (round-half-to-even), giving ``2 * n_planes + 1`` signed
    levels. Columns are zero-padded to a multiple of 64 and packed into
    uint64 words, with the lowest column index in the least-significant bit.

    Args:
        weight: 2-D array-like of shape (out_dim, in_dim); converted to float32.
        n_planes: number of magnitude planes; must be >= 1.

    Returns:
        sign_bits: uint64 array (out_dim, chunks); bit set where weight < 0.
        mag_planes: C-contiguous uint64 array (out_dim, chunks, n_planes);
            plane ``p`` has a bit set where the magnitude is >= p + 1.
        scales: float32 array (out_dim,); weight ~= sign * magnitude * scale.
        sparsity: fraction of the (unpadded) weights whose magnitude is 0.

    Raises:
        ValueError: if ``n_planes < 1`` or ``weight`` is not a 2-D matrix
            with at least one column.
    """
    if n_planes < 1:
        raise ValueError(f"n_planes must be >= 1, got {n_planes}")
    w = np.asarray(weight, dtype=np.float32)
    if w.ndim != 2 or w.shape[1] == 0:
        raise ValueError(f"expected a 2-D weight matrix with in_dim > 0, got shape {w.shape}")
    out_dim, in_dim = w.shape
    chunks = (in_dim + 63) // 64
    padded = chunks * 64

    row_max = np.max(np.abs(w), axis=1)
    row_max = np.where(row_max == 0, 1.0, row_max)  # all-zero row: avoid 0/0
    scales = (row_max / n_planes).astype(np.float32)

    magnitudes = np.round(np.abs(w / scales[:, None])).astype(np.int32)
    magnitudes = np.clip(magnitudes, 0, n_planes)
    signs = w < 0
    sparsity = np.mean(magnitudes == 0)  # measured before padding

    pad = padded - in_dim
    if pad:
        magnitudes = np.pad(magnitudes, ((0, 0), (0, pad)))
        signs = np.pad(signs, ((0, 0), (0, pad)))

    sign_bits = _pack_bits_le64(signs)
    # Interleaved layout [out_dim][chunks][n_planes] so one chunk's planes are
    # adjacent in memory.
    mag_planes = np.empty((out_dim, chunks, n_planes), dtype=np.uint64)
    for p in range(n_planes):
        mag_planes[:, :, p] = _pack_bits_le64(magnitudes > p)
    return sign_bits, mag_planes, scales, sparsity
# Architecture fields written to the output config.json. Values are taken from
# the source model's own config.json when present; these fallbacks are the
# values this converter previously hard-coded for its default model.
_ARCH_DEFAULTS = {
    "hidden_size": 1536,
    "intermediate_size": 8960,
    "num_attention_heads": 12,
    "num_key_value_heads": 2,
    "num_hidden_layers": 28,
    "vocab_size": 151936,
    "head_dim": 128,
    "rope_theta": 1000000.0,
    "rms_norm_eps": 1e-6,
}

# Name suffixes of the linear projection weights that get unary-quantized.
_LINEAR_SUFFIXES = (
    "q_proj.weight", "k_proj.weight", "v_proj.weight", "o_proj.weight",
    "gate_proj.weight", "up_proj.weight", "down_proj.weight",
)


def _build_config(model_dir, n_planes):
    """Return the output config dict, preferring values from model_dir/config.json."""
    source = {}
    source_path = os.path.join(model_dir, "config.json")
    if os.path.isfile(source_path):
        with open(source_path, encoding="utf-8") as f:
            source = json.load(f)
    config = {}
    for key, default in _ARCH_DEFAULTS.items():
        value = source.get(key)
        config[key] = default if value is None else value
    # If the source config has no head_dim, derive it from the source's own
    # dimensions rather than using the fallback for a different model.
    if (source.get("head_dim") is None and source.get("hidden_size")
            and source.get("num_attention_heads")):
        config["head_dim"] = source["hidden_size"] // source["num_attention_heads"]
    config["n_planes"] = n_planes
    config["quant_type"] = "unary_interleaved"
    return config


def convert(model_dir, output_dir, n_planes):
    """Convert a safetensors checkpoint into the unary-interleaved on-disk format.

    Writes into ``output_dir`` (created if needed):
      * ``config.json``: architecture fields plus ``n_planes``/``quant_type``.
      * ``<key>.sign``, ``<key>.planes``, ``<key>.scales`` (raw ``tofile``
        dumps from ``quantize_unary_interleaved``) for each 2-D linear
        projection weight, with ``.`` in the key replaced by ``_``.
      * ``<key>.fp16``: every other tensor cast to float16.
      * ``manifest.json``: shapes of the unary and fp16 tensors.

    Progress and a size summary are printed to stdout.
    """
    os.makedirs(output_dir, exist_ok=True)
    tensors = load_safetensors(model_dir)
    config = _build_config(model_dir, n_planes)

    linear_keys = [
        k for k, v in tensors.items()
        if k.endswith(_LINEAR_SUFFIXES) and v.ndim == 2
    ]
    linear_set = set(linear_keys)
    other_keys = [k for k in tensors if k not in linear_set]
    print(f"\nUnary: {len(linear_keys)} layers, {n_planes} planes ({2*n_planes+1} levels)")
    print(f"FP16: {len(other_keys)} layers\n")

    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=2)

    total_unary = total_fp16 = 0
    total_linear_weights = 0
    for key in linear_keys:
        w = tensors[key]
        t0 = time.perf_counter()
        sign_bits, mag_planes, scales, sparsity = quantize_unary_interleaved(w, n_planes)
        dt = time.perf_counter() - t0
        prefix = os.path.join(output_dir, key.replace(".", "_"))
        sign_bits.tofile(prefix + ".sign")
        mag_planes.tofile(prefix + ".planes")  # [out_dim][chunks][n_planes] contiguous
        scales.tofile(prefix + ".scales")
        ub = sign_bits.nbytes + mag_planes.nbytes + scales.nbytes
        total_unary += ub
        total_linear_weights += w.size
        bpw = (ub * 8) / w.size
        print(f" {key}: {w.shape} -> {ub/1024:.0f}KB ({bpw:.1f}bpw, {sparsity:.0%} sparse, {dt:.1f}s)")

    for key in other_keys:
        w = tensors[key].astype(np.float16)
        prefix = os.path.join(output_dir, key.replace(".", "_"))
        w.tofile(prefix + ".fp16")
        total_fp16 += w.nbytes
        print(f" {key}: {w.shape} -> fp16 ({w.nbytes/1024:.0f}KB)")

    manifest = {
        "unary": {k: list(tensors[k].shape) for k in linear_keys},
        "fp16": {k: list(tensors[k].shape) for k in other_keys},
    }
    with open(os.path.join(output_dir, "manifest.json"), "w") as f:
        json.dump(manifest, f, indent=2)

    total = total_unary + total_fp16
    # Guard: a checkpoint with no matching linear weights must not divide by 0.
    avg_bpw = (total_unary * 8) / total_linear_weights if total_linear_weights else 0.0
    print("\n=== Summary ===")
    print(f"Unary weights: {total_unary/1024/1024:.1f} MB ({avg_bpw:.1f} avg bpw)")
    print(f"FP16 weights: {total_fp16/1024/1024:.1f} MB")
    print(f"Total: {total/1024/1024:.1f} MB")
    print(f"Planes: {n_planes}, Levels: {2*n_planes+1}")
    print("Layout: interleaved [out_dim][chunks][n_planes]")
    print("Done!")
if __name__ == "__main__":
    # Positional CLI: [model_dir] [output_dir] [n_planes]. Missing trailing
    # arguments fall back to the defaults; extra arguments are ignored.
    cli_args = sys.argv[1:]
    fallbacks = ["deepseek-r1-1.5b-hf", "deepseek-r1-1.5b-unary31", "31"]
    src_dir, dst_dir, planes_text = (cli_args + fallbacks[len(cli_args):])[:3]
    convert(src_dir, dst_dir, int(planes_text))