"""Backend dispatch for paged attention primitives.

Routes page preparation, scoring, mixing and decode steps to the ``cpu_ref``,
``torch_mps`` or ``torch_cuda`` implementations. The backend is either
requested explicitly or chosen automatically with ``backend="auto"``.
"""

from __future__ import annotations

from typing import Literal, Sequence

import numpy as np

from .attention_reference import softmax
from .backends import (
    PreparedPageTorch,
    cuda_available,
    decode_multi_query_step_cuda,
    decode_step_cuda,
    decode_multi_query_step_mps,
    decode_step_mps,
    mix_page_cpu_ref,
    mix_page_cuda,
    mix_page_mps,
    mps_available,
    page_supported_cuda,
    page_supported_mps,
    prepare_page_cuda,
    prepare_page_mps,
    prepare_pages_cuda,
    prepare_pages_mps,
    score_pages_cuda,
    score_pages_mps,
    score_page_cpu_ref,
    score_page_cuda,
    score_page_mps,
)
from .page_cache import PreparedPageCache
from .tracing import ExecutionTrace
from .types import EncodedPage

BackendName = Literal["cpu_ref", "torch_mps", "torch_cuda", "auto"]
PageLike = EncodedPage | PreparedPageTorch


def _resolve_backend(
    backend: BackendName, page: PageLike
) -> Literal["cpu_ref", "torch_mps", "torch_cuda"]:
    """Map a requested backend name to the concrete backend used for ``page``.

    Explicit GPU backends are checked against the current machine and the page.
    ``"auto"`` keeps an already-prepared torch page on its own device.
    Otherwise it prefers ``torch_cuda``, then ``torch_mps``, then ``cpu_ref``.

    Raises:
        RuntimeError: an explicitly requested GPU backend is unavailable.
        ValueError: the page is unsupported by the requested GPU backend, or
            ``backend`` is not one of the known names.
    """
    if backend == "cpu_ref":
        return "cpu_ref"
    if backend == "torch_mps":
        if not mps_available():
            raise RuntimeError("torch_mps is unavailable on this machine")
        if not page_supported_mps(page):
            raise ValueError("page is unsupported by torch_mps in this phase")
        return "torch_mps"
    if backend == "torch_cuda":
        if not cuda_available():
            raise RuntimeError("torch_cuda is unavailable on this machine")
        if not page_supported_cuda(page):
            raise ValueError("page is unsupported by torch_cuda in this phase")
        return "torch_cuda"
    if backend != "auto":
        # Reject unknown names (e.g. typos) instead of silently treating them as "auto".
        raise ValueError(
            f"unknown backend {backend!r}; expected one of "
            "'cpu_ref', 'torch_mps', 'torch_cuda', 'auto'"
        )
    if isinstance(page, PreparedPageTorch):
        # An already-prepared page stays on the device it was prepared for.
        return "torch_cuda" if page.device_type == "cuda" else "torch_mps"
    if cuda_available() and page_supported_cuda(page):
        return "torch_cuda"
    if mps_available() and page_supported_mps(page):
        return "torch_mps"
    return "cpu_ref"


def _prepared_pages_backend(
    pages: Sequence[PageLike],
) -> Literal["torch_mps", "torch_cuda"] | None:
    """Return the shared torch backend of ``pages``, or ``None``.

    ``None`` is returned when ``pages`` is empty or when any page is not a
    ``PreparedPageTorch``. A ``device_type`` of ``"cuda"`` maps to
    ``torch_cuda``; any other device type maps to ``torch_mps``.

    Raises:
        ValueError: the prepared pages live on different device types.
    """
    if not pages or not all(isinstance(page, PreparedPageTorch) for page in pages):
        return None
    device_type = pages[0].device_type
    if any(page.device_type != device_type for page in pages):
        raise ValueError("prepared torch pages must all target the same device")
    return "torch_cuda" if device_type == "cuda" else "torch_mps"


