| | |
| | """ |
| | PROPER UNARY CONVERTER — Global quantum, torch-based, BF16 support |
| | |
| | Clips at P99.9 of |weights| instead of absmax to avoid wasting |
| | quantization range on rare outliers. Values above clip point |
| | saturate at K (still represented, just capped). |
| | |
| | (c) 2026 OpenTransformers Ltd / Scott Bisset |
| | """ |
| |
|
| | import torch, json, os, sys, gc, shutil |
| | from safetensors import safe_open |
| | import numpy as np |
| |
|
def scan_all_linears(model_dir):
    """Scan all shards and classify/sample the linear-layer weights.

    A tensor is treated as a linear layer when it is 2-D and its name
    contains neither "norm" nor "embed" (heuristic — everything else is
    kept in FP16 by the converter).

    Args:
        model_dir: directory holding config.json and the safetensors shards.

    Returns:
        (global_max, all_abs, linear_names, shards):
          global_max: float, largest |weight| across all linear tensors.
          all_abs: 1-D float32 tensor of |weight| samples (2000 random
            draws with replacement per linear tensor) for quantile stats.
          linear_names: list of tensor names classified as linear.
          shards: shard filenames to process (from the index, or the
            single-file default).

    Raises:
        ValueError: if no linear layers were found in any shard.
    """
    index_path = os.path.join(model_dir, "model.safetensors.index.json")
    if os.path.exists(index_path):
        # Use a context manager — json.load(open(...)) leaks the handle.
        with open(index_path) as fh:
            index = json.load(fh)
        shards = sorted(set(index["weight_map"].values()))
    else:
        shards = ["model.safetensors"]

    all_abs_samples = []
    linear_names = []
    global_max = 0.0

    for shard in shards:
        path = os.path.join(model_dir, shard)
        print(f"  Scanning {shard}...")
        with safe_open(path, framework="pt") as f:
            for name in f.keys():
                t = f.get_tensor(name).float()
                if t.dim() == 2 and "norm" not in name and "embed" not in name:
                    linear_names.append(name)
                    global_max = max(global_max, t.abs().max().item())
                    # Sample instead of keeping every value: bounds memory
                    # while still giving a usable quantile estimate.
                    idx = torch.randint(0, t.numel(), (2000,))
                    all_abs_samples.append(t.flatten()[idx].abs())

    if not all_abs_samples:
        # torch.cat([]) would raise an opaque RuntimeError; fail clearly.
        raise ValueError(f"No linear layers found in {model_dir}")
    all_abs = torch.cat(all_abs_samples)
    return global_max, all_abs, linear_names, shards
| |
|
| |
|
def _pack_bits_u64(mask):
    """Pack a boolean [rows, cols] mask into [rows, ceil(cols/64)] uint64.

    Bit (j % 64) of word (j // 64) is set exactly where mask[:, j] is True,
    i.e. the same layout as setting `word |= 1 << (j % 64)` per column.
    """
    rows, cols = mask.shape
    chunks = (cols + 63) // 64
    pad = chunks * 64 - cols
    if pad:
        mask = np.pad(mask, ((0, 0), (0, pad)))
    # packbits with bitorder="little" puts column j at bit j%8 of byte j//8;
    # viewing 8 bytes as one uint64 then matches `1 << (j % 64)` on
    # little-endian hosts.
    words = np.packbits(mask, axis=1, bitorder="little").view(np.uint64)
    if sys.byteorder == "big":
        # Byte order inside each 64-bit word must be little-endian.
        words = words.byteswap()
    return np.ascontiguousarray(words)


