File size: 5,018 Bytes
a90aca1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#!/usr/bin/env python3
"""
Benchmark: baseline vs torch.inference_mode vs torch.compile.
Tests identical outputs and measures speedup.

Usage: CUDA_VISIBLE_DEVICES=1 python test_kv_cache.py
"""
import time
import torch
import copy
from types import SimpleNamespace
from camel.camel_model import CAMEL
from camel.process_ecg import get_waveform


def run_inference(model, ecg_path, prompt, max_new_tokens=64):
    """Run one generation pass through the CAMEL model and time it.

    Args:
        model: loaded CAMEL instance; must expose ``.device`` and ``.run()``.
        ecg_path: path to the ECG recording handed to the model.
        prompt: text prompt for the language model.
        max_new_tokens: generation budget forwarded to the sampler.

    Returns:
        tuple ``(output, elapsed)``: the model's text output and the elapsed
        seconds of the ``model.run`` call, measured with a monotonic clock.
    """
    args = SimpleNamespace(
        mode='base', text=prompt, ecgs=[ecg_path], device=model.device,
        ecg_configs=None, json=None, temperature=0.0, top_k=64,
        top_p=0.95, min_p=0.0, max_new_tokens=max_new_tokens
    )
    # perf_counter() is monotonic and high-resolution; time.time() is the
    # wall clock and can jump (NTP sync, DST), skewing benchmark numbers.
    t0 = time.perf_counter()
    output, _ = model.run(args)
    elapsed = time.perf_counter() - t0
    return output, elapsed


def main():
    """Benchmark three inference configurations of the CAMEL model on one
    fixed ECG/prompt pair: plain no_grad, inference_mode, and torch.compile
    applied to the wrapped language model. Prints timings, output equality
    checks, and speedup ratios; returns nothing.
    """
    # Fixed benchmark inputs. NOTE(review): 'cuda:0' here is the first
    # *visible* device — the docstring's suggested CUDA_VISIBLE_DEVICES=1
    # remaps physical GPU 1 to logical index 0.
    device = 'cuda:0'
    ecg_path = 'demo/08704_hr'
    prompt = 'Describe the ECG.'
    max_tokens = 64

    # --- Load model ---
    print("=== Loading model ===")
    t0 = time.time()
    model = CAMEL(mode='base', device=device)
    print(f"Model loaded in {time.time() - t0:.1f}s\n")

    # --- Warmup (1 run to warm CUDA) ---
    # First run pays one-time costs (CUDA context, kernel caches) so the
    # timed tests below are not skewed. Its output is intentionally unused.
    print("=== Warmup run ===")
    out_warmup, t_warmup = run_inference(model, ecg_path, prompt, max_tokens)
    print(f"Warmup: {t_warmup:.2f}s\n")

    # ============================================================
    # Test 1: Baseline (no special context)
    # ============================================================
    print("=== Test 1: Baseline (torch.no_grad, default) ===")
    with torch.no_grad():
        out_baseline, t_baseline = run_inference(model, ecg_path, prompt, max_tokens)
    print(f"Time: {t_baseline:.2f}s")
    print(f"Output: {out_baseline[:150]}...\n")

    # ============================================================
    # Test 2: torch.inference_mode
    # ============================================================
    # inference_mode is stricter than no_grad (disables view/version
    # tracking as well) and can be marginally faster.
    print("=== Test 2: torch.inference_mode ===")
    with torch.inference_mode():
        out_infer, t_infer = run_inference(model, ecg_path, prompt, max_tokens)
    print(f"Time: {t_infer:.2f}s")
    print(f"Output: {out_infer[:150]}...\n")

    # ============================================================
    # Test 3: torch.compile on the language model
    # ============================================================
    print("=== Test 3: torch.compile (language model) ===")
    # Reach into the model wrapper and swap the LM submodule for a compiled
    # version. Keep the original so it can be restored if compilation fails.
    lm = model.session.wrapper.language_model
    original_lm = lm  # keep reference

    # Compile the language model
    print("Compiling model...")
    t_compile_start = time.time()
    try:
        # torch.compile is lazy: this call is cheap; real compilation
        # happens on the first forward pass below.
        compiled_lm = torch.compile(lm, mode="reduce-overhead")
        model.session.wrapper.language_model = compiled_lm
        print(f"Compile call: {time.time() - t_compile_start:.2f}s")

        # First run triggers actual compilation
        print("First run (triggers compilation)...")
        with torch.inference_mode():
            out_compile_first, t_compile_first = run_inference(model, ecg_path, prompt, max_tokens)
        print(f"First compiled run: {t_compile_first:.2f}s")

        # Second run uses compiled graph
        print("Second run (uses compiled graph)...")
        with torch.inference_mode():
            out_compile, t_compile = run_inference(model, ecg_path, prompt, max_tokens)
        print(f"Second compiled run: {t_compile:.2f}s")
        print(f"Output: {out_compile[:150]}...\n")

        compile_success = True
    except Exception as e:
        # Broad catch is deliberate: torch.compile can fail in many ways
        # (unsupported ops, backend errors) and the benchmark should still
        # report the first two tests. Restore the uncompiled LM.
        print(f"torch.compile FAILED: {e}\n")
        compile_success = False
        t_compile = None
        out_compile = None
        model.session.wrapper.language_model = original_lm

    # ============================================================
    # Summary
    # ============================================================
    print("=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"Baseline (no_grad):      {t_baseline:.2f}s")
    print(f"inference_mode:          {t_infer:.2f}s")
    # t_compile_first/t_compile are only defined on the success path, so
    # they must stay behind the compile_success guard.
    if compile_success:
        print(f"torch.compile (warmup):  {t_compile_first:.2f}s")
        print(f"torch.compile (steady):  {t_compile:.2f}s")
    print()

    # Compare outputs — with temperature=0.0 decoding is expected to be
    # deterministic, so exact string equality is the correctness check.
    match_infer = out_baseline == out_infer
    print(f"Baseline vs inference_mode outputs match: {match_infer}")
    if not match_infer:
        print(f"  Baseline: {repr(out_baseline[:200])}")
        print(f"  Infer:    {repr(out_infer[:200])}")

    if compile_success:
        match_compile = out_baseline == out_compile
        print(f"Baseline vs torch.compile outputs match: {match_compile}")
        if not match_compile:
            print(f"  Baseline: {repr(out_baseline[:200])}")
            print(f"  Compile:  {repr(out_compile[:200])}")

    print()
    # Guard against division by zero on pathological timings.
    if t_infer > 0:
        print(f"inference_mode speedup: {t_baseline / t_infer:.2f}x")
    if compile_success and t_compile > 0:
        print(f"torch.compile speedup:  {t_baseline / t_compile:.2f}x")
    print("=" * 60)


if __name__ == "__main__":
    main()