| | |
| | """Perplexity benchmark for FireEcho quantization formats. |
| | |
| | Evaluates WikiText-2 perplexity across quantization configs: |
| | 1. FP4 baseline (Goliath FP4, all experts) |
| | 2. FE-XC 10% cold (codebook 2-bit, plain k-means) |
| | 3. FE-XVQ 10% cold (codebook 2-bit, Hessian-weighted k-means) |
| | 4. INT2 10% cold (scalar 2-bit) |
| | |
| | Each config runs in a SEPARATE SUBPROCESS to guarantee clean CUDA context |
| | (PyTorch's memory allocator doesn't fully release between del+gc.collect). |
| | |
| | Usage: |
| | python benchmark_perplexity.py [--max_tokens 50000] [--stride 256] |
| | |
| | Output: PPL comparison table suitable for paper. |
| | |
| | Copyright (c) 2025-2026 Echo (FireEcho Project). All rights reserved. |
| | """ |
| |
|
| | import sys |
| | import os |
| | import time |
| | import math |
| | import json |
| | import argparse |
| | import subprocess |
| | import tempfile |
| |
|
| | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| |
|
# Local path of the quantized Qwen3-Omni checkpoint evaluated by every config.
MODEL_DIR = '/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct'
# Pre-calibrated FE-XVQ codebooks (produced offline); loaded only for the
# 'fexvq' config — if the file is missing, that config reports an error.
FEXVQ_CODEBOOKS = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                               'fexvq_codebooks.pt')
# Directory containing this script; used as cwd for worker subprocesses and
# re-inserted into sys.path inside the worker.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
| |
|
| |
|
| | |
| |
|
def run_single_config(config, max_tokens, stride, max_len, cold_pct, result_file):
    """Evaluate WikiText-2 perplexity for one quantization config.

    Runs inside a dedicated worker subprocess (spawned by main()) so each
    config gets a fresh CUDA context.

    Args:
        config: One of 'fp4', 'fexc', 'fexvq', 'int2'.
        max_tokens: Truncate the tokenized test set to this many tokens
            (<= 0 means use the full test set).
        stride: Sliding-window stride in tokens.
        max_len: Context length of each evaluation window.
        cold_pct: Fraction of experts to demote for the 2-bit configs.
        result_file: Path where the result dict is written as JSON
            (either {'config', 'ppl', 'tokens', 'vram_gb', 'time_s'}
            or {'error': ...}).
    """
    import torch
    import torch.nn.functional as F

    sys.path.insert(0, SCRIPT_DIR)

    print(f"\n{'=' * 70}")
    print(f" Config: {config.upper()}")
    print(f"{'=' * 70}")

    from fireecho_kernel import FireEchoEngine
    from transformers import AutoTokenizer

    print("[1] Loading model...")
    engine = FireEchoEngine.from_pretrained(MODEL_DIR)
    engine.pack_all_experts()
    engine.eval()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)

    from datasets import load_dataset
    print(" Loading WikiText-2 test set...")
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    text = "\n\n".join([t for t in ds["text"] if t.strip()])
    print(f" Text length: {len(text):,} chars")
    tokens = tokenizer.encode(text, add_special_tokens=False)
    if max_tokens > 0 and len(tokens) > max_tokens:
        tokens = tokens[:max_tokens]
    print(f" Tokenized: {len(tokens):,} tokens")
    token_ids = torch.tensor(tokens, dtype=torch.long)

    # Warmup: drive realistic expert-routing statistics so cold-expert
    # demotion (below) picks genuinely cold experts, not all-zero counters.
    warmup_prompts = [
        "Explain how neural networks learn from data.",
        "Write a Python function that sorts a list.",
        "What are the main causes of climate change?",
        "Describe the architecture of a transformer.",
        "How does public key cryptography work?",
        "What is the halting problem?",
        "Explain quantum computing simply.",
        "Write a recursive Fibonacci function.",
        "What are the fundamental forces in physics?",
        "How does the human immune system work?",
        "Describe the process of photosynthesis.",
        "What is the P vs NP problem?",
        "How does GPS determine your location?",
        "Explain machine learning overfitting.",
        "What are design patterns in software?",
        "How do search engines rank pages?",
        "Describe the lifecycle of a star.",
        "What is Shannon's information theory?",
        "How do operating systems manage memory?",
        "Explain the CAP theorem.",
    ]
    print(f" Warming up expert usage (20 prompts)...")
    for prompt in warmup_prompts:
        ids = tokenizer.encode(prompt, return_tensors='pt').cuda()
        engine.reset_cache()
        engine._current_seq_id = 0
        engine.generate(ids, max_new_tokens=32, temperature=0.0)

    ffn = engine.layers[0].ffn
    if hasattr(ffn, 'expert_usage'):
        usage = ffn.expert_usage
        top5 = usage.topk(5)
        bot5 = usage.topk(5, largest=False)
        print(f" Layer 0 usage: top5={top5.values.tolist()}, bot5={bot5.values.tolist()}")

    # Apply the per-config quantization / demotion scheme.
    if config == 'fp4':
        print(" [FP4 baseline — no demotion]")
    elif config == 'fexc':
        engine.enable_auto_fexc_demotion(cold_threshold_pct=cold_pct)
        total = 0
        for layer in engine.layers:
            layer.ffn._maybe_demote_to_fexc()
            if hasattr(layer.ffn, '_expert_is_fexc'):
                total += layer.ffn._expert_is_fexc.sum().item()
        print(f" FE-XC demoted: {total} experts ({total // len(engine.layers)}/layer)")
    elif config == 'fexvq':
        if os.path.exists(FEXVQ_CODEBOOKS):
            print(f" Loading pre-calibrated FE-XVQ codebooks...")
            ckpt = torch.load(FEXVQ_CODEBOOKS, weights_only=True)
            codebooks = ckpt['codebooks']
            engine.enable_auto_fexc_demotion(cold_threshold_pct=cold_pct)

            # Overwrite the plain k-means codebooks with Hessian-weighted ones.
            for li, layer in enumerate(engine.layers):
                ffn_l = layer.ffn
                if not getattr(ffn_l, '_fexc_enabled', False):
                    ffn_l._init_fexc_buffers()
                if li in codebooks:
                    ffn_l.gu_codebooks = codebooks[li]['gate_up'].cuda().half()
                    ffn_l.dn_codebooks = codebooks[li]['down'].cuda().half()
            total = 0
            for layer in engine.layers:
                layer.ffn._maybe_demote_to_fexc()
                if hasattr(layer.ffn, '_expert_is_fexc'):
                    total += layer.ffn._expert_is_fexc.sum().item()
            print(f" FE-XVQ demoted: {total} experts ({total // len(engine.layers)}/layer)")
        else:
            print(f" ERROR: No pre-calibrated codebooks at {FEXVQ_CODEBOOKS}")
            # Fix: use a context manager — the original leaked the file handle
            # via json.dump({...}, open(result_file, 'w')).
            with open(result_file, 'w') as f:
                json.dump({'error': 'no codebooks'}, f)
            return
    elif config == 'int2':
        engine.enable_auto_int2_demotion(cold_threshold_pct=cold_pct)
        total = 0
        for layer in engine.layers:
            layer.ffn._maybe_demote_to_int2()
            if hasattr(layer.ffn, '_expert_is_int2'):
                total += layer.ffn._expert_is_int2.sum().item()
        print(f" INT2 demoted: {total} experts ({total // len(engine.layers)}/layer)")

    vram_gb = torch.cuda.memory_allocated() / 1e9
    print(f" VRAM: {vram_gb:.1f} GB")

    print(f"\n Evaluating perplexity...")
    t0 = time.time()

    total_nll = 0.0
    total_tokens = 0
    num_windows = 0
    seq_len = token_ids.shape[0]
    num_windows_total = max(1, (seq_len - max_len) // stride + 1)

    # Strided sliding-window evaluation. `prev_end` tracks how far scoring
    # has progressed so each token contributes exactly once.
    #
    # Fix vs. original: the old code trimmed a fixed `max_len - stride`
    # labels from every non-first window, which (a) skipped one token at
    # every window boundary (only max_len - stride - 1 labels were already
    # scored) and (b) lost tail tokens in the final, shorter windows. It
    # also never broke out once the window reached the end of the sequence.
    prev_end = 0
    for begin in range(0, seq_len - 1, stride):
        end = min(begin + max_len, seq_len)
        input_ids = token_ids[begin:end].unsqueeze(0).cuda()

        engine.reset_cache()
        engine._current_seq_id = 0
        if hasattr(engine.kv_cache, '_graph_mode'):
            engine.kv_cache._graph_mode = False

        with torch.no_grad():
            logits = engine.forward(input_ids, use_cache=False)

        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()

        # Labels cover tokens begin+1 .. end-1; only those at or beyond
        # prev_end are new. Keep exactly the unscored suffix.
        new_tokens = end - max(prev_end, begin + 1)
        prev_end = end
        if new_tokens <= 0:
            if end == seq_len:
                break
            continue
        shift_logits = shift_logits[:, -new_tokens:, :]
        shift_labels = shift_labels[:, -new_tokens:]

        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            reduction='sum'
        )

        total_nll += loss.item()
        total_tokens += shift_labels.numel()
        num_windows += 1

        if num_windows % 20 == 0 or num_windows == 1:
            elapsed = time.time() - t0
            current_ppl = math.exp(total_nll / total_tokens)
            tok_per_s = total_tokens / elapsed
            print(f" Window {num_windows}/{num_windows_total}: "
                  f"PPL={current_ppl:.2f}, {total_tokens} tok, "
                  f"{tok_per_s:.0f} tok/s eval")

        # All tokens scored — further windows would be redundant.
        if end == seq_len:
            break

    elapsed = time.time() - t0
    ppl = math.exp(total_nll / total_tokens) if total_tokens > 0 else float('inf')
    print(f" Final: PPL={ppl:.2f}, {total_tokens} tok, "
          f"{num_windows} windows, {elapsed:.1f}s")

    result = {
        'config': config,
        'ppl': ppl,
        'tokens': total_tokens,
        'vram_gb': vram_gb,
        'time_s': elapsed,
    }
    with open(result_file, 'w') as f:
        json.dump(result, f)
