# Source: unary-quantization-research / convert_fast.py
# Provenance: uploaded by OpenTransformer ("Add files using upload-large-folder tool",
# commit 19ed98b, verified).
#!/usr/bin/env python3
"""
FAST proper unary converter — vectorized bitpacking via numpy.
Instead of iterating columns one at a time, processes plane-by-plane
with vectorized comparisons, then packs to uint64 using np.packbits.
(c) 2026 OpenTransformers Ltd / Scott Bisset
"""
import torch, json, os, sys, gc, shutil
from safetensors import safe_open
import numpy as np
def pack_bits_to_uint64(bool_matrix):
    """
    Pack a [rows, cols] boolean matrix into a [rows, chunks] uint64 matrix,
    where chunks = ceil(cols / 64).

    Bit j of element (r, c) corresponds to column c*64 + j, i.e. little-endian
    bit ordering within each uint64 word. Pad bits beyond `cols` are zero.

    Args:
        bool_matrix: 2-D numpy array of bool (or 0/1 integers).

    Returns:
        numpy array of shape [rows, chunks], dtype uint64.
    """
    rows, cols = bool_matrix.shape
    chunks = (cols + 63) // 64
    # Zero-pad the column count up to a multiple of 64 so every uint64 word
    # is fully populated.
    if cols % 64:
        padded = np.zeros((rows, chunks * 64), dtype=np.uint8)
        padded[:, :cols] = bool_matrix.astype(np.uint8)
    else:
        padded = bool_matrix.astype(np.uint8)
    # Single C-level pass: pack 8 bits per byte with little-endian bit order,
    # then reinterpret each run of 8 bytes as one little-endian uint64.
    # This replaces 64 shift-and-OR passes over the whole matrix with one
    # np.packbits call (which is what the module header advertises).
    packed_bytes = np.packbits(padded, axis=1, bitorder="little")  # [rows, chunks * 8]
    # '<u8' makes the byte interpretation explicit; astype normalizes to the
    # native uint64 dtype (no copy on little-endian hosts).
    return packed_bytes.view("<u8").astype(np.uint64, copy=False)
def encode_fast(weight_f32_np, quantum, K):
    """
    Vectorized proper-unary encoder for one weight matrix.

    weight_f32_np: [rows, cols] numpy float32.
    quantum: magnitude step; integer magnitude = round(|w| / quantum), clipped to K.
    K: number of unary (thermometer) planes.

    Returns:
        sign [rows, chunks] uint64 (bit set where the weight is negative),
        slots [K, rows, chunks] uint64 (plane p set where magnitude > p),
        clip_count (elements whose magnitude exceeded K before clipping).
    """
    rows, cols = weight_f32_np.shape
    chunks = (cols + 63) // 64
    # Quantize magnitudes to integer multiples of `quantum`.
    inv_q = 1.0 / quantum
    mags = np.round(np.abs(weight_f32_np) * inv_q).astype(np.int32)
    clip_count = int((mags > K).sum())
    mags = np.clip(mags, 0, K)
    # Sign bitmap: one bit per element, set for strictly negative weights.
    sign_packed = pack_bits_to_uint64(weight_f32_np < 0)
    # Thermometer planes: element with magnitude m contributes set bits in
    # planes 0..m-1. Each plane is one vectorized comparison + pack.
    slots_packed = np.zeros((K, rows, chunks), dtype=np.uint64)
    for p in range(K):
        slots_packed[p] = pack_bits_to_uint64(mags > p)
        is_last = p == K - 1
        if is_last or (p + 1) % 8 == 0:
            print(f" plane {p+1}/{K}", end="\r", flush=True)
    print(f" {K}/{K} planes done, {clip_count} clipped")
    return sign_packed, slots_packed, clip_count
