SparseVLM / test_e2e.py
Aryan3108's picture
Upload folder using huggingface_hub
176b11a verified
Raw
History Blame Contribute Delete
4.87 kB
"""
test_e2e.py
-----------
End-to-end SparseVLM test on Qwen2.5-VL-7B-Instruct.
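Downloads the model on first run (~15 GB). Runs four configurations:
- Baseline (no pruning)
- SparseVLM n_vis=128
- SparseVLM n_vis=96
- SparseVLM n_vis=64
Reports latency and speedup for each configuration and prints the start of
each decoded output so its coherence can be checked by eye. Finally runs
benchmark/bench_layer1.py as a subprocess.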
Usage:
python test_e2e.py
"""
import sys, time, io, requests
import torch
from PIL import Image
from sparsevlm import apply_sparsevlm, reset_n_vis
from sparsevlm.patch import remove_hooks
# Qwen2.5-VL uses Qwen2_5_VLForConditionalGeneration in transformers >= 4.49;
# fall back to the older name for 4.48.x and below.
try:
from transformers import Qwen2_5_VLForConditionalGeneration as QwenVLModel, AutoProcessor
except ImportError:
from transformers import Qwen2VLForConditionalGeneration as QwenVLModel, AutoProcessor
# Hugging Face model identifier passed to from_pretrained().
MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"

# Pick device and dtype together from a single CUDA availability check.
# bfloat16 has float32 dynamic range — avoids NaN overflow with eager attention;
# without CUDA the script falls back to float32 on the CPU.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.bfloat16
else:
    DEVICE, DTYPE = "cpu", torch.float32
def load_test_image() -> Image.Image:
    """Return the RGB image used as input for every benchmark configuration.

    Tries to download a small PNG from Wikimedia Commons. If the request or
    the image decoding fails for any reason, the failure is reported and a
    solid-colour 448x448 image is returned instead, so the test can still run
    without network access.

    Returns:
        A PIL image in RGB mode.
    """
    url = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/PNG_transparency_demonstration_1.png/280px-PNG_transparency_demonstration_1.png"
    try:
        resp = requests.get(url, timeout=10)
        # Surface HTTP errors here; otherwise an error page would be handed
        # to PIL and fail there with an unrelated-looking decode error.
        resp.raise_for_status()
        img = Image.open(io.BytesIO(resp.content)).convert("RGB")
    except Exception as exc:
        # Deliberately broad: this is a best-effort download and any failure
        # should fall back to the synthetic image, but say why it happened.
        print(f"Could not load test image from web: {exc!r}")
        img = Image.new("RGB", (448, 448), color=(100, 149, 237))
        print("Using synthetic test image (448x448 cornflower blue).")
        return img
    print(f"Loaded test image from web: {img.size}")
    return img
def timeit_generate(model, inputs, n_warmup=2, n_runs=5, state=None, n_vis=None,
                    *, max_new_tokens=32):
    """Measure the mean wall-clock latency of ``model.generate``.

    Args:
        model: Object exposing ``generate(**inputs, max_new_tokens=...)``.
        inputs: Mapping unpacked as keyword arguments into ``generate``.
        n_warmup: Number of untimed calls made before measuring.
        n_runs: Number of timed calls; must be at least 1.
        state: Optional SparseVLM state. When given,
            ``reset_n_vis(state, n_vis)`` is called before every
            ``generate`` call (warm-up and timed).
        n_vis: Value forwarded to ``reset_n_vis``; unused when ``state`` is
            ``None``.
        max_new_tokens: Generation length cap passed to every call.

    Returns:
        ``(mean_ms, out)``: the mean latency per timed call in milliseconds
        and the return value of the last timed ``generate`` call.

    Raises:
        ValueError: If ``n_runs`` is less than 1 (there would be nothing to
            average and no output to return).
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    # Same condition that defines the module-level DEVICE constant; checking
    # it directly keeps this helper usable without that global.
    use_cuda = torch.cuda.is_available()

    def _run_once():
        if state is not None:
            reset_n_vis(state, n_vis)
        with torch.no_grad():
            return model.generate(**inputs, max_new_tokens=max_new_tokens)

    for _ in range(n_warmup):
        _run_once()
    # Drain queued GPU work so warm-up does not leak into the timed region.
    if use_cuda:
        torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(n_runs):
        out = _run_once()
    # Wait for outstanding kernels before reading the clock.
    if use_cuda:
        torch.cuda.synchronize()
    elapsed = (time.perf_counter() - t0) / n_runs * 1000
    return elapsed, out
def _decode_generated(processor, inputs, out):
    """Decode only the tokens that ``generate`` appended after the prompt.

    Assumes ``out[0]`` starts with the prompt ids from ``inputs["input_ids"]``
    (the usual layout for ``generate`` on decoder-only models) -- TODO confirm
    this still holds with the SparseVLM hooks installed.
    """
    prompt_len = inputs["input_ids"].shape[1]
    return processor.decode(out[0][prompt_len:], skip_special_tokens=True)


def main():
    """Run the baseline and SparseVLM configurations and print a latency table.

    Steps:
      1. Load the model (eager attention) and its processor.
      2. Build a single image + text chat prompt.
      3. Time the unmodified model as the baseline.
      4. For each ``n_vis`` setting, apply SparseVLM, time generation, print
         latency / speedup / the start of the generated text, then remove
         the hooks again.
      5. Run ``benchmark/bench_layer1.py`` in a subprocess.

    Raises:
        subprocess.CalledProcessError: If the kernel benchmark script exits
            with a non-zero status (``check=True``).
    """
    import os
    import subprocess

    print("\n=== SparseVLM End-to-End Test ===")
    print(f"Device: {DEVICE}")
    if DEVICE == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")

    print(f"\nLoading {MODEL_ID} ...")
    # eager attention required so attention weights are returned (flash attn returns None)
    model = QwenVLModel.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        device_map="auto",
        attn_implementation="eager",
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model.eval()
    print("Model loaded.")

    image = load_test_image()
    prompt = "Describe this image in one sentence."
    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], return_tensors="pt").to(DEVICE)

    print("\n── Baseline (no SparseVLM) ──────────────────────────────────────")
    baseline_ms, out = timeit_generate(model, inputs)
    # Decode only the completion: decoding the full sequence would fill the
    # truncated preview with the chat template and prompt text.
    response = _decode_generated(processor, inputs, out)
    print(f"Latency: {baseline_ms:.1f} ms")
    print(f"Output: {response[:120]}")

    configs = [
        ("SparseVLM n_vis=128 (50%)", 128),
        ("SparseVLM n_vis=96 (37%)", 96),
        ("SparseVLM n_vis=64 (25%)", 64),
    ]
    print(f"\n{'Config':<30} {'Latency':>10} {'Speedup':>10} Output")
    print("─" * 100)
    for label, n_vis in configs:
        state = apply_sparsevlm(model, n_vis=n_vis)
        try:
            reset_n_vis(state, n_vis=n_vis)
            ms, out = timeit_generate(model, inputs, state=state, n_vis=n_vis)
            response = _decode_generated(processor, inputs, out)
            speedup = baseline_ms / ms
            print(f"{label:<30} {ms:>9.1f}ms {speedup:>9.1f}x {response[:60]}")
        finally:
            # Always unpatch the model, even if timing or decoding failed, so
            # a failure cannot leave hooks installed on the shared model.
            remove_hooks(state)

    print("\n── Layer 1 kernel benchmark ─────────────────────────────────────")
    bench = os.path.join(os.path.dirname(os.path.abspath(__file__)), "benchmark", "bench_layer1.py")
    # Run with the current interpreter so the benchmark uses the same environment.
    subprocess.run([sys.executable, bench], check=True)
    print("\n=== Test complete. Update README.md benchmark table with the numbers above. ===")


# Run the end-to-end test only when executed as a script, not on import.
if __name__ == "__main__":
    main()