#!/usr/bin/env python3
"""Latency and peak-VRAM benchmark for the SRT-Adapter v8a.
Measures forward-pass latency and peak GPU memory for:
1. Backbone-only (Qwen2.5-7B, no adapter)
2. Backbone + SRT-Adapter (full forward including all 4 readouts)
Prints tokens/sec and peak GiB at sequence lengths 64, 256, and 512 (batch=1) by default,
and writes the raw numbers to benchmarks/latency_vram.json.
Usage:
python scripts/benchmark_latency.py
python scripts/benchmark_latency.py --warmup 5 --iters 20
"""
from __future__ import annotations
import argparse
import json
import sys
import time
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
HERE = Path(__file__).resolve().parent
ROOT = HERE.parent
sys.path.insert(0, str((ROOT / "src").resolve()))
from srt.adapter import SRTAdapter # noqa: E402
from srt.config import ( # noqa: E402
SRTConfig, MAHConfig, RRMConfig, BENConfig, CommunityConfig, LossConfig,
)
def build_config(p: Path) -> SRTConfig:
raw = json.loads(p.read_text())
return SRTConfig(
backbone_id=raw["backbone_id"],
backbone_dtype=raw["backbone_dtype"],
mah_layer_indices=list(raw["mah_layer_indices"]),
rrm_inject_indices=list(raw["rrm_inject_indices"]),
community_layer_idx=raw["community_layer_idx"],
num_mah_layers=raw["num_mah_layers"],
mah=MAHConfig(**raw["mah"]),
rrm=RRMConfig(**raw["rrm"]),
ben=BENConfig(**raw["ben"]),
community=CommunityConfig(**raw["community"]),
loss=LossConfig(**{k: v for k, v in raw["loss"].items() if k in LossConfig.__dataclass_fields__}),
)
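# Minimal sketch of the config.json shape that build_config() above expects.
# Only the key names are taken from the loader; every value below is an
# illustrative assumption, not the shipped v8a configuration.
#
#   {
#     "backbone_id": "Qwen/Qwen2.5-7B",
#     "backbone_dtype": "bfloat16",
#     "mah_layer_indices": [8, 16, 24],
#     "rrm_inject_indices": [12, 20],
#     "community_layer_idx": 24,
#     "num_mah_layers": 3,
#     "mah": {...}, "rrm": {...}, "ben": {...}, "community": {...}, "loss": {...}
#   }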
def time_forward(fn, warmup: int, iters: int) -> tuple[float, int]:
    """Return (mean seconds per forward, peak allocated bytes) over `iters` timed passes after `warmup` untimed ones."""
for _ in range(warmup):
fn()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(iters):
fn()
torch.cuda.synchronize()
elapsed = (time.perf_counter() - t0) / iters
peak = torch.cuda.max_memory_allocated()
return elapsed, peak
def fmt_gib(b: int) -> str:
return f"{b / (1024 ** 3):.2f} GiB"
def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--config", default=str(ROOT / "config.json"))
ap.add_argument("--adapter", default=str(ROOT / "adapter.safetensors"))
ap.add_argument("--seq-lens", type=int, nargs="+", default=[64, 256, 512])
ap.add_argument("--warmup", type=int, default=3)
ap.add_argument("--iters", type=int, default=10)
args = ap.parse_args()
if not torch.cuda.is_available():
print("CUDA required.", file=sys.stderr)
sys.exit(1)
device = "cuda"
gpu_name = torch.cuda.get_device_name(0)
print(f"GPU: {gpu_name}")
print(f"torch={torch.__version__}, dtype=bfloat16\n")
config = build_config(Path(args.config))
tok = AutoTokenizer.from_pretrained(config.backbone_id)
# Build a synthetic input pool we can slice
long_text = "The quick brown fox jumps over the lazy dog. " * 200
enc_full = tok(long_text, return_tensors="pt").to(device)
rows = []
# 1. Backbone-only baseline
print("=== Backbone-only (Qwen2.5-7B, bfloat16) ===")
backbone = AutoModelForCausalLM.from_pretrained(
config.backbone_id, torch_dtype=torch.bfloat16
).to(device)
backbone.eval()
for T in args.seq_lens:
ids = enc_full.input_ids[:, :T]
mask = enc_full.attention_mask[:, :T]
torch.cuda.empty_cache()
def fn():
with torch.no_grad():
backbone(input_ids=ids, attention_mask=mask)
sec, peak = time_forward(fn, args.warmup, args.iters)
tps = T / sec
print(f" T={T:4d} {sec*1000:7.2f} ms/fwd {tps:8.1f} tok/s peak={fmt_gib(peak)}")
rows.append({"variant": "backbone_only", "seq_len": T,
"ms_per_forward": sec * 1000, "tokens_per_sec": tps,
"peak_vram_gib": peak / (1024 ** 3)})
del backbone
torch.cuda.empty_cache()
# 2. Backbone + adapter
print("\n=== Backbone + SRT-Adapter v8a ===")
model = SRTAdapter(config).to(device)
if args.adapter.endswith(".safetensors"):
from safetensors.torch import load_file
state = load_file(args.adapter, device=device)
else:
state = torch.load(args.adapter, map_location=device)
model.load_state_dict(state, strict=False)
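    # Optional sanity check (an addition, not part of the original script): with
    # strict=False a wrong --adapter path or renamed keys would fail silently, so
    # warn when checkpoint keys do not line up with the module's state dict.
    unmatched = set(state) - set(model.state_dict())
    if unmatched:
        print(f"warning: {len(unmatched)} adapter keys not found in SRTAdapter state_dict",
              file=sys.stderr)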
model.eval()
for T in args.seq_lens:
ids = enc_full.input_ids[:, :T]
mask = enc_full.attention_mask[:, :T]
torch.cuda.empty_cache()
def fn():
with torch.no_grad():
model(input_ids=ids, attention_mask=mask)
sec, peak = time_forward(fn, args.warmup, args.iters)
tps = T / sec
print(f" T={T:4d} {sec*1000:7.2f} ms/fwd {tps:8.1f} tok/s peak={fmt_gib(peak)}")
rows.append({"variant": "backbone_plus_adapter", "seq_len": T,
"ms_per_forward": sec * 1000, "tokens_per_sec": tps,
"peak_vram_gib": peak / (1024 ** 3)})
# Adapter overhead summary
print("\n=== Adapter overhead ===")
by = {(r["variant"], r["seq_len"]): r for r in rows}
for T in args.seq_lens:
b = by[("backbone_only", T)]
a = by[("backbone_plus_adapter", T)]
latency_overhead = (a["ms_per_forward"] / b["ms_per_forward"]) - 1.0
vram_overhead = a["peak_vram_gib"] - b["peak_vram_gib"]
print(f" T={T:4d} latency +{latency_overhead*100:5.1f}% "
f"vram +{vram_overhead:.2f} GiB")
out = ROOT / "benchmarks" / "latency_vram.json"
out.write_text(json.dumps({
"gpu": gpu_name,
"torch_version": torch.__version__,
"warmup": args.warmup,
"iters": args.iters,
"rows": rows,
}, indent=2))
print(f"\nwrote {out}")
if __name__ == "__main__":
main()