#!/usr/bin/env python3
"""Perplexity benchmark for FireEcho quantization formats.

Evaluates WikiText-2 perplexity across quantization configs:
1. FP4 baseline (Goliath FP4, all experts)
2. FE-XC 10% cold (codebook 2-bit, plain k-means)
3. FE-XVQ 10% cold (codebook 2-bit, Hessian-weighted k-means)
4. INT2 10% cold (scalar 2-bit)

Each config runs in a SEPARATE SUBPROCESS to guarantee clean CUDA context
(PyTorch's memory allocator doesn't fully release between del+gc.collect).

Usage:
    python benchmark_perplexity.py [--max_tokens 50000] [--stride 256]

Output: PPL comparison table suitable for paper.

Copyright (c) 2025-2026 Echo (FireEcho Project). All rights reserved.
"""

import sys
import os
import time
import math
import json
import argparse
import subprocess
import tempfile

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

MODEL_DIR = '/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct'
FEXVQ_CODEBOOKS = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'fexvq_codebooks.pt')
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))


# ===== Pure helpers (stdlib-only, unit-testable) =====

def fmt_delta(ppl, baseline):
    """Format a PPL delta vs. baseline with an explicit sign.

    Returns an em-dash when there is no baseline. Uses the ``:+.2f``
    format spec so negative deltas render as ``-0.12`` (the old
    ``f"+{delta:.2f}"`` produced the garbled ``+-0.12``).
    """
    if not baseline:
        return "—"
    return f"{ppl - baseline:+.2f}"


def window_spans(seq_len, max_len, stride):
    """Yield ``(begin, end, trim)`` for each sliding evaluation window.

    ``trim`` is the count of leading labels already scored by the previous
    window (``max_len - stride`` for every window after the first, clamped
    at 0 in case ``stride > max_len`` — a negative slice would silently
    re-score tail tokens). Windows whose labels would be fully trimmed are
    skipped here, BEFORE the caller runs a forward pass, instead of after
    a wasted GPU forward as the original loop did. The set of scored
    label positions is identical to the original.
    """
    overlap = max(0, max_len - stride)
    for begin in range(0, seq_len - 1, stride):
        end = min(begin + max_len, seq_len)
        trim = overlap if begin > 0 else 0
        # (end - begin - 1) labels exist after the logit/label shift.
        if (end - begin - 1) - trim <= 0:
            continue
        yield begin, end, trim


def count_demoted(engine, demote_method, flag_attr):
    """Trigger per-layer demotion and count demoted experts across layers.

    ``demote_method`` is the name of the FFN hook (e.g.
    ``'_maybe_demote_to_fexc'``) and ``flag_attr`` the boolean-mask buffer
    it sets (e.g. ``'_expert_is_fexc'``). Replaces three copy-pasted loops.
    """
    total = 0
    for layer in engine.layers:
        getattr(layer.ffn, demote_method)()
        if hasattr(layer.ffn, flag_attr):
            total += getattr(layer.ffn, flag_attr).sum().item()
    return total


# ===== Worker code (runs in subprocess) =====

