Add scripts/reproduce.py + benchmark_latency.py; pin reqs to working transformers/torch range; refresh README with adapter overhead numbers
8e4e948 verified

#!/usr/bin/env python3
"""Latency and peak-VRAM benchmark for the SRT-Adapter v8a.

Measures forward-pass latency and peak GPU memory for:

1. Backbone-only (Qwen2.5-7B, no adapter)
2. Backbone + SRT-Adapter (full forward including all 4 readouts)

Reports tokens/sec and peak GiB at sequence lengths 64, 256, 512, with batch=1.

Usage:
    python scripts/benchmark_latency.py
    python scripts/benchmark_latency.py --warmup 5 --iters 20
"""
from __future__ import annotations

import argparse
import json
import sys
import time
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

HERE = Path(__file__).resolve().parent
ROOT = HERE.parent
sys.path.insert(0, str((ROOT / "src").resolve()))

from srt.adapter import SRTAdapter  # noqa: E402
from srt.config import (  # noqa: E402
    SRTConfig, MAHConfig, RRMConfig, BENConfig, CommunityConfig, LossConfig,
)


def build_config(p: Path) -> SRTConfig:
    """Instantiate an SRTConfig from the JSON file shipped with the release."""
    raw = json.loads(p.read_text())
    return SRTConfig(
        backbone_id=raw["backbone_id"],
        backbone_dtype=raw["backbone_dtype"],
        mah_layer_indices=list(raw["mah_layer_indices"]),
        rrm_inject_indices=list(raw["rrm_inject_indices"]),
        community_layer_idx=raw["community_layer_idx"],
        num_mah_layers=raw["num_mah_layers"],
        mah=MAHConfig(**raw["mah"]),
        rrm=RRMConfig(**raw["rrm"]),
        ben=BENConfig(**raw["ben"]),
        community=CommunityConfig(**raw["community"]),
        # Drop unknown loss keys so older/newer config files still load.
        loss=LossConfig(**{k: v for k, v in raw["loss"].items()
                           if k in LossConfig.__dataclass_fields__}),
    )
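
# For reference, the shape of config.json that build_config() reads; the key
# names come from the code above, the values here are purely illustrative:
# {
#   "backbone_id": "Qwen/Qwen2.5-7B",
#   "backbone_dtype": "bfloat16",
#   "mah_layer_indices": [...], "rrm_inject_indices": [...],
#   "community_layer_idx": ..., "num_mah_layers": ...,
#   "mah": {...}, "rrm": {...}, "ben": {...}, "community": {...}, "loss": {...}
# }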


def time_forward(fn, warmup: int, iters: int) -> tuple[float, int]:
    """Return (mean_sec, peak_bytes) over `iters` timed calls after `warmup`."""
    for _ in range(warmup):
        fn()
    # Drain the warmup kernels, then reset the peak counter so the reported
    # peak covers only the timed iterations.
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    # Kernels launch asynchronously: stop the clock only after all queued
    # work has finished.
    torch.cuda.synchronize()
    elapsed = (time.perf_counter() - t0) / iters
    peak = torch.cuda.max_memory_allocated()
    return elapsed, peak
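

# Illustrative alternative (not used by main): CUDA events measure GPU time
# directly and exclude host-side Python overhead from the reading.
def time_forward_events(fn, warmup: int, iters: int) -> float:
    """Mean seconds per call via CUDA events -- a sketch, same contract as above."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1000.0 / iters  # elapsed_time() is in ms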


def fmt_gib(b: int) -> str:
    return f"{b / (1024 ** 3):.2f} GiB"


def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", default=str(ROOT / "config.json"))
    ap.add_argument("--adapter", default=str(ROOT / "adapter.safetensors"))
    ap.add_argument("--seq-lens", type=int, nargs="+", default=[64, 256, 512])
    ap.add_argument("--warmup", type=int, default=3)
    ap.add_argument("--iters", type=int, default=10)
    args = ap.parse_args()

    if not torch.cuda.is_available():
        print("CUDA required.", file=sys.stderr)
        sys.exit(1)
    device = "cuda"
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")
    print(f"torch={torch.__version__}, dtype=bfloat16\n")

    config = build_config(Path(args.config))
    tok = AutoTokenizer.from_pretrained(config.backbone_id)

    # Build a synthetic input pool we can slice
    long_text = "The quick brown fox jumps over the lazy dog. " * 200
    enc_full = tok(long_text, return_tensors="pt").to(device)
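    # Sanity check: the pool must cover the longest requested length, or the
    # slices below would silently benchmark shorter sequences.
    assert enc_full.input_ids.shape[1] >= max(args.seq_lens), (
        f"input pool has {enc_full.input_ids.shape[1]} tokens, "
        f"need >= {max(args.seq_lens)}"
    )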
    rows = []

    # 1. Backbone-only baseline
    print("=== Backbone-only (Qwen2.5-7B, bfloat16) ===")
    backbone = AutoModelForCausalLM.from_pretrained(
        config.backbone_id, torch_dtype=torch.bfloat16
    ).to(device)
    backbone.eval()
    for T in args.seq_lens:
        ids = enc_full.input_ids[:, :T]
        mask = enc_full.attention_mask[:, :T]
        torch.cuda.empty_cache()

        def fn():
            with torch.no_grad():
                backbone(input_ids=ids, attention_mask=mask)

        sec, peak = time_forward(fn, args.warmup, args.iters)
        tps = T / sec
        print(f"  T={T:4d}  {sec*1000:7.2f} ms/fwd  {tps:8.1f} tok/s  peak={fmt_gib(peak)}")
        rows.append({"variant": "backbone_only", "seq_len": T,
                     "ms_per_forward": sec * 1000, "tokens_per_sec": tps,
                     "peak_vram_gib": peak / (1024 ** 3)})

    # Free the baseline before loading the adapter-wrapped model.
    del backbone
    torch.cuda.empty_cache()
    # 2. Backbone + adapter
    print("\n=== Backbone + SRT-Adapter v8a ===")
    model = SRTAdapter(config).to(device)
    if args.adapter.endswith(".safetensors"):
        from safetensors.torch import load_file
        state = load_file(args.adapter, device=device)
    else:
        state = torch.load(args.adapter, map_location=device)
    # strict=False: the checkpoint holds adapter weights only, not the backbone.
    model.load_state_dict(state, strict=False)
    model.eval()
    for T in args.seq_lens:
        ids = enc_full.input_ids[:, :T]
        mask = enc_full.attention_mask[:, :T]
        torch.cuda.empty_cache()

        def fn():
            with torch.no_grad():
                model(input_ids=ids, attention_mask=mask)

        sec, peak = time_forward(fn, args.warmup, args.iters)
        tps = T / sec
        print(f"  T={T:4d}  {sec*1000:7.2f} ms/fwd  {tps:8.1f} tok/s  peak={fmt_gib(peak)}")
        rows.append({"variant": "backbone_plus_adapter", "seq_len": T,
                     "ms_per_forward": sec * 1000, "tokens_per_sec": tps,
                     "peak_vram_gib": peak / (1024 ** 3)})
    # Adapter overhead summary
    print("\n=== Adapter overhead ===")
    by = {(r["variant"], r["seq_len"]): r for r in rows}
    for T in args.seq_lens:
        b = by[("backbone_only", T)]
        a = by[("backbone_plus_adapter", T)]
        latency_overhead = (a["ms_per_forward"] / b["ms_per_forward"]) - 1.0
        vram_overhead = a["peak_vram_gib"] - b["peak_vram_gib"]
        print(f"  T={T:4d}  latency +{latency_overhead*100:5.1f}%  "
              f"vram +{vram_overhead:.2f} GiB")
    out = ROOT / "benchmarks" / "latency_vram.json"
    out.parent.mkdir(parents=True, exist_ok=True)  # first run creates benchmarks/
    out.write_text(json.dumps({
        "gpu": gpu_name,
        "torch_version": torch.__version__,
        "warmup": args.warmup,
        "iters": args.iters,
        "rows": rows,
    }, indent=2))
    print(f"\nwrote {out}")


if __name__ == "__main__":
    main()
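
The refreshed README overhead numbers mentioned in the commit message can be regenerated from the JSON this script writes. A minimal sketch, assuming only the schema emitted above (the markdown column layout is illustrative, not necessarily the repo's actual README table):

    import json
    from pathlib import Path

    rows = json.loads(Path("benchmarks/latency_vram.json").read_text())["rows"]
    by = {(r["variant"], r["seq_len"]): r for r in rows}
    for T in sorted({r["seq_len"] for r in rows}):
        b, a = by[("backbone_only", T)], by[("backbone_plus_adapter", T)]
        over = a["ms_per_forward"] / b["ms_per_forward"] - 1.0
        extra = a["peak_vram_gib"] - b["peak_vram_gib"]
        print(f"| {T} | {b['ms_per_forward']:.2f} ms | {a['ms_per_forward']:.2f} ms "
              f"| +{over * 100:.1f}% | +{extra:.2f} GiB |")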