""" test_e2e.py ----------- End-to-end SparseVLM test on Qwen2.5-VL-7B-Instruct. Downloads the model on first run (~15 GB). Runs three configurations: - Baseline (no pruning) - SparseVLM n_vis=128 - SparseVLM n_vis=64 Reports latency, speedup, and checks the output is coherent. Usage: python test_e2e.py """ import sys, time, io, requests import torch from PIL import Image from sparsevlm import apply_sparsevlm, reset_n_vis from sparsevlm.patch import remove_hooks # Qwen2.5-VL uses Qwen2_5_VLForConditionalGeneration in transformers >= 4.49; # fall back to the older name for 4.48.x and below. try: from transformers import Qwen2_5_VLForConditionalGeneration as QwenVLModel, AutoProcessor except ImportError: from transformers import Qwen2VLForConditionalGeneration as QwenVLModel, AutoProcessor MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct" DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # bfloat16 has float32 dynamic range — avoids NaN overflow with eager attention DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32 def load_test_image() -> Image.Image: url = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/PNG_transparency_demonstration_1.png/280px-PNG_transparency_demonstration_1.png" try: resp = requests.get(url, timeout=10) img = Image.open(io.BytesIO(resp.content)).convert("RGB") print(f"Loaded test image from web: {img.size}") return img except Exception: img = Image.new("RGB", (448, 448), color=(100, 149, 237)) print("Using synthetic test image (448x448 cornflower blue).") return img def timeit_generate(model, inputs, n_warmup=2, n_runs=5, state=None, n_vis=None): for _ in range(n_warmup): if state is not None: reset_n_vis(state, n_vis) with torch.no_grad(): model.generate(**inputs, max_new_tokens=32) if DEVICE == "cuda": torch.cuda.synchronize() t0 = time.perf_counter() for _ in range(n_runs): if state is not None: reset_n_vis(state, n_vis) with torch.no_grad(): out = model.generate(**inputs, max_new_tokens=32) if DEVICE == "cuda": torch.cuda.synchronize() elapsed = (time.perf_counter() - t0) / n_runs * 1000 return elapsed, out def main(): print(f"\n=== SparseVLM End-to-End Test ===") print(f"Device: {DEVICE}") if DEVICE == "cuda": print(f"GPU: {torch.cuda.get_device_name(0)}") print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB") print(f"\nLoading {MODEL_ID} ...") # eager attention required so attention weights are returned (flash attn returns None) model = QwenVLModel.from_pretrained( MODEL_ID, torch_dtype=DTYPE, device_map="auto", attn_implementation="eager", ) processor = AutoProcessor.from_pretrained(MODEL_ID) model.eval() print("Model loaded.") image = load_test_image() prompt = "Describe this image in one sentence." messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}] text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) inputs = processor(text=[text], images=[image], return_tensors="pt").to(DEVICE) print("\n── Baseline (no SparseVLM) ──────────────────────────────────────") baseline_ms, out = timeit_generate(model, inputs) response = processor.decode(out[0], skip_special_tokens=True) print(f"Latency: {baseline_ms:.1f} ms") print(f"Output: {response[:120]}") configs = [ ("SparseVLM n_vis=128 (50%)", 128), ("SparseVLM n_vis=96 (37%)", 96), ("SparseVLM n_vis=64 (25%)", 64), ] print(f"\n{'Config':<30} {'Latency':>10} {'Speedup':>10} Output") print("─" * 100) for label, n_vis in configs: state = apply_sparsevlm(model, n_vis=n_vis) reset_n_vis(state, n_vis=n_vis) ms, out = timeit_generate(model, inputs, state=state, n_vis=n_vis) response = processor.decode(out[0], skip_special_tokens=True) speedup = baseline_ms / ms print(f"{label:<30} {ms:>9.1f}ms {speedup:>9.1f}x {response[:60]}") remove_hooks(state) print("\n── Layer 1 kernel benchmark ─────────────────────────────────────") import subprocess, sys, os bench = os.path.join(os.path.dirname(os.path.abspath(__file__)), "benchmark", "bench_layer1.py") subprocess.run([sys.executable, bench], check=True) print("\n=== Test complete. Update README.md benchmark table with the numbers above. ===") if __name__ == "__main__": main()