unary-quantization-research / convert_proper_unary_v2.py
OpenTransformer's picture
Add files using upload-large-folder tool
19ed98b verified
#!/usr/bin/env python3
"""
PROPER UNARY CONVERTER — Global quantum, torch-based, BF16 support
Clips at P99.9 of |weights| instead of absmax to avoid wasting
quantization range on rare outliers. Values above clip point
saturate at K (still represented, just capped).
(c) 2026 OpenTransformers Ltd / Scott Bisset
"""
import torch, json, os, sys, gc, shutil
from safetensors import safe_open
import numpy as np
def scan_all_linears(model_dir):
    """Scan the model's safetensors shards and gather stats of linear weights.

    A tensor is treated as a linear layer when it is 2-D and its name
    contains neither "norm" nor "embed".

    Args:
        model_dir: Directory containing either ``model.safetensors.index.json``
            (whose ``weight_map`` values name the shards) or a single
            ``model.safetensors`` file.

    Returns:
        Tuple ``(global_max, all_abs, linear_names, shards)``:
        ``global_max`` is the largest absolute value over all linear weights,
        ``all_abs`` is a 1-D float tensor of 2000 randomly sampled absolute
        values per linear layer (empty when no layer matched),
        ``linear_names`` lists the matching tensor names in scan order and
        ``shards`` is the sorted list of shard file names.
    """
    index_path = os.path.join(model_dir, "model.safetensors.index.json")
    if os.path.exists(index_path):
        # Context manager so the index file handle is closed deterministically.
        with open(index_path, encoding="utf-8") as fh:
            index = json.load(fh)
        shards = sorted(set(index["weight_map"].values()))
    else:
        shards = ["model.safetensors"]

    all_abs_samples = []
    linear_names = []
    global_max = 0.0
    for shard in shards:
        path = os.path.join(model_dir, shard)
        print(f" Scanning {shard}...")
        with safe_open(path, framework="pt") as f:
            for name in f.keys():
                # Filter on the name first: excluded tensors (e.g. the
                # embedding table) never need to be loaded at all.
                if "norm" in name or "embed" in name:
                    continue
                t = f.get_tensor(name)
                if t.dim() != 2:
                    continue
                linear_names.append(name)
                abs_t = t.float().abs()
                am = abs_t.max().item()
                if am > global_max:
                    global_max = am
                # Sample 2000 values (with replacement) for the distribution
                idx = torch.randint(0, abs_t.numel(), (2000,))
                all_abs_samples.append(abs_t.flatten()[idx])

    # torch.cat rejects an empty list; return an empty sample tensor instead
    # so the caller can report "no linear layers" itself.
    all_abs = torch.cat(all_abs_samples) if all_abs_samples else torch.zeros(0)
    return global_max, all_abs, linear_names, shards
def encode_to_proper_unary_torch(weight_f32, quantum, K):
    """
    Encode a [rows, cols] float tensor to proper unary bitplanes.

    Each weight is quantized to the integer magnitude
    ``round(|w| / quantum)`` clamped to ``[0, K]``. A magnitude ``m`` is
    stored by setting the element's bit in slot planes ``0 .. m-1``; a
    separate plane holds the sign (bit set when ``w < 0``). Within every
    plane, column ``j`` of a row lives in word ``j // 64`` at bit ``j % 64``.

    Args:
        weight_f32: 2-D CPU torch tensor of weights.
        quantum: Positive step size; one unary slot represents this much
            magnitude.
        K: Number of unary slots (maximum representable magnitude).

    Returns:
        ``(sign_packed, slots_packed, clip_count)`` where ``sign_packed`` is a
        ``[rows, chunks]`` uint64 array, ``slots_packed`` is a
        ``[K, rows, chunks]`` uint64 array (``chunks = ceil(cols / 64)``) and
        ``clip_count`` is the number of elements with ``|w| / quantum > K``.
    """
    rows, cols = weight_f32.shape
    chunks = (cols + 63) // 64
    inv_q = 1.0 / quantum
    scaled = weight_f32.abs() * inv_q
    magnitudes = scaled.round().long().clamp(0, K)
    clip_count = int((scaled > K).sum().item())

    mags_np = magnitudes.numpy()
    signs_np = (weight_f32 < 0).numpy()

    sign_packed = np.zeros((rows, chunks), dtype=np.uint64)
    slots_packed = np.zeros((K, rows, chunks), dtype=np.uint64)

    if rows > 0 and cols > 0:
        # One reusable boolean bitplane, padded to a whole number of 64-bit
        # words. The padding columns are never written, so they stay zero.
        plane = np.zeros((rows, chunks * 64), dtype=bool)
        live = plane[:, :cols]

        def pack_plane():
            # Little-bit-order packing puts column j at bit j % 8 of byte
            # j // 8; reading 8 such bytes as a little-endian uint64 places
            # it at bit j % 64 of word j // 64. The explicit "<u8" keeps the
            # result independent of host endianness.
            packed_bytes = np.packbits(plane, axis=1, bitorder="little")
            return packed_bytes.view("<u8")

        live[...] = signs_np
        sign_packed[...] = pack_plane()

        # Unary slots: plane p is set wherever the magnitude exceeds p,
        # i.e. an element with magnitude m occupies slots 0..m-1.
        for p in range(K):
            np.greater(mags_np, p, out=live)
            slots_packed[p] = pack_plane()

    print(f" {cols}/{cols} done, {clip_count} clipped")
    return sign_packed, slots_packed, clip_count
