"""Ternary quantizer v2 — multi-config single-pass with per-matrix checkpointing.

Key improvements over v1:
  1. Multi-config: --configs d3scale-sens002,d3scale-sens003,uniform-d2,uniform-d3
     Computes per-group MSE-best scales (over a fixed 4-candidate set) ONCE per
     matrix, derives all configs. ~3x faster than running v1 four times.
  2. Per-matrix checkpoint: each matrix's quantized output saved to .checkpoint/
     dir as soon as it's done. Crash-resume picks up where it left off.
  3. Durable atomic writes (write to .tmp, fsync, rename) — no half-written or
     post-power-loss-truncated checkpoints.
  4. Streaming progress.json — monitors can poll without parsing logs.
  5. Per-config HF model assembled at the end from checkpoints.
  6. Resume validation: a fingerprint of (model id, revision, codec version,
     group size, depth-power mapping) is stored in each checkpoint together
     with the tensor shape, and both are re-checked on resume. A mismatch
     causes the stale checkpoint to be discarded and re-quantized rather than
     silently mixed.

What this codec quantizes (and what it does not):
  - Quantized: every 2D linear weight matrix in the model.
  - Kept FP16: token embeddings, all *_norm layers, and lm_head.
  This matches the convention used by GPTQ/AWQ/NF4 and is what the paper's
  bits-per-weight figures account for.

Usage:
    python quantize_model_v2.py --model Qwen/Qwen2.5-7B \
        --configs uniform-d2,uniform-d3 \
        --output /path/to/output_root \
        --revision <commit-sha-or-tag> \
        --workers 8 --dtype float16

Output structure:
    output_root/
        .checkpoint/
            matrix_00000__model_layers_0_self_attn_q_proj.npz   # all configs in one file
            matrix_00001__model_layers_0_self_attn_k_proj.npz
            ...
        progress.json       # live status
        <config_name>/
            model/          # HF-format output
            config.json
"""
import os
import sys
import time
import json
import gc
import argparse
import tempfile
from multiprocessing import Pool

import numpy as np

# ============================================================
# CODEC CORE (algorithm unchanged from v1)
# ============================================================
GS = 16
DEPTH_POWERS = {1: 1.0, 2: 1.5, 3: 1.2, 4: 1.0}


def build_levels(half, power):
    """Return the 2*half+1 reconstruction levels spanning [-half, +half].

    With power == 1.0 the levels are the plain integers. Otherwise the
    normalised levels are warped as sign(n) * |n| ** power and rescaled back
    to the [-half, +half] range, which keeps the end points and zero fixed.
    """
    int_levels = np.arange(-half, half + 1).astype(np.float64)
    if power == 1.0:
        return int_levels
    span = max(half, 1)  # guards the division when half == 0
    n = int_levels / span
    return np.sign(n) * np.abs(n) ** power * span


def make_boundaries(level_map, zero_boundary=None):
    """Return the decision boundaries between adjacent levels.

    Default = midpoints between levels. If zero_boundary is given, the two
    boundaries straddling the level closest to 0 are overridden with
    -|zero_boundary| and +|zero_boundary| (used for d1 with a custom
    zero-zone width).
    """
    boundaries = (level_map[:-1] + level_map[1:]) / 2
    if zero_boundary is not None:
        zero_idx = int(np.argmin(np.abs(level_map)))
        if zero_idx > 0:
            boundaries[zero_idx - 1] = -abs(zero_boundary)
        if zero_idx < len(level_map) - 1:
            boundaries[zero_idx] = abs(zero_boundary)
    return boundaries


