"""
test_e2e.py
-----------
End-to-end SparseVLM test on Qwen2.5-VL-7B-Instruct.

Downloads the model on first run (~15 GB). Runs four configurations:
  - Baseline (no pruning)
  - SparseVLM n_vis=128
  - SparseVLM n_vis=96
  - SparseVLM n_vis=64

Reports latency and speedup for each configuration and prints the start of
each decoded output so its coherence can be checked by eye, then runs the
Layer 1 kernel benchmark script.

Usage:
    python test_e2e.py
"""
|
|
| import sys, time, io, requests |
| import torch |
| from PIL import Image |
| from sparsevlm import apply_sparsevlm, reset_n_vis |
| from sparsevlm.patch import remove_hooks |
|
|
| |
| |
| try: |
| from transformers import Qwen2_5_VLForConditionalGeneration as QwenVLModel, AutoProcessor |
| except ImportError: |
| from transformers import Qwen2VLForConditionalGeneration as QwenVLModel, AutoProcessor |
|
|
# Hugging Face Hub identifier of the checkpoint exercised by this test.
MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"

# Probe CUDA once and derive both settings from that single answer:
# "cuda" + bfloat16 when a GPU is available, "cpu" + float32 otherwise.
DEVICE, DTYPE = (
    ("cuda", torch.bfloat16)
    if torch.cuda.is_available()
    else ("cpu", torch.float32)
)
|
|
|
|
def load_test_image() -> Image.Image:
    """Return the RGB image used as the visual input for the test.

    First tries to download a small PNG from Wikimedia Commons. If the
    download or decoding fails for any reason (no network, HTTP error status,
    undecodable payload), the reason is printed and a locally generated
    448x448 solid-colour image is returned instead, so the test can still run
    offline.

    Returns:
        A PIL image in RGB mode.
    """
    url = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/PNG_transparency_demonstration_1.png/280px-PNG_transparency_demonstration_1.png"
    try:
        resp = requests.get(url, timeout=10)
        # Without this check an HTTP error page (403/404/5xx) would be handed
        # to PIL and only fail later with an unrelated decoding error.
        resp.raise_for_status()
        img = Image.open(io.BytesIO(resp.content)).convert("RGB")
    except Exception as exc:
        # Deliberately broad: fetching the image is best-effort and must never
        # abort the test. The cause is reported rather than silently dropped.
        print(f"Could not load test image from web ({type(exc).__name__}: {exc}).")
        img = Image.new("RGB", (448, 448), color=(100, 149, 237))
        print("Using synthetic test image (448x448 cornflower blue).")
        return img

    print(f"Loaded test image from web: {img.size}")
    return img