def encode_to_proper_unary_torch(weight_f32, quantum, K):
    """Encode a [rows, cols] float32 tensor to proper unary (thermometer) code.

    Each weight is quantized to an integer magnitude round(|w| / quantum)
    clamped to [0, K]. Magnitudes are stored as K bit-planes: plane p has
    the bit for column j set iff magnitude > p. Signs get their own plane.

    Args:
        weight_f32: torch.Tensor [rows, cols], float32.
        quantum: quantization step size (> 0).
        K: number of unary slots; values above K*quantum saturate at K.

    Returns:
        (sign_packed, slots_packed, clip_count):
          sign_packed: np.uint64 [rows, chunks] (chunks = ceil(cols/64)),
            bit set where the weight is negative.
          slots_packed: np.uint64 [K, rows, chunks], thermometer planes.
          clip_count: int, number of weights whose pre-round scaled
            magnitude exceeded K (saturated, not dropped).
    """
    rows, cols = weight_f32.shape
    chunks = (cols + 63) // 64

    inv_q = 1.0 / quantum
    scaled = weight_f32.abs() * inv_q
    magnitudes = scaled.round().long().clamp(0, K)
    # Clip criterion is checked *before* rounding, matching the reported
    # "saturated" count (a value may round exactly to K without clipping).
    clip_count = int((scaled > K).sum().item())

    signs_np = (weight_f32 < 0).numpy()
    mags_np = magnitudes.numpy()

    # Vectorized packing replaces the former O(cols * K) Python bit loop.
    sign_packed = _pack_bits_u64(signs_np)
    slots_packed = np.empty((K, rows, chunks), dtype=np.uint64)
    for p in range(K):
        slots_packed[p] = _pack_bits_u64(mags_np > p)

    print(f"    {cols}/{cols} done, {clip_count} clipped")
    return sign_packed, slots_packed, clip_count