def prepare_page(
    page: PageLike,
    *,
    backend: BackendName = "auto",
    cache: PreparedPageCache | None = None,
    trace: ExecutionTrace | None = None,
) -> PageLike:
    """Prepare one page for the resolved backend.

    GPU backends go through ``cache`` when one is given, and otherwise through
    the backend's own ``prepare_page_*`` function. For ``cpu_ref``, a
    ``PreparedPageTorch`` is unwrapped to its ``source_page``; any other page
    is returned unchanged.
    """
    resolved_backend = _resolve_backend(backend, page)
    if resolved_backend == "torch_mps":
        if cache is not None:
            return cache.prepare_page(page, backend="torch_mps", trace=trace)
        return prepare_page_mps(page, trace=trace)
    if resolved_backend == "torch_cuda":
        if cache is not None:
            return cache.prepare_page(page, backend="torch_cuda", trace=trace)
        return prepare_page_cuda(page, trace=trace)
    return page.source_page if isinstance(page, PreparedPageTorch) else page


def prepare_pages(
    pages: Sequence[PageLike],
    *,
    backend: BackendName = "auto",
    cache: PreparedPageCache | None = None,
    trace: ExecutionTrace | None = None,
) -> list[PageLike]:
    """Prepare a sequence of pages, batching when a GPU backend is selected.

    The batch backend is resolved from the *first* page only. When that page
    resolves to ``cpu_ref``, every page is passed to :func:`prepare_page`
    with the original ``backend`` argument, so under ``"auto"`` each page is
    resolved on its own. An empty ``pages`` returns ``[]``.
    """
    if pages:
        resolved_backend = _resolve_backend(backend, pages[0])
        if resolved_backend == "torch_mps":
            if cache is not None:
                return cache.prepare_pages(list(pages), backend="torch_mps", trace=trace)
            return prepare_pages_mps(pages, trace=trace)
        if resolved_backend == "torch_cuda":
            if cache is not None:
                return cache.prepare_pages(list(pages), backend="torch_cuda", trace=trace)
            return prepare_pages_cuda(pages, trace=trace)
    return [prepare_page(page, backend=backend, cache=cache, trace=trace) for page in pages]


def score_page(
    query_slice: np.ndarray,
    page: PageLike,
    *,
    backend: BackendName = "auto",
    trace: ExecutionTrace | None = None,
) -> np.ndarray:
    """Score ``query_slice`` against one key page on the resolved backend."""
    resolved_backend = _resolve_backend(backend, page)
    if resolved_backend == "torch_mps":
        return score_page_mps(query_slice, page, trace=trace)
    if resolved_backend == "torch_cuda":
        return score_page_cuda(query_slice, page, trace=trace)
    return score_page_cpu_ref(query_slice, page, trace=trace)


def score_pages(
    query_slice: np.ndarray,
    pages: Sequence[PageLike],
    *,
    backend: BackendName = "auto",
    cache: PreparedPageCache | None = None,
    trace: ExecutionTrace | None = None,
) -> list[np.ndarray]:
    """Score ``query_slice`` against each page; return one array per page.

    All pages are prepared first. If every prepared page is a torch page on
    the same device, the batched ``score_pages_*`` kernel is used. Otherwise
    each page is scored individually with :func:`score_page`.
    """
    if not pages:
        return []
    prepared_pages = prepare_pages(pages, backend=backend, cache=cache, trace=trace)
    prepared_backend = _prepared_pages_backend(prepared_pages)
    if prepared_backend == "torch_mps":
        return score_pages_mps(query_slice, prepared_pages, trace=trace)
    if prepared_backend == "torch_cuda":
        return score_pages_cuda(query_slice, prepared_pages, trace=trace)
    return [score_page(query_slice, page, backend=backend, trace=trace) for page in prepared_pages]


