SparseVLM / sparsevlm /scheduler.py
Aryan3108's picture
Upload folder using huggingface_hub
176b11a verified
Raw
History Blame Contribute Delete
2.98 kB
"""
scheduler.py
------------
CUDA graph bucketing for zero kernel-launch overhead (Layer 3).
Snaps dynamic token counts to 10 pre-defined buckets.
Captures one CUDA graph per bucket. Routes requests to nearest bucket.
"""
import torch
class SparsityScheduler:
def __init__(self, n_vis_max: int, n_buckets: int = 10, min_tokens: int = 32):
self.n_vis_max = n_vis_max
self.n_buckets = n_buckets
self.min_tokens = min_tokens
self.buckets = self._compute_buckets()
self._graphs = {}
self._static_inputs = {}
self._static_outputs = {}
self._warmed_up = False
def _compute_buckets(self) -> list:
step = (self.n_vis_max - self.min_tokens) / self.n_buckets
buckets = [int(self.min_tokens + i * step) for i in range(self.n_buckets)]
buckets[-1] = self.n_vis_max
return sorted(set(buckets))
def snap_to_bucket(self, n_vis: int) -> int:
"""Snap to nearest bucket >= n_vis."""
for b in self.buckets:
if b >= n_vis:
return b
return self.n_vis_max
def get_bucket_idx(self, n_vis: int) -> int:
return self.buckets.index(self.snap_to_bucket(n_vis))
def warmup(self, model_forward_fn, sample_inputs_fn, n_warmup: int = 3):
"""Capture CUDA graphs for all buckets."""
if not torch.cuda.is_available():
print("[SparsityScheduler] CUDA not available — skipping.")
return
for idx, n_vis in enumerate(self.buckets):
static_inputs = sample_inputs_fn(n_vis)
for _ in range(n_warmup):
model_forward_fn(static_inputs)
torch.cuda.synchronize()
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
static_output = model_forward_fn(static_inputs)
self._graphs[idx] = g
self._static_inputs[idx] = static_inputs
self._static_outputs[idx] = static_output
self._warmed_up = True
print(f"[SparsityScheduler] Captured graphs for {len(self.buckets)} buckets.")
def replay(self, bucket_idx: int, new_inputs: dict) -> torch.Tensor:
"""Copy new inputs into static tensors and replay graph."""
if not self._warmed_up:
raise RuntimeError("Call warmup() first.")
for key, tensor in new_inputs.items():
if key in self._static_inputs[bucket_idx]:
self._static_inputs[bucket_idx][key].copy_(tensor)
self._graphs[bucket_idx].replay()
return self._static_outputs[bucket_idx]
def summary(self) -> str:
return (
f"SparsityScheduler: {len(self.buckets)} buckets\n"
f" Token counts: {self.buckets}\n"
f" Warmed up: {self._warmed_up}"
)
def make_scheduler(n_vis_max: int, n_buckets: int = 10, min_tokens: int = 32):
return SparsityScheduler(n_vis_max, n_buckets, min_tokens)