File size: 4,869 Bytes
176b11a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""
test_e2e.py
-----------
End-to-end SparseVLM test on Qwen2.5-VL-7B-Instruct.

Downloads the model on first run (~15 GB). Runs four configurations:
  - Baseline (no pruning)
  - SparseVLM n_vis=128
  - SparseVLM n_vis=96
  - SparseVLM n_vis=64

Reports latency and speedup, and prints each output so its coherence can be checked by eye.

Usage:
    python test_e2e.py
"""

import sys, time, io, requests
import torch
from PIL import Image
from sparsevlm import apply_sparsevlm, reset_n_vis
from sparsevlm.patch import remove_hooks

# Qwen2.5-VL uses Qwen2_5_VLForConditionalGeneration in transformers >= 4.49;
# fall back to the older name for 4.48.x and below.
try:
    from transformers import Qwen2_5_VLForConditionalGeneration as QwenVLModel, AutoProcessor
except ImportError:
    from transformers import Qwen2VLForConditionalGeneration as QwenVLModel, AutoProcessor

# Checkpoint identifier passed to from_pretrained() for both model and processor.
MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE   = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 has float32 dynamic range — avoids NaN overflow with eager attention.
# Without CUDA the model is loaded in float32 instead.
DTYPE    = torch.bfloat16 if torch.cuda.is_available() else torch.float32


def load_test_image() -> Image.Image:
    """Return an RGB test image, preferring a web download over a synthetic one.

    Tries to fetch a small PNG over HTTP (10 s timeout). If the request
    fails, the server answers with an error status, or the payload cannot be
    decoded as an image, a solid-colour 448x448 image is returned instead so
    the test can still run without network access.

    Returns:
        A PIL image in RGB mode.
    """
    url = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/PNG_transparency_demonstration_1.png/280px-PNG_transparency_demonstration_1.png"
    try:
        resp = requests.get(url, timeout=10)
        # Surface 4xx/5xx explicitly instead of handing an error page to PIL.
        resp.raise_for_status()
        img = Image.open(io.BytesIO(resp.content)).convert("RGB")
    except Exception as exc:
        # Deliberately broad: this is a best-effort download and *any* failure
        # (network, HTTP status, decoding) should lead to the fallback image.
        # The reason is reported so the fallback is not silent.
        print(f"Could not load test image from web ({type(exc).__name__}: {exc}).")
        img = Image.new("RGB", (448, 448), color=(100, 149, 237))
        print("Using synthetic test image (448x448 cornflower blue).")
        return img

    print(f"Loaded test image from web: {img.size}")
    return img


def timeit_generate(model, inputs, n_warmup=2, n_runs=5, state=None, n_vis=None):
    """Measure the mean wall-clock latency of ``model.generate``.

    Runs ``n_warmup`` untimed generations followed by ``n_runs`` timed ones,
    each producing at most 32 new tokens under ``torch.no_grad()``. When
    ``state`` is given, ``reset_n_vis(state, n_vis)`` is called before every
    generation (warm-up and timed alike), so that call is included in the
    measured time.

    Args:
        model: Object exposing ``generate(**inputs, max_new_tokens=...)``.
        inputs: Mapping of keyword arguments forwarded to ``generate``.
        n_warmup: Number of untimed warm-up generations.
        n_runs: Number of timed generations; must be at least 1.
        state: Optional SparseVLM state handed to ``reset_n_vis``.
        n_vis: Value handed to ``reset_n_vis`` together with ``state``.

    Returns:
        A ``(elapsed_ms, out)`` tuple: the mean latency per run in
        milliseconds and the output of the last timed ``generate`` call.

    Raises:
        ValueError: If ``n_runs`` is smaller than 1 (there would be nothing
            to average and no output to return).
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    def _generate_once():
        # Reset the pruning state (if any) before each generation.
        if state is not None:
            reset_n_vis(state, n_vis)
        with torch.no_grad():
            return model.generate(**inputs, max_new_tokens=32)

    def _sync():
        # CUDA kernels run asynchronously; wait for them so that the
        # perf_counter readings bracket the actual GPU work.
        if DEVICE == "cuda":
            torch.cuda.synchronize()

    for _ in range(n_warmup):
        _generate_once()
    _sync()

    t0 = time.perf_counter()
    for _ in range(n_runs):
        out = _generate_once()
    _sync()
    elapsed = (time.perf_counter() - t0) / n_runs * 1000
    return elapsed, out