def run_single_config(config, max_tokens, stride, max_len, cold_pct, result_file):
    """Run a single config evaluation. Called in subprocess.

    Args:
        config: one of 'fp4', 'fexc', 'fexvq', 'int2'.
        max_tokens: truncate WikiText-2 to this many tokens (<=0 = all).
        stride: sliding-window stride in tokens.
        max_len: context length per window.
        cold_pct: fraction of experts to demote for the 2-bit configs.
        result_file: JSON path where the result dict is written.
    """
    # Heavy deps imported here so the orchestrator process stays CUDA-free.
    import torch
    import torch.nn.functional as F
    sys.path.insert(0, SCRIPT_DIR)

    print(f"\n{'=' * 70}")
    print(f" Config: {config.upper()}")
    print(f"{'=' * 70}")

    # Load model
    from fireecho_kernel import FireEchoEngine
    from transformers import AutoTokenizer
    print("[1] Loading model...")
    engine = FireEchoEngine.from_pretrained(MODEL_DIR)
    engine.pack_all_experts()
    engine.eval()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)

    # Load WikiText-2
    from datasets import load_dataset
    print(" Loading WikiText-2 test set...")
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    text = "\n\n".join([t for t in ds["text"] if t.strip()])
    print(f" Text length: {len(text):,} chars")
    tokens = tokenizer.encode(text, add_special_tokens=False)
    if max_tokens > 0 and len(tokens) > max_tokens:
        tokens = tokens[:max_tokens]
    print(f" Tokenized: {len(tokens):,} tokens")
    token_ids = torch.tensor(tokens, dtype=torch.long)

    # Warmup usage counters so cold-expert demotion has real statistics.
    warmup_prompts = [
        "Explain how neural networks learn from data.",
        "Write a Python function that sorts a list.",
        "What are the main causes of climate change?",
        "Describe the architecture of a transformer.",
        "How does public key cryptography work?",
        "What is the halting problem?",
        "Explain quantum computing simply.",
        "Write a recursive Fibonacci function.",
        "What are the fundamental forces in physics?",
        "How does the human immune system work?",
        "Describe the process of photosynthesis.",
        "What is the P vs NP problem?",
        "How does GPS determine your location?",
        "Explain machine learning overfitting.",
        "What are design patterns in software?",
        "How do search engines rank pages?",
        "Describe the lifecycle of a star.",
        "What is Shannon's information theory?",
        "How do operating systems manage memory?",
        "Explain the CAP theorem.",
    ]
    print(f" Warming up expert usage (20 prompts)...")
    for prompt in warmup_prompts:
        ids = tokenizer.encode(prompt, return_tensors='pt').cuda()
        engine.reset_cache()
        engine._current_seq_id = 0
        engine.generate(ids, max_new_tokens=32, temperature=0.0)
    ffn = engine.layers[0].ffn
    if hasattr(ffn, 'expert_usage'):
        usage = ffn.expert_usage
        top5 = usage.topk(5)
        bot5 = usage.topk(5, largest=False)
        print(f" Layer 0 usage: top5={top5.values.tolist()}, bot5={bot5.values.tolist()}")

    # Apply quantization config
    if config == 'fp4':
        print(" [FP4 baseline — no demotion]")
    elif config == 'fexc':
        engine.enable_auto_fexc_demotion(cold_threshold_pct=cold_pct)
        total = count_demoted(engine, '_maybe_demote_to_fexc', '_expert_is_fexc')
        print(f" FE-XC demoted: {total} experts ({total // len(engine.layers)}/layer)")
    elif config == 'fexvq':
        if os.path.exists(FEXVQ_CODEBOOKS):
            print(f" Loading pre-calibrated FE-XVQ codebooks...")
            ckpt = torch.load(FEXVQ_CODEBOOKS, weights_only=True)
            codebooks = ckpt['codebooks']
            engine.enable_auto_fexc_demotion(cold_threshold_pct=cold_pct)
            # Force init + inject Hessian-weighted codebooks BEFORE demotion
            for li, layer in enumerate(engine.layers):
                ffn_l = layer.ffn
                if not getattr(ffn_l, '_fexc_enabled', False):
                    ffn_l._init_fexc_buffers()
                if li in codebooks:
                    ffn_l.gu_codebooks = codebooks[li]['gate_up'].cuda().half()
                    ffn_l.dn_codebooks = codebooks[li]['down'].cuda().half()
            total = count_demoted(engine, '_maybe_demote_to_fexc', '_expert_is_fexc')
            print(f" FE-XVQ demoted: {total} experts ({total // len(engine.layers)}/layer)")
        else:
            print(f" ERROR: No pre-calibrated codebooks at {FEXVQ_CODEBOOKS}")
            # Was json.dump(..., open(result_file, 'w')): leaked the handle.
            with open(result_file, 'w') as f:
                json.dump({'error': 'no codebooks'}, f)
            return
    elif config == 'int2':
        engine.enable_auto_int2_demotion(cold_threshold_pct=cold_pct)
        total = count_demoted(engine, '_maybe_demote_to_int2', '_expert_is_int2')
        print(f" INT2 demoted: {total} experts ({total // len(engine.layers)}/layer)")

    vram_gb = torch.cuda.memory_allocated() / 1e9
    print(f" VRAM: {vram_gb:.1f} GB")

    # Evaluate perplexity (sliding window, sum-of-NLL over new labels only)
    print(f"\n Evaluating perplexity...")
    t0 = time.time()
    total_nll = 0.0
    total_tokens = 0
    num_windows = 0
    seq_len = token_ids.shape[0]
    num_windows_total = max(1, (seq_len - max_len) // stride + 1)
    for begin, end, trim in window_spans(seq_len, max_len, stride):
        input_ids = token_ids[begin:end].unsqueeze(0).cuda()
        engine.reset_cache()
        engine._current_seq_id = 0
        if hasattr(engine.kv_cache, '_graph_mode'):
            engine.kv_cache._graph_mode = False
        with torch.no_grad():
            logits = engine.forward(input_ids, use_cache=False)
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()
        if trim:
            # Drop labels already scored by the overlapping previous window.
            shift_logits = shift_logits[:, trim:, :]
            shift_labels = shift_labels[:, trim:]
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            reduction='sum'
        )
        total_nll += loss.item()
        total_tokens += shift_labels.numel()
        num_windows += 1
        if num_windows % 20 == 0 or num_windows == 1:
            elapsed = time.time() - t0
            current_ppl = math.exp(total_nll / total_tokens)
            tok_per_s = total_tokens / elapsed
            print(f" Window {num_windows}/{num_windows_total}: "
                  f"PPL={current_ppl:.2f}, {total_tokens} tok, "
                  f"{tok_per_s:.0f} tok/s eval")
    elapsed = time.time() - t0
    ppl = math.exp(total_nll / total_tokens) if total_tokens > 0 else float('inf')
    print(f" Final: PPL={ppl:.2f}, {total_tokens} tok, "
          f"{num_windows} windows, {elapsed:.1f}s")

    # Write result for the orchestrator to pick up.
    result = {
        'config': config,
        'ppl': ppl,
        'tokens': total_tokens,
        'vram_gb': vram_gb,
        'time_s': elapsed,
    }
    with open(result_file, 'w') as f:
        json.dump(result, f)


