DotCache-Arena / dotcache /page_cache.py
DeanoCalver's picture
Initial DotCache Arena Space upload
751ad26 verified
Raw
History Blame Contribute Delete
9.27 kB
from __future__ import annotations
from dataclasses import dataclass, field
from collections import OrderedDict
from typing import Literal
from .backends import (
PreparedPageTorch,
cuda_available,
mps_available,
prepare_page_cuda,
prepare_page_mps,
prepare_pages_cuda,
prepare_pages_mps,
)
from .tracing import ExecutionTrace
from .types import EncodedPage
# Eviction strategies accepted by ``PreparedPageCache.policy``.
CachePolicy = Literal[
    "fifo",
    "lru",
    "pinned_recent_fifo",
]
# Backend selectors accepted by the ``backend`` argument of the cache's
# prepare/append methods.
CacheBackend = Literal[
    "auto",
    "torch_mps",
    "torch_cuda",
    "cpu_ref",
]
@dataclass(slots=True)
class PreparedPageCache:
max_resident_bytes: int | None = None
policy: CachePolicy = "fifo"
pinned_recent_pages: int = 0
_prepared_pages: dict[tuple[str, int], PreparedPageTorch] = field(default_factory=dict)
_prepared_page_ids: set[int] = field(default_factory=set)
_resident_bytes: int = 0
_order: OrderedDict[tuple[str, int], None] = field(default_factory=OrderedDict)
@property
def resident_bytes(self) -> int:
return self._resident_bytes
@property
def size(self) -> int:
return len(self._prepared_pages)
def clear(self) -> None:
self._prepared_pages.clear()
self._prepared_page_ids.clear()
self._resident_bytes = 0
self._order.clear()
def owns_prepared_page(self, page: PreparedPageTorch) -> bool:
return id(page) in self._prepared_page_ids
def _page_nbytes(self, page: PreparedPageTorch) -> int:
resident_nbytes = int(page.resident_nbytes)
if resident_nbytes > 0:
return resident_nbytes
return int(page.host_to_device_nbytes)
def _pinned_keys(self) -> set[int]:
if self.policy != "pinned_recent_fifo" or self.pinned_recent_pages <= 0:
return set()
keys = list(self._order.keys())
if not keys:
return set()
return set(keys[-self.pinned_recent_pages :])
def _touch_cached_page(self, cache_key: tuple[str, int]) -> None:
if self.policy == "lru":
self._order.move_to_end(cache_key)
def _evict_one(self, *, trace: ExecutionTrace | None = None) -> bool:
pinned_keys = self._pinned_keys()
fallback_key: tuple[str, int] | None = None
while self._order:
cache_key = next(iter(self._order))
self._order.pop(cache_key, None)
if cache_key in pinned_keys:
if fallback_key is None:
fallback_key = cache_key
self._order[cache_key] = None
if len(pinned_keys) >= len(self._order):
break
continue
cached_page = self._prepared_pages.pop(cache_key, None)
if cached_page is None:
continue
self._prepared_page_ids.discard(id(cached_page))
evicted_bytes = self._page_nbytes(cached_page)
self._resident_bytes = max(0, self._resident_bytes - evicted_bytes)
if trace is not None:
trace.record_cache_eviction(evicted_bytes)
trace.observe_cache_resident_bytes(self._resident_bytes)
return True
if fallback_key is not None:
self._order.pop(fallback_key, None)
cached_page = self._prepared_pages.pop(fallback_key, None)
if cached_page is None:
return False
self._prepared_page_ids.discard(id(cached_page))
evicted_bytes = self._page_nbytes(cached_page)
self._resident_bytes = max(0, self._resident_bytes - evicted_bytes)
if trace is not None:
trace.record_cache_eviction(evicted_bytes)
trace.observe_cache_resident_bytes(self._resident_bytes)
return True
return False
def _ensure_capacity(self, incoming_nbytes: int, *, trace: ExecutionTrace | None = None) -> None:
if self.max_resident_bytes is None:
return
while self._resident_bytes + incoming_nbytes > self.max_resident_bytes and self._prepared_pages:
if not self._evict_one(trace=trace):
break
def _resolve_backend(self, backend: CacheBackend) -> CacheBackend:
if backend != "auto":
return backend
if cuda_available():
return "torch_cuda"
if mps_available():
return "torch_mps"
return "cpu_ref"
def append_page(
self,
page: EncodedPage | PreparedPageTorch,
*,
backend: CacheBackend = "auto",
trace: ExecutionTrace | None = None,
) -> EncodedPage | PreparedPageTorch:
return self.prepare_page(page, backend=backend, trace=trace)
def append_pages(
self,
pages: list[EncodedPage | PreparedPageTorch],
*,
backend: CacheBackend = "auto",
trace: ExecutionTrace | None = None,
) -> list[EncodedPage | PreparedPageTorch]:
return self.prepare_pages(pages, backend=backend, trace=trace)
def prepare_page(
self,
page: EncodedPage | PreparedPageTorch,
*,
backend: CacheBackend = "auto",
trace: ExecutionTrace | None = None,
) -> EncodedPage | PreparedPageTorch:
resolved_backend = self._resolve_backend(backend)
if resolved_backend == "cpu_ref":
return page.source_page if isinstance(page, PreparedPageTorch) else page
if isinstance(page, PreparedPageTorch):
if trace is not None:
trace.record_cache_hit()
trace.observe_cache_resident_bytes(self._resident_bytes)
return page
cache_key = ("cuda" if resolved_backend == "torch_cuda" else "mps", id(page))
cached_page = self._prepared_pages.get(cache_key)
if cached_page is not None:
self._touch_cached_page(cache_key)
if trace is not None:
trace.record_cache_hit()
trace.observe_cache_resident_bytes(self._resident_bytes)
return cached_page
prepared_page = prepare_page_cuda(page, trace=trace) if resolved_backend == "torch_cuda" else prepare_page_mps(page, trace=trace)
self._ensure_capacity(self._page_nbytes(prepared_page), trace=trace)
self._prepared_pages[cache_key] = prepared_page
self._prepared_page_ids.add(id(prepared_page))
self._order[cache_key] = None
self._resident_bytes += self._page_nbytes(prepared_page)
if trace is not None:
trace.record_cache_miss()
trace.observe_cache_resident_bytes(self._resident_bytes)
return prepared_page
def prepare_pages(
self,
pages: list[EncodedPage | PreparedPageTorch],
*,
backend: CacheBackend = "auto",
trace: ExecutionTrace | None = None,
) -> list[EncodedPage | PreparedPageTorch]:
resolved_backend = self._resolve_backend(backend)
if resolved_backend == "cpu_ref":
return [page.source_page if isinstance(page, PreparedPageTorch) else page for page in pages]
prepared_pages: list[EncodedPage | PreparedPageTorch | None] = [None] * len(pages)
miss_indices: list[int] = []
miss_pages: list[EncodedPage] = []
for index, page in enumerate(pages):
if isinstance(page, PreparedPageTorch):
prepared_pages[index] = page
if trace is not None:
trace.record_cache_hit()
trace.observe_cache_resident_bytes(self._resident_bytes)
continue
cache_key = ("cuda" if resolved_backend == "torch_cuda" else "mps", id(page))
cached_page = self._prepared_pages.get(cache_key)
if cached_page is not None:
self._touch_cached_page(cache_key)
prepared_pages[index] = cached_page
if trace is not None:
trace.record_cache_hit()
trace.observe_cache_resident_bytes(self._resident_bytes)
continue
miss_indices.append(index)
miss_pages.append(page)
if miss_pages:
new_prepared_pages = prepare_pages_cuda(miss_pages, trace=trace) if resolved_backend == "torch_cuda" else prepare_pages_mps(miss_pages, trace=trace)
for index, source_page, prepared_page in zip(miss_indices, miss_pages, new_prepared_pages, strict=True):
self._ensure_capacity(self._page_nbytes(prepared_page), trace=trace)
cache_key = (prepared_page.device_type, id(source_page))
self._prepared_pages[cache_key] = prepared_page
self._prepared_page_ids.add(id(prepared_page))
self._order[cache_key] = None
self._resident_bytes += self._page_nbytes(prepared_page)
prepared_pages[index] = prepared_page
if trace is not None:
trace.record_cache_miss()
trace.observe_cache_resident_bytes(self._resident_bytes)
if any(page is None for page in prepared_pages):
raise RuntimeError("prepared page cache failed to populate all requested pages")
return [page for page in prepared_pages if page is not None]