# FireEcho / benchmark_generation.py
# (Provenance: Hugging Face upload "Upload 3258 files" by Joysulem, commit b5bff9c — verified.)
"""
Benchmark Generation — Prefill + Decode across Goliath configs
==============================================================
Measures:
- Prefill throughput (tok/s)
- Decode throughput (tok/s) and per-token latency (ms)
- VRAM usage (GB)
Configs tested:
- Goliath FP4 (goliath_bits=4)
- Goliath FP8 (goliath_bits=8)
- Goliath Auto (goliath_bits='auto')
- Legacy path (use_goliath=False)
Context lengths: 512, 2048, 8192
Usage:
python3 benchmark_generation.py
"""
import gc
import sys
import time
import torch

# Ensure the kernel directory is importable regardless of the working
# directory the script is launched from.
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from fireecho_kernel import FireEchoConfig, FireEchoEngine, _GOLIATH_AVAILABLE

# The Goliath kernel is optional: fall back to a None sentinel so callers
# can check `_can_use_goliath_dot_scaled is not None` before invoking it.
if _GOLIATH_AVAILABLE:
    from goliath_kernel import _can_use_goliath_dot_scaled
else:
    _can_use_goliath_dot_scaled = None
# ============================================================================
# Engine Factory
# ============================================================================
def create_bench_engine(goliath_bits=4, use_goliath=True, num_layers=8):
    """Build a CUDA FireEchoEngine at 7B-scale dimensions (reduced layer count).

    Only the Goliath-related knobs vary between benchmark configs; every
    other architecture constant is pinned to the values used throughout
    this script (dim 4096, 32 heads / 8 KV heads, 32000-token vocab).
    """
    arch_kwargs = dict(
        dim=4096,
        num_heads=32,
        num_kv_heads=8,
        num_layers=num_layers,
        vocab_size=32000,
        intermediate_size=11008,
        max_seq_len=16384,
        max_kv_blocks=1024,
        use_nvfp4=True,
        quantize_weights=True,
        goliath_bits=goliath_bits,
        use_goliath=use_goliath,
        use_hebbian=False,
        use_vision=False,
        use_audio=False,
    )
    engine = FireEchoEngine(FireEchoConfig(**arch_kwargs)).cuda()
    engine.eval()
    return engine
# ============================================================================
# Benchmark Helpers
# ============================================================================
def bench_prefill(engine, seq_len, warmup=3, iters=5, vocab_size=32000):
    """Benchmark prefill (a full forward pass over a random prompt).

    Args:
        engine: FireEchoEngine instance already moved to CUDA.
        seq_len: Prompt length in tokens.
        warmup: Untimed forward passes run first so kernel compilation and
            allocator growth do not pollute the measurement.
        iters: Timed iterations averaged into the reported latency.
        vocab_size: Exclusive upper bound for the random token ids; defaults
            to the 32000-token vocab used by create_bench_engine.

    Returns:
        dict with 'ms' (mean latency per forward pass), 'tok_s' (prefill
        throughput, tokens/second), and 'vram_gb' (currently allocated
        CUDA memory in GB).
    """
    input_ids = torch.randint(0, vocab_size, (1, seq_len), device='cuda')

    # Warmup (untimed).
    for _ in range(warmup):
        engine.reset_cache()
        with torch.no_grad():
            _ = engine(input_ids, use_cache=False)
    torch.cuda.synchronize()

    # Timed region: CUDA events measure GPU-side time across all iterations;
    # divide once at the end for the per-iteration mean. reset_cache() stays
    # inside the timed loop to match the warmup path.
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    start_evt.record()
    for _ in range(iters):
        engine.reset_cache()
        with torch.no_grad():
            _ = engine(input_ids, use_cache=False)
    end_evt.record()
    torch.cuda.synchronize()

    ms = start_evt.elapsed_time(end_evt) / iters
    tok_s = seq_len / (ms / 1000.0)
    # NOTE(review): memory_allocated() reports the *current* allocation, not
    # the peak; switch to torch.cuda.max_memory_allocated() if peak VRAM is
    # the figure of interest.
    vram_gb = torch.cuda.memory_allocated() / 1e9
    return {'ms': ms, 'tok_s': tok_s, 'vram_gb': vram_gb}