def compute_best_scale_4cand(groups, depth, power, zero_boundary=None):
    """Pick the per-group scale that minimises reconstruction MSE among 4 fixed
    order-statistic candidates of the sorted absolute weights: indices
    [gs-6, gs-4, gs-2, gs-1] (roughly the 69th/81st/94th/100th percentiles
    for gs=16).

    This is a deliberately small candidate set, not an exhaustive sweep.
    Empirically <1% PPL gap from a dense sweep on Qwen2.5-7B; in exchange
    quantization is ~50x faster than evaluating every percentile.

    Args:
        groups: 2-D array, one weight group per row.
        depth: trit depth; the codec uses 3 ** depth levels.
        power: level-warping exponent passed to build_levels.
        zero_boundary: optional zero-zone override passed to make_boundaries.

    Returns:
        (best_scale, best_mse): two 1-D arrays with one entry per group.
    """
    half = (3 ** depth) // 2
    denom = max(half, 1)
    gs = groups.shape[1]
    sa = np.sort(np.abs(groups), axis=1)
    # Clipping can collapse candidates when gs is small; a repeated candidate
    # can never win the strict `<` comparison below, so drop it up front.
    cand_idx = np.unique(np.clip(np.array([gs - 6, gs - 4, gs - 2, gs - 1]), 0, gs - 1))
    level_map = build_levels(half, power)
    boundaries = make_boundaries(level_map, zero_boundary)
    top = len(level_map) - 1
    N = len(groups)
    best_scale = np.zeros(N)
    best_mse = np.full(N, np.inf)
    for ki in cand_idx:
        scales = np.maximum(sa[:, ki] / denom, 1e-30)  # floor avoids divide-by-zero
        normalized = groups / scales[:, None]
        idx = np.clip(np.searchsorted(boundaries, normalized.ravel()), 0, top)
        recon = level_map[idx].reshape(N, gs) * scales[:, None]
        mse = np.mean((groups - recon) ** 2, axis=1)
        better = mse < best_mse
        best_mse[better] = mse[better]
        best_scale[better] = scales[better]
    return best_scale, best_mse


# Backwards-compatible alias — earlier scripts and the published paper repo
# refer to this as the "MSE-optimal" call site. The name overstates the
# guarantee (see docstring on compute_best_scale_4cand) but the algorithm is
# unchanged.
compute_optimal_scale = compute_best_scale_4cand


def trit_quantize_scales(scales, sd, chunk_rows=1 << 16):
    """Snap each scale to the nearest entry of a log-uniform codebook.

    The codebook has 3 ** sd entries, evenly spaced in log space between the
    0.1th percentile and the maximum of log(scales). Ties go to the lower
    codebook entry.

    Args:
        scales: 1-D array of positive per-group scales.
        sd: scale depth in trits.
        chunk_rows: number of scales compared against the codebook at once.
            The nearest-entry search needs a (rows x 3**sd) temporary; doing
            it in chunks bounds that temporary instead of letting it grow
            with the matrix size. The result does not depend on this value.

    Returns:
        Array of quantized scales, same length as `scales` (empty in, empty out).
    """
    scales = np.asarray(scales)
    if scales.size == 0:
        # np.percentile / np.max raise on empty input; nothing to quantize.
        return np.empty(0, dtype=np.float64)
    log_scales = np.log(np.maximum(scales, 1e-30))
    half = (3 ** sd) // 2
    n_levels = 2 * half + 1
    log_min = np.percentile(log_scales, 0.1)
    log_max = np.max(log_scales)  # 100th pct — never clip large scales
    if log_max - log_min < 1e-9:
        log_max = log_min + 1e-9
    codebook_log = np.linspace(log_min, log_max, n_levels)
    idx = np.empty(len(log_scales), dtype=np.intp)
    step = max(int(chunk_rows), 1)
    for lo in range(0, len(log_scales), step):
        part = log_scales[lo:lo + step]
        idx[lo:lo + step] = np.argmin(
            np.abs(part[:, None] - codebook_log[None, :]), axis=1)
    return np.exp(codebook_log[idx])


def quantize_with_scale(groups, scale, depth, power, zero_boundary=None):
    """Quantize `groups` with the given per-group `scale` and return the
    dequantized reconstruction (same shape as `groups`)."""
    half = (3 ** depth) // 2
    level_map = build_levels(half, power)
    boundaries = make_boundaries(level_map, zero_boundary)
    scale = np.maximum(scale, 1e-30)
    normalized = groups / scale[:, None]
    idx = np.searchsorted(boundaries, normalized.ravel())
    idx = np.clip(idx, 0, len(level_map) - 1)
    q = level_map[idx].reshape(groups.shape)
    return q * scale[:, None]