# ===== Main orchestrator =====

def main():
    """Parse args, spawn one subprocess per config, print the PPL table."""
    parser = argparse.ArgumentParser(description='FireEcho Perplexity Benchmark')
    parser.add_argument('--max_tokens', type=int, default=50000,
                        help='Max tokens from WikiText-2 (default: 50000)')
    parser.add_argument('--stride', type=int, default=256,
                        help='Sliding window stride (default: 256)')
    parser.add_argument('--max_len', type=int, default=512,
                        help='Max context per window (default: 512)')
    parser.add_argument('--configs', type=str, default='fp4,fexc,fexvq,int2',
                        help='Comma-separated configs to test (default: fp4,fexc,fexvq,int2)')
    parser.add_argument('--cold_pct', type=float, default=0.10,
                        help='Fraction of experts to demote (default: 0.10)')
    parser.add_argument('--_worker', type=str, default=None,
                        help=argparse.SUPPRESS)  # Internal: run single config
    parser.add_argument('--_result_file', type=str, default=None,
                        help=argparse.SUPPRESS)
    args = parser.parse_args()

    # Worker mode: run single config in subprocess
    if args._worker:
        run_single_config(args._worker, args.max_tokens, args.stride,
                          args.max_len, args.cold_pct, args._result_file)
        return

    # Orchestrator mode: spawn subprocesses
    configs = [c.strip() for c in args.configs.split(',')]
    print("=" * 70)
    print(" FireEcho Perplexity Benchmark")
    print(" WikiText-2 | Qwen3-Omni 30B MoE | RTX 5090")
    print("=" * 70)
    print(f" Max tokens: {args.max_tokens:,}")
    print(f" Window: {args.max_len}, stride: {args.stride}")
    print(f" Cold threshold: {args.cold_pct*100:.0f}%")
    print(f" Configs: {configs}")
    print(f" Subprocess isolation: enabled (clean CUDA context per config)")

    results = {}
    script_path = os.path.abspath(__file__)
    python = sys.executable
    for config in configs:
        # Temp file is how the subprocess hands its result back.
        fd, result_file = tempfile.mkstemp(suffix='.json', prefix=f'ppl_{config}_')
        os.close(fd)
        try:
            cmd = [
                python, '-u', script_path,
                '--_worker', config,
                '--_result_file', result_file,
                '--max_tokens', str(args.max_tokens),
                '--stride', str(args.stride),
                '--max_len', str(args.max_len),
                '--cold_pct', str(args.cold_pct),
            ]
            ret = subprocess.run(cmd, cwd=SCRIPT_DIR)
            if ret.returncode != 0:
                print(f"\n SUBPROCESS FAILED for {config.upper()} (exit code {ret.returncode})")
                results[config] = {'error': f'exit code {ret.returncode}'}
                continue
            # Read result (was a redundant if/else assigning r in both arms)
            with open(result_file) as f:
                r = json.load(f)
            results[config] = r
            if 'error' not in r:
                print(f" >> {config.upper()}: PPL={r['ppl']:.2f}, "
                      f"VRAM={r['vram_gb']:.1f}G, {r['time_s']:.0f}s")
        except Exception as e:
            print(f"\n ERROR launching {config.upper()}: {e}")
            results[config] = {'error': str(e)}
        finally:
            if os.path.exists(result_file):
                os.unlink(result_file)

    # === Results Table ===
    print(f"\n{'=' * 70}")
    print(f" RESULTS — WikiText-2 Perplexity")
    print(f"{'=' * 70}")
    print(f"\n{'Config':<12} {'PPL':>8} {'Δ PPL':>8} {'VRAM':>8} {'Tokens':>10} {'bits/w':>7} {'Time':>7}")
    print(f"{'─' * 66}")
    baseline_ppl = results.get('fp4', {}).get('ppl', None)
    for config in configs:
        if config not in results:
            continue
        r = results[config]
        if r.get('error'):
            print(f"{config.upper():<12} {'ERROR':>8} {'—':>8} {'—':>8} {'—':>10} {'—':>7} {'—':>7}")
            continue
        # fmt_delta renders a signed delta (fixes the old "+-0.12" output).
        delta = fmt_delta(r['ppl'], baseline_ppl) if config != 'fp4' else "—"
        bits = {'fp4': '4.0', 'fexc': '~2.2', 'fexvq': '~2.2', 'int2': '2.0'}.get(config, '?')
        time_s = f"{r.get('time_s', 0):.0f}s"
        print(f"{config.upper():<12} {r['ppl']:>8.2f} {delta:>8} {r['vram_gb']:>7.1f}G "
              f"{r['tokens']:>10,} {bits:>7} {time_s:>7}")

    # Ablation analysis: FE-XC vs FE-XVQ
    if (baseline_ppl and 'fexc' in results and 'fexvq' in results
            and not results['fexc'].get('error') and not results['fexvq'].get('error')):
        fexc_delta = results['fexc']['ppl'] - baseline_ppl
        fexvq_delta = results['fexvq']['ppl'] - baseline_ppl
        print(f"\n Ablation: Hessian-weighted codebooks (FE-XVQ vs FE-XC)")
        print(f" FE-XC (plain k-means): +{fexc_delta:.2f} PPL")
        print(f" FE-XVQ (Hessian-weighted): +{fexvq_delta:.2f} PPL")
        if fexc_delta > 0:
            hessian_gain = (1 - fexvq_delta / fexc_delta) * 100
            print(f" Hessian reduces {hessian_gain:.0f}% of codebook PPL degradation")

    # FE-XVQ vs INT2
    if (baseline_ppl and 'fexvq' in results and 'int2' in results
            and not results['fexvq'].get('error') and not results['int2'].get('error')):
        fexvq_delta = results['fexvq']['ppl'] - baseline_ppl
        int2_delta = results['int2']['ppl'] - baseline_ppl
        if int2_delta > 0:
            improvement = (1 - fexvq_delta / int2_delta) * 100
            print(f"\n FE-XVQ recovers {improvement:.0f}% of INT2's PPL degradation")
            print(f" (same 2-bit storage, codebook quality advantage)")

    # Note about BF16
    print(f"\n Note: BF16 baseline omitted — Qwen3-Omni 30B BF16 = ~61GB,")
    print(f" exceeds RTX 5090 32GB. FP4 (Goliath) is practical baseline.")
    print(f"\n{'=' * 70}")


if __name__ == '__main__':
    main()