#!/usr/bin/env python3
"""Latency and peak-VRAM benchmark for the SRT-Adapter v8a.

Measures forward-pass latency and peak GPU memory for:

  1. Backbone-only (Qwen2.5-7B, no adapter)
  2. Backbone + SRT-Adapter (full forward including all 4 readouts)

Reports tokens/sec and peak GiB at sequence lengths 64, 256, 512, with batch=1.

Usage:
    python scripts/benchmark_latency.py
    python scripts/benchmark_latency.py --warmup 5 --iters 20
"""
from __future__ import annotations

import argparse
import json
import sys
import time
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

HERE = Path(__file__).resolve().parent
ROOT = HERE.parent
sys.path.insert(0, str((ROOT / "src").resolve()))

from srt.adapter import SRTAdapter  # noqa: E402
from srt.config import (  # noqa: E402
    SRTConfig,
    MAHConfig,
    RRMConfig,
    BENConfig,
    CommunityConfig,
    LossConfig,
)


def build_config(p: Path) -> SRTConfig:
    raw = json.loads(p.read_text())
    return SRTConfig(
        backbone_id=raw["backbone_id"],
        backbone_dtype=raw["backbone_dtype"],
        mah_layer_indices=list(raw["mah_layer_indices"]),
        rrm_inject_indices=list(raw["rrm_inject_indices"]),
        community_layer_idx=raw["community_layer_idx"],
        num_mah_layers=raw["num_mah_layers"],
        mah=MAHConfig(**raw["mah"]),
        rrm=RRMConfig(**raw["rrm"]),
        ben=BENConfig(**raw["ben"]),
        community=CommunityConfig(**raw["community"]),
        # Drop unknown keys so configs carrying extra loss fields still load.
        loss=LossConfig(**{k: v for k, v in raw["loss"].items()
                           if k in LossConfig.__dataclass_fields__}),
    )


def time_forward(fn, warmup: int, iters: int) -> tuple[float, int]:
    """Return (mean_sec, peak_bytes) over `iters` calls after `warmup`."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    # Reset the peak-memory counter only after warmup, so one-off allocations
    # from the first calls (e.g. cuBLAS workspaces) are not counted.
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    elapsed = (time.perf_counter() - t0) / iters
    peak = torch.cuda.max_memory_allocated()
    return elapsed, peak
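
# Alternative timing sketch (added; not used by the benchmark below): CUDA
# events measure device-side time directly and avoid host-side perf_counter
# jitter. The helper name `time_forward_events` is an assumption of this
# sketch, not part of the original script; it can be swapped into main() if
# host overhead ever becomes noticeable.
def time_forward_events(fn, warmup: int, iters: int) -> float:
    """Return mean seconds per call, timed with CUDA events."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1000.0 / iters  # elapsed_time() is in ms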

def fmt_gib(b: int) -> str:
    return f"{b / (1024 ** 3):.2f} GiB"


def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", default=str(ROOT / "config.json"))
    ap.add_argument("--adapter", default=str(ROOT / "adapter.safetensors"))
    ap.add_argument("--seq-lens", type=int, nargs="+", default=[64, 256, 512])
    ap.add_argument("--warmup", type=int, default=3)
    ap.add_argument("--iters", type=int, default=10)
    args = ap.parse_args()

    if not torch.cuda.is_available():
        print("CUDA required.", file=sys.stderr)
        sys.exit(1)

    device = "cuda"
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")
    print(f"torch={torch.__version__}, dtype=bfloat16\n")

    config = build_config(Path(args.config))
    tok = AutoTokenizer.from_pretrained(config.backbone_id)

    # Build a synthetic input pool we can slice per sequence length.
    long_text = "The quick brown fox jumps over the lazy dog. " * 200
    enc_full = tok(long_text, return_tensors="pt").to(device)
    # Guard against silent truncation: the pool must cover the longest
    # requested sequence, or the slices below would be shorter than reported.
    if enc_full.input_ids.shape[1] < max(args.seq_lens):
        print(f"input pool too short ({enc_full.input_ids.shape[1]} tokens "
              f"< {max(args.seq_lens)})", file=sys.stderr)
        sys.exit(1)

    rows = []

    # 1. Backbone-only baseline
    print("=== Backbone-only (Qwen2.5-7B, bfloat16) ===")
    backbone = AutoModelForCausalLM.from_pretrained(
        config.backbone_id, torch_dtype=torch.bfloat16
    ).to(device)
    backbone.eval()
    for T in args.seq_lens:
        ids = enc_full.input_ids[:, :T]
        mask = enc_full.attention_mask[:, :T]
        torch.cuda.empty_cache()

        def fn():
            with torch.no_grad():
                backbone(input_ids=ids, attention_mask=mask)

        sec, peak = time_forward(fn, args.warmup, args.iters)
        tps = T / sec
        print(f"  T={T:4d}  {sec * 1000:7.2f} ms/fwd  {tps:8.1f} tok/s  "
              f"peak={fmt_gib(peak)}")
        rows.append({"variant": "backbone_only", "seq_len": T,
                     "ms_per_forward": sec * 1000, "tokens_per_sec": tps,
                     "peak_vram_gib": peak / (1024 ** 3)})

    # Free the baseline before building the adapter variant so the two peak
    # measurements do not overlap.
    del backbone
    torch.cuda.empty_cache()

    # 2. Backbone + adapter
    print("\n=== Backbone + SRT-Adapter v8a ===")
    model = SRTAdapter(config).to(device)
    if args.adapter.endswith(".safetensors"):
        from safetensors.torch import load_file
        state = load_file(args.adapter, device=device)
    else:
        state = torch.load(args.adapter, map_location=device)
    # strict=False: the checkpoint is expected to hold adapter weights only,
    # not the backbone.
    model.load_state_dict(state, strict=False)
    model.eval()
    for T in args.seq_lens:
        ids = enc_full.input_ids[:, :T]
        mask = enc_full.attention_mask[:, :T]
        torch.cuda.empty_cache()

        def fn():
            with torch.no_grad():
                model(input_ids=ids, attention_mask=mask)

        sec, peak = time_forward(fn, args.warmup, args.iters)
        tps = T / sec
        print(f"  T={T:4d}  {sec * 1000:7.2f} ms/fwd  {tps:8.1f} tok/s  "
              f"peak={fmt_gib(peak)}")
        rows.append({"variant": "backbone_plus_adapter", "seq_len": T,
                     "ms_per_forward": sec * 1000, "tokens_per_sec": tps,
                     "peak_vram_gib": peak / (1024 ** 3)})

    # Adapter overhead summary
    print("\n=== Adapter overhead ===")
    by = {(r["variant"], r["seq_len"]): r for r in rows}
    for T in args.seq_lens:
        b = by[("backbone_only", T)]
        a = by[("backbone_plus_adapter", T)]
        latency_overhead = (a["ms_per_forward"] / b["ms_per_forward"]) - 1.0
        vram_overhead = a["peak_vram_gib"] - b["peak_vram_gib"]
        print(f"  T={T:4d}  latency +{latency_overhead * 100:5.1f}%  "
              f"vram +{vram_overhead:.2f} GiB")

    out = ROOT / "benchmarks" / "latency_vram.json"
    # Fix: create the output directory; write_text fails if it is missing.
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps({
        "gpu": gpu_name,
        "torch_version": torch.__version__,
        "warmup": args.warmup,
        "iters": args.iters,
        "rows": rows,
    }, indent=2))
    print(f"\nwrote {out}")


if __name__ == "__main__":
    main()