# FireEcho / FireEcho Engine / benchmark_perplexity.py
# (Hugging Face upload-page residue, commented out so the file parses:
#  "Joysulem's picture / Upload 3258 files / b5bff9c verified")
#!/usr/bin/env python3
"""Perplexity benchmark for FireEcho quantization formats.
Evaluates WikiText-2 perplexity across quantization configs:
1. FP4 baseline (Goliath FP4, all experts)
2. FE-XC 10% cold (codebook 2-bit, plain k-means)
3. FE-XVQ 10% cold (codebook 2-bit, Hessian-weighted k-means)
4. INT2 10% cold (scalar 2-bit)
Each config runs in a SEPARATE SUBPROCESS to guarantee clean CUDA context
(PyTorch's memory allocator doesn't fully release between del+gc.collect).
Usage:
python benchmark_perplexity.py [--max_tokens 50000] [--stride 256]
Output: PPL comparison table suitable for paper.
Copyright (c) 2025-2026 Echo (FireEcho Project). All rights reserved.
"""
import sys
import os
import time
import math
import json
import argparse
import subprocess
import tempfile
# Make sibling modules (e.g. fireecho_kernel, imported below) resolvable when
# this file is executed as a script rather than as part of a package.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# Local path to the quantized Qwen3-Omni checkpoint used for all configs.
MODEL_DIR = '/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct'
# Optional pre-calibrated Hessian-weighted codebooks for the FE-XVQ config;
# the fexvq run aborts with an error result if this file is absent.
FEXVQ_CODEBOOKS = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'fexvq_codebooks.pt')
# Directory of this script; used as cwd for worker subprocesses.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# ===== Worker code (runs in subprocess) =====
def run_single_config(config, max_tokens, stride, max_len, cold_pct, result_file):
    """Run a single config evaluation. Called in subprocess.

    Loads the model, applies the requested quantization config, evaluates
    sliding-window perplexity on WikiText-2, and writes a JSON result dict
    (or ``{'error': ...}``) to *result_file*.

    Args:
        config (str): one of 'fp4', 'fexc', 'fexvq', 'int2'.
        max_tokens (int): truncate the tokenized test set to this many
            tokens; values <= 0 keep the full set.
        stride (int): sliding-window stride in tokens.
        max_len (int): maximum context length per window.
        cold_pct (float): fraction of experts to demote to the cold format.
        result_file (str): path where the JSON result is written.
    """
    import torch
    import torch.nn.functional as F
    sys.path.insert(0, SCRIPT_DIR)
    print(f"\n{'=' * 70}")
    print(f" Config: {config.upper()}")
    print(f"{'=' * 70}")

    # Load model
    from fireecho_kernel import FireEchoEngine
    from transformers import AutoTokenizer
    print("[1] Loading model...")
    engine = FireEchoEngine.from_pretrained(MODEL_DIR)
    engine.pack_all_experts()
    engine.eval()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)

    # Load WikiText-2 test split and flatten to one long token stream.
    from datasets import load_dataset
    print(" Loading WikiText-2 test set...")
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    text = "\n\n".join([t for t in ds["text"] if t.strip()])
    print(f" Text length: {len(text):,} chars")
    tokens = tokenizer.encode(text, add_special_tokens=False)
    if max_tokens > 0 and len(tokens) > max_tokens:
        tokens = tokens[:max_tokens]
    print(f" Tokenized: {len(tokens):,} tokens")
    token_ids = torch.tensor(tokens, dtype=torch.long)

    # Warm up the per-layer expert usage counters so cold-expert selection
    # reflects realistic routing before any demotion is applied.
    warmup_prompts = [
        "Explain how neural networks learn from data.",
        "Write a Python function that sorts a list.",
        "What are the main causes of climate change?",
        "Describe the architecture of a transformer.",
        "How does public key cryptography work?",
        "What is the halting problem?",
        "Explain quantum computing simply.",
        "Write a recursive Fibonacci function.",
        "What are the fundamental forces in physics?",
        "How does the human immune system work?",
        "Describe the process of photosynthesis.",
        "What is the P vs NP problem?",
        "How does GPS determine your location?",
        "Explain machine learning overfitting.",
        "What are design patterns in software?",
        "How do search engines rank pages?",
        "Describe the lifecycle of a star.",
        "What is Shannon's information theory?",
        "How do operating systems manage memory?",
        "Explain the CAP theorem.",
    ]
    print(f" Warming up expert usage (20 prompts)...")
    for prompt in warmup_prompts:
        ids = tokenizer.encode(prompt, return_tensors='pt').cuda()
        engine.reset_cache()
        engine._current_seq_id = 0
        engine.generate(ids, max_new_tokens=32, temperature=0.0)
    ffn = engine.layers[0].ffn
    if hasattr(ffn, 'expert_usage'):
        usage = ffn.expert_usage
        top5 = usage.topk(5)
        bot5 = usage.topk(5, largest=False)
        print(f" Layer 0 usage: top5={top5.values.tolist()}, bot5={bot5.values.tolist()}")

    # Apply quantization config
    if config == 'fp4':
        print(" [FP4 baseline — no demotion]")
    elif config == 'fexc':
        engine.enable_auto_fexc_demotion(cold_threshold_pct=cold_pct)
        total = 0
        for layer in engine.layers:
            layer.ffn._maybe_demote_to_fexc()
            if hasattr(layer.ffn, '_expert_is_fexc'):
                total += layer.ffn._expert_is_fexc.sum().item()
        print(f" FE-XC demoted: {total} experts ({total // len(engine.layers)}/layer)")
    elif config == 'fexvq':
        if os.path.exists(FEXVQ_CODEBOOKS):
            print(f" Loading pre-calibrated FE-XVQ codebooks...")
            ckpt = torch.load(FEXVQ_CODEBOOKS, weights_only=True)
            codebooks = ckpt['codebooks']
            engine.enable_auto_fexc_demotion(cold_threshold_pct=cold_pct)
            # Force init + inject Hessian-weighted codebooks BEFORE demotion
            for li, layer in enumerate(engine.layers):
                ffn_l = layer.ffn
                if not getattr(ffn_l, '_fexc_enabled', False):
                    ffn_l._init_fexc_buffers()
                if li in codebooks:
                    ffn_l.gu_codebooks = codebooks[li]['gate_up'].cuda().half()
                    ffn_l.dn_codebooks = codebooks[li]['down'].cuda().half()
            total = 0
            for layer in engine.layers:
                layer.ffn._maybe_demote_to_fexc()
                if hasattr(layer.ffn, '_expert_is_fexc'):
                    total += layer.ffn._expert_is_fexc.sum().item()
            print(f" FE-XVQ demoted: {total} experts ({total // len(engine.layers)}/layer)")
        else:
            print(f" ERROR: No pre-calibrated codebooks at {FEXVQ_CODEBOOKS}")
            # Use a context manager so the handle is closed before the
            # subprocess exits (the original leaked an open file object).
            with open(result_file, 'w') as f:
                json.dump({'error': 'no codebooks'}, f)
            return
    elif config == 'int2':
        engine.enable_auto_int2_demotion(cold_threshold_pct=cold_pct)
        total = 0
        for layer in engine.layers:
            layer.ffn._maybe_demote_to_int2()
            if hasattr(layer.ffn, '_expert_is_int2'):
                total += layer.ffn._expert_is_int2.sum().item()
        print(f" INT2 demoted: {total} experts ({total // len(engine.layers)}/layer)")

    vram_gb = torch.cuda.memory_allocated() / 1e9
    print(f" VRAM: {vram_gb:.1f} GB")

    # Evaluate sliding-window perplexity. `prev_end` tracks how far the
    # stream has already been scored so every token's NLL is counted exactly
    # once. The original fixed `max_len - stride` trim was off by one (it
    # skipped one token per window) and mis-handled tail windows clamped to
    # seq_len; this follows the standard HF strided-perplexity recipe.
    print(f"\n Evaluating perplexity...")
    t0 = time.time()
    total_nll = 0.0
    total_tokens = 0
    num_windows = 0
    seq_len = token_ids.shape[0]
    num_windows_total = max(1, (seq_len - max_len) // stride + 1)
    prev_end = 0
    for begin in range(0, seq_len - 1, stride):
        end = min(begin + max_len, seq_len)
        input_ids = token_ids[begin:end].unsqueeze(0).cuda()
        engine.reset_cache()
        engine._current_seq_id = 0
        if hasattr(engine.kv_cache, '_graph_mode'):
            engine.kv_cache._graph_mode = False
        with torch.no_grad():
            logits = engine.forward(input_ids, use_cache=False)
        # Labels cover stream positions begin+1 .. end-1; the last
        # (end - prev_end) of them have not been scored by any earlier window.
        shift_logits = logits[:, :-1, :]
        shift_labels = input_ids[:, 1:]
        keep = min(end - prev_end, shift_labels.size(1))
        if keep <= 0:
            # Window fully covered by previous ones (can happen at the tail).
            continue
        shift_logits = shift_logits[:, -keep:, :].contiguous()
        shift_labels = shift_labels[:, -keep:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            reduction='sum'
        )
        total_nll += loss.item()
        total_tokens += shift_labels.numel()
        num_windows += 1
        prev_end = end
        if num_windows % 20 == 0 or num_windows == 1:
            elapsed = time.time() - t0
            current_ppl = math.exp(total_nll / total_tokens)
            tok_per_s = total_tokens / elapsed
            print(f" Window {num_windows}/{num_windows_total}: "
                  f"PPL={current_ppl:.2f}, {total_tokens} tok, "
                  f"{tok_per_s:.0f} tok/s eval")

    elapsed = time.time() - t0
    ppl = math.exp(total_nll / total_tokens) if total_tokens > 0 else float('inf')
    print(f" Final: PPL={ppl:.2f}, {total_tokens} tok, "
          f"{num_windows} windows, {elapsed:.1f}s")

    # Write result for the orchestrator to pick up.
    result = {
        'config': config,
        'ppl': ppl,
        'tokens': total_tokens,
        'vram_gb': vram_gb,
        'time_s': elapsed,
    }
    with open(result_file, 'w') as f:
        json.dump(result, f)
# ===== Main orchestrator =====
def main():
    """Orchestrate the benchmark.

    Parses CLI args, then either runs one config directly (worker mode,
    internal ``--_worker`` flag) or spawns one subprocess per config,
    collects the JSON results, and prints the comparison table and
    ablation analyses.
    """
    parser = argparse.ArgumentParser(description='FireEcho Perplexity Benchmark')
    parser.add_argument('--max_tokens', type=int, default=50000,
                        help='Max tokens from WikiText-2 (default: 50000)')
    parser.add_argument('--stride', type=int, default=256,
                        help='Sliding window stride (default: 256)')
    parser.add_argument('--max_len', type=int, default=512,
                        help='Max context per window (default: 512)')
    parser.add_argument('--configs', type=str, default='fp4,fexc,fexvq,int2',
                        help='Comma-separated configs to test (default: fp4,fexc,fexvq,int2)')
    parser.add_argument('--cold_pct', type=float, default=0.10,
                        help='Fraction of experts to demote (default: 0.10)')
    parser.add_argument('--_worker', type=str, default=None,
                        help=argparse.SUPPRESS)  # Internal: run single config
    parser.add_argument('--_result_file', type=str, default=None,
                        help=argparse.SUPPRESS)
    args = parser.parse_args()

    # Worker mode: run a single config in this (sub)process and exit.
    if args._worker:
        run_single_config(args._worker, args.max_tokens, args.stride,
                          args.max_len, args.cold_pct, args._result_file)
        return

    # Orchestrator mode: spawn one subprocess per config for a clean CUDA
    # context each time (see module docstring).
    configs = [c.strip() for c in args.configs.split(',')]
    print("=" * 70)
    print(" FireEcho Perplexity Benchmark")
    print(" WikiText-2 | Qwen3-Omni 30B MoE | RTX 5090")
    print("=" * 70)
    print(f" Max tokens: {args.max_tokens:,}")
    print(f" Window: {args.max_len}, stride: {args.stride}")
    print(f" Cold threshold: {args.cold_pct*100:.0f}%")
    print(f" Configs: {configs}")
    print(f" Subprocess isolation: enabled (clean CUDA context per config)")

    results = {}
    script_path = os.path.abspath(__file__)
    python = sys.executable
    for config in configs:
        # Temp file through which the worker returns its JSON result.
        fd, result_file = tempfile.mkstemp(suffix='.json', prefix=f'ppl_{config}_')
        os.close(fd)
        try:
            cmd = [
                python, '-u', script_path,
                '--_worker', config,
                '--_result_file', result_file,
                '--max_tokens', str(args.max_tokens),
                '--stride', str(args.stride),
                '--max_len', str(args.max_len),
                '--cold_pct', str(args.cold_pct),
            ]
            ret = subprocess.run(cmd, cwd=SCRIPT_DIR)
            if ret.returncode != 0:
                print(f"\n SUBPROCESS FAILED for {config.upper()} (exit code {ret.returncode})")
                results[config] = {'error': f'exit code {ret.returncode}'}
                continue
            # Read result; the worker may itself report an {'error': ...}.
            with open(result_file) as f:
                r = json.load(f)
            results[config] = r
            if 'error' not in r:
                print(f" >> {config.upper()}: PPL={r['ppl']:.2f}, "
                      f"VRAM={r['vram_gb']:.1f}G, {r['time_s']:.0f}s")
        except Exception as e:
            print(f"\n ERROR launching {config.upper()}: {e}")
            results[config] = {'error': str(e)}
        finally:
            if os.path.exists(result_file):
                os.unlink(result_file)

    # === Results Table ===
    print(f"\n{'=' * 70}")
    print(f" RESULTS — WikiText-2 Perplexity")
    print(f"{'=' * 70}")
    print(f"\n{'Config':<12} {'PPL':>8} {'Δ PPL':>8} {'VRAM':>8} {'Tokens':>10} {'bits/w':>7} {'Time':>7}")
    print(f"{'─' * 66}")
    baseline_ppl = results.get('fp4', {}).get('ppl', None)
    for config in configs:
        if config not in results:
            continue
        r = results[config]
        if r.get('error'):
            print(f"{config.upper():<12} {'ERROR':>8} {'—':>8} {'—':>8} {'—':>10} {'—':>7} {'—':>7}")
            continue
        # ':+.2f' renders the sign correctly for improvements too; the
        # original "+{...:.2f}" printed "+-0.50" for negative deltas.
        delta = f"{r['ppl'] - baseline_ppl:+.2f}" if baseline_ppl and config != 'fp4' else "—"
        bits = {'fp4': '4.0', 'fexc': '~2.2', 'fexvq': '~2.2', 'int2': '2.0'}.get(config, '?')
        time_s = f"{r.get('time_s', 0):.0f}s"
        print(f"{config.upper():<12} {r['ppl']:>8.2f} {delta:>8} {r['vram_gb']:>7.1f}G "
              f"{r['tokens']:>10,} {bits:>7} {time_s:>7}")

    # Ablation analysis: FE-XC vs FE-XVQ (value of Hessian-weighted k-means)
    if (baseline_ppl and 'fexc' in results and 'fexvq' in results
            and not results['fexc'].get('error') and not results['fexvq'].get('error')):
        fexc_delta = results['fexc']['ppl'] - baseline_ppl
        fexvq_delta = results['fexvq']['ppl'] - baseline_ppl
        print(f"\n Ablation: Hessian-weighted codebooks (FE-XVQ vs FE-XC)")
        print(f" FE-XC (plain k-means): {fexc_delta:+.2f} PPL")
        print(f" FE-XVQ (Hessian-weighted): {fexvq_delta:+.2f} PPL")
        if fexc_delta > 0:
            hessian_gain = (1 - fexvq_delta / fexc_delta) * 100
            print(f" Hessian reduces {hessian_gain:.0f}% of codebook PPL degradation")

    # FE-XVQ vs INT2 (codebook quality at equal 2-bit storage)
    if (baseline_ppl and 'fexvq' in results and 'int2' in results
            and not results['fexvq'].get('error') and not results['int2'].get('error')):
        fexvq_delta = results['fexvq']['ppl'] - baseline_ppl
        int2_delta = results['int2']['ppl'] - baseline_ppl
        if int2_delta > 0:
            improvement = (1 - fexvq_delta / int2_delta) * 100
            print(f"\n FE-XVQ recovers {improvement:.0f}% of INT2's PPL degradation")
            print(f" (same 2-bit storage, codebook quality advantage)")

    # Note about BF16
    print(f"\n Note: BF16 baseline omitted — Qwen3-Omni 30B BF16 = ~61GB,")
    print(f" exceeds RTX 5090 32GB. FP4 (Goliath) is practical baseline.")
    print(f"\n{'=' * 70}")
# Script entry point; worker subprocesses re-enter here with --_worker set.
if __name__ == '__main__':
    main()