def mix_page(
    attn_weights: np.ndarray,
    page: PageLike,
    *,
    out_acc: np.ndarray | None = None,
    backend: BackendName = "auto",
    trace: ExecutionTrace | None = None,
) -> np.ndarray:
    """Mix one value page with ``attn_weights`` on the resolved backend.

    ``out_acc`` is forwarded unchanged to the backend implementation.
    NOTE(review): presumably it is an accumulator that the result is added
    to — confirm in the ``mix_page_*`` implementations.
    """
    resolved_backend = _resolve_backend(backend, page)
    if resolved_backend == "torch_mps":
        return mix_page_mps(attn_weights, page, out_acc=out_acc, trace=trace)
    if resolved_backend == "torch_cuda":
        return mix_page_cuda(attn_weights, page, out_acc=out_acc, trace=trace)
    return mix_page_cpu_ref(attn_weights, page, out_acc=out_acc, trace=trace)


def attention_step(
    query_slice: np.ndarray,
    key_page: PageLike,
    value_page: PageLike,
    *,
    backend: BackendName = "cpu_ref",
    cache: PreparedPageCache | None = None,
    trace: ExecutionTrace | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run attention over a single key/value page pair.

    Returns ``(logits, weights, output)``, where ``weights = softmax(logits)``.
    Unlike the other entry points, the default backend here is ``"cpu_ref"``.
    """
    prepared_key_page = prepare_page(key_page, backend=backend, cache=cache, trace=trace)
    prepared_value_page = prepare_page(value_page, backend=backend, cache=cache, trace=trace)
    logits = score_page(query_slice, prepared_key_page, backend=backend, trace=trace)
    weights = softmax(logits)
    output = mix_page(weights, prepared_value_page, backend=backend, trace=trace)
    return logits, weights, output


def decode_step(
    query_slice: np.ndarray,
    key_pages: Sequence[PageLike],
    value_pages: Sequence[PageLike],
    *,
    backend: BackendName = "auto",
    cache: PreparedPageCache | None = None,
    trace: ExecutionTrace | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run one decode step for a single query over all key/value pages.

    Thin wrapper around :func:`decode_step_with_page_logits` with no
    precomputed logits. That function performs the same validation: matching
    page counts and at least one page.
    """
    return decode_step_with_page_logits(
        query_slice,
        key_pages,
        value_pages,
        backend=backend,
        cache=cache,
        trace=trace,
    )


def decode_multi_query_step(
    query_slices: np.ndarray,
    key_pages: Sequence[PageLike],
    value_pages: Sequence[PageLike],
    *,
    backend: BackendName = "auto",
    cache: PreparedPageCache | None = None,
    trace: ExecutionTrace | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run a decode step for several queries over the same key/value pages.

    ``query_slices`` is coerced to float32 and must be 2-D
    ``[query_count, head_dim]``. If both key and value pages end up prepared
    on the same torch backend, a batched kernel is used. Otherwise each query
    goes through :func:`decode_step`, and the results are stacked along axis 0.

    Raises:
        ValueError: bad query rank, mismatched page counts, or no pages.
    """
    queries = np.asarray(query_slices, dtype=np.float32)
    if queries.ndim != 2:
        raise ValueError("query_slices must have shape [query_count, head_dim]")
    if len(key_pages) != len(value_pages):
        raise ValueError("key_pages and value_pages must contain the same number of pages")
    if not key_pages:
        raise ValueError("decode_multi_query_step requires at least one page")

    prepared_key_pages = prepare_pages(key_pages, backend=backend, cache=cache, trace=trace)
    prepared_value_pages = prepare_pages(value_pages, backend=backend, cache=cache, trace=trace)
    prepared_backend = _prepared_pages_backend(prepared_key_pages)
    if prepared_backend is not None and prepared_backend == _prepared_pages_backend(
        prepared_value_pages
    ):
        if prepared_backend == "torch_cuda":
            return decode_multi_query_step_cuda(
                queries,
                prepared_key_pages,
                prepared_value_pages,
                trace=trace,
            )
        return decode_multi_query_step_mps(
            queries,
            prepared_key_pages,
            prepared_value_pages,
            trace=trace,
        )

    if queries.shape[0] == 0:
        # np.stack cannot stack an empty list; return correctly shaped empty results.
        token_total = sum(page.header.token_count for page in prepared_value_pages)
        head_dim = prepared_value_pages[0].header.head_dim
        return (
            np.zeros((0, token_total), dtype=np.float32),
            np.zeros((0, token_total), dtype=np.float32),
            np.zeros((0, head_dim), dtype=np.float32),
        )

    per_query = [
        decode_step(
            query_slice,
            prepared_key_pages,
            prepared_value_pages,
            backend=backend,
            trace=trace,
        )
        for query_slice in queries
    ]
    logits_rows, weight_rows, output_rows = zip(*per_query)
    return (
        np.stack(logits_rows, axis=0).astype(np.float32, copy=False),
        np.stack(weight_rows, axis=0).astype(np.float32, copy=False),
        np.stack(output_rows, axis=0).astype(np.float32, copy=False),
    )


def decode_step_with_page_logits(
    query_slice: np.ndarray,
    key_pages: Sequence[PageLike],
    value_pages: Sequence[PageLike],
    *,
    page_logits: Sequence[np.ndarray | None] | None = None,
    backend: BackendName = "auto",
    cache: PreparedPageCache | None = None,
    trace: ExecutionTrace | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run one decode step, optionally reusing precomputed per-page logits.

    Args:
        page_logits: optional per-key-page logits aligned with ``key_pages``.
            A ``None`` entry (or ``None`` overall) means "score this page now".

    Returns:
        ``(logits, weights, output)``. ``logits`` concatenates all pages,
        ``weights`` is the softmax over all tokens, and ``output`` mixes the
        value pages with consecutive slices of ``weights``.

    Raises:
        ValueError: mismatched page counts, no pages, misaligned
            ``page_logits``, or (CPU path) a logits length that does not
            match the total value-page token count.
    """
    if len(key_pages) != len(value_pages):
        raise ValueError("key_pages and value_pages must contain the same number of pages")
    if not key_pages:
        raise ValueError("decode_step requires at least one page")
    if page_logits is not None and len(page_logits) != len(key_pages):
        raise ValueError("page_logits must align with key_pages")

    prepared_key_pages = prepare_pages(key_pages, backend=backend, cache=cache, trace=trace)
    prepared_value_pages = prepare_pages(value_pages, backend=backend, cache=cache, trace=trace)
    prepared_backend = _prepared_pages_backend(prepared_key_pages)
    if prepared_backend is not None and prepared_backend == _prepared_pages_backend(
        prepared_value_pages
    ):
        if prepared_backend == "torch_cuda":
            return decode_step_cuda(
                query_slice,
                prepared_key_pages,
                prepared_value_pages,
                precomputed_page_logits=page_logits,
                trace=trace,
            )
        return decode_step_mps(
            query_slice,
            prepared_key_pages,
            prepared_value_pages,
            precomputed_page_logits=page_logits,
            trace=trace,
        )

    resolved_page_logits = []
    for index, page in enumerate(prepared_key_pages):
        cached_logits = None if page_logits is None else page_logits[index]
        if cached_logits is None:
            cached_logits = score_page(query_slice, page, backend=backend, trace=trace)
        resolved_page_logits.append(np.asarray(cached_logits, dtype=np.float32))
    logits = np.concatenate(resolved_page_logits).astype(np.float32, copy=False)

    # Each value page consumes the next `token_count` weights; a length mismatch
    # would silently drop or misalign weights, so fail loudly instead.
    value_token_total = sum(page.header.token_count for page in prepared_value_pages)
    if logits.shape[0] != value_token_total:
        raise ValueError(
            f"logits length {logits.shape[0]} does not match total value token count "
            f"{value_token_total}"
        )
    weights = softmax(logits)

    # The output is a weighted mix of value rows, so size it from the value pages.
    output = np.zeros(prepared_value_pages[0].header.head_dim, dtype=np.float32)
    offset = 0
    for value_page in prepared_value_pages:
        token_count = value_page.header.token_count
        page_weights = weights[offset : offset + token_count]
        output = mix_page(page_weights, value_page, out_acc=output, backend=backend, trace=trace)
        offset += token_count
    return logits, weights, output