# Source: Hugging Face Hub repo "FireEcho", file benchmark_real_model.py (commit b5bff9c).
#!/usr/bin/env python3
"""
Real Model Benchmark: FireEcho vs HuggingFace
==============================================
Loads Qwen2-0.5B into both HuggingFace and FireEcho, validates correctness,
then benchmarks generation tok/s, TTFT, and VRAM across prompt lengths.
Usage:
python benchmark_real_model.py
python benchmark_real_model.py --model Qwen/Qwen2-0.5B --prompt-lengths 128 512 2048
"""
import argparse
import sys
import time
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _cuda_sync():
if torch.cuda.is_available():
torch.cuda.synchronize()
def _peak_vram_mb() -> float:
if torch.cuda.is_available():
return torch.cuda.max_memory_allocated() / (1024 ** 2)
return 0.0
def _reset_peak_vram():
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
def _timed_cuda(fn, warmup: int = 2, repeats: int = 5) -> float:
"""Run *fn* with CUDA events, return median wall-time in seconds."""
for _ in range(warmup):
fn()
_cuda_sync()
times = []
for _ in range(repeats):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
fn()
end.record()
_cuda_sync()
times.append(start.elapsed_time(end) / 1000.0) # ms → s
times.sort()
return times[len(times) // 2] # median
# ---------------------------------------------------------------------------
# 1. Load models
# ---------------------------------------------------------------------------
def load_hf_model(model_name: str, dtype=torch.bfloat16, device='cuda'):
    """Load *model_name* with HuggingFace transformers in eval mode.

    Returns a ``(model, tokenizer)`` pair placed on *device* with *dtype*,
    using the SDPA attention backend.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"\n[HF] Loading {model_name} ...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True,
        attn_implementation="sdpa",
    )
    model.eval()
    total_params = sum(p.numel() for p in model.parameters())
    print(f"[HF] {total_params / 1e6:.1f}M params, dtype={dtype}, device={device}")
    return model, tokenizer
def load_fireecho(model_name: str, dtype=torch.bfloat16, device='cuda',
                  use_goliath: bool = False, goliath_bits: int = 4):
    """Load *model_name* into a FireEcho engine in eval mode.

    With ``use_goliath=True`` an explicit FireEchoConfig is built from the
    HuggingFace config and the weights are quantised to ``goliath_bits``;
    otherwise the engine is loaded directly (BF16 path).
    """
    sys.path.insert(0, '.')
    from fireecho_kernel import FireEchoEngine, FireEchoConfig

    if not use_goliath:
        tag = "BF16"
        print(f"\n[FE-{tag}] Loading {model_name} ...")
        engine = FireEchoEngine.from_pretrained(model_name, dtype=dtype,
                                                device=device)
    else:
        # Build a quantised config from the HF model's hyperparameters.
        from transformers import AutoConfig
        hf_cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        n_heads = hf_cfg.num_attention_heads
        config = FireEchoConfig(
            dim=hf_cfg.hidden_size,
            num_heads=n_heads,
            # MHA models may lack num_key_value_heads; fall back to MHA.
            num_kv_heads=getattr(hf_cfg, 'num_key_value_heads', n_heads),
            num_layers=hf_cfg.num_hidden_layers,
            vocab_size=hf_cfg.vocab_size,
            intermediate_size=hf_cfg.intermediate_size,
            # Cap the context length at 32k regardless of what the HF
            # config advertises.
            max_seq_len=min(getattr(hf_cfg, 'max_position_embeddings', 4096),
                            32768),
            rope_theta=getattr(hf_cfg, 'rope_theta', 10000.0),
            attn_bias=getattr(hf_cfg, 'attention_bias', False),
            tie_word_embeddings=getattr(hf_cfg, 'tie_word_embeddings', False),
            use_nvfp4=True,
            use_goliath=True,
            goliath_bits=goliath_bits,
            use_hebbian=False,
        )
        tag = f"FP{goliath_bits}"
        print(f"\n[FE-{tag}] Loading {model_name} (Goliath {tag}) ...")
        engine = FireEchoEngine.from_pretrained(model_name, config=config,
                                                dtype=dtype, device=device)

    engine.eval()
    total_params = sum(p.numel() for p in engine.parameters())
    print(f"[FE-{tag}] {total_params / 1e6:.1f}M params")
    return engine
# ---------------------------------------------------------------------------
# 2. Correctness validation
# ---------------------------------------------------------------------------
def validate_correctness(hf_model, fe_engine, tokenizer, device='cuda',
                         seq_len: int = 128) -> Dict:
    """Run both models on one prompt and quantify how closely they agree.

    Returns a dict with the actual sequence length, top-1 token match rate,
    max absolute logit difference, and mean per-position cosine similarity.
    """
    prompt = "The quick brown fox jumps over the lazy dog. " * 20
    encoded = tokenizer(prompt, return_tensors='pt',
                        max_length=seq_len, truncation=True)
    input_ids = encoded.input_ids.to(device)

    with torch.no_grad():
        hf_logits = hf_model(input_ids).logits  # [1, S, V]
        fe_logits = fe_engine(input_ids)        # [1, S, V]

    # Fraction of positions where both engines pick the same argmax token.
    top1_agree = hf_logits.argmax(dim=-1) == fe_logits.argmax(dim=-1)
    match_rate = top1_agree.float().mean().item()

    # Numerical closeness of the raw logit vectors.
    max_abs_diff = (hf_logits - fe_logits).abs().max().item()
    cos_sim = F.cosine_similarity(
        hf_logits.reshape(-1, hf_logits.shape[-1]).float(),
        fe_logits.reshape(-1, fe_logits.shape[-1]).float(),
        dim=-1,
    ).mean().item()

    return {
        'seq_len': input_ids.shape[1],
        'top1_match': match_rate,
        'max_abs_diff': max_abs_diff,
        'cosine_sim': cos_sim,
    }
# ---------------------------------------------------------------------------
# 3. Benchmark helpers
# ---------------------------------------------------------------------------
@torch.no_grad()
def bench_prefill(model, input_ids, is_hf: bool) -> Tuple[float, float]:
    """Measure TTFT (time-to-first-token) and peak VRAM for one prefill pass.

    Args:
        model: Either engine; both are invoked as ``model(input_ids)``.
        input_ids: Prompt token tensor of shape [1, S].
        is_hf: Kept for interface compatibility with ``bench_decode`` — the
            original branched on it, but both branches were byte-identical,
            so the dead branching is collapsed here.

    Returns:
        ``(ttft_seconds, peak_vram_mb)``.
    """
    _reset_peak_vram()
    _cuda_sync()
    ttft = _timed_cuda(lambda: model(input_ids), warmup=2, repeats=5)
    vram = _peak_vram_mb()
    return ttft, vram
@torch.no_grad()
def bench_decode(model, input_ids, max_new_tokens: int, is_hf: bool,
                 tokenizer=None) -> Tuple[float, float]:
    """Measure greedy-decode throughput (tok/s) and peak VRAM.

    HF uses ``do_sample=False`` for greedy decoding; FireEcho's equivalent
    is ``temperature=0.0, top_k=1``.  ``tokenizer`` is accepted for
    signature compatibility but unused.
    """
    _reset_peak_vram()
    _cuda_sync()

    if is_hf:
        def _run():
            model.generate(input_ids, max_new_tokens=max_new_tokens,
                           do_sample=False, use_cache=True)
    else:
        def _run():
            model.generate(input_ids, max_new_tokens=max_new_tokens,
                           temperature=0.0, top_k=1, use_cache=True)

    elapsed = _timed_cuda(_run, warmup=1, repeats=3)
    return max_new_tokens / elapsed, _peak_vram_mb()
def make_input(tokenizer, seq_len: int, device='cuda') -> torch.Tensor:
    """Build an input_ids tensor of (at most) *seq_len* tokens on *device*.

    Tiles a seed passage until it is long enough, then truncates via the
    tokenizer's ``max_length``.
    """
    seed = ("The quick brown fox jumps over the lazy dog. "
            "In a distant land, ancient scholars studied the stars. ")
    repeats = (seq_len // 20) + 1
    encoded = tokenizer(seed * repeats, return_tensors='pt',
                        max_length=seq_len, truncation=True)
    return encoded.input_ids.to(device)
# ---------------------------------------------------------------------------
# 4. Run full benchmark
# ---------------------------------------------------------------------------
def _free_model(model):
"""Move model to CPU and free GPU memory."""
if model is not None:
model.cpu()
del model
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
def run_benchmark(model_name: str, prompt_lengths: List[int],
                  max_new_tokens: int, device: str, dtype):
    """Validate and benchmark every engine configuration against HuggingFace.

    Phases:
      1. Load the HF model once to get the tokenizer and reference logits,
         then free it (models are loaded one at a time to avoid OOM).
      2. For each FireEcho config, compare logits against the saved HF
         reference (top-1 match rate, cosine similarity, max abs diff).
      3. Benchmark TTFT, decode tok/s, and peak VRAM per prompt length.
      4. Print a summary table and return the collected rows.

    Args:
        model_name: HuggingFace model identifier.
        prompt_lengths: Prompt lengths (tokens) to benchmark.
        max_new_tokens: Tokens to generate per decode benchmark.
        device: Target device string ('cuda' or 'cpu').
        dtype: torch dtype for both engines.

    Returns:
        List of per-(config, prompt_len) result dicts.
    """
    results = []
    # Define configs to benchmark (loaded/freed one at a time)
    config_specs = [
        ('HF-BF16', True, {}),
        ('FE-BF16', False, {}),
        ('FE-FP4', False, {'use_goliath': True, 'goliath_bits': 4}),
        ('FE-FP8', False, {'use_goliath': True, 'goliath_bits': 8}),
    ]

    # --- Load HF for tokenizer + reference logits ---
    hf_model, tokenizer = load_hf_model(model_name, dtype=dtype, device=device)

    # Generate reference logits for correctness validation (then free HF)
    print("\n" + "=" * 70)
    print("CORRECTNESS VALIDATION (vs HuggingFace)")
    print("=" * 70)
    ref_prompt = "The quick brown fox jumps over the lazy dog. " * 20
    ref_ids = tokenizer(ref_prompt, return_tensors='pt',
                        max_length=128, truncation=True).input_ids.to(device)
    with torch.no_grad():
        ref_logits = hf_model(ref_ids).logits.cpu()  # save to CPU
    ref_top1 = ref_logits.argmax(dim=-1)
    # Free the HF model before loading any FE config so only one model
    # occupies the GPU at a time.
    _free_model(hf_model)
    hf_model = None

    # Validate each FE config against saved reference
    for name, is_hf, fe_kwargs in config_specs:
        if is_hf:
            continue  # HF *is* the reference; nothing to validate
        try:
            fe_model = load_fireecho(model_name, dtype=dtype, device=device, **fe_kwargs)
            with torch.no_grad():
                fe_logits = fe_model(ref_ids).cpu()
            fe_top1 = fe_logits.argmax(dim=-1)
            match_rate = (ref_top1 == fe_top1).float().mean().item()
            cos_sim = F.cosine_similarity(
                ref_logits.view(-1, ref_logits.shape[-1]).float(),
                fe_logits.view(-1, fe_logits.shape[-1]).float(),
                dim=-1).mean().item()
            max_diff = (ref_logits - fe_logits).abs().max().item()
            # 90% top-1 agreement is the pass bar (quantised configs are
            # not expected to match exactly).
            status = "PASS" if match_rate > 0.90 else "FAIL"
            print(f" {name}: top1={match_rate:.3f} "
                  f"cos_sim={cos_sim:.5f} "
                  f"max_diff={max_diff:.4f} [{status}]")
            _free_model(fe_model)
        except Exception as e:
            # Best-effort: a broken config should not abort the whole run.
            print(f" {name}: ERROR - {e}")
    # Drop the CPU-side reference tensors before the benchmark phase.
    del ref_logits, ref_top1
    import gc; gc.collect()

    # --- Benchmark (one config at a time to avoid OOM on large models) ---
    print("\n" + "=" * 70)
    print(f"INFERENCE BENCHMARK (decode {max_new_tokens} tokens)")
    print("=" * 70)
    printed_headers = set()
    for name, is_hf, fe_kwargs in config_specs:
        try:
            if is_hf:
                model, _ = load_hf_model(model_name, dtype=dtype, device=device)
            else:
                model = load_fireecho(model_name, dtype=dtype, device=device, **fe_kwargs)
        except Exception as e:
            print(f"\n[WARN] {name} load failed: {e}")
            continue
        for seq_len in prompt_lengths:
            input_ids = make_input(tokenizer, seq_len, device)
            # Tokenized length may differ from the requested seq_len.
            actual_len = input_ids.shape[1]
            # NOTE(review): because configs are the outer loop, these
            # per-length headers only print during the first config's pass;
            # later configs' rows appear without a preceding header.
            if actual_len not in printed_headers:
                print(f"\n--- Prompt length: {actual_len} tokens ---")
                print(f"{'Config':<12} {'TTFT(ms)':>10} {'Tok/s':>10} "
                      f"{'Prefill MB':>12} {'Decode MB':>12}")
                print("-" * 60)
                printed_headers.add(actual_len)
            try:
                # Clear any KV cache left over from the previous measurement
                # so prefill and decode each start from a cold cache.
                if not is_hf and hasattr(model, 'reset_cache'):
                    model.reset_cache()
                ttft, pre_vram = bench_prefill(model, input_ids, is_hf)
                if not is_hf and hasattr(model, 'reset_cache'):
                    model.reset_cache()
                tok_s, dec_vram = bench_decode(
                    model, input_ids, max_new_tokens, is_hf, tokenizer)
                print(f"{name:<12} {ttft*1000:>10.1f} {tok_s:>10.1f} "
                      f"{pre_vram:>12.1f} {dec_vram:>12.1f}")
                results.append({
                    'config': name, 'prompt_len': actual_len,
                    'ttft_ms': ttft * 1000, 'tok_s': tok_s,
                    'prefill_vram_mb': pre_vram, 'decode_vram_mb': dec_vram,
                })
            except Exception as e:
                print(f"{name:<12} {'ERROR':>10} - {e}")
        _free_model(model)

    # --- Summary table ---
    print("\n" + "=" * 70)
    print("SUMMARY TABLE")
    print("=" * 70)
    print(f"{'Config':<12} {'Prompt':>7} {'TTFT(ms)':>10} {'Tok/s':>10} "
          f"{'Peak VRAM':>12}")
    print("-" * 55)
    for r in results:
        print(f"{r['config']:<12} {r['prompt_len']:>7} "
              f"{r['ttft_ms']:>10.1f} {r['tok_s']:>10.1f} "
              f"{r['decode_vram_mb']:>12.1f}")
    return results
# ---------------------------------------------------------------------------
# 5. Generation demo
# ---------------------------------------------------------------------------
def generation_demo(model_name: str, device: str, dtype):
    """Show side-by-side greedy generation from both engines.

    Loads the HF reference model and a BF16 FireEcho engine, generates 60
    tokens from the same prompt with each, and prints both decodings.

    (The original also did ``sys.path.insert`` + a module import of
    ``FireEchoEngine`` that was never used — ``load_fireecho`` handles both
    itself — and a redundant ``eval()`` call; those are removed.)
    """
    hf_model, tokenizer = load_hf_model(model_name, dtype=dtype, device=device)
    fe_engine = load_fireecho(model_name, dtype=dtype, device=device)

    prompt = "Once upon a time in a land far away,"
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)

    print("\n" + "=" * 70)
    print(f"GENERATION DEMO (prompt: {prompt!r})")
    print("=" * 70)

    # HuggingFace reference (greedy decoding).
    with torch.no_grad():
        hf_out = hf_model.generate(input_ids, max_new_tokens=60,
                                   do_sample=False, use_cache=True)
    hf_text = tokenizer.decode(hf_out[0], skip_special_tokens=True)
    print(f"\n[HF] {hf_text}")

    # FireEcho (temperature=0 + top_k=1 is its greedy equivalent).
    fe_engine.reset_cache()
    with torch.no_grad():
        fe_out = fe_engine.generate(input_ids, max_new_tokens=60,
                                    temperature=0.0, top_k=1, use_cache=True)
    fe_text = tokenizer.decode(fe_out[0], skip_special_tokens=True)
    print(f"[FE] {fe_text}")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Parse CLI arguments and dispatch to the demo or the full benchmark."""
    parser = argparse.ArgumentParser(
        description="Benchmark FireEcho vs HuggingFace on a real model")
    parser.add_argument('--model', default='Qwen/Qwen2-0.5B',
                        help='HuggingFace model name')
    parser.add_argument('--prompt-lengths', nargs='+', type=int,
                        default=[128, 512, 2048],
                        help='Prompt lengths to benchmark')
    parser.add_argument('--max-new-tokens', type=int, default=100,
                        help='Tokens to generate per benchmark')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--dtype', default='bfloat16',
                        choices=['bfloat16', 'float16', 'float32'])
    parser.add_argument('--demo', action='store_true',
                        help='Run generation demo only')
    args = parser.parse_args()

    dtype = {
        'bfloat16': torch.bfloat16,
        'float16': torch.float16,
        'float32': torch.float32,
    }[args.dtype]

    cuda_ok = torch.cuda.is_available()
    if not cuda_ok:
        print("CUDA not available, falling back to CPU")
        args.device = 'cpu'

    print(f"Model: {args.model}")
    print(f"Device: {args.device}")
    print(f"Dtype: {dtype}")
    if cuda_ok:
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    if args.demo:
        generation_demo(args.model, args.device, dtype)
    else:
        run_benchmark(args.model, args.prompt_lengths,
                      args.max_new_tokens, args.device, dtype)
# Run the CLI when executed as a script (not on import).
if __name__ == '__main__':
    main()