#!/usr/bin/env python3
"""
Benchmark: baseline vs torch.inference_mode vs torch.compile.
Tests identical outputs and measures speedup.
Usage: CUDA_VISIBLE_DEVICES=1 python test_kv_cache.py
"""
import time
import torch
import copy
from types import SimpleNamespace
from camel.camel_model import CAMEL
from camel.process_ecg import get_waveform
def run_inference(model, ecg_path, prompt, max_new_tokens=64):
    """Run one generation pass through the CAMEL model and time it.

    Args:
        model: CAMEL instance; must expose ``.device`` and ``.run(args)``.
        ecg_path: path of the ECG record to condition on.
        prompt: user text prompt.
        max_new_tokens: generation budget passed through to the model.

    Returns:
        tuple ``(output, elapsed)``: the model's output and the wall-clock
        seconds spent inside ``model.run``.
    """
    args = SimpleNamespace(
        mode='base', text=prompt, ecgs=[ecg_path], device=model.device,
        ecg_configs=None, json=None, temperature=0.0, top_k=64,
        top_p=0.95, min_p=0.0, max_new_tokens=max_new_tokens,
    )
    # BUGFIX: time.perf_counter() instead of time.time() — perf_counter is
    # monotonic and high-resolution, while time.time() can jump (NTP) and has
    # coarse resolution on some platforms, which skews benchmark numbers.
    t0 = time.perf_counter()
    output, _ = model.run(args)
    elapsed = time.perf_counter() - t0
    return output, elapsed
def main():
    """Benchmark CAMEL inference: torch.no_grad vs torch.inference_mode vs
    torch.compile.

    Loads the model once, does a warmup run, then times the three
    configurations on the same ECG/prompt, checks that outputs are identical,
    and prints speedup ratios.
    """
    device = 'cuda:0'
    ecg_path = 'demo/08704_hr'
    prompt = 'Describe the ECG.'
    max_tokens = 64

    # --- Load model ---
    print("=== Loading model ===")
    # perf_counter: monotonic, high-resolution (time.time can jump under NTP).
    t0 = time.perf_counter()
    model = CAMEL(mode='base', device=device)
    print(f"Model loaded in {time.perf_counter() - t0:.1f}s\n")

    # --- Warmup (1 run to warm CUDA) ---
    print("=== Warmup run ===")
    out_warmup, t_warmup = run_inference(model, ecg_path, prompt, max_tokens)
    print(f"Warmup: {t_warmup:.2f}s\n")

    # ============================================================
    # Test 1: Baseline (no special context)
    # ============================================================
    print("=== Test 1: Baseline (torch.no_grad, default) ===")
    with torch.no_grad():
        out_baseline, t_baseline = run_inference(model, ecg_path, prompt, max_tokens)
    print(f"Time: {t_baseline:.2f}s")
    print(f"Output: {out_baseline[:150]}...\n")

    # ============================================================
    # Test 2: torch.inference_mode
    # ============================================================
    print("=== Test 2: torch.inference_mode ===")
    with torch.inference_mode():
        out_infer, t_infer = run_inference(model, ecg_path, prompt, max_tokens)
    print(f"Time: {t_infer:.2f}s")
    print(f"Output: {out_infer[:150]}...\n")

    # ============================================================
    # Test 3: torch.compile on the language model
    # ============================================================
    print("=== Test 3: torch.compile (language model) ===")
    lm = model.session.wrapper.language_model
    original_lm = lm  # keep reference so the model can always be restored

    # Compile the language model
    print("Compiling model...")
    t_compile_start = time.perf_counter()
    t_compile_first = None  # defined even if compilation aborts early
    try:
        compiled_lm = torch.compile(lm, mode="reduce-overhead")
        model.session.wrapper.language_model = compiled_lm
        print(f"Compile call: {time.perf_counter() - t_compile_start:.2f}s")

        # First run triggers actual compilation (torch.compile is lazy:
        # graphs are captured/compiled on first execution).
        print("First run (triggers compilation)...")
        with torch.inference_mode():
            out_compile_first, t_compile_first = run_inference(model, ecg_path, prompt, max_tokens)
        print(f"First compiled run: {t_compile_first:.2f}s")

        # Second run uses compiled graph
        print("Second run (uses compiled graph)...")
        with torch.inference_mode():
            out_compile, t_compile = run_inference(model, ecg_path, prompt, max_tokens)
        print(f"Second compiled run: {t_compile:.2f}s")
        print(f"Output: {out_compile[:150]}...\n")
        compile_success = True
    except Exception as e:
        print(f"torch.compile FAILED: {e}\n")
        compile_success = False
        t_compile = None
        out_compile = None
    finally:
        # BUGFIX: restore in a `finally` so the uncompiled module comes back
        # even on KeyboardInterrupt / other BaseExceptions, not only on the
        # normal and `except Exception` paths.
        model.session.wrapper.language_model = original_lm

    # ============================================================
    # Summary
    # ============================================================
    print("=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"Baseline (no_grad): {t_baseline:.2f}s")
    print(f"inference_mode: {t_infer:.2f}s")
    if compile_success:
        print(f"torch.compile (warmup): {t_compile_first:.2f}s")
        print(f"torch.compile (steady): {t_compile:.2f}s")
    print()

    # Compare outputs
    match_infer = out_baseline == out_infer
    print(f"Baseline vs inference_mode outputs match: {match_infer}")
    if not match_infer:
        print(f" Baseline: {repr(out_baseline[:200])}")
        print(f" Infer: {repr(out_infer[:200])}")
    if compile_success:
        match_compile = out_baseline == out_compile
        print(f"Baseline vs torch.compile outputs match: {match_compile}")
        if not match_compile:
            print(f" Baseline: {repr(out_baseline[:200])}")
            print(f" Compile: {repr(out_compile[:200])}")
    print()

    if t_infer > 0:
        print(f"inference_mode speedup: {t_baseline / t_infer:.2f}x")
    if compile_success and t_compile > 0:
        print(f"torch.compile speedup: {t_baseline / t_compile:.2f}x")
    print("=" * 60)
# Script entry point: run the benchmark only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()