def _decode_generated(processor, output_ids, prompt_len):
    """Decode only the newly generated tokens of the first sequence.

    ``generate`` returns the prompt tokens followed by the completion, so the
    first ``prompt_len`` positions are dropped before decoding.
    """
    return processor.decode(output_ids[0][prompt_len:], skip_special_tokens=True)


def main():
    """Run the baseline and SparseVLM configurations and print a latency table.

    Steps:
      1. Load the model (eager attention) and its processor.
      2. Build a single image + text chat prompt.
      3. Time generation without SparseVLM to obtain the baseline latency.
      4. For each ``n_vis`` setting, apply SparseVLM, time generation, print
         latency, speedup relative to the baseline and the start of the
         generated text, then remove the hooks again.
      5. Run ``benchmark/bench_layer1.py`` (located next to this file) as a
         subprocess; a non-zero exit status raises ``CalledProcessError``.
    """
    import os
    import subprocess

    print(f"\n=== SparseVLM End-to-End Test ===")
    print(f"Device: {DEVICE}")
    if DEVICE == "cuda":
        print(f"GPU:    {torch.cuda.get_device_name(0)}")
        print(f"VRAM:   {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")

    print(f"\nLoading {MODEL_ID} ...")
    # eager attention required so attention weights are returned (flash attn returns None)
    model = QwenVLModel.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        device_map="auto",
        attn_implementation="eager",
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model.eval()
    print("Model loaded.")

    image = load_test_image()
    prompt = "Describe this image in one sentence."

    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], return_tensors="pt").to(DEVICE)

    # Number of prompt tokens; used to strip the echoed prompt from the
    # generate() output so the printed "Output" is the model's reply rather
    # than the chat template and question.
    prompt_len = inputs["input_ids"].shape[1]

    print("\n── Baseline (no SparseVLM) ──────────────────────────────────────")
    baseline_ms, out = timeit_generate(model, inputs)
    response = _decode_generated(processor, out, prompt_len)
    print(f"Latency: {baseline_ms:.1f} ms")
    print(f"Output:  {response[:120]}")

    configs = [
        ("SparseVLM n_vis=128 (50%)", 128),
        ("SparseVLM n_vis=96  (37%)", 96),
        ("SparseVLM n_vis=64  (25%)", 64),
    ]

    print(f"\n{'Config':<30} {'Latency':>10} {'Speedup':>10}  Output")
    print("─" * 100)

    for label, n_vis in configs:
        state = apply_sparsevlm(model, n_vis=n_vis)
        try:
            reset_n_vis(state, n_vis=n_vis)
            ms, out = timeit_generate(model, inputs, state=state, n_vis=n_vis)
        finally:
            # Always detach the hooks, even if generation fails, so a broken
            # configuration cannot leak its hooks into later runs.
            remove_hooks(state)

        response = _decode_generated(processor, out, prompt_len)
        speedup  = baseline_ms / ms

        print(f"{label:<30} {ms:>9.1f}ms {speedup:>9.1f}x  {response[:60]}")

    print("\n── Layer 1 kernel benchmark ─────────────────────────────────────")
    bench = os.path.join(os.path.dirname(os.path.abspath(__file__)), "benchmark", "bench_layer1.py")
    subprocess.run([sys.executable, bench], check=True)

    print("\n=== Test complete. Update README.md benchmark table with the numbers above. ===")


# Run the end-to-end test only when executed as a script, not on import.
if __name__ == "__main__":
    main()