| |
|
| |
|
| | |
| |
|
def main():
    """Orchestrate the benchmark: one isolated subprocess per config.

    Parses CLI args, launches this same script with hidden --_worker /
    --_result_file flags for each requested config, collects the per-config
    JSON results, and prints a comparison table plus ablation summaries.
    """
    parser = argparse.ArgumentParser(description='FireEcho Perplexity Benchmark')
    parser.add_argument('--max_tokens', type=int, default=50000,
                        help='Max tokens from WikiText-2 (default: 50000)')
    parser.add_argument('--stride', type=int, default=256,
                        help='Sliding window stride (default: 256)')
    parser.add_argument('--max_len', type=int, default=512,
                        help='Max context per window (default: 512)')
    parser.add_argument('--configs', type=str, default='fp4,fexc,fexvq,int2',
                        help='Comma-separated configs to test (default: fp4,fexc,fexvq,int2)')
    parser.add_argument('--cold_pct', type=float, default=0.10,
                        help='Fraction of experts to demote (default: 0.10)')
    # Hidden flags used for subprocess self-invocation only.
    parser.add_argument('--_worker', type=str, default=None,
                        help=argparse.SUPPRESS)
    parser.add_argument('--_result_file', type=str, default=None,
                        help=argparse.SUPPRESS)
    args = parser.parse_args()

    # Worker mode: run one config and exit (we ARE the subprocess).
    if args._worker:
        run_single_config(args._worker, args.max_tokens, args.stride,
                          args.max_len, args.cold_pct, args._result_file)
        return

    configs = [c.strip() for c in args.configs.split(',')]

    print("=" * 70)
    print(" FireEcho Perplexity Benchmark")
    print(" WikiText-2 | Qwen3-Omni 30B MoE | RTX 5090")
    print("=" * 70)
    print(f" Max tokens: {args.max_tokens:,}")
    print(f" Window: {args.max_len}, stride: {args.stride}")
    print(f" Cold threshold: {args.cold_pct*100:.0f}%")
    print(f" Configs: {configs}")
    print(f" Subprocess isolation: enabled (clean CUDA context per config)")

    results = {}
    script_path = os.path.abspath(__file__)
    python = sys.executable

    for config in configs:
        # Temp file is the result channel from the subprocess back to us.
        fd, result_file = tempfile.mkstemp(suffix='.json', prefix=f'ppl_{config}_')
        os.close(fd)

        try:
            cmd = [
                python, '-u', script_path,
                '--_worker', config,
                '--_result_file', result_file,
                '--max_tokens', str(args.max_tokens),
                '--stride', str(args.stride),
                '--max_len', str(args.max_len),
                '--cold_pct', str(args.cold_pct),
            ]
            ret = subprocess.run(cmd, cwd=SCRIPT_DIR)

            if ret.returncode != 0:
                print(f"\n SUBPROCESS FAILED for {config.upper()} (exit code {ret.returncode})")
                results[config] = {'error': f'exit code {ret.returncode}'}
                continue

            with open(result_file) as f:
                r = json.load(f)
            # Fix: the original if/else stored `r` identically in both
            # branches; only the success print needs to be conditional.
            results[config] = r
            if 'error' not in r:
                print(f" >> {config.upper()}: PPL={r['ppl']:.2f}, "
                      f"VRAM={r['vram_gb']:.1f}G, {r['time_s']:.0f}s")

        except Exception as e:
            print(f"\n ERROR launching {config.upper()}: {e}")
            results[config] = {'error': str(e)}
        finally:
            if os.path.exists(result_file):
                os.unlink(result_file)

    # ---- Summary table ----
    print(f"\n{'=' * 70}")
    print(f" RESULTS — WikiText-2 Perplexity")
    print(f"{'=' * 70}")
    print(f"\n{'Config':<12} {'PPL':>8} {'Δ PPL':>8} {'VRAM':>8} {'Tokens':>10} {'bits/w':>7} {'Time':>7}")
    print(f"{'─' * 66}")

    baseline_ppl = results.get('fp4', {}).get('ppl', None)
    for config in configs:
        if config not in results:
            continue
        r = results[config]
        if r.get('error'):
            print(f"{config.upper():<12} {'ERROR':>8} {'—':>8} {'—':>8} {'—':>10} {'—':>7} {'—':>7}")
            continue
        # Fix: ':+.2f' renders the sign correctly for negative deltas too
        # (the original f"+{...}" produced strings like '+-0.12').
        delta = f"{r['ppl'] - baseline_ppl:+.2f}" if baseline_ppl and config != 'fp4' else "—"
        bits = {'fp4': '4.0', 'fexc': '~2.2', 'fexvq': '~2.2', 'int2': '2.0'}.get(config, '?')
        time_s = f"{r.get('time_s', 0):.0f}s"
        print(f"{config.upper():<12} {r['ppl']:>8.2f} {delta:>8} {r['vram_gb']:>7.1f}G "
              f"{r['tokens']:>10,} {bits:>7} {time_s:>7}")

    # Ablation 1: Hessian-weighted codebooks vs plain k-means.
    if (baseline_ppl and 'fexc' in results and 'fexvq' in results
            and not results['fexc'].get('error') and not results['fexvq'].get('error')):
        fexc_delta = results['fexc']['ppl'] - baseline_ppl
        fexvq_delta = results['fexvq']['ppl'] - baseline_ppl
        print(f"\n Ablation: Hessian-weighted codebooks (FE-XVQ vs FE-XC)")
        print(f" FE-XC (plain k-means): +{fexc_delta:.2f} PPL")
        print(f" FE-XVQ (Hessian-weighted): +{fexvq_delta:.2f} PPL")
        if fexc_delta > 0:
            hessian_gain = (1 - fexvq_delta / fexc_delta) * 100
            print(f" Hessian reduces {hessian_gain:.0f}% of codebook PPL degradation")

    # Ablation 2: codebook 2-bit vs scalar 2-bit at equal storage.
    if (baseline_ppl and 'fexvq' in results and 'int2' in results
            and not results['fexvq'].get('error') and not results['int2'].get('error')):
        fexvq_delta = results['fexvq']['ppl'] - baseline_ppl
        int2_delta = results['int2']['ppl'] - baseline_ppl
        if int2_delta > 0:
            improvement = (1 - fexvq_delta / int2_delta) * 100
            print(f"\n FE-XVQ recovers {improvement:.0f}% of INT2's PPL degradation")
            print(f" (same 2-bit storage, codebook quality advantage)")

    print(f"\n Note: BF16 baseline omitted — Qwen3-Omni 30B BF16 = ~61GB,")
    print(f" exceeds RTX 5090 32GB. FP4 (Goliath) is practical baseline.")

    print(f"\n{'=' * 70}")


if __name__ == '__main__':
    main()
| |
|