File size: 1,458 Bytes
176b11a
 
 
 
 
 
 
 
 
 
 
f6cafdf
176b11a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6cafdf
 
45c83c9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""
sparsevlm — Training-free visual token sparsification for VLMs.

Quick start:
    from sparsevlm import apply_sparsevlm, reset_n_vis
    state = apply_sparsevlm(model, n_vis=256)
    reset_n_vis(state, n_vis=256)   # call before every new image
    output = model.generate(...)
"""

from .patch import patch_qwen2vl, reset_n_vis, unpatch_qwen2vl, remove_hooks
from .generate import sparsevlm_generate


def apply_sparsevlm(
    model,
    n_vis: int = 256,
    target_layers=None,
    min_keep: int = 32,
    tau: float = 0.5,
    theta: float = 0.5,
) -> dict:
    """
    Patch a Qwen2-VL-family model in place with SparseVLM pruning (no training).

    Thin convenience wrapper: every argument is forwarded unchanged, by
    keyword, to ``patch_qwen2vl`` and its return value is passed straight
    back to the caller.

    Args:
        model:         the vision-language model to patch (documented as a
                       Qwen2VLForConditionalGeneration instance)
        n_vis:         number of visual tokens per image (about 256 for a
                       448px image on Qwen2.5-VL-7B)
        target_layers: layer indices at which pruning is applied; ``None``
                       defers to patch_qwen2vl's own default (documented as
                       every 4th layer starting at layer 2 — confirm in patch)
        min_keep:      lower bound on visual tokens retained after pruning
        tau:           recycling fraction (paper default: 0.5)
        theta:         cluster ratio (paper default: 0.5)

    Returns:
        The state dict produced by ``patch_qwen2vl``; hand it to
        ``reset_n_vis()`` before processing each new image.
    """
    # Gather the tunables once so the forwarding call stays a single line.
    forwarded = dict(
        n_vis=n_vis,
        target_layers=target_layers,
        min_keep=min_keep,
        tau=tau,
        theta=theta,
    )
    return patch_qwen2vl(model=model, **forwarded)


# Names exported by ``from sparsevlm import *`` (order preserved).
__all__ = [
    "apply_sparsevlm",
    "reset_n_vis",
    "unpatch_qwen2vl",
    "remove_hooks",
    "sparsevlm_generate",
]
__version__ = "0.1.3"