File size: 1,789 Bytes
751ad26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from __future__ import annotations

import numpy as np

from ..attention_reference import mix_page_ref, score_page_ref
from ..tracing import ExecutionTrace
from ..types import EncodedPage
from .torch_mps import PreparedPageTorch


def _source_page(page: EncodedPage | PreparedPageTorch) -> EncodedPage:
    """Return the ``EncodedPage`` behind *page*.

    A ``PreparedPageTorch`` is unwrapped to its ``source_page`` attribute;
    any other value is handed back unchanged.
    """
    return page.source_page if isinstance(page, PreparedPageTorch) else page


def _record_trace(page: EncodedPage, trace: ExecutionTrace | None) -> None:
    """Report the cost of reading *page* to *trace*, if a trace was supplied.

    Always records one page read (payload and metadata byte counts). A
    temporary-buffer size is additionally recorded depending on the page
    header's ``mode_default``:

    * ``"M0"``, ``"M1"``, ``"T3"``: ``token_count * group_size`` float32 values.
    * ``"M2"``, ``"M4"`` with an ``m2_sketch`` present: first dimension times
      last dimension of the sketch, in float32 values.

    Any other mode (or a missing sketch) records no temporary.
    """
    if trace is None:
        return

    trace.record_page_read(page.payload_nbytes, page.metadata_nbytes)

    header = page.header
    mode = header.mode_default
    float32_nbytes = np.dtype(np.float32).itemsize

    if mode in ("M0", "M1", "T3"):
        trace.record_temporary(header.token_count * header.group_size * float32_nbytes)
        return

    if mode not in ("M2", "M4"):
        return

    sketch = page.m2_sketch
    if sketch is None:
        return
    # NOTE(review): presumably sized as a float32 buffer derived from the
    # sketch; confirm against the M2/M4 reference kernels.
    trace.record_temporary(int(sketch.shape[0] * sketch.shape[-1] * float32_nbytes))


def score_page_cpu_ref(
    query_slice: np.ndarray,
    page: EncodedPage | PreparedPageTorch,
    *,
    trace: ExecutionTrace | None = None,
) -> np.ndarray:
    """Score *query_slice* against *page* using the reference implementation.

    Args:
        query_slice: Query data forwarded unchanged to ``score_page_ref``.
        page: Either an ``EncodedPage`` or a ``PreparedPageTorch`` wrapper; a
            wrapper is unwrapped to its source page first.
        trace: Optional execution trace. When given, the page read (and any
            mode-dependent temporary) is recorded before scoring runs.

    Returns:
        Whatever ``score_page_ref`` returns for the unwrapped page.
    """
    encoded = _source_page(page)
    # Trace accounting happens before the computation, so it is recorded even
    # if the reference kernel subsequently raises.
    _record_trace(encoded, trace)
    return score_page_ref(query_slice, encoded)


def mix_page_cpu_ref(
    attn_weights: np.ndarray,
    page: EncodedPage | PreparedPageTorch,
    *,
    out_acc: np.ndarray | None = None,
    trace: ExecutionTrace | None = None,
) -> np.ndarray:
    """Mix *page* with *attn_weights* using the reference implementation.

    Args:
        attn_weights: Weights forwarded unchanged to ``mix_page_ref``.
        page: Either an ``EncodedPage`` or a ``PreparedPageTorch`` wrapper; a
            wrapper is unwrapped to its source page first.
        out_acc: Optional running accumulator. It is viewed as float32 and
            must have the same shape as the mixed result.
        trace: Optional execution trace. When given, the page read (and any
            mode-dependent temporary) is recorded before mixing runs.

    Returns:
        The mixed result alone when ``out_acc`` is ``None``; otherwise a new
        array holding ``out_acc + mixed``. The sum is not written back into
        ``out_acc``.

    Raises:
        ValueError: If ``out_acc`` does not match the mixed result's shape.
    """
    encoded = _source_page(page)
    _record_trace(encoded, trace)
    contribution = mix_page_ref(attn_weights, encoded)

    if out_acc is not None:
        accumulator = np.asarray(out_acc, dtype=np.float32)
        if accumulator.shape != contribution.shape:
            raise ValueError("out_acc must have shape [head_dim]")
        # Out-of-place addition: callers must use the return value.
        return accumulator + contribution

    return contribution