def convert(model_dir, output_dir, K=32, clip_pct=99.9):
    """
    Convert a safetensors model directory into the proper-unary on-disk format.

    2-D tensors whose names contain neither "norm" nor "embed" are treated as
    linear weights and encoded as packed unary planes (<name>.usign +
    <name>.uslots); every other tensor is dumped as raw FP16 (<name>.fp16).
    A manifest.json describing the layout is written to output_dir.
    The conversion is resumable: tensor files that already exist are skipped
    but still counted and recorded in the manifest.

    Args:
        model_dir: HF-style model directory (config.json + *.safetensors).
        output_dir: destination directory, created if missing.
        K: number of unary planes (max magnitude in quantum units).
        clip_pct: percentile of sampled |weights| used as the clip value;
                  quantum = P{clip_pct} / K.
    """
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(model_dir, "config.json")) as cf:
        config = json.load(cf)
    print(f"Model: {config.get('model_type', '?')}")
    print(f" Layers={config['num_hidden_layers']} Hidden={config['hidden_size']} Inter={config['intermediate_size']}")
    # Sharded checkpoints ship an index mapping tensor name -> shard file.
    index_path = os.path.join(model_dir, "model.safetensors.index.json")
    if os.path.exists(index_path):
        with open(index_path) as inf:
            index = json.load(inf)
        shards = sorted(set(index["weight_map"].values()))
        weight_map = index["weight_map"]
    else:
        shards = ["model.safetensors"]
        weight_map = None
    # Pass 1: sample |weights| of the linear layers to pick the quantum.
    print("\nScanning weights...")
    all_abs = []
    linear_names = []
    global_max = 0.0
    for shard in shards:
        path = os.path.join(model_dir, shard)
        print(f" {shard}...")
        with safe_open(path, framework="pt") as f:
            for name in f.keys():
                t = f.get_tensor(name).float()
                # Heuristic: 2-D, not a norm or embedding -> linear weight.
                if t.dim() == 2 and "norm" not in name and "embed" not in name:
                    linear_names.append(name)
                    am = t.abs().max().item()
                    if am > global_max:
                        global_max = am
                    # 2000 random samples per tensor suffice for the quantile.
                    idx = torch.randint(0, t.numel(), (2000,))
                    all_abs.append(t.flatten()[idx].abs())
    all_abs_t = torch.cat(all_abs)
    clip_val = torch.quantile(all_abs_t, clip_pct / 100.0).item()
    quantum = clip_val / K
    print(f"\n Absmax={global_max:.6f} P{clip_pct}={clip_val:.6f}")
    print(f" K={K} quantum={quantum:.8f}")
    mags = (all_abs_t / quantum).round().clamp(0, K)
    print(f" Mean mag={mags.mean():.1f} Median={mags.median():.1f} Zero={100*(mags==0).float().mean():.1f}% Clipped={100*(mags==K).float().mean():.1f}%")
    del all_abs, all_abs_t, mags
    gc.collect()
    manifest = {
        "format": "proper_unary",
        "quantum": float(quantum),
        "K": K,
        "clip_pct": clip_pct,
        "clip_val": float(clip_val),
        "global_absmax": float(global_max),
        "unary": {},
        "fp16": [],
    }
    total_unary = 0
    total_fp16 = 0
    total_clip = 0
    done = 0
    # Pass 2: encode and write every tensor, shard by shard.
    for shard in shards:
        path = os.path.join(model_dir, shard)
        # Linear tensors that live in this shard (single-file models map all
        # names to "model.safetensors").
        shard_lins = [n for n in linear_names if (weight_map or {}).get(n, "model.safetensors") == shard]
        print(f"\n{shard}: {len(shard_lins)} linear layers")
        with safe_open(path, framework="pt") as f:
            # Non-linear tensors -> raw FP16 dumps.
            for name in f.keys():
                if name in linear_names:
                    continue
                fname = name.replace(".", "_") + ".fp16"
                out_path = os.path.join(output_dir, fname)
                if os.path.exists(out_path):
                    # Resume: already written — still record it below.
                    print(f" FP16 (skip): {name}")
                else:
                    t = f.get_tensor(name).half().numpy()
                    # Write raw little-endian uint16 bit patterns of the FP16 data.
                    t.view(np.uint16).tofile(out_path)
                    print(f" FP16: {name} {t.shape}")
                total_fp16 += os.path.getsize(out_path)
                manifest["fp16"].append(name)
            # Linear tensors -> proper unary (sign plane + K magnitude planes).
            for name in shard_lins:
                fname = name.replace(".", "_")
                sign_path = os.path.join(output_dir, f"{fname}.usign")
                slots_path = os.path.join(output_dir, f"{fname}.uslots")
                if os.path.exists(sign_path) and os.path.exists(slots_path):
                    # Resume: only the shape is needed for the manifest.
                    t_shape = list(f.get_tensor(name).shape)
                    manifest["unary"][name] = t_shape
                    total_unary += os.path.getsize(sign_path) + os.path.getsize(slots_path)
                    done += 1
                    print(f" Skip: {name}")
                    continue
                t = f.get_tensor(name).float().numpy()
                rows, cols = t.shape
                print(f" {name} [{rows}x{cols}]", flush=True)
                sign_p, slots_p, clip_c = encode_fast(t, quantum, K)
                total_clip += clip_c
                sign_p.tofile(sign_path)
                slots_p.tofile(slots_path)
                s_sz = os.path.getsize(sign_path)
                sl_sz = os.path.getsize(slots_path)
                total_unary += s_sz + sl_sz
                manifest["unary"][name] = [rows, cols]
                done += 1
                mb = (s_sz + sl_sz) / 1e6
                print(f" → {mb:.1f} MB ({s_sz//1024}KB sign + {sl_sz//1024}KB slots)")
                # Free the large buffers before the next tensor to cap peak RSS.
                del t, sign_p, slots_p
                gc.collect()
    # Copy tokenizer/config sidecar files (but not the safetensors shards/index).
    for fname in os.listdir(model_dir):
        if fname.endswith(('.json', '.txt', '.model')) and not fname.startswith('model.safetensors'):
            src = os.path.join(model_dir, fname)
            dst = os.path.join(output_dir, fname)
            if not os.path.exists(dst):
                shutil.copy2(src, dst)
    # Use a context manager so the manifest is fully flushed and closed.
    with open(os.path.join(output_dir, "manifest.json"), "w") as mf:
        json.dump(manifest, mf, indent=2)
    total = total_unary + total_fp16
    print(f"\n{'='*60}")
    print(f"DONE: {done} layers, quantum={quantum:.8f}, K={K}")
    print(f" Unary: {total_unary/1e9:.2f} GB")
    print(f" FP16: {total_fp16/1e6:.1f} MB")
    # NOTE(review): the 7.6 GB BF16 baseline is hard-coded for the default
    # qwen3-4b model — it is only a reference figure in the summary line.
    print(f" Total: {total/1e9:.2f} GB (vs ~7.6 GB BF16 = {total/7.6e9:.1f}x)")
    print(f" Clipped: {total_clip} values")
    print(f"{'='*60}")
if __name__ == "__main__":
    # CLI: convert_fast.py [model_dir] [output_dir] [K] [clip_pct]
    cli = sys.argv[1:]
    model_dir = cli[0] if len(cli) > 0 else "qwen3-4b-thinking-hf"
    output_dir = cli[1] if len(cli) > 1 else "qwen3-4b-proper-unary"
    K = int(cli[2]) if len(cli) > 2 else 32
    clip = float(cli[3]) if len(cli) > 3 else 99.9
    convert(model_dir, output_dir, K=K, clip_pct=clip)