DotCache-Arena / dotcache /attention_runtime.py
DeanoCalver's picture
Initial DotCache Arena Space upload
751ad26 verified
Raw
History Blame Contribute Delete
11.7 kB
from __future__ import annotations
from typing import Literal, Sequence
import numpy as np
from .attention_reference import softmax
from .backends import (
PreparedPageTorch,
cuda_available,
decode_multi_query_step_cuda,
decode_step_cuda,
decode_multi_query_step_mps,
decode_step_mps,
mix_page_cpu_ref,
mix_page_cuda,
mix_page_mps,
mps_available,
page_supported_cuda,
page_supported_mps,
prepare_page_cuda,
prepare_page_mps,
prepare_pages_cuda,
prepare_pages_mps,
score_pages_cuda,
score_pages_mps,
score_page_cpu_ref,
score_page_cuda,
score_page_mps,
)
from .page_cache import PreparedPageCache
from .tracing import ExecutionTrace
from .types import EncodedPage
# Backend selector accepted by the public entry points in this module.
# "auto" lets _resolve_backend choose a backend per page (CUDA, then MPS,
# then the CPU reference implementation).
BackendName = Literal["cpu_ref", "torch_mps", "torch_cuda", "auto"]
# A page either in its encoded form or already prepared for a torch device
# (a PreparedPageTorch exposes `device_type` and `source_page`).
PageLike = EncodedPage | PreparedPageTorch
def _resolve_backend(backend: BackendName, page: PageLike) -> Literal["cpu_ref", "torch_mps", "torch_cuda"]:
if backend == "cpu_ref":
return "cpu_ref"
if backend == "torch_mps":
if not mps_available():
raise RuntimeError("torch_mps is unavailable on this machine")
if not page_supported_mps(page):
raise ValueError("page is unsupported by torch_mps in this phase")
return "torch_mps"
if backend == "torch_cuda":
if not cuda_available():
raise RuntimeError("torch_cuda is unavailable on this machine")
if not page_supported_cuda(page):
raise ValueError("page is unsupported by torch_cuda in this phase")
return "torch_cuda"
if isinstance(page, PreparedPageTorch):
return "torch_cuda" if page.device_type == "cuda" else "torch_mps"
if cuda_available() and page_supported_cuda(page):
return "torch_cuda"
if mps_available() and page_supported_mps(page):
return "torch_mps"
return "cpu_ref"
def _prepared_pages_backend(pages: Sequence[PageLike]) -> Literal["torch_mps", "torch_cuda"] | None:
if not pages or not all(isinstance(page, PreparedPageTorch) for page in pages):
return None
device_type = pages[0].device_type
if any(page.device_type != device_type for page in pages):
raise ValueError("prepared torch pages must all target the same device")
return "torch_cuda" if device_type == "cuda" else "torch_mps"
def prepare_page(
    page: PageLike,
    *,
    backend: BackendName = "auto",
    cache: PreparedPageCache | None = None,
    trace: ExecutionTrace | None = None,
) -> PageLike:
    """Prepare a single page for the backend selected by ``_resolve_backend``.

    On a torch backend the page goes through ``cache.prepare_page`` when a
    cache is supplied, otherwise through ``prepare_page_cuda`` /
    ``prepare_page_mps``. On ``cpu_ref`` the encoded page is returned: a
    ``PreparedPageTorch`` is unwrapped to its ``source_page`` and any other
    page is returned unchanged.
    """
    target = _resolve_backend(backend, page)
    if target == "cpu_ref":
        if isinstance(page, PreparedPageTorch):
            return page.source_page
        return page
    if cache is not None:
        return cache.prepare_page(page, backend=target, trace=trace)
    preparer = prepare_page_cuda if target == "torch_cuda" else prepare_page_mps
    return preparer(page, trace=trace)
def prepare_pages(
    pages: Sequence[PageLike],
    *,
    backend: BackendName = "auto",
    cache: PreparedPageCache | None = None,
    trace: ExecutionTrace | None = None,
) -> list[PageLike]:
    """Prepare a sequence of pages, batching the work on torch backends.

    The backend is resolved from the first page only. For a torch backend the
    whole sequence is handed to ``cache.prepare_pages`` (as a list) when a
    cache is given, or else to ``prepare_pages_cuda`` / ``prepare_pages_mps``.
    Otherwise each page is prepared individually with ``prepare_page``.
    An empty input yields an empty list.
    """
    if not pages:
        return []
    target = _resolve_backend(backend, pages[0])
    if target == "cpu_ref":
        return [prepare_page(page, backend=backend, cache=cache, trace=trace) for page in pages]
    if cache is not None:
        return cache.prepare_pages(list(pages), backend=target, trace=trace)
    batch_preparer = prepare_pages_cuda if target == "torch_cuda" else prepare_pages_mps
    return batch_preparer(pages, trace=trace)