# ============================================================
# CODEC CONFIGS
# ============================================================
CODECS = {
    'd3scale-sens002': {'mode': 'adaptive', 'scale_depth': 3, 'threshold': 0.002},
    'd3scale-sens003': {'mode': 'adaptive', 'scale_depth': 3, 'threshold': 0.003},
    # d1 with narrow zero zone (zw=0.25): 3 levels {-1,0,+1}, zero only when |w|<0.25*scale.
    # Old default was zw=0.5 which made 97.5% of weights round to 0 (random-chance MMLU).
    'uniform-d1': {'mode': 'uniform', 'scale_depth': 3, 'depth': 1, 'zero_boundary': 0.25},
    'uniform-d2': {'mode': 'uniform', 'scale_depth': 3, 'depth': 2},
    'uniform-d3': {'mode': 'uniform', 'scale_depth': 3, 'depth': 3},
    'uniform-d4': {'mode': 'uniform', 'scale_depth': 3, 'depth': 4},
}

# ============================================================
# MULTI-CONFIG MATRIX QUANTIZATION
# ============================================================

# Depths tried, in order, by the adaptive mode; the last one is the fallback
# that takes every group not accepted at a shallower depth.
_ADAPTIVE_DEPTHS = (2, 3, 4)


def quantize_matrix_multi(args):
    """Quantize one matrix for ALL requested configs in a single pass.

    Args:
        args: tuple (w_flat, rows, cols, config_names). `w_flat` is the
            flattened weight matrix of shape rows x cols; `config_names`
            are keys of CODECS. Packed as one tuple so the function can be
            handed directly to Pool.imap.

    Returns:
        dict: config_name -> dict with keys
            'recon_w'      float32 reconstruction of shape (rows, cols)
            'depth_counts' {depth: number of groups quantized at that depth}
            'weight_bits'  bits spent on weight trits
            'scale_bits'   bits spent on scale trits
            'n_groups'     number of GS-sized groups (after column padding)
    """
    w_flat, rows, cols, config_names = args
    w = w_flat.reshape(rows, cols)
    # Zero-pad columns up to a multiple of GS so each row splits into whole groups.
    pad = (GS - cols % GS) % GS
    if pad > 0:
        w = np.pad(w, ((0, 0), (0, pad)))
    groups = w.reshape(-1, GS).astype(np.float64)
    N = len(groups)
    group_var = np.maximum(np.var(groups, axis=1), 1e-30)

    # Work out every (depth, zero_boundary, scale_depth) combo the requested
    # configs need. scale_depth is part of the key so the scales are quantized
    # with the same depth the bit accounting below charges for.
    # Adaptive uses default boundaries for d2/d3/d4; uniform configs may
    # override (e.g. d1 zw=0.25).
    needed_keys = set()
    for cn in config_names:
        cfg = CODECS[cn]
        sd = cfg['scale_depth']
        if cfg['mode'] == 'adaptive':
            needed_keys.update((d, None, sd) for d in _ADAPTIVE_DEPTHS)
        else:
            needed_keys.add((cfg['depth'], cfg.get('zero_boundary'), sd))

    raw_scales = {}    # (depth, zb) -> unquantized best scales, shared across scale depths
    mse_per_key = {}
    recon_per_key = {}
    for d, zb, sd in sorted(needed_keys, key=lambda k: (k[0], k[1] or 0, k[2])):
        power = DEPTH_POWERS[d]
        if (d, zb) not in raw_scales:
            raw_scales[(d, zb)], _ = compute_optimal_scale(groups, d, power, zero_boundary=zb)
        use_s = trit_quantize_scales(raw_scales[(d, zb)], sd)
        r = quantize_with_scale(groups, use_s, d, power, zero_boundary=zb)
        # MSE is re-measured after scale quantization, since that is the
        # reconstruction that actually gets stored.
        mse_per_key[(d, zb, sd)] = np.mean((groups - r) ** 2, axis=1)
        recon_per_key[(d, zb, sd)] = r

    out = {}
    for cn in config_names:
        cfg = CODECS[cn]
        sd = cfg['scale_depth']
        depth_counts = {1: 0, 2: 0, 3: 0, 4: 0}
        if cfg['mode'] == 'uniform':
            d = cfg['depth']
            recon = recon_per_key[(d, cfg.get('zero_boundary'), sd)]
            depth_counts[d] = N
            wb = N * GS * d * np.log2(3)
            sb = N * sd * np.log2(3)
        else:  # adaptive
            eff_thresh = cfg['threshold'] * 5.5
            recon = np.zeros_like(groups)
            assigned = np.zeros(N, dtype=bool)
            wb = 0.0
            sb = 0.0
            for d in _ADAPTIVE_DEPTHS:
                pending = np.flatnonzero(~assigned)
                if pending.size == 0:
                    break
                key = (d, None, sd)
                if d == _ADAPTIVE_DEPTHS[-1]:
                    # Deepest level takes everything that is left.
                    chosen = pending
                else:
                    # Accept a group at this depth when its MSE relative to
                    # the group's variance is below the threshold.
                    rel_err = mse_per_key[key][pending] / group_var[pending]
                    chosen = pending[rel_err < eff_thresh]
                recon[chosen] = recon_per_key[key][chosen]
                assigned[chosen] = True
                n_d = int(chosen.size)
                depth_counts[d] = n_d
                wb += n_d * GS * d * np.log2(3)
                sb += n_d * sd * np.log2(3)

        # Drop the padding columns again and store as float32.
        recon_w = recon.reshape(rows, -1)[:, :cols].astype(np.float32)
        out[cn] = {
            'recon_w': recon_w,
            'depth_counts': depth_counts,
            'weight_bits': float(wb),
            'scale_bits': float(sb),
            'n_groups': N,
        }
    return out
