Spaces:
Paused
Paused
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from collections import OrderedDict | |
| from typing import Literal | |
| from .backends import ( | |
| PreparedPageTorch, | |
| cuda_available, | |
| mps_available, | |
| prepare_page_cuda, | |
| prepare_page_mps, | |
| prepare_pages_cuda, | |
| prepare_pages_mps, | |
| ) | |
| from .tracing import ExecutionTrace | |
| from .types import EncodedPage | |
# Eviction orderings a PreparedPageCache can be configured with.
CachePolicy = Literal[
    "fifo",
    "lru",
    "pinned_recent_fifo",
]

# Backend selectors accepted by the cache; "auto" is resolved to one of the
# concrete backends at call time.
CacheBackend = Literal[
    "auto",
    "torch_mps",
    "torch_cuda",
    "cpu_ref",
]
@dataclass
class PreparedPageCache:
    """Byte-budgeted cache of device-prepared pages.

    Entries are keyed by ``(device tag, id(source page))`` where the device
    tag is ``"cuda"`` or ``"mps"``. ``_order`` tracks eviction order (oldest
    first); under the ``"lru"`` policy a cache hit moves the entry to the
    newest position. ``_prepared_page_ids`` holds ``id()`` of every prepared
    page currently stored so ownership can be checked without a key.

    Attributes:
        max_resident_bytes: Byte budget; ``None`` disables eviction.
        policy: Eviction ordering, one of the ``CachePolicy`` values.
        pinned_recent_pages: Number of newest entries treated as pinned
            under the ``"pinned_recent_fifo"`` policy.
    """

    # NOTE(review): keys use id(source page); this assumes a cached prepared
    # page keeps its source page alive (it exposes ``source_page``) so an id
    # cannot be reused while the entry is cached -- TODO confirm.
    max_resident_bytes: int | None = None
    policy: CachePolicy = "fifo"
    pinned_recent_pages: int = 0
    _prepared_pages: dict[tuple[str, int], PreparedPageTorch] = field(default_factory=dict)
    _prepared_page_ids: set[int] = field(default_factory=set)
    _resident_bytes: int = 0
    _order: OrderedDict[tuple[str, int], None] = field(default_factory=OrderedDict)

    # NOTE(review): ``resident_bytes`` and ``size`` are plain methods here;
    # confirm that callers do not expect properties.
    def resident_bytes(self) -> int:
        """Return the number of bytes currently accounted to cached pages."""
        return self._resident_bytes

    def size(self) -> int:
        """Return the number of prepared pages currently cached."""
        return len(self._prepared_pages)

    def clear(self) -> None:
        """Drop every cached page and reset the byte accounting to zero."""
        self._prepared_pages.clear()
        self._prepared_page_ids.clear()
        self._resident_bytes = 0
        self._order.clear()

    def owns_prepared_page(self, page: PreparedPageTorch) -> bool:
        """Return True if *page* is (by identity) a page stored in this cache."""
        return id(page) in self._prepared_page_ids

    def _page_nbytes(self, page: PreparedPageTorch) -> int:
        """Return the byte size charged for *page*.

        Uses ``resident_nbytes`` when it is positive, otherwise falls back to
        ``host_to_device_nbytes``.
        """
        resident_nbytes = int(page.resident_nbytes)
        if resident_nbytes > 0:
            return resident_nbytes
        return int(page.host_to_device_nbytes)

    def _pinned_keys(self) -> set[tuple[str, int]]:
        """Return the cache keys of the newest ``pinned_recent_pages`` entries.

        Empty unless the policy is ``"pinned_recent_fifo"`` and
        ``pinned_recent_pages`` is positive.
        """
        if self.policy != "pinned_recent_fifo" or self.pinned_recent_pages <= 0:
            return set()
        keys = list(self._order)
        return set(keys[-self.pinned_recent_pages:])

    def _touch_cached_page(self, cache_key: tuple[str, int]) -> None:
        """Mark *cache_key* as most recently used (only under the LRU policy)."""
        if self.policy == "lru":
            self._order.move_to_end(cache_key)

    @staticmethod
    def _cache_key(resolved_backend: CacheBackend, page: EncodedPage) -> tuple[str, int]:
        """Build the cache key for *page* on an already-resolved device backend."""
        device_tag = "cuda" if resolved_backend == "torch_cuda" else "mps"
        return (device_tag, id(page))

    def _record_hit(self, trace: ExecutionTrace | None) -> None:
        """Report a cache hit and the current resident size to *trace*, if given."""
        if trace is not None:
            trace.record_cache_hit()
            trace.observe_cache_resident_bytes(self._resident_bytes)

    def _remove_entry(self, cache_key: tuple[str, int], *, trace: ExecutionTrace | None = None) -> bool:
        """Remove *cache_key* and release the bytes charged for its page.

        The key is always dropped from the eviction order. Returns False when
        no page was stored under the key (a stale order entry), True otherwise.
        """
        self._order.pop(cache_key, None)
        cached_page = self._prepared_pages.pop(cache_key, None)
        if cached_page is None:
            return False
        self._prepared_page_ids.discard(id(cached_page))
        evicted_bytes = self._page_nbytes(cached_page)
        # Clamp at zero so an accounting mismatch can never go negative.
        self._resident_bytes = max(0, self._resident_bytes - evicted_bytes)
        if trace is not None:
            trace.record_cache_eviction(evicted_bytes)
            trace.observe_cache_resident_bytes(self._resident_bytes)
        return True

    def _evict_one(self, *, trace: ExecutionTrace | None = None) -> bool:
        """Evict a single page; return True if a page was actually removed.

        Unpinned entries are evicted oldest-first. Pinned entries are evicted
        (also oldest-first) only when no unpinned entry is left. Stale order
        entries that have no page are discarded along the way.
        """
        pinned_keys = self._pinned_keys()
        if pinned_keys:
            # Because the pinned set is the newest entries of ``_order``, this
            # ordering currently coincides with plain oldest-first eviction.
            ordered = list(self._order)
            candidates = [key for key in ordered if key not in pinned_keys]
            candidates += [key for key in ordered if key in pinned_keys]
            for cache_key in candidates:
                if self._remove_entry(cache_key, trace=trace):
                    return True
            return False
        # No pinning: pop from the front without copying the whole order.
        while self._order:
            if self._remove_entry(next(iter(self._order)), trace=trace):
                return True
        return False

    def _ensure_capacity(self, incoming_nbytes: int, *, trace: ExecutionTrace | None = None) -> None:
        """Evict pages until *incoming_nbytes* fits within ``max_resident_bytes``.

        Does nothing when no budget is set. Eviction stops once the cache is
        empty, so a page larger than the whole budget is still admitted by
        the caller afterwards.
        """
        if self.max_resident_bytes is None:
            return
        while self._resident_bytes + incoming_nbytes > self.max_resident_bytes and self._prepared_pages:
            if not self._evict_one(trace=trace):
                break

    def _store_prepared_page(
        self,
        cache_key: tuple[str, int],
        prepared_page: PreparedPageTorch,
        *,
        trace: ExecutionTrace | None = None,
    ) -> None:
        """Insert a freshly prepared page and record the cache miss.

        If *cache_key* is already occupied, the displaced page's bytes and
        identity are released first so the key is never charged twice.
        """
        if cache_key in self._prepared_pages:
            self._remove_entry(cache_key, trace=trace)
        incoming_nbytes = self._page_nbytes(prepared_page)
        self._ensure_capacity(incoming_nbytes, trace=trace)
        self._prepared_pages[cache_key] = prepared_page
        self._prepared_page_ids.add(id(prepared_page))
        self._order[cache_key] = None
        self._resident_bytes += incoming_nbytes
        if trace is not None:
            trace.record_cache_miss()
            trace.observe_cache_resident_bytes(self._resident_bytes)

    def _resolve_backend(self, backend: CacheBackend) -> CacheBackend:
        """Resolve ``"auto"`` to CUDA, then MPS, then the CPU reference backend.

        Any other value is returned unchanged.
        """
        if backend != "auto":
            return backend
        if cuda_available():
            return "torch_cuda"
        if mps_available():
            return "torch_mps"
        return "cpu_ref"

    def append_page(
        self,
        page: EncodedPage | PreparedPageTorch,
        *,
        backend: CacheBackend = "auto",
        trace: ExecutionTrace | None = None,
    ) -> EncodedPage | PreparedPageTorch:
        """Alias of :meth:`prepare_page`."""
        return self.prepare_page(page, backend=backend, trace=trace)

    def append_pages(
        self,
        pages: list[EncodedPage | PreparedPageTorch],
        *,
        backend: CacheBackend = "auto",
        trace: ExecutionTrace | None = None,
    ) -> list[EncodedPage | PreparedPageTorch]:
        """Alias of :meth:`prepare_pages`."""
        return self.prepare_pages(pages, backend=backend, trace=trace)

    def prepare_page(
        self,
        page: EncodedPage | PreparedPageTorch,
        *,
        backend: CacheBackend = "auto",
        trace: ExecutionTrace | None = None,
    ) -> EncodedPage | PreparedPageTorch:
        """Return the prepared form of *page*, preparing and caching it on a miss.

        Args:
            page: An encoded page, or a page that is already prepared.
            backend: Backend selector; ``"auto"`` is resolved first.
            trace: Optional trace that receives hit/miss/eviction events.

        Returns:
            For the ``"cpu_ref"`` backend, the unprepared page (the
            ``source_page`` of an already-prepared input); the cache is not
            consulted. Otherwise a prepared page: the input itself when it is
            already prepared, the cached page on a hit, or a newly prepared
            and cached page on a miss.
        """
        resolved_backend = self._resolve_backend(backend)
        if resolved_backend == "cpu_ref":
            return page.source_page if isinstance(page, PreparedPageTorch) else page
        if isinstance(page, PreparedPageTorch):
            # Already prepared: counted as a hit and passed through untouched.
            self._record_hit(trace)
            return page
        cache_key = self._cache_key(resolved_backend, page)
        cached_page = self._prepared_pages.get(cache_key)
        if cached_page is not None:
            self._touch_cached_page(cache_key)
            self._record_hit(trace)
            return cached_page
        prepare = prepare_page_cuda if resolved_backend == "torch_cuda" else prepare_page_mps
        prepared_page = prepare(page, trace=trace)
        self._store_prepared_page(cache_key, prepared_page, trace=trace)
        return prepared_page

    def prepare_pages(
        self,
        pages: list[EncodedPage | PreparedPageTorch],
        *,
        backend: CacheBackend = "auto",
        trace: ExecutionTrace | None = None,
    ) -> list[EncodedPage | PreparedPageTorch]:
        """Batch variant of :meth:`prepare_page`; results keep the input order.

        All cache misses are prepared with a single backend call. A source
        page that occurs several times among the misses is prepared once;
        its later occurrences reuse that prepared page and count as hits.

        Args:
            pages: Encoded and/or already-prepared pages.
            backend: Backend selector; ``"auto"`` is resolved first.
            trace: Optional trace that receives hit/miss/eviction events.

        Returns:
            One entry per input page, as described for :meth:`prepare_page`.

        Raises:
            RuntimeError: If any requested page ended up without a result.
            ValueError: If the backend returns a different number of pages
                than were requested (raised by ``zip(..., strict=True)``).
        """
        resolved_backend = self._resolve_backend(backend)
        if resolved_backend == "cpu_ref":
            return [page.source_page if isinstance(page, PreparedPageTorch) else page for page in pages]
        prepared_pages: list[EncodedPage | PreparedPageTorch | None] = [None] * len(pages)
        # cache key -> every input position that asked for that source page
        miss_slots: dict[tuple[str, int], list[int]] = {}
        miss_pages: list[EncodedPage] = []
        for index, page in enumerate(pages):
            if isinstance(page, PreparedPageTorch):
                prepared_pages[index] = page
                self._record_hit(trace)
                continue
            cache_key = self._cache_key(resolved_backend, page)
            cached_page = self._prepared_pages.get(cache_key)
            if cached_page is not None:
                self._touch_cached_page(cache_key)
                prepared_pages[index] = cached_page
                self._record_hit(trace)
                continue
            slots = miss_slots.get(cache_key)
            if slots is None:
                miss_slots[cache_key] = [index]
                miss_pages.append(page)
            else:
                # Same source page again in this batch: do not prepare twice.
                slots.append(index)
        if miss_pages:
            prepare_batch = prepare_pages_cuda if resolved_backend == "torch_cuda" else prepare_pages_mps
            new_prepared_pages = prepare_batch(miss_pages, trace=trace)
            # dict preserves insertion order, so miss_slots lines up with miss_pages.
            for (cache_key, slots), prepared_page in zip(miss_slots.items(), new_prepared_pages, strict=True):
                # Store under the same key the lookup above used.
                self._store_prepared_page(cache_key, prepared_page, trace=trace)
                for index in slots:
                    prepared_pages[index] = prepared_page
                for _ in slots[1:]:
                    self._record_hit(trace)
        if any(page is None for page in prepared_pages):
            raise RuntimeError("prepared page cache failed to populate all requested pages")
        return [page for page in prepared_pages if page is not None]