def bench_decode(engine, prompt_len, num_decode_tokens=50, warmup=2, vocab_size=32000):
    """Benchmark autoregressive generation following a random prompt.

    Args:
        engine: FireEchoEngine instance already moved to CUDA.
        prompt_len: Prompt length in tokens.
        num_decode_tokens: Tokens requested via max_new_tokens.
        warmup: Untimed short generations run first.
        vocab_size: Exclusive upper bound for the random token ids; defaults
            to the 32000-token vocab used by create_bench_engine.

    Returns:
        dict with 'total_ms', 'per_token_ms', 'tok_s', 'vram_gb', and
        'gen_tokens' (tokens actually produced).

    NOTE(review): the timed region covers engine.generate() end to end, so
    total_ms includes the prompt prefill as well as the decode steps — the
    per-token figures slightly overstate pure decode latency.
    TODO(review): confirm use_cache=False is intentional; on a conventional
    KV-cache engine it would force a full re-forward per generated token.
    """
    prompt = torch.randint(0, vocab_size, (1, prompt_len), device='cuda')

    # Warmup with a short generation, then return freed memory to the
    # allocator so the timed run starts from a clean pool.
    for _ in range(warmup):
        engine.reset_cache()
        with torch.no_grad():
            _ = engine.generate(prompt, max_new_tokens=5, use_cache=False)
    torch.cuda.synchronize()
    gc.collect()
    torch.cuda.empty_cache()

    # Timed region (GPU-side, via CUDA events).
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    engine.reset_cache()
    start_evt.record()
    with torch.no_grad():
        output = engine.generate(prompt, max_new_tokens=num_decode_tokens,
                                 use_cache=False)
    end_evt.record()
    torch.cuda.synchronize()

    # Count tokens actually produced (generate() might return fewer than
    # requested — TODO confirm against the engine's stopping behavior);
    # max(gen_tokens, 1) guards the division if nothing was generated.
    gen_tokens = output.shape[1] - prompt_len
    total_ms = start_evt.elapsed_time(end_evt)
    per_token_ms = total_ms / max(gen_tokens, 1)
    tok_s = gen_tokens / (total_ms / 1000.0) if total_ms > 0 else 0.0
    vram_gb = torch.cuda.memory_allocated() / 1e9
    return {
        'total_ms': total_ms,
        'per_token_ms': per_token_ms,
        'tok_s': tok_s,
        'vram_gb': vram_gb,
        'gen_tokens': gen_tokens,
    }
# ============================================================================
# Main Benchmark
# ============================================================================
def _run_bench_table(title, header, configs, context_lengths, bench_fn, fmt_row):
    """Run one benchmark table and print it.

    For each (name, kwargs) in *configs*: build an engine, run *bench_fn*
    (engine, ctx) at every context length, print one row per result via
    *fmt_row*(name, ctx, result), then free the engine and the CUDA cache
    before the next config. Engine-build and per-context failures are
    reported inline rather than aborting the table.
    """
    print("-" * 85)
    print(title)
    print("-" * 85)
    print(header)
    print("-" * len(header))
    for cfg_name, cfg_kwargs in configs:
        try:
            engine = create_bench_engine(**cfg_kwargs)
        except Exception as e:
            print(f"{cfg_name:<16} | {'ERROR':>5} | {str(e)[:40]}")
            continue
        for ctx in context_lengths:
            try:
                r = bench_fn(engine, ctx)
                print(fmt_row(cfg_name, ctx, r))
            except Exception as e:
                print(f"{cfg_name:<16} | {ctx:>5} | ERROR: {str(e)[:30]}")
        # Release the engine and return its VRAM before the next config.
        del engine
        gc.collect()
        torch.cuda.empty_cache()


def main():
    """Entry point: print GPU/kernel info, then the prefill and decode tables."""
    if not torch.cuda.is_available():
        print("CUDA not available.")
        return
    props = torch.cuda.get_device_properties(0)
    print("=" * 85)
    print("GENERATION BENCHMARK — Goliath FP4/FP8 Configs")
    print("=" * 85)
    print(f"GPU: {props.name}")
    print(f"VRAM: {props.total_memory / 1e9:.1f} GB")
    print(f"Goliath available: {_GOLIATH_AVAILABLE}")
    if _can_use_goliath_dot_scaled is not None:
        print(f"Goliath dot_scaled (native FP4 TCs): {_can_use_goliath_dot_scaled()}")
    print()

    configs = [
        ('Goliath FP4', dict(goliath_bits=4, use_goliath=True)),
        ('Goliath FP8', dict(goliath_bits=8, use_goliath=True)),
        ('Goliath Auto', dict(goliath_bits='auto', use_goliath=True)),
        ('Legacy path', dict(goliath_bits=4, use_goliath=False)),
    ]
    context_lengths = [512, 2048, 8192]

    # --- Prefill benchmark ---
    _run_bench_table(
        "PREFILL BENCHMARK",
        f"{'Config':<16} | {'Ctx':>5} | {'Prefill ms':>11} | {'Prefill tok/s':>14} | {'VRAM GB':>8}",
        configs, context_lengths, bench_prefill,
        lambda name, ctx, r: (
            f"{name:<16} | {ctx:>5} | {r['ms']:>9.2f}ms | {r['tok_s']:>12,.0f} | {r['vram_gb']:>7.2f}"
        ),
    )

    # --- Decode benchmark ---
    print()
    _run_bench_table(
        "DECODE BENCHMARK (50 tokens)",
        f"{'Config':<16} | {'Ctx':>5} | {'Decode tok/s':>13} | {'ms/token':>9} | {'VRAM GB':>8}",
        configs, context_lengths, bench_decode,
        lambda name, ctx, r: (
            f"{name:<16} | {ctx:>5} | {r['tok_s']:>11,.1f} | {r['per_token_ms']:>7.2f}ms | {r['vram_gb']:>7.2f}"
        ),
    )

    print()
    print("=" * 85)
    print("BENCHMARK COMPLETE")
    print("=" * 85)
# Script entry point: run the benchmark only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()