#!/usr/bin/env python3
# Commit a90aca1 — Mayank Keoliya
# perf: pre-load model on CPU, inference_mode, remove Load buttons, auto-load ECGs
| """ | |
| Benchmark: baseline vs torch.inference_mode vs torch.compile. | |
| Tests identical outputs and measures speedup. | |
| Usage: CUDA_VISIBLE_DEVICES=1 python test_kv_cache.py | |
| """ | |
| import time | |
| import torch | |
| import copy | |
| from types import SimpleNamespace | |
| from camel.camel_model import CAMEL | |
| from camel.process_ecg import get_waveform | |
def run_inference(model, ecg_path, prompt, max_new_tokens=64):
    """Run one generation through the CAMEL model and time it.

    Args:
        model: Loaded CAMEL instance; must expose ``.device`` and ``.run(args)``.
        ecg_path: Path of the ECG record to condition on.
        prompt: Text prompt passed to the model.
        max_new_tokens: Cap on generated tokens (default 64).

    Returns:
        Tuple ``(output, elapsed)``: the model's text output and the
        wall-clock seconds spent inside ``model.run``.
    """
    args = SimpleNamespace(
        mode='base', text=prompt, ecgs=[ecg_path], device=model.device,
        ecg_configs=None, json=None, temperature=0.0, top_k=64,
        top_p=0.95, min_p=0.0, max_new_tokens=max_new_tokens
    )
    # perf_counter is monotonic and higher-resolution than time.time(),
    # so the measured interval cannot go negative on clock adjustments.
    start = time.perf_counter()
    output, _ = model.run(args)
    elapsed = time.perf_counter() - start
    return output, elapsed
def main():
    """Benchmark CAMEL inference three ways and compare outputs.

    Loads the model once, performs a warmup run, then times:
      1. baseline generation under ``torch.no_grad``,
      2. the same generation under ``torch.inference_mode``,
      3. generation with the language model wrapped in ``torch.compile``.

    Prints per-configuration timings, checks that all configurations
    produce identical text, and reports speedups over the baseline.
    """
    device = 'cuda:0'
    ecg_path = 'demo/08704_hr'
    prompt = 'Describe the ECG.'
    max_tokens = 64

    # --- Load model ---
    print("=== Loading model ===")
    t0 = time.perf_counter()
    model = CAMEL(mode='base', device=device)
    print(f"Model loaded in {time.perf_counter() - t0:.1f}s\n")

    # --- Warmup (1 run to warm CUDA kernels / allocator) ---
    print("=== Warmup run ===")
    out_warmup, t_warmup = run_inference(model, ecg_path, prompt, max_tokens)
    print(f"Warmup: {t_warmup:.2f}s\n")

    # ============================================================
    # Test 1: Baseline (no special context)
    # ============================================================
    print("=== Test 1: Baseline (torch.no_grad, default) ===")
    with torch.no_grad():
        out_baseline, t_baseline = run_inference(model, ecg_path, prompt, max_tokens)
    print(f"Time: {t_baseline:.2f}s")
    print(f"Output: {out_baseline[:150]}...\n")

    # ============================================================
    # Test 2: torch.inference_mode
    # ============================================================
    print("=== Test 2: torch.inference_mode ===")
    with torch.inference_mode():
        out_infer, t_infer = run_inference(model, ecg_path, prompt, max_tokens)
    print(f"Time: {t_infer:.2f}s")
    print(f"Output: {out_infer[:150]}...\n")

    # ============================================================
    # Test 3: torch.compile on the language model
    # ============================================================
    print("=== Test 3: torch.compile (language model) ===")
    lm = model.session.wrapper.language_model
    original_lm = lm  # keep reference so we can restore the eager module

    print("Compiling model...")
    t_compile_start = time.perf_counter()
    try:
        compiled_lm = torch.compile(lm, mode="reduce-overhead")
        model.session.wrapper.language_model = compiled_lm
        print(f"Compile call: {time.perf_counter() - t_compile_start:.2f}s")

        # torch.compile is lazy: the first forward pass triggers the
        # actual graph capture + codegen, so time it separately.
        print("First run (triggers compilation)...")
        with torch.inference_mode():
            out_compile_first, t_compile_first = run_inference(model, ecg_path, prompt, max_tokens)
        print(f"First compiled run: {t_compile_first:.2f}s")

        # Second run hits the cached compiled graph — the steady state.
        print("Second run (uses compiled graph)...")
        with torch.inference_mode():
            out_compile, t_compile = run_inference(model, ecg_path, prompt, max_tokens)
        print(f"Second compiled run: {t_compile:.2f}s")
        print(f"Output: {out_compile[:150]}...\n")
        compile_success = True
    except Exception as e:
        print(f"torch.compile FAILED: {e}\n")
        compile_success = False
        t_compile = None
        out_compile = None
    finally:
        # Restore the eager model unconditionally. (The original only
        # restored it on failure, leaving the compiled module installed
        # after a successful benchmark.)
        model.session.wrapper.language_model = original_lm

    # ============================================================
    # Summary
    # ============================================================
    print("=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"Baseline (no_grad): {t_baseline:.2f}s")
    print(f"inference_mode: {t_infer:.2f}s")
    if compile_success:
        print(f"torch.compile (warmup): {t_compile_first:.2f}s")
        print(f"torch.compile (steady): {t_compile:.2f}s")
    print()

    # Compare outputs — greedy decoding (temperature=0) should be
    # deterministic, so any mismatch indicates a real behavioral change.
    match_infer = out_baseline == out_infer
    print(f"Baseline vs inference_mode outputs match: {match_infer}")
    if not match_infer:
        print(f" Baseline: {repr(out_baseline[:200])}")
        print(f" Infer: {repr(out_infer[:200])}")
    if compile_success:
        match_compile = out_baseline == out_compile
        print(f"Baseline vs torch.compile outputs match: {match_compile}")
        if not match_compile:
            print(f" Baseline: {repr(out_baseline[:200])}")
            print(f" Compile: {repr(out_compile[:200])}")
    print()

    if t_infer > 0:
        print(f"inference_mode speedup: {t_baseline / t_infer:.2f}x")
    if compile_success and t_compile > 0:
        print(f"torch.compile speedup: {t_baseline / t_compile:.2f}x")
    print("=" * 60)


if __name__ == "__main__":
    main()