# ============================================================
# CHECKPOINTING
# ============================================================

def matrix_ckpt_path(ckpt_dir, idx, name):
    """Return the checkpoint file path for matrix number `idx` named `name`.

    '/' becomes '__' and '.' becomes '_' so the parameter name is a single
    safe filename component.
    """
    safe = name.replace('/', '__').replace('.', '_')
    return os.path.join(ckpt_dir, f'matrix_{idx:05d}__{safe}.npz')


def _remove_quietly(path):
    """Delete `path`, ignoring the case where it is already gone."""
    try:
        os.unlink(path)
    except FileNotFoundError:
        pass


def atomic_save_npz(path, data):
    """Write `data` to `path` atomically, with fsync before rename so the
    checkpoint survives power loss / SIGKILL after the rename returns.

    Args:
        path: destination .npz file.
        data: dict of array name -> array, passed to np.savez_compressed.

    Raises:
        Whatever np.savez_compressed / os.fsync / os.replace raise. The
        temporary file is removed before the exception propagates.
    """
    out_dir = os.path.dirname(path) or '.'
    # NOTE: np.savez_compressed silently appends '.npz' if missing — so we
    # name the tmp file with .npz suffix and pass it the same path.
    fd, tmp = tempfile.mkstemp(prefix='.tmp_', suffix='.npz', dir=out_dir)
    os.close(fd)
    try:
        np.savez_compressed(tmp, **data)
        # fsync the file so its data is durable before we rename. os.replace then
        # makes the rename atomic (POSIX guarantees same-filesystem rename atomicity).
        fd = os.open(tmp, os.O_RDONLY)
        try:
            os.fsync(fd)
        finally:
            os.close(fd)
        os.replace(tmp, path)
    except BaseException:
        # Don't litter the checkpoint dir with orphaned .tmp_ files on failure.
        _remove_quietly(tmp)
        raise
    # fsync the parent directory so the rename itself is durable. This is
    # best-effort: the checkpoint is already in place, so a platform that
    # cannot open or fsync a directory must not turn success into a crash.
    try:
        dir_fd = os.open(out_dir, os.O_RDONLY)
    except OSError:
        return
    try:
        os.fsync(dir_fd)
    except OSError:
        pass  # not all filesystems support directory fsync (e.g. some FUSE)
    finally:
        os.close(dir_fd)


