Aryan3108 commited on
Commit
f6cafdf
Β·
verified Β·
1 Parent(s): a31ee88

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. README.md +17 -11
  2. pyproject.toml +1 -1
  3. sparsevlm/__init__.py +4 -2
  4. sparsevlm/generate.py +104 -0
README.md CHANGED
@@ -37,7 +37,7 @@ pip install sparsevlm
37
  ```python
38
  import torch
39
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
40
- from sparsevlm import apply_sparsevlm, reset_n_vis, remove_hooks
41
 
42
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
43
  "Qwen/Qwen2.5-VL-7B-Instruct",
@@ -47,16 +47,22 @@ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
47
  )
48
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
49
 
50
- # Enable SparseVLM β€” no retraining needed
51
- state = apply_sparsevlm(model, n_vis=256)
52
-
53
- # Reset before each new image forward pass
54
- reset_n_vis(state, n_vis=256)
55
- inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda")
56
- output = model.generate(**inputs, max_new_tokens=256)
57
-
58
- # Remove hooks when done
59
- remove_hooks(state)
 
 
 
 
 
 
60
  ```
61
 
62
  ---
 
37
  ```python
38
  import torch
39
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
40
+ from sparsevlm import sparsevlm_generate
41
 
42
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
43
  "Qwen/Qwen2.5-VL-7B-Instruct",
 
47
  )
48
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
49
 
50
+ # Prepare inputs normally
51
+ messages = [{"role": "user", "content": [
52
+ {"type": "image", "image": image},
53
+ {"type": "text", "text": "Describe this image."}
54
+ ]}]
55
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
56
+ inputs = processor(text=[text], images=[image], return_tensors="pt").to("cuda")
57
+
58
+ # Run SparseVLM β€” keeps top-64 visual tokens out of 256 (25%)
59
+ output = sparsevlm_generate(
60
+ model, processor, inputs,
61
+ n_vis=256, # visual tokens in your sequence
62
+ keep_n_vis=64, # keep 25% β€” tune this
63
+ max_new_tokens=256,
64
+ )
65
+ print(processor.decode(output[0][1:], skip_special_tokens=True))
66
  ```
67
 
68
  ---
pyproject.toml CHANGED
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
 
5
  [project]
6
  name = "sparsevlm"
7
- version = "0.1.1"
8
  description = "Training-free visual token sparsification for vision-language models (ICML 2025)"
9
  readme = "README.md"
10
  license = { text = "Apache-2.0" }
 
4
 
5
  [project]
6
  name = "sparsevlm"
7
+ version = "0.1.2"
8
  description = "Training-free visual token sparsification for vision-language models (ICML 2025)"
9
  readme = "README.md"
10
  license = { text = "Apache-2.0" }
sparsevlm/__init__.py CHANGED
@@ -9,6 +9,7 @@ Quick start:
9
  """
10
 
11
  from .patch import patch_qwen2vl, reset_n_vis, unpatch_qwen2vl, remove_hooks
 
12
 
13
 
14
  def apply_sparsevlm(
@@ -43,5 +44,6 @@ def apply_sparsevlm(
43
  )
44
 
45
 
46
- __all__ = ["apply_sparsevlm", "reset_n_vis", "unpatch_qwen2vl", "remove_hooks"]
47
- __version__ = "0.1.1"
 
 
9
  """
10
 
11
  from .patch import patch_qwen2vl, reset_n_vis, unpatch_qwen2vl, remove_hooks
12
+ from .generate import sparsevlm_generate
13
 
14
 
15
  def apply_sparsevlm(
 
44
  )
45
 
46
 