|
|
|
|
def timeit_generate(model, inputs, n_warmup=2, n_runs=5, state=None, n_vis=None,
                    *, max_new_tokens=32):
    """Measure the mean wall-clock latency of ``model.generate``.

    Runs ``n_warmup`` untimed calls followed by ``n_runs`` timed calls. When
    CUDA is available the device is synchronized before the clock is started
    and again before it is read, so queued GPU work is not missed.

    Args:
        model: Object exposing ``generate(**inputs, max_new_tokens=...)``.
        inputs: Mapping of keyword arguments forwarded to ``model.generate``.
        n_warmup: Number of untimed warm-up calls.
        n_runs: Number of timed calls; must be at least 1.
        state: Optional SparseVLM state. When given,
            ``reset_n_vis(state, n_vis)`` is called before every generate
            call (presumably to restore the token budget for that call;
            confirm against the sparsevlm package).
        n_vis: Value passed to ``reset_n_vis``; only used when ``state`` is
            not ``None``.
        max_new_tokens: Generation length forwarded to ``model.generate``.

    Returns:
        A ``(elapsed_ms, out)`` tuple: the mean latency per timed call in
        milliseconds, and the return value of the last timed call.

    Raises:
        ValueError: If ``n_runs`` is smaller than 1 (there would be nothing to
            average and no output to return).
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    def _generate_once():
        # Shared by the warm-up and the timed loop so both do identical work.
        if state is not None:
            reset_n_vis(state, n_vis)
        with torch.no_grad():
            return model.generate(**inputs, max_new_tokens=max_new_tokens)

    def _sync():
        # Same condition the module-level DEVICE constant is derived from.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    for _ in range(n_warmup):
        _generate_once()
    _sync()

    out = None
    t0 = time.perf_counter()
    for _ in range(n_runs):
        out = _generate_once()
    _sync()
    elapsed_ms = (time.perf_counter() - t0) / n_runs * 1000
    return elapsed_ms, out
|
|
|
|
def main():
    """Run the end-to-end comparison and print a latency report.

    Steps:
      1. Print device information and load the model and processor.
      2. Build one chat-formatted image + text input.
      3. Time the unpatched model (baseline).
      4. For each SparseVLM setting, patch the model, time it, print latency
         and speedup relative to the baseline, then remove the hooks again.
      5. Run ``benchmark/bench_layer1.py`` (located next to this file) in a
         subprocess.

    Raises:
        subprocess.CalledProcessError: If the benchmark script exits with a
            non-zero status.
    """
    import os
    import subprocess
    import sys

    # NOTE(review): the runs of 'β' in the printed separators below look like
    # mis-encoded box-drawing characters; they are kept as found -- confirm.
    print("\n=== SparseVLM End-to-End Test ===")
    print(f"Device: {DEVICE}")
    if DEVICE == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")

    print(f"\nLoading {MODEL_ID} ...")
    model = QwenVLModel.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        device_map="auto",
        attn_implementation="eager",
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model.eval()
    print("Model loaded.")

    image = load_test_image()
    prompt = "Describe this image in one sentence."

    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The model is placed by device_map="auto", so follow the device it
    # actually landed on instead of assuming it matches DEVICE.
    inputs = processor(text=[text], images=[image], return_tensors="pt").to(model.device)
    # generate() returns the prompt followed by the new tokens; remember the
    # prompt length so only the model's reply is decoded for the report.
    prompt_len = inputs["input_ids"].shape[1]

    print("\nββ Baseline (no SparseVLM) ββββββββββββββββββββββββββββββββββββββ")
    baseline_ms, out = timeit_generate(model, inputs)
    response = _decode_new_tokens(processor, out, prompt_len)
    print(f"Latency: {baseline_ms:.1f} ms")
    print(f"Output: {response[:120]}")

    configs = [
        ("SparseVLM n_vis=128 (50%)", 128),
        ("SparseVLM n_vis=96 (37%)", 96),
        ("SparseVLM n_vis=64 (25%)", 64),
    ]

    print(f"\n{'Config':<30} {'Latency':>10} {'Speedup':>10} Output")
    print("β" * 100)

    for label, n_vis in configs:
        state = apply_sparsevlm(model, n_vis=n_vis)
        try:
            reset_n_vis(state, n_vis=n_vis)
            ms, out = timeit_generate(model, inputs, state=state, n_vis=n_vis)
            response = _decode_new_tokens(processor, out, prompt_len)
            speedup = baseline_ms / ms
            print(f"{label:<30} {ms:>9.1f}ms {speedup:>9.1f}x {response[:60]}")
        finally:
            # Always unpatch, so a failing configuration cannot leave hooks
            # installed on the model.
            remove_hooks(state)

    print("\nββ Layer 1 kernel benchmark βββββββββββββββββββββββββββββββββββββ")
    bench = os.path.join(os.path.dirname(os.path.abspath(__file__)), "benchmark", "bench_layer1.py")
    subprocess.run([sys.executable, bench], check=True)

    print("\n=== Test complete. Update README.md benchmark table with the numbers above. ===")


def _decode_new_tokens(processor, out, prompt_len):
    """Decode only the tokens of the first sequence that follow the prompt."""
    return processor.decode(out[0][prompt_len:], skip_special_tokens=True)
|
|
|
|
# Run the end-to-end test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|