# Codec version — bumped whenever the algorithm changes in a way that would
# make older checkpoints invalid (e.g. depth-power mapping change, scale
# codebook range change, group-size change). Used by the fingerprint validator.
CODEC_VERSION = 'v2.0'


def codec_fingerprint(model_id, revision, depth_powers, group_size, codec_version):
    """Stable string that identifies the algorithmic state behind a checkpoint.

    Two checkpoints with the same fingerprint can be safely interleaved. Two
    with different fingerprints must not be mixed — a mismatch on resume causes
    the stale checkpoint to be discarded and re-quantized.

    A missing/empty `revision` is recorded as "unspecified".
    """
    powers = ','.join(f'{d}:{p}' for d, p in sorted(depth_powers.items()))
    return '|'.join([
        f'codec={codec_version}',
        f'model={model_id}',
        f'revision={revision or "unspecified"}',
        f'gs={group_size}',
        f'powers={powers}',
    ])


def load_ckpt(path):
    """Load every array of the .npz checkpoint at `path` into a plain dict."""
    # NOTE(security): allow_pickle=True lets a tampered .npz execute arbitrary
    # code on load. The arrays this script writes are numeric / unicode only,
    # so consider allow_pickle=False if the checkpoint dir is shared or
    # otherwise not fully trusted.
    with np.load(path, allow_pickle=True) as z:
        return {k: z[k] for k in z.files}


def write_progress(out_root, state):
    """Atomically (temp file + rename) write `state` as JSON to
    <out_root>/progress.json so pollers never read a partial file."""
    path = os.path.join(out_root, 'progress.json')
    fd, tmp = tempfile.mkstemp(prefix='.tmp_', dir=out_root)
    try:
        with os.fdopen(fd, 'w') as f:
            json.dump(state, f, indent=2)
        os.replace(tmp, path)
    except BaseException:
        _remove_quietly(tmp)
        raise


def _pack_checkpoint(result, name, idx, shape, config_names, fingerprint):
    """Build the dict of arrays stored in one matrix checkpoint.

    One '<config>__w' / '<config>__stats' pair per config plus a '_meta' JSON
    blob carrying the codec fingerprint, so a future resume can detect a stale
    checkpoint and discard it.
    """
    save_data = {'_meta': np.array(json.dumps({
        'name': name, 'idx': idx, 'shape': list(shape),
        'configs': config_names,
        'fingerprint': fingerprint,
    }))}
    for cn, info in result.items():
        save_data[f'{cn}__w'] = info['recon_w']
        save_data[f'{cn}__stats'] = np.array(json.dumps({
            'depth_counts': info['depth_counts'],
            'weight_bits': info['weight_bits'],
            'scale_bits': info['scale_bits'],
            'n_groups': info['n_groups'],
        }))
    return save_data


def _parse_matrix_range(spec):
    """Parse a "start:end" spec into (start, end); an empty `end` yields None.

    Raises:
        ValueError: if `spec` is not two (possibly empty) integers joined by ':'.
    """
    s, e = spec.split(':')
    return (int(s) if s else 0), (int(e) if e else None)


# ============================================================
# MAIN
# ============================================================