47
+ __all__ = ["apply_sparsevlm", "reset_n_vis", "unpatch_qwen2vl",
48
+ "remove_hooks", "sparsevlm_generate"]
49
+ __version__ = "0.1.2"
sparsevlm/generate.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ generate.py β€” KV cache pruning for SparseVLM.
3
+
4
+ Usage:
5
+ from sparsevlm import sparsevlm_generate
6
+
7
+ output = sparsevlm_generate(
8
+ model, processor, inputs,
9
+ n_vis=256, # total visual tokens in the sequence
10
+ keep_n_vis=64, # how many to keep (25%)
11
+ max_new_tokens=256,
12
+ )
13
+ """
14
+
15
+ import torch
16
+
17
+
18
+ def _prune_kv_cache(cache, kept_indices):
19
+ """
20
+ Remove KV entries for pruned visual tokens.
21
+ Works with transformers 5.x DynamicCache (cache.layers[i].keys / .values).
22
+ .contiguous() ensures no stride gaps after indexing.
23
+ """
24
+ for layer in cache.layers:
25
+ k = kept_indices.to(layer.keys.device)
26
+ layer.keys = layer.keys[:, :, k, :].contiguous()
27
+ layer.values = layer.values[:, :, k, :].contiguous()
28
+ return cache
29
+
30
+
31
+ def sparsevlm_generate(
32
+ model,
33
+ processor,
34
+ inputs,
35
+ n_vis: int,
36
+ keep_n_vis: int,
37
+ max_new_tokens: int = 256,
38
+ target_layer: int = 2,
39
+ device: str = "cuda",
40
+ ):
41
+ """
42
+ SparseVLM generation via KV cache pruning.
43
+
44
+ Runs prefill once with output_attentions=True, scores all n_vis visual
45
+ tokens by their total text attention, keeps the top keep_n_vis, and
46
+ decodes with the pruned KV cache.
47
+
48
+ Args:
49
+ model: Qwen2_5_VLForConditionalGeneration (loaded with
50
+ attn_implementation="eager")
51
+ processor: AutoProcessor
52
+ inputs: dict from processor(..., return_tensors="pt")
53
+ n_vis: number of visual tokens in the sequence
54
+ (inputs["input_ids"].shape[1] - n_text)
55
+ keep_n_vis: how many visual tokens to keep
56
+ max_new_tokens: generation length
57
+ target_layer: which layer's attention to use for scoring (default 2)
58
+ device: primary device (default "cuda")
59
+
60
+ Returns:
61
+ generated token ids [B, max_new_tokens]
62
+ """
63
+ N_TOTAL = inputs["input_ids"].shape[1]
64
+
65
+ # ── 1. Prefill β€” get KV cache + attention weights ─────────────────────────
66
+ with torch.no_grad():
67
+ prefill = model(**inputs, use_cache=True, output_attentions=True)
68
+
69
+ # ── 2. Score all n_vis visual tokens ──────────────────────────────────────
70
+ # text→visual attention submatrix: [B, H, N_text, N_vis] averaged over heads
71
+ attn = prefill.attentions[target_layer]
72
+ A_tv = attn[:, :, n_vis:, :n_vis].mean(dim=1) # [B, N_text, N_vis]
73
+ scores = A_tv.sum(dim=1)[0] # [N_vis]
74
+
75
+ # ── 3. Keep top-keep_n_vis visual tokens by attention score ───────────────
76
+ kept_vis = scores.topk(keep_n_vis).indices
77
+ text_idx = torch.arange(n_vis, N_TOTAL, device=kept_vis.device)
78
+ kept_all = torch.cat([kept_vis, text_idx])
79
+
80
+ cache = _prune_kv_cache(prefill.past_key_values, kept_all)
81
+ n_kept = cache.get_seq_length()
82
+
83
+ # ── 4. Fix rope_deltas so decode positions are correct ────────────────────
84
+ # generate() computes: next_pos = cache.get_seq_length() + rope_deltas
85
+ # After pruning get_seq_length() = n_kept < N_TOTAL, so we compensate:
86
+ n_pruned = N_TOTAL - n_kept
87
+ orig_deltas = model.model.rope_deltas.clone()
88
+ model.model.rope_deltas = orig_deltas + n_pruned
89
+
90
+ # ── 5. Decode with pruned cache ────────────────────────────────────────────
91
+ attn_mask = torch.ones(1, n_kept + 1, device=device, dtype=torch.long)
92
+ with torch.no_grad():
93
+ output = model.generate(
94
+ input_ids=inputs["input_ids"][:, -1:],
95
+ attention_mask=attn_mask,
96
+ past_key_values=cache,
97
+ max_new_tokens=max_new_tokens,
98
+ use_cache=True,
99
+ )
100
+
101
+ # ── 6. Restore rope_deltas ─────────────────────────────────────────────────
102
+ model.model.rope_deltas = orig_deltas
103
+
104
+ return output