def _read_json(path):
    """Load a JSON document from *path*, closing the file afterwards."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


def convert(model_dir, output_dir, K=32, clip_pct=99.9):
    """Convert a safetensors checkpoint into the proper-unary on-disk format.

    Linear weights (as selected by ``scan_all_linears``) are quantized with a
    single global quantum ``clip_val / K``, where ``clip_val`` is the
    ``clip_pct`` percentile of sampled absolute weights, and written as
    ``<name>.usign`` / ``<name>.uslots`` raw uint64 files. All other tensors
    are written as raw float16 ``<name>.fp16`` files. Dots in tensor names are
    replaced by underscores in file names. Files that already exist in
    ``output_dir`` are reused, so an interrupted run can be resumed. Finally
    the config/tokenizer files are copied and ``manifest.json`` is written.

    Args:
        model_dir: Source directory with ``config.json`` and the safetensors
            shard(s).
        output_dir: Destination directory (created if missing).
        K: Number of unary slots per weight; must be at least 1.
        clip_pct: Percentile (0-100) of |weights| that maps to magnitude K.

    Raises:
        ValueError: If ``K`` or ``clip_pct`` is out of range, if no linear
            layer is found, or if the clip value is not positive (which would
            give a zero quantum).
    """
    if K < 1:
        raise ValueError(f"K must be >= 1, got {K}")
    if not 0.0 <= clip_pct <= 100.0:
        raise ValueError(f"clip_pct must be within [0, 100], got {clip_pct}")

    os.makedirs(output_dir, exist_ok=True)
    config = _read_json(os.path.join(model_dir, "config.json"))
    print(f"Model: {config.get('_name_or_path', config.get('model_type', '?'))}")
    # Purely informational, so tolerate configs lacking these keys.
    print(
        f" Layers={config.get('num_hidden_layers', '?')}"
        f" Hidden={config.get('hidden_size', '?')}"
        f" Inter={config.get('intermediate_size', '?')}"
    )

    # Scan
    print("\nScanning weights...")
    global_max, all_abs, linear_names, shards = scan_all_linears(model_dir)
    if not linear_names:
        raise ValueError(f"no linear layers found in {model_dir!r}")

    # Pick quantum from clip percentile
    clip_val = torch.quantile(all_abs, clip_pct / 100.0).item()
    if not clip_val > 0:
        raise ValueError(
            f"P{clip_pct} of |weights| is {clip_val}; cannot derive a positive quantum"
        )
    quantum = clip_val / K
    print(f"\n Global absmax: {global_max:.6f}")
    print(f" P{clip_pct} clip: {clip_val:.6f}")
    print(f" K = {K}")
    print(f" Quantum = {quantum:.8f}")
    print(f" Values > clip ({clip_pct}%): saturate at K={K}")

    # Distribution with chosen quantum
    mags = (all_abs / quantum).round().clamp(0, K)
    print(f"\n Mean magnitude: {mags.mean():.1f} slots")
    print(f" Median: {mags.median():.1f} slots")
    print(f" Zero fraction: {100*(mags==0).float().mean():.1f}%")
    print(f" At K (clipped): {100*(mags==K).float().mean():.1f}%")
    print(f" Unique levels: {len(mags.unique())} / {K+1}")

    # Memory estimate
    # Per linear: sign=rows*chunks*8 bytes, slots=K*rows*chunks*8 bytes
    # Approx: (K+1) bits per element vs 16 bits BF16
    bits_per_elem = K + 1  # K slot bits + 1 sign bit (stored in uint64 chunks)
    ratio = bits_per_elem / 16.0
    # Measure the source checkpoint instead of assuming one specific model.
    orig_bytes = sum(os.path.getsize(os.path.join(model_dir, s)) for s in shards)
    orig_gb = orig_bytes / 1e9
    print(f"\n Bits per weight: {bits_per_elem}")
    print(f" vs BF16 (16 bit): {ratio:.1f}x")
    print(f" Original: ~{orig_gb:.1f} GB → Estimated: ~{orig_gb * ratio:.1f} GB")

    # Build weight map
    index_path = os.path.join(model_dir, "model.safetensors.index.json")
    if os.path.exists(index_path):
        weight_map = _read_json(index_path)["weight_map"]
    else:
        weight_map = None

    manifest = {
        "format": "proper_unary",
        "quantum": float(quantum),
        "K": K,
        "clip_pct": clip_pct,
        "clip_val": float(clip_val),
        "global_absmax": float(global_max),
        "unary": {},
        "fp16": [],
    }

    # Group linears by shard
    shard_linears = {}
    for name in linear_names:
        shard = weight_map[name] if weight_map else "model.safetensors"
        shard_linears.setdefault(shard, []).append(name)

    # Set for O(1) membership tests in the per-tensor loop below.
    linear_set = set(linear_names)

    total_unary_bytes = 0
    total_fp16_bytes = 0
    total_clipped = 0
    done = 0
    for shard in shards:
        path = os.path.join(model_dir, shard)
        shard_lins = shard_linears.get(shard, [])
        print(f"\nProcessing {shard} ({len(shard_lins)} linear layers)...")
        with safe_open(path, framework="pt") as f:
            # Non-linear weights → FP16
            for name in list(f.keys()):
                if name in linear_set:
                    continue
                fname = name.replace(".", "_") + ".fp16"
                out_path = os.path.join(output_dir, fname)
                if os.path.exists(out_path):
                    print(f" Skip: {name}")
                else:
                    t = f.get_tensor(name).half()
                    t.numpy().view(np.uint16).tofile(out_path)
                    print(
                        f" FP16: {name} {list(t.shape)}"
                        f" ({os.path.getsize(out_path)//1024}KB)"
                    )
                # Record the tensor whether it was written now or on an
                # earlier run, so a resumed conversion still produces a
                # complete manifest and correct size totals.
                total_fp16_bytes += os.path.getsize(out_path)
                manifest["fp16"].append(name)

            # Linear weights → proper unary
            for name in shard_lins:
                fname = name.replace(".", "_")
                sign_path = os.path.join(output_dir, f"{fname}.usign")
                slots_path = os.path.join(output_dir, f"{fname}.uslots")
                if os.path.exists(sign_path) and os.path.exists(slots_path):
                    t = f.get_tensor(name)
                    manifest["unary"][name] = list(t.shape)
                    total_unary_bytes += os.path.getsize(sign_path) + os.path.getsize(slots_path)
                    done += 1
                    print(f" Skip: {name}")
                    continue
                t = f.get_tensor(name).float()
                rows, cols = t.shape
                print(f" Converting: {name} [{rows}x{cols}]...", flush=True)
                sign_p, slots_p, clip_c = encode_to_proper_unary_torch(t, quantum, K)
                total_clipped += clip_c
                sign_p.tofile(sign_path)
                slots_p.tofile(slots_path)
                s_sz = os.path.getsize(sign_path)
                sl_sz = os.path.getsize(slots_path)
                total_unary_bytes += s_sz + sl_sz
                manifest["unary"][name] = [rows, cols]
                done += 1
                print(f" sign={s_sz//1024}KB slots={sl_sz//1024}KB total={(s_sz+sl_sz)//1024//1024}MB")
                # Free the large arrays before loading the next layer.
                del t, sign_p, slots_p
                gc.collect()

    # Copy config and tokenizer
    for fname in os.listdir(model_dir):
        if fname.endswith(('.json', '.txt', '.model')) and not fname.startswith('model.safetensors'):
            src = os.path.join(model_dir, fname)
            dst = os.path.join(output_dir, fname)
            if not os.path.exists(dst):
                shutil.copy2(src, dst)

    manifest_path = os.path.join(output_dir, "manifest.json")
    # Context manager guarantees the manifest is flushed and closed.
    with open(manifest_path, "w", encoding="utf-8") as fh:
        json.dump(manifest, fh, indent=2)

    total = total_unary_bytes + total_fp16_bytes
    size_ratio = total / orig_bytes if orig_bytes else float("nan")
    print(f"\n{'='*60}")
    print("PROPER UNARY CONVERSION COMPLETE")
    print(f"{'='*60}")
    print(f" Quantum: {quantum:.8f}")
    print(f" K: {K}")
    print(f" Clip at P{clip_pct}: {clip_val:.6f}")
    print(f" Linear layers: {done}")
    print(f" Clipped vals: {total_clipped}")
    print(f" Unary: {total_unary_bytes/1e9:.2f} GB")
    print(f" FP16 (norms): {total_fp16_bytes/1e6:.1f} MB")
    print(f" Total: {total/1e9:.2f} GB")
    print(f" Original: {orig_gb:.2f} GB")
    print(f" Ratio: {size_ratio:.1f}x")
    print(f" Output dir: {output_dir}")
if __name__ == "__main__":
    # Positional CLI: [model_dir] [output_dir] [K] [clip_pct].
    # Any argument that is not supplied falls back to its default below;
    # arguments beyond the fourth are ignored.
    defaults = ["qwen3-4b-thinking-hf", "qwen3-4b-proper-unary", "32", "99.9"]
    given = sys.argv[1:5]
    model_dir, output_dir, k_text, clip_text = given + defaults[len(given):]
    K = int(k_text)
    clip = float(clip_text)
    convert(model_dir, output_dir, K=K, clip_pct=clip)