def main():
    """CLI entry point: load the model, quantize every eligible 2D matrix
    (resuming from checkpoints), then assemble one HF model dir per config."""
    parser = argparse.ArgumentParser(description='Multi-config ternary quantizer with checkpointing')
    parser.add_argument('--model', required=True)
    parser.add_argument('--configs', required=True,
                        help='Comma-separated codec names: ' + ','.join(CODECS.keys()))
    parser.add_argument('--output', required=True, help='Output root dir')
    parser.add_argument('--workers', type=int, default=1)
    parser.add_argument('--dtype', default='float16', choices=['float16', 'bfloat16'])
    parser.add_argument('--skip-assembly', action='store_true',
                        help='Quantize matrices and checkpoint only; skip final HF model assembly.')
    parser.add_argument('--matrix-range', default=None,
                        help='Slice of matrices to process: "start:end" (0-indexed, end exclusive). '
                             'Use to manually parallelize across processes/machines via shared checkpoint dir.')
    parser.add_argument('--revision', default=None,
                        help='HuggingFace revision (commit SHA or tag) to pin the source model. '
                             'Recommended for reproducibility — without it, the upstream repo can move under you.')
    args = parser.parse_args()

    config_names = [c.strip() for c in args.configs.split(',') if c.strip()]
    for cn in config_names:
        if cn not in CODECS:
            print(f'ERROR: unknown codec {cn}', file=sys.stderr); sys.exit(2)

    # Validate --matrix-range before the (slow) model load so a typo fails fast.
    range_spec = None
    if args.matrix_range:
        try:
            range_spec = _parse_matrix_range(args.matrix_range)
        except ValueError:
            print(f'ERROR: --matrix-range must be "start:end", got {args.matrix_range!r}',
                  file=sys.stderr)
            sys.exit(2)

    os.makedirs(args.output, exist_ok=True)
    ckpt_dir = os.path.join(args.output, '.checkpoint')
    os.makedirs(ckpt_dir, exist_ok=True)

    print(f'=== Quantizing {args.model} ===', flush=True)
    print(f' configs: {config_names}', flush=True)
    print(f' workers: {args.workers}', flush=True)

    # Heavy imports are deferred so --help and argument errors stay fast.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModel

    dtype = torch.bfloat16 if args.dtype == 'bfloat16' else torch.float16
    print(' loading model (CPU)...', flush=True)
    t_load = time.time()
    _cfg = AutoConfig.from_pretrained(args.model, revision=args.revision, trust_remote_code=True)
    _arch = ((getattr(_cfg, 'architectures', None) or [''])[0] or '').lower()
    if 't5' in _arch or 'encoder' in _arch:
        from transformers import T5EncoderModel
        print(' loading as T5EncoderModel (encoder-only)', flush=True)
        model = T5EncoderModel.from_pretrained(args.model, revision=args.revision, torch_dtype=dtype,
                                               device_map='cpu', trust_remote_code=True,
                                               low_cpu_mem_usage=True)
    else:
        try:
            model = AutoModelForCausalLM.from_pretrained(args.model, revision=args.revision,
                                                         torch_dtype=dtype, device_map='cpu',
                                                         trust_remote_code=True,
                                                         low_cpu_mem_usage=True)
        except ValueError:
            print(' fallback to generic AutoModel', flush=True)
            model = AutoModel.from_pretrained(args.model, revision=args.revision, torch_dtype=dtype,
                                              device_map='cpu', trust_remote_code=True,
                                              low_cpu_mem_usage=True)
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.model, revision=args.revision,
                                                  trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
    except Exception as e:
        # Deliberately best-effort: the quantized weights are still usable
        # without a tokenizer.
        print(f' tokenizer load failed (ok for encoder-only): {e}', flush=True)
        tokenizer = None
    print(f' loaded in {time.time()-t_load:.0f}s', flush=True)

    # Collect matrices to quantize (skip embeddings, norms, lm_head)
    matrices = []
    for pn, p in model.named_parameters():
        if p.dim() != 2 or 'norm' in pn or 'embed' in pn or 'lm_head' in pn:
            continue
        matrices.append((pn, p))
    print(f' {len(matrices)} matrices to quantize', flush=True)

    # Apply --matrix-range slice (for parallel sharded processing)
    range_start, range_end = 0, len(matrices)
    if range_spec is not None:
        range_start, range_end = range_spec
        if range_end is None:
            range_end = len(matrices)
        range_end = min(range_end, len(matrices))
        print(f' matrix-range: [{range_start}:{range_end})', flush=True)

    # Codec fingerprint for this run — used to validate resumed checkpoints.
    expected_fp = codec_fingerprint(args.model, args.revision, DEPTH_POWERS, GS, CODEC_VERSION)

    # Determine which need work (resume from checkpoints)
    todo = []
    done_count = 0
    discarded_count = 0
    for idx, (pn, p) in enumerate(matrices):
        if idx < range_start or idx >= range_end:
            continue
        cp = matrix_ckpt_path(ckpt_dir, idx, pn)
        if os.path.exists(cp):
            try:
                # Close the archive before any os.remove below: an open handle
                # leaks a descriptor and blocks deletion on some platforms.
                # NOTE(security): allow_pickle=True trusts the checkpoint dir.
                with np.load(cp, allow_pickle=True) as z:
                    meta = json.loads(str(z['_meta'][()]))
                # Validate: configs cover requested set, fingerprint matches, shape matches.
                have_configs = set(meta.get('configs', []))
                ckpt_fp = meta.get('fingerprint')
                ckpt_shape = tuple(meta.get('shape', ()))
                cur_shape = tuple(p.shape)
                if all(cn in have_configs for cn in config_names) \
                        and ckpt_fp == expected_fp \
                        and ckpt_shape == cur_shape:
                    done_count += 1
                    continue
                if ckpt_fp != expected_fp:
                    print(f' fingerprint mismatch on {cp}: stale={ckpt_fp!r} expected={expected_fp!r} — discarding', flush=True)
                elif ckpt_shape != cur_shape:
                    print(f' shape mismatch on {cp}: stale={ckpt_shape} current={cur_shape} — discarding', flush=True)
                else:
                    print(f' missing configs in {cp}: have={have_configs}, need={config_names} — redoing', flush=True)
                discarded_count += 1
                _remove_quietly(cp)
            except Exception as e:
                print(f' bad checkpoint {cp}: {e}, will redo', flush=True)
                _remove_quietly(cp)
        todo.append((idx, pn, p))
    if discarded_count:
        print(f' discarded {discarded_count} stale checkpoint(s)', flush=True)
    print(f' {done_count} matrices already checkpointed, {len(todo)} to do', flush=True)

    t0 = time.time()
    state = {
        'model': args.model, 'configs': config_names,
        'total_matrices': len(matrices), 'done_matrices': done_count,
        'started_at': t0, 'updated_at': t0,
    }
    write_progress(args.output, state)

    def save_result(result, idx, pn, shape):
        """Persist one matrix's multi-config result as a checkpoint."""
        atomic_save_npz(matrix_ckpt_path(ckpt_dir, idx, pn),
                        _pack_checkpoint(result, pn, idx, shape, config_names, expected_fp))

    def mark_done(i):
        """Bump counters after todo item `i`; report every 5th and the last."""
        nonlocal done_count
        done_count += 1
        state['done_matrices'] = done_count
        state['updated_at'] = time.time()
        state['elapsed_s'] = time.time() - t0
        if (i+1) % 5 == 0 or (i+1) == len(todo):
            write_progress(args.output, state)
            eta = (len(todo) - (i+1)) * (time.time() - t0) / max(i+1, 1)
            print(f' {done_count}/{len(matrices)} ({time.time()-t0:.0f}s, ETA {eta:.0f}s)', flush=True)

    def process_one(idx, pn, p):
        """Quantize and checkpoint a single matrix in this process."""
        w = p.data.float().numpy()
        result = quantize_matrix_multi(
            (w.ravel(), w.shape[0], w.shape[1], config_names))
        save_result(result, idx, pn, w.shape)
        return idx

    if args.workers > 1 and len(todo) > 1:
        # Streaming generator: yield (matrix, config_names) one at a time.
        # CRITICAL: do NOT pre-build all matrices in a list — for large models
        # (Llama 70B = 140GB) that OOMs the box at multiple hundred GB. The generator
        # is consumed lazily by Pool.imap.
        idx_name = [(idx, pn, list(p.shape)) for idx, pn, p in todo]

        def gen():
            for idx, pn, p in todo:
                w = p.data.float().numpy()
                yield (w.ravel(), w.shape[0], w.shape[1], config_names)
                # Free the source tensor after we've handed off the numpy view.
                # The Pool worker has its own copy via pickle.
                p.data = torch.zeros(1, dtype=p.dtype)

        with Pool(args.workers) as pool:
            # imap preserves input order, so result i belongs to idx_name[i].
            for i, result in enumerate(pool.imap(quantize_matrix_multi, gen(), chunksize=1)):
                idx, pn, shape = idx_name[i]
                save_result(result, idx, pn, shape)
                mark_done(i)
    else:
        for i, (idx, pn, p) in enumerate(todo):
            process_one(idx, pn, p)
            mark_done(i)

    print(f' Quantization complete in {time.time()-t0:.0f}s', flush=True)

    # If we processed only a slice, don't assemble — leave that for the merge step.
    if args.matrix_range:
        # Verify which checkpoints exist for this slice; print summary
        slice_done = sum(1 for idx, (pn, p) in enumerate(matrices)
                         if range_start <= idx < range_end
                         and os.path.exists(matrix_ckpt_path(ckpt_dir, idx, pn)))
        print(f' slice [{range_start}:{range_end}): {slice_done} checkpointed', flush=True)
        return

    if args.skip_assembly:
        print(' --skip-assembly: not building HF model dirs', flush=True)
        return

    # ============================================================
    # ASSEMBLY: load each config from checkpoints, write HF model
    # ============================================================
    print(' Assembling HF models per config...', flush=True)
    for cn in config_names:
        cfg_dir = os.path.join(args.output, cn)
        os.makedirs(cfg_dir, exist_ok=True)
        model_dir = os.path.join(cfg_dir, 'model')

        # Aggregate stats
        total_groups = 0
        total_depth = {1: 0, 2: 0, 3: 0, 4: 0}
        total_wb = 0.0
        total_sb = 0.0

        # Replace tensors in-place with this config's reconstruction
        for idx, (pn, p) in enumerate(matrices):
            cp = matrix_ckpt_path(ckpt_dir, idx, pn)
            with np.load(cp, allow_pickle=True) as z:
                recon_w = z[f'{cn}__w']
                stats = json.loads(str(z[f'{cn}__stats'][()]))
            p.data = torch.from_numpy(recon_w).to(p.dtype)
            total_groups += stats['n_groups']
            for d in [1, 2, 3, 4]:
                # JSON turns the int depth keys into strings; accept both.
                total_depth[d] += stats['depth_counts'].get(str(d), stats['depth_counts'].get(d, 0))
            total_wb += stats['weight_bits']
            total_sb += stats['scale_bits']

        tg = max(total_groups, 1)
        trit_bpw = total_wb / (tg * GS)
        scale_bpw = total_sb / (tg * GS)
        total_bpw = trit_bpw + scale_bpw
        print(f' [{cn}] BPW={total_bpw:.3f} (trit={trit_bpw:.3f}+scale={scale_bpw:.3f})', flush=True)

        print(f' [{cn}] Saving to {model_dir}...', flush=True)
        model.save_pretrained(model_dir, safe_serialization=True)
        if tokenizer is not None:
            tokenizer.save_pretrained(model_dir)

        config = {
            'model': os.path.basename(args.model.rstrip('/')),
            'model_revision': args.revision,
            'codec_version': CODEC_VERSION,
            'codec_fingerprint': expected_fp,
            'codec': cn,
            'bpw': total_bpw, 'trit_bpw': trit_bpw, 'scale_bpw': scale_bpw,
            'depth_pcts': {str(d): total_depth[d]/tg for d in [1, 2, 3, 4]},
            'n_matrices': len(matrices),
            'group_size': GS,
            'fp16_layers': ['lm_head', 'embed_tokens', '*_norm'],
            'codec_params': CODECS[cn],
        }
        with open(os.path.join(cfg_dir, 'config.json'), 'w') as f:
            json.dump(config, f, indent=2)
        print(f' [{cn}] DONE: {cfg_dir}', flush=True)

    print(f' ALL CONFIGS COMPLETE in {time.time()-t0:.0f}s total', flush=True)


if __name__ == '__main__':
    main()