def score_page(
    query_slice: np.ndarray,
    page: PageLike,
    *,
    backend: BackendName = "auto",
    trace: ExecutionTrace | None = None,
) -> np.ndarray:
    """Compute the logits of ``query_slice`` against a single key page.

    Dispatches to the CUDA, MPS, or CPU-reference scorer, depending on what
    ``_resolve_backend`` picks for ``page``.
    """
    target = _resolve_backend(backend, page)
    if target == "torch_cuda":
        scorer = score_page_cuda
    elif target == "torch_mps":
        scorer = score_page_mps
    else:
        scorer = score_page_cpu_ref
    return scorer(query_slice, page, trace=trace)
def score_pages(
    query_slice: np.ndarray,
    pages: Sequence[PageLike],
    *,
    backend: BackendName = "auto",
    cache: PreparedPageCache | None = None,
    trace: ExecutionTrace | None = None,
) -> list[np.ndarray]:
    """Score ``query_slice`` against every page, returning one array per page.

    Pages are first prepared (optionally through ``cache``). If all prepared
    pages are torch pages on the same device, a single batched
    ``score_pages_cuda`` / ``score_pages_mps`` call is made. Otherwise each page
    is scored on its own with ``score_page``. An empty ``pages`` returns ``[]``
    without preparing anything.
    """
    if not pages:
        return []
    prepared = prepare_pages(pages, backend=backend, cache=cache, trace=trace)
    batched_backend = _prepared_pages_backend(prepared)
    if batched_backend is None:
        return [score_page(query_slice, page, backend=backend, trace=trace) for page in prepared]
    batch_scorer = score_pages_cuda if batched_backend == "torch_cuda" else score_pages_mps
    return batch_scorer(query_slice, prepared, trace=trace)
def mix_page(
    attn_weights: np.ndarray,
    page: PageLike,
    *,
    out_acc: np.ndarray | None = None,
    backend: BackendName = "auto",
    trace: ExecutionTrace | None = None,
) -> np.ndarray:
    """Combine a value page with ``attn_weights`` on the resolved backend.

    ``out_acc`` is forwarded unchanged to the backend helper.
    ``decode_step_with_page_logits`` passes the running output here to
    accumulate across pages.
    """
    mixers = {"torch_mps": mix_page_mps, "torch_cuda": mix_page_cuda}
    mixer = mixers.get(_resolve_backend(backend, page), mix_page_cpu_ref)
    return mixer(attn_weights, page, out_acc=out_acc, trace=trace)
