# CAMEL / test_kv_cache.py
# Mayank Keoliya
# perf: pre-load model on CPU, inference_mode, remove Load buttons, auto-load ECGs
# a90aca1
#!/usr/bin/env python3
"""
Benchmark: baseline vs torch.inference_mode vs torch.compile.
Tests identical outputs and measures speedup.
Usage: CUDA_VISIBLE_DEVICES=1 python test_kv_cache.py
"""
import time
import torch
import copy
from types import SimpleNamespace
from camel.camel_model import CAMEL
from camel.process_ecg import get_waveform
def run_inference(model, ecg_path, prompt, max_new_tokens=64):
    """Run one generation through the CAMEL model and time it.

    Args:
        model: CAMEL instance; must expose ``.device`` and ``.run(args)``.
        ecg_path: path of the ECG record to condition generation on.
        prompt: user text prompt passed to the model.
        max_new_tokens: cap on generated tokens (default 64).

    Returns:
        Tuple ``(output, elapsed)``: the model's output and the wall-clock
        seconds spent inside ``model.run``.
    """
    args = SimpleNamespace(
        mode='base', text=prompt, ecgs=[ecg_path], device=model.device,
        ecg_configs=None, json=None, temperature=0.0, top_k=64,
        top_p=0.95, min_p=0.0, max_new_tokens=max_new_tokens
    )
    # perf_counter is monotonic and high-resolution -- unlike time.time(),
    # it cannot jump backwards with wall-clock adjustments, so it is the
    # correct clock for benchmarking intervals.
    t0 = time.perf_counter()
    output, _ = model.run(args)
    elapsed = time.perf_counter() - t0
    return output, elapsed
def main():
    """Benchmark CAMEL inference under three configurations.

    Loads the model once, performs a warmup run, then times:
      1. baseline generation under ``torch.no_grad``,
      2. generation under ``torch.inference_mode``,
      3. generation with the language model wrapped in ``torch.compile``
         (two runs: the first triggers compilation, the second measures
         steady-state).

    Prints per-test timings, checks that all variants produce identical
    output, and reports speedups relative to the baseline. All intervals
    are measured with ``time.perf_counter`` (monotonic, high-resolution)
    rather than ``time.time``.
    """
    device = 'cuda:0'
    ecg_path = 'demo/08704_hr'
    prompt = 'Describe the ECG.'
    max_tokens = 64

    # --- Load model ---
    print("=== Loading model ===")
    t0 = time.perf_counter()
    model = CAMEL(mode='base', device=device)
    print(f"Model loaded in {time.perf_counter() - t0:.1f}s\n")

    # --- Warmup (1 run to warm CUDA) ---
    # First CUDA call pays kernel/context init cost; exclude it from timings.
    print("=== Warmup run ===")
    out_warmup, t_warmup = run_inference(model, ecg_path, prompt, max_tokens)
    print(f"Warmup: {t_warmup:.2f}s\n")

    # ============================================================
    # Test 1: Baseline (no special context)
    # ============================================================
    print("=== Test 1: Baseline (torch.no_grad, default) ===")
    with torch.no_grad():
        out_baseline, t_baseline = run_inference(model, ecg_path, prompt, max_tokens)
    print(f"Time: {t_baseline:.2f}s")
    print(f"Output: {out_baseline[:150]}...\n")

    # ============================================================
    # Test 2: torch.inference_mode
    # ============================================================
    print("=== Test 2: torch.inference_mode ===")
    with torch.inference_mode():
        out_infer, t_infer = run_inference(model, ecg_path, prompt, max_tokens)
    print(f"Time: {t_infer:.2f}s")
    print(f"Output: {out_infer[:150]}...\n")

    # ============================================================
    # Test 3: torch.compile on the language model
    # ============================================================
    print("=== Test 3: torch.compile (language model) ===")
    lm = model.session.wrapper.language_model
    original_lm = lm  # keep reference so we can restore on failure

    # Compile the language model. torch.compile can fail on unsupported
    # ops/graphs, so the whole section is best-effort.
    print("Compiling model...")
    t_compile_start = time.perf_counter()
    try:
        compiled_lm = torch.compile(lm, mode="reduce-overhead")
        model.session.wrapper.language_model = compiled_lm
        print(f"Compile call: {time.perf_counter() - t_compile_start:.2f}s")

        # First run triggers actual compilation (torch.compile is lazy).
        print("First run (triggers compilation)...")
        with torch.inference_mode():
            out_compile_first, t_compile_first = run_inference(model, ecg_path, prompt, max_tokens)
        print(f"First compiled run: {t_compile_first:.2f}s")

        # Second run uses the cached compiled graph: the steady-state number.
        print("Second run (uses compiled graph)...")
        with torch.inference_mode():
            out_compile, t_compile = run_inference(model, ecg_path, prompt, max_tokens)
        print(f"Second compiled run: {t_compile:.2f}s")
        print(f"Output: {out_compile[:150]}...\n")
        compile_success = True
    except Exception as e:
        print(f"torch.compile FAILED: {e}\n")
        compile_success = False
        t_compile = None
        out_compile = None
        # Restore the uncompiled module so the model stays usable.
        model.session.wrapper.language_model = original_lm

    # ============================================================
    # Summary
    # ============================================================
    print("=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"Baseline (no_grad): {t_baseline:.2f}s")
    print(f"inference_mode: {t_infer:.2f}s")
    if compile_success:
        print(f"torch.compile (warmup): {t_compile_first:.2f}s")
        print(f"torch.compile (steady): {t_compile:.2f}s")
    print()

    # Compare outputs: all variants should generate identical text at
    # temperature 0; a mismatch indicates a numerical/semantic divergence.
    match_infer = out_baseline == out_infer
    print(f"Baseline vs inference_mode outputs match: {match_infer}")
    if not match_infer:
        print(f"  Baseline: {repr(out_baseline[:200])}")
        print(f"  Infer:    {repr(out_infer[:200])}")
    if compile_success:
        match_compile = out_baseline == out_compile
        print(f"Baseline vs torch.compile outputs match: {match_compile}")
        if not match_compile:
            print(f"  Baseline: {repr(out_baseline[:200])}")
            print(f"  Compile:  {repr(out_compile[:200])}")
    print()

    if t_infer > 0:
        print(f"inference_mode speedup: {t_baseline / t_infer:.2f}x")
    if compile_success and t_compile > 0:
        print(f"torch.compile speedup: {t_baseline / t_compile:.2f}x")
    print("=" * 60)
# Run the benchmark only when executed as a script (not on import).
if __name__ == "__main__":
    main()