#!/usr/bin/env python3
"""
Benchmark: baseline vs torch.inference_mode vs torch.compile.

Tests identical outputs and measures speedup.

Usage: CUDA_VISIBLE_DEVICES=1 python test_kv_cache.py
"""
import time
import torch
import copy
from types import SimpleNamespace

from camel.camel_model import CAMEL
from camel.process_ecg import get_waveform


def run_inference(model, ecg_path, prompt, max_new_tokens=64):
    """Run one generation through the CAMEL model.

    Args:
        model: A loaded CAMEL instance (must expose ``.device`` and ``.run``).
        ecg_path: Path to the ECG record passed as the single item of ``ecgs``.
        prompt: Text prompt for the model.
        max_new_tokens: Generation length cap (default 64).

    Returns:
        Tuple of (decoded output, elapsed wall time in seconds).

    Note:
        Sampling args pin temperature to 0.0 so repeated runs should produce
        identical outputs, which the caller compares across execution modes.
    """
    args = SimpleNamespace(
        mode='base',
        text=prompt,
        ecgs=[ecg_path],
        device=model.device,
        ecg_configs=None,
        json=None,
        temperature=0.0,
        top_k=64,
        top_p=0.95,
        min_p=0.0,
        max_new_tokens=max_new_tokens,
    )
    # perf_counter is monotonic and high-resolution — correct for benchmarking
    # (time.time() can jump with wall-clock adjustments).
    t0 = time.perf_counter()
    output, _ = model.run(args)
    elapsed = time.perf_counter() - t0
    return output, elapsed


def main():
    device = 'cuda:0'
    ecg_path = 'demo/08704_hr'
    prompt = 'Describe the ECG.'
    max_tokens = 64

    # --- Load model ---
    print("=== Loading model ===")
    t0 = time.perf_counter()
    model = CAMEL(mode='base', device=device)
    print(f"Model loaded in {time.perf_counter() - t0:.1f}s\n")

    # --- Warmup (1 run to warm CUDA) ---
    print("=== Warmup run ===")
    out_warmup, t_warmup = run_inference(model, ecg_path, prompt, max_tokens)
    print(f"Warmup: {t_warmup:.2f}s\n")

    # ============================================================
    # Test 1: Baseline (no special context)
    # ============================================================
    print("=== Test 1: Baseline (torch.no_grad, default) ===")
    with torch.no_grad():
        out_baseline, t_baseline = run_inference(model, ecg_path, prompt, max_tokens)
    print(f"Time: {t_baseline:.2f}s")
    print(f"Output: {out_baseline[:150]}...\n")

    # ============================================================
    # Test 2: torch.inference_mode
    # ============================================================
    print("=== Test 2: torch.inference_mode ===")
    with torch.inference_mode():
        out_infer, t_infer = run_inference(model, ecg_path, prompt, max_tokens)
    print(f"Time: {t_infer:.2f}s")
    print(f"Output: {out_infer[:150]}...\n")

    # ============================================================
    # Test 3: torch.compile on the language model
    # ============================================================
    print("=== Test 3: torch.compile (language model) ===")
    lm = model.session.wrapper.language_model
    original_lm = lm  # keep reference so we can restore afterwards

    # Defensive defaults: the summary reads these even if compilation fails
    # partway through.
    t_compile_first = None
    t_compile = None
    out_compile = None

    print("Compiling model...")
    t_compile_start = time.perf_counter()
    try:
        # torch.compile returns lazily; real compilation happens on first call.
        compiled_lm = torch.compile(lm, mode="reduce-overhead")
        model.session.wrapper.language_model = compiled_lm
        print(f"Compile call: {time.perf_counter() - t_compile_start:.2f}s")

        # First run triggers actual compilation
        print("First run (triggers compilation)...")
        with torch.inference_mode():
            out_compile_first, t_compile_first = run_inference(model, ecg_path, prompt, max_tokens)
        print(f"First compiled run: {t_compile_first:.2f}s")

        # Second run uses compiled graph
        print("Second run (uses compiled graph)...")
        with torch.inference_mode():
            out_compile, t_compile = run_inference(model, ecg_path, prompt, max_tokens)
        print(f"Second compiled run: {t_compile:.2f}s")
        print(f"Output: {out_compile[:150]}...\n")
        compile_success = True
    except Exception as e:
        print(f"torch.compile FAILED: {e}\n")
        compile_success = False
        t_compile = None
        out_compile = None
    finally:
        # Always put the eager-mode module back, success or failure, so the
        # model object is left in its original state.
        model.session.wrapper.language_model = original_lm

    # ============================================================
    # Summary
    # ============================================================
    print("=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"Baseline (no_grad): {t_baseline:.2f}s")
    print(f"inference_mode: {t_infer:.2f}s")
    if compile_success:
        print(f"torch.compile (warmup): {t_compile_first:.2f}s")
        print(f"torch.compile (steady): {t_compile:.2f}s")
    print()

    # Compare outputs
    match_infer = out_baseline == out_infer
    print(f"Baseline vs inference_mode outputs match: {match_infer}")
    if not match_infer:
        print(f" Baseline: {repr(out_baseline[:200])}")
        print(f" Infer: {repr(out_infer[:200])}")
    if compile_success:
        match_compile = out_baseline == out_compile
        print(f"Baseline vs torch.compile outputs match: {match_compile}")
        if not match_compile:
            print(f" Baseline: {repr(out_baseline[:200])}")
            print(f" Compile: {repr(out_compile[:200])}")
    print()

    if t_infer > 0:
        print(f"inference_mode speedup: {t_baseline / t_infer:.2f}x")
    if compile_success and t_compile > 0:
        print(f"torch.compile speedup: {t_baseline / t_compile:.2f}x")
    print("=" * 60)


if __name__ == "__main__":
    main()