def attention_step(
    query_slice: np.ndarray,
    key_page: PageLike,
    value_page: PageLike,
    *,
    backend: BackendName = "cpu_ref",
    cache: PreparedPageCache | None = None,
    trace: ExecutionTrace | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run attention for one query over a single key/value page pair.

    Prepares both pages, scores the query against the key page, applies
    ``softmax``, and mixes the value page with the resulting weights. Unlike
    the other entry points, the default backend here is ``"cpu_ref"``.

    Returns:
        ``(logits, weights, output)``.
    """
    keys = prepare_page(key_page, backend=backend, cache=cache, trace=trace)
    values = prepare_page(value_page, backend=backend, cache=cache, trace=trace)
    logits = score_page(query_slice, keys, backend=backend, trace=trace)
    weights = softmax(logits)
    return logits, weights, mix_page(weights, values, backend=backend, trace=trace)
def decode_step(
    query_slice: np.ndarray,
    key_pages: Sequence[PageLike],
    value_pages: Sequence[PageLike],
    *,
    backend: BackendName = "auto",
    cache: PreparedPageCache | None = None,
    trace: ExecutionTrace | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Decode one query against paged keys and values.

    Validates the page lists, then delegates to
    ``decode_step_with_page_logits`` without any precomputed page logits.

    Raises:
        ValueError: if the page lists differ in length or are empty.
    """
    pages_match = len(key_pages) == len(value_pages)
    if not pages_match:
        raise ValueError("key_pages and value_pages must contain the same number of pages")
    if not key_pages:
        raise ValueError("decode_step requires at least one page")
    return decode_step_with_page_logits(
        query_slice,
        key_pages,
        value_pages,
        page_logits=None,
        backend=backend,
        cache=cache,
        trace=trace,
    )
def decode_multi_query_step(
    query_slices: np.ndarray,
    key_pages: Sequence[PageLike],
    value_pages: Sequence[PageLike],
    *,
    backend: BackendName = "auto",
    cache: PreparedPageCache | None = None,
    trace: ExecutionTrace | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Decode several queries against the same paged keys and values.

    ``query_slices`` is converted to ``float32``. It must be 2-D
    (``[query_count, head_dim]``) and contain at least one query. Pages are
    prepared once. If both key and value pages end up as torch pages on the
    same device, the batched ``decode_multi_query_step_cuda`` /
    ``decode_multi_query_step_mps`` kernel is used. Otherwise each query is
    decoded with ``decode_step`` and the per-query results are stacked along
    axis 0.

    Returns:
        ``(logits, weights, output)``, each with one leading row per query.

    Raises:
        ValueError: if ``query_slices`` is not 2-D or has no rows, if the page
            lists differ in length, or if they are empty.
    """
    queries = np.asarray(query_slices, dtype=np.float32)
    if queries.ndim != 2:
        raise ValueError("query_slices must have shape [query_count, head_dim]")
    if len(key_pages) != len(value_pages):
        raise ValueError("key_pages and value_pages must contain the same number of pages")
    if not key_pages:
        raise ValueError("decode_multi_query_step requires at least one page")
    # An empty batch used to slip past the shape check and then fail inside
    # np.stack([]) on the fallback path with an opaque numpy error.
    if queries.shape[0] == 0:
        raise ValueError("decode_multi_query_step requires at least one query")
    prepared_key_pages = prepare_pages(key_pages, backend=backend, cache=cache, trace=trace)
    prepared_value_pages = prepare_pages(value_pages, backend=backend, cache=cache, trace=trace)
    prepared_backend = _prepared_pages_backend(prepared_key_pages)
    # The batched kernels need keys and values prepared on the same device.
    if prepared_backend is not None and prepared_backend == _prepared_pages_backend(prepared_value_pages):
        batched_decode = (
            decode_multi_query_step_cuda if prepared_backend == "torch_cuda" else decode_multi_query_step_mps
        )
        return batched_decode(
            queries,
            prepared_key_pages,
            prepared_value_pages,
            trace=trace,
        )
    # Fallback: decode each query independently. The pages are already
    # prepared, so the cache is not passed down again.
    per_query = [
        decode_step(
            query_slice,
            prepared_key_pages,
            prepared_value_pages,
            backend=backend,
            trace=trace,
        )
        for query_slice in queries
    ]
    logits_rows, weight_rows, output_rows = zip(*per_query)
    return (
        np.stack(logits_rows, axis=0).astype(np.float32, copy=False),
        np.stack(weight_rows, axis=0).astype(np.float32, copy=False),
        np.stack(output_rows, axis=0).astype(np.float32, copy=False),
    )
def decode_step_with_page_logits(
    query_slice: np.ndarray,
    key_pages: Sequence[PageLike],
    value_pages: Sequence[PageLike],
    *,
    page_logits: Sequence[np.ndarray | None] | None = None,
    backend: BackendName = "auto",
    cache: PreparedPageCache | None = None,
    trace: ExecutionTrace | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Decode one query, optionally reusing precomputed logits for some pages.

    ``page_logits``, when given, must have one entry per key page. A ``None``
    entry means that page is scored here with ``score_page``. If keys and
    values are both prepared as torch pages on the same device, the whole step
    (including ``page_logits``) is delegated to ``decode_step_cuda`` /
    ``decode_step_mps``. Otherwise the per-page logits are concatenated and
    softmaxed. Each value page is then mixed, with its slice of the weights,
    into a zero-initialised ``float32`` accumulator. The accumulator is sized
    from the first key page's ``header.head_dim``.

    Returns:
        ``(logits, weights, output)``.

    Raises:
        ValueError: if the page lists differ in length, are empty, or if
            ``page_logits`` does not have one entry per key page.
    """
    if len(key_pages) != len(value_pages):
        raise ValueError("key_pages and value_pages must contain the same number of pages")
    if not key_pages:
        raise ValueError("decode_step requires at least one page")
    if page_logits is not None and len(page_logits) != len(key_pages):
        raise ValueError("page_logits must align with key_pages")

    keys = prepare_pages(key_pages, backend=backend, cache=cache, trace=trace)
    values = prepare_pages(value_pages, backend=backend, cache=cache, trace=trace)
    device_backend = _prepared_pages_backend(keys)
    if device_backend is not None and device_backend == _prepared_pages_backend(values):
        fused_decode = decode_step_cuda if device_backend == "torch_cuda" else decode_step_mps
        return fused_decode(
            query_slice,
            keys,
            values,
            precomputed_page_logits=page_logits,
            trace=trace,
        )

    # Fallback path: gather logits page by page, reusing cached ones.
    per_page_logits = []
    for position, key_page in enumerate(keys):
        page_scores = page_logits[position] if page_logits is not None else None
        if page_scores is None:
            page_scores = score_page(query_slice, key_page, backend=backend, trace=trace)
        per_page_logits.append(np.asarray(page_scores, dtype=np.float32))
    logits = np.concatenate(per_page_logits).astype(np.float32, copy=False)
    weights = softmax(logits)

    # Walk the value pages, handing each its contiguous slice of the weights.
    output = np.zeros(keys[0].header.head_dim, dtype=np.float32)
    start = 0
    for value_page in values:
        stop = start + value_page.header.token_count
        output = mix_page(weights[start:stop], value_page, out_acc=output, backend=backend, trace=trace)
        start = stop
    return logits, weights, output