""" scheduler.py ------------ CUDA graph bucketing for zero kernel-launch overhead (Layer 3). Snaps dynamic token counts to 10 pre-defined buckets. Captures one CUDA graph per bucket. Routes requests to nearest bucket. """ import torch class SparsityScheduler: def __init__(self, n_vis_max: int, n_buckets: int = 10, min_tokens: int = 32): self.n_vis_max = n_vis_max self.n_buckets = n_buckets self.min_tokens = min_tokens self.buckets = self._compute_buckets() self._graphs = {} self._static_inputs = {} self._static_outputs = {} self._warmed_up = False def _compute_buckets(self) -> list: step = (self.n_vis_max - self.min_tokens) / self.n_buckets buckets = [int(self.min_tokens + i * step) for i in range(self.n_buckets)] buckets[-1] = self.n_vis_max return sorted(set(buckets)) def snap_to_bucket(self, n_vis: int) -> int: """Snap to nearest bucket >= n_vis.""" for b in self.buckets: if b >= n_vis: return b return self.n_vis_max def get_bucket_idx(self, n_vis: int) -> int: return self.buckets.index(self.snap_to_bucket(n_vis)) def warmup(self, model_forward_fn, sample_inputs_fn, n_warmup: int = 3): """Capture CUDA graphs for all buckets.""" if not torch.cuda.is_available(): print("[SparsityScheduler] CUDA not available — skipping.") return for idx, n_vis in enumerate(self.buckets): static_inputs = sample_inputs_fn(n_vis) for _ in range(n_warmup): model_forward_fn(static_inputs) torch.cuda.synchronize() g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): static_output = model_forward_fn(static_inputs) self._graphs[idx] = g self._static_inputs[idx] = static_inputs self._static_outputs[idx] = static_output self._warmed_up = True print(f"[SparsityScheduler] Captured graphs for {len(self.buckets)} buckets.") def replay(self, bucket_idx: int, new_inputs: dict) -> torch.Tensor: """Copy new inputs into static tensors and replay graph.""" if not self._warmed_up: raise RuntimeError("Call warmup() first.") for key, tensor in new_inputs.items(): if key in self._static_inputs[bucket_idx]: self._static_inputs[bucket_idx][key].copy_(tensor) self._graphs[bucket_idx].replay() return self._static_outputs[bucket_idx] def summary(self) -> str: return ( f"SparsityScheduler: {len(self.buckets)} buckets\n" f" Token counts: {self.buckets}\n" f" Warmed up: {self._warmed_up}" ) def make_scheduler(n_vis_max: int, n_buckets: int = 10, min_tokens: int = 32): return SparsityScheduler(n_vis_max, n_buckets, min_tokens)