| |
|
| |
|
def convert(model_dir, output_dir, K=32, clip_pct=99.9):
    """Convert a safetensors model to the proper-unary on-disk format.

    Linear layers (as classified by scan_all_linears) are encoded to
    `.usign` / `.uslots` bit-plane files with a single global quantum
    derived from the P<clip_pct> of sampled |weights|; everything else
    (norms, embeddings) is dumped as raw FP16. A manifest.json ties it
    together. Conversion is resumable: already-written outputs are skipped.

    Args:
        model_dir: input HF-style model directory.
        output_dir: destination directory (created if missing).
        K: number of unary slots per weight (bits per magnitude).
        clip_pct: percentile of |weights| used as the clip point; values
            above it saturate at K.
    """
    os.makedirs(output_dir, exist_ok=True)

    with open(os.path.join(model_dir, "config.json")) as fh:
        config = json.load(fh)
    print(f"Model: {config.get('_name_or_path', config.get('model_type', '?'))}")
    print(f"  Layers={config['num_hidden_layers']} Hidden={config['hidden_size']} Inter={config['intermediate_size']}")

    print("\nScanning weights...")
    global_max, all_abs, linear_names, shards = scan_all_linears(model_dir)

    # Global quantum: clip at the P<clip_pct> of sampled |weights| rather
    # than absmax, so rare outliers don't waste quantization range.
    clip_val = torch.quantile(all_abs, clip_pct / 100.0).item()
    quantum = clip_val / K

    print(f"\n  Global absmax:  {global_max:.6f}")
    print(f"  P{clip_pct} clip: {clip_val:.6f}")
    print(f"  K = {K}")
    print(f"  Quantum = {quantum:.8f}")
    print(f"  Values > clip ({clip_pct}%): saturate at K={K}")

    # Distribution stats on the sampled magnitudes (diagnostic only).
    mags = (all_abs / quantum).round().clamp(0, K)
    print(f"\n  Mean magnitude: {mags.mean():.1f} slots")
    print(f"  Median:         {mags.median():.1f} slots")
    print(f"  Zero fraction:  {100*(mags==0).float().mean():.1f}%")
    print(f"  At K (clipped): {100*(mags==K).float().mean():.1f}%")
    print(f"  Unique levels:  {len(mags.unique())} / {K+1}")

    # K magnitude planes + 1 sign plane per weight.
    bits_per_elem = K + 1
    ratio = bits_per_elem / 16.0
    print(f"\n  Bits per weight: {bits_per_elem}")
    print(f"  vs BF16 (16 bit): {ratio:.1f}x")
    # NOTE(review): 7.6 GB is a hard-coded size estimate for the default
    # model; it is display-only and does not affect conversion.
    print(f"  Original: ~7.6 GB → Estimated: ~{7.6 * ratio:.1f} GB")

    index_path = os.path.join(model_dir, "model.safetensors.index.json")
    if os.path.exists(index_path):
        with open(index_path) as fh:
            weight_map = json.load(fh)["weight_map"]
    else:
        weight_map = None

    manifest = {
        "format": "proper_unary",
        "quantum": float(quantum),
        "K": K,
        "clip_pct": clip_pct,
        "clip_val": float(clip_val),
        "global_absmax": float(global_max),
        "unary": {},
        "fp16": [],
    }

    # Group linear layers by the shard that holds them.
    shard_linears = {}
    for name in linear_names:
        shard = weight_map[name] if weight_map else "model.safetensors"
        shard_linears.setdefault(shard, []).append(name)

    total_unary_bytes = 0
    total_fp16_bytes = 0
    total_clipped = 0
    done = 0

    for shard in shards:
        path = os.path.join(model_dir, shard)
        shard_lins = shard_linears.get(shard, [])
        print(f"\nProcessing {shard} ({len(shard_lins)} linear layers)...")

        with safe_open(path, framework="pt") as f:
            all_keys = list(f.keys())

            # Pass 1: dump non-linear tensors (norms, embeddings) as FP16.
            for name in all_keys:
                if name in linear_names:
                    continue
                fname = name.replace(".", "_") + ".fp16"
                out_path = os.path.join(output_dir, fname)
                if not os.path.exists(out_path):  # resumable
                    t = f.get_tensor(name).half()
                    t.numpy().view(np.uint16).tofile(out_path)
                sz = os.path.getsize(out_path)
                total_fp16_bytes += sz
                manifest["fp16"].append(name)
                print(f"  FP16: {name} {list(t.shape)} ({sz//1024}KB)")

            # Pass 2: unary-encode the linear layers.
            for name in shard_lins:
                fname = name.replace(".", "_")
                sign_path = os.path.join(output_dir, f"{fname}.usign")
                slots_path = os.path.join(output_dir, f"{fname}.uslots")

                # Resume support: both output files already on disk.
                if os.path.exists(sign_path) and os.path.exists(slots_path):
                    t = f.get_tensor(name)
                    manifest["unary"][name] = list(t.shape)
                    total_unary_bytes += os.path.getsize(sign_path) + os.path.getsize(slots_path)
                    done += 1
                    print(f"  Skip: {name}")
                    continue

                t = f.get_tensor(name).float()
                rows, cols = t.shape
                print(f"  Converting: {name} [{rows}x{cols}]...", flush=True)

                sign_p, slots_p, clip_c = encode_to_proper_unary_torch(t, quantum, K)
                total_clipped += clip_c

                sign_p.tofile(sign_path)
                slots_p.tofile(slots_path)

                s_sz = os.path.getsize(sign_path)
                sl_sz = os.path.getsize(slots_path)
                total_unary_bytes += s_sz + sl_sz

                manifest["unary"][name] = [rows, cols]
                done += 1
                print(f"    sign={s_sz//1024}KB slots={sl_sz//1024}KB total={(s_sz+sl_sz)//1024//1024}MB")

                # Free the large intermediates before the next layer.
                del t, sign_p, slots_p
                gc.collect()

    # Copy tokenizer/config sidecar files (but not the safetensors index).
    for fname in os.listdir(model_dir):
        if fname.endswith(('.json', '.txt', '.model')) and not fname.startswith('model.safetensors'):
            src = os.path.join(model_dir, fname)
            dst = os.path.join(output_dir, fname)
            if not os.path.exists(dst):
                shutil.copy2(src, dst)

    manifest_path = os.path.join(output_dir, "manifest.json")
    with open(manifest_path, "w") as fh:
        json.dump(manifest, fh, indent=2)

    total = total_unary_bytes + total_fp16_bytes
    print(f"\n{'='*60}")
    print(f"PROPER UNARY CONVERSION COMPLETE")
    print(f"{'='*60}")
    print(f"  Quantum: {quantum:.8f}")
    print(f"  K: {K}")
    print(f"  Clip at P{clip_pct}: {clip_val:.6f}")
    print(f"  Linear layers: {done}")
    print(f"  Clipped vals: {total_clipped}")
    print(f"  Unary: {total_unary_bytes/1e9:.2f} GB")
    print(f"  FP16 (norms): {total_fp16_bytes/1e6:.1f} MB")
    print(f"  Total: {total/1e9:.2f} GB")
    print(f"  Original BF16: ~7.6 GB")
    print(f"  Ratio: {total/7.6e9:.1f}x")
    print(f"  Output dir: {output_dir}")
| |
|
| |
|
if __name__ == "__main__":
    # Positional args, all optional: model_dir output_dir K clip_pct
    defaults = ["qwen3-4b-thinking-hf", "qwen3-4b-proper-unary", "32", "99.9"]
    args = sys.argv[1:5]
    model_dir, output_dir, k_arg, clip_arg = args + defaults[len(args):]

    convert(model_dir, output_dir, K=int(k_arg), clip_pct=float(clip_arg))
| |
|