DotCache-Arena / dotcache /attention_reference.py
DeanoCalver's picture
Initial DotCache Arena Space upload
751ad26 verified
Raw
History Blame Contribute Delete
10.5 kB
from __future__ import annotations
import numpy as np
from .decode_reference import decode_page
from .modes.m1_lut import dequantize_group_lut
from .modes.m2_key_sketch import segment_ids_for_token_count
from .modes.m4_key_project import fixed_project_basis
from .modes.m3_escape import decode_escape_payload
from .modes.turbo3 import fwht_last_dim
from .page_format import load_group_words
from .packing import unpack_bits
from .types import EncodedPage
def _pad_query(query_slice: np.ndarray, padded_head_dim: int) -> np.ndarray:
    """Return *query_slice* as a float32 vector of length *padded_head_dim*.

    The input is converted with ``np.asarray(..., dtype=np.float32)`` and
    zero-padded on the right when it is shorter than ``padded_head_dim``.
    When the length already matches, the converted array itself is returned
    (no extra copy is made).

    Raises:
        ValueError: If the query is not one-dimensional, or if it is longer
            than ``padded_head_dim``.
    """
    vector = np.asarray(query_slice, dtype=np.float32)
    if vector.ndim != 1:
        raise ValueError("query_slice must have shape [head_dim]")

    # Number of trailing zeros required; negative means the query is too long.
    missing = padded_head_dim - vector.shape[0]
    if missing < 0:
        raise ValueError("query head_dim exceeds padded_head_dim")
    if missing == 0:
        return vector
    return np.pad(vector, (0, missing), mode="constant")
def softmax(logits: np.ndarray) -> np.ndarray:
    """Return the numerically stable softmax of *logits* as float32.

    The normalisation runs over every element of the input (no axis is
    selected), so a one-dimensional logit vector yields weights that sum
    to one.

    Args:
        logits: Array-like of scores; converted to float32.

    Returns:
        An array with the same shape as the converted input. Empty input
        yields an empty float32 array instead of raising.
    """
    values = np.asarray(logits, dtype=np.float32)
    if values.size == 0:
        # np.max has no identity for a zero-size reduction and would raise;
        # an empty set of logits simply has an empty set of weights.
        return values.copy()
    # Subtracting the maximum keeps np.exp from overflowing on large logits.
    shifted = values - np.max(values)
    weights = np.exp(shifted)
    return weights / np.sum(weights)
def _score_m2_groups(query_groups: np.ndarray, page: EncodedPage) -> np.ndarray:
    """Accumulate M2 (key-sketch) logits over every group of *page*."""
    header = page.header
    logits = np.zeros(header.token_count, dtype=np.float32)
    for group_index in range(header.num_groups):
        qg = query_groups[group_index]
        group_mean = None if page.m2_mean is None else page.m2_mean[group_index].astype(np.float32)
        group_basis = page.m2_basis[group_index].astype(np.float32)
        sketch = page.m2_sketch[:, group_index, :].astype(np.float32)
        if group_basis.ndim == 2:
            # A single basis is shared by all tokens of the page.
            q_proj = group_basis @ qg
            logits += sketch @ q_proj.astype(np.float32)
            if group_mean is not None:
                logits += np.dot(group_mean, qg).astype(np.float32)
            continue
        # Otherwise the leading basis axis indexes segments: the query is
        # projected once per segment and each token picks its segment's row.
        segment_ids = segment_ids_for_token_count(header.token_count, int(group_basis.shape[0]))
        q_proj = np.einsum("srg,g->sr", group_basis, qg)
        logits += np.einsum("tr,tr->t", sketch, q_proj[segment_ids])
        if group_mean is not None:
            logits += group_mean[segment_ids].astype(np.float32) @ qg
    return logits


def _score_m4_groups(query_groups: np.ndarray, page: EncodedPage) -> np.ndarray:
    """Accumulate M4 (projected-key) logits over every group of *page*."""
    header = page.header
    logits = np.zeros(header.token_count, dtype=np.float32)
    for group_index in range(header.num_groups):
        qg = query_groups[group_index]
        # Prefer a basis stored on the page; otherwise rebuild the fixed
        # basis from the header's ``project_basis`` setting.
        basis = (
            np.asarray(page.m2_basis[group_index], dtype=np.float32)
            if page.m2_basis is not None
            else fixed_project_basis(header.group_size, int(page.m2_sketch.shape[-1]), header.project_basis)
        )
        q_proj = basis @ qg
        logits += page.m2_sketch[:, group_index, :].astype(np.float32) @ q_proj.astype(np.float32)
        logits += np.dot(page.m2_mean[group_index].astype(np.float32), qg).astype(np.float32)
    return logits


def _score_t3_groups(query_groups: np.ndarray, page: EncodedPage) -> np.ndarray:
    """Accumulate T3 logits by scoring codebook centroids against the FWHT-rotated query."""
    header = page.header
    # The query is passed through fwht_last_dim once, so the stored codes are
    # scored without transforming them back (mix_page_ref applies
    # fwht_last_dim to the decoded group instead).
    rotated_query_groups = fwht_last_dim(query_groups)
    logits = np.zeros(header.token_count, dtype=np.float32)
    centroids = np.asarray(page.codebooks, dtype=np.float32)
    for group_index in range(header.num_groups):
        words = load_group_words(page, group_index)
        codes = unpack_bits(words, header.bits, header.group_size).astype(np.int64, copy=False)
        corrected = centroids[codes] * page.scales[:, group_index].astype(np.float32)[:, None]
        logits += corrected @ rotated_query_groups[group_index]
    return logits


def _score_packed_groups(query_groups: np.ndarray, page: EncodedPage) -> np.ndarray:
    """Accumulate logits for bit-packed pages: M1 (LUT) or the scale-based default."""
    header = page.header
    query_group_sums = query_groups.sum(axis=-1)
    logits = np.zeros(header.token_count, dtype=np.float32)
    for group_index in range(header.num_groups):
        words = load_group_words(page, group_index)
        codes_u8 = unpack_bits(words, header.bits, header.group_size)
        qg = query_groups[group_index]
        if header.mode_default == "M1":
            if page.codebooks is None:
                raise ValueError("M1 page is missing codebooks")
            group = dequantize_group_lut(codes_u8, codebook=np.asarray(page.codebooks[group_index], dtype=np.float32))
            logits += group @ qg
            continue
        if page.scales is None:
            raise ValueError("M0 page is missing scales")
        codes = codes_u8.astype(np.float32)
        scales = page.scales[:, group_index].astype(np.float32)
        if header.quant_scheme == "affine":
            if page.bias is None:
                raise ValueError("affine pages require bias metadata")
            # (scale * code + bias) . q == scale * (code . q) + bias * sum(q),
            # so the group never has to be dequantised explicitly.
            int_dot = codes @ qg
            bias = page.bias[:, group_index].astype(np.float32)
            logits += scales * int_dot + bias * query_group_sums[group_index]
            continue
        # Non-affine codes are recentred around 2**(bits - 1) - 1.
        zero_point = (1 << (header.bits - 1)) - 1
        logits += scales * ((codes - zero_point) @ qg)
    return logits


def score_page_ref(query_slice: np.ndarray, page: EncodedPage) -> np.ndarray:
    """Score one query vector against every token of an encoded key page.

    The page's ``header.mode_default`` selects the scoring path: "M3"
    (escape payload), "M2" (key sketch), "M4" (projected keys), "T3"
    (codebook centroids with an FWHT-rotated query), "M1" (per-group LUT),
    and anything else is treated as the scale-based packed format.

    Args:
        query_slice: Array-like of shape ``[head_dim]``; it is zero-padded to
            ``header.padded_head_dim`` for the grouped modes.
        page: Encoded key page.

    Returns:
        float32 logits; the grouped modes produce shape ``[token_count]``.
        For "M3" the shape is whatever ``decode_escape_payload`` yields
        matmul'd with the query (presumably ``[token_count]`` — confirm
        against that helper).

    Raises:
        ValueError: If the query is not 1-D or is longer than the padded head
            dimension, or if the page lacks the payload/metadata its mode
            requires.
    """
    header = page.header
    query = _pad_query(query_slice, header.padded_head_dim)
    mode = header.mode_default

    if mode == "M3":
        if page.escape_payload is None:
            raise ValueError("escape payload is missing")
        dense = decode_escape_payload(page.escape_payload, head_dim=header.head_dim, scales=page.escape_scales)
        # The payload is decoded with head_dim, so the unpadded query is used.
        # np.asarray (rather than query_slice.astype) keeps list/tuple queries
        # working here exactly as they do for every other mode.
        return dense @ np.asarray(query_slice, dtype=np.float32)

    if mode == "M2":
        if page.m2_sketch is None or page.m2_basis is None:
            raise ValueError("M2 page is missing sketch payload")
        return _score_m2_groups(query.reshape(header.num_groups, header.group_size), page)

    if mode == "M4":
        if page.m2_sketch is None or page.m2_mean is None:
            raise ValueError("M4 page is missing projected payload")
        return _score_m4_groups(query.reshape(header.num_groups, header.group_size), page)

    if mode == "T3":
        if page.payload is None or page.scales is None or page.codebooks is None:
            raise ValueError("T3 page is missing payload or correction metadata")
        return _score_t3_groups(query.reshape(header.num_groups, header.group_size), page)

    if page.payload is None:
        raise ValueError(f"{header.mode_default} page is missing payload")
    return _score_packed_groups(query.reshape(header.num_groups, header.group_size), page)
def mix_page_ref(attn_weights: np.ndarray, page: EncodedPage, out_acc: np.ndarray | None = None) -> np.ndarray:
    """Mix an encoded value page with attention weights, one group at a time.

    Args:
        attn_weights: Array-like with one weight per token; must have shape
            ``[token_count]``.
        page: Encoded value page. Modes "M3", "T3", "M1" and the scale-based
            default are handled; "M2" and "M4" are rejected.
        out_acc: Optional accumulator of shape ``[padded_head_dim]`` that the
            weighted sum is added onto. It is converted with ``np.asarray``,
            so an accumulator that is already a float32 ndarray is updated
            in place.

    Returns:
        A copy of the first ``head_dim`` entries of the accumulator.

    Raises:
        ValueError: On a shape mismatch, an unsupported mode, or a page that
            lacks the payload/metadata its mode requires.
    """
    header = page.header
    weights = np.asarray(attn_weights, dtype=np.float32)
    if weights.shape != (header.token_count,):
        raise ValueError("attn_weights must have shape [token_count]")

    if out_acc is None:
        acc = np.zeros(header.padded_head_dim, dtype=np.float32)
    else:
        acc = np.asarray(out_acc, dtype=np.float32)
    if acc.shape != (header.padded_head_dim,):
        raise ValueError("out_acc must have shape [padded_head_dim]")

    mode = header.mode_default

    if mode == "M3":
        if page.escape_payload is None:
            raise ValueError("escape payload is missing")
        rows = decode_escape_payload(page.escape_payload, head_dim=header.head_dim, scales=page.escape_scales)
        acc[: header.head_dim] += weights @ rows
        return acc[: header.head_dim].copy()

    if mode in {"M2", "M4"}:
        raise ValueError(f"{header.mode_default} is only supported for key scoring in this phase")

    if mode == "T3":
        if page.payload is None or page.scales is None or page.codebooks is None:
            raise ValueError("T3 page is missing payload or correction metadata")
        centroids = np.asarray(page.codebooks, dtype=np.float32)
        for group_index in range(header.num_groups):
            packed = load_group_words(page, group_index)
            indices = unpack_bits(packed, header.bits, header.group_size).astype(np.int64, copy=False)
            rotated = centroids[indices] * page.scales[:, group_index].astype(np.float32)[:, None]
            # The scaled centroids are passed through fwht_last_dim before
            # they are mixed into the output.
            group = fwht_last_dim(rotated)
            lo = group_index * header.group_size
            acc[lo : lo + header.group_size] += weights @ group
        return acc[: header.head_dim].copy()

    if page.payload is None:
        raise ValueError(f"{header.mode_default} page is missing payload")

    for group_index in range(header.num_groups):
        packed = load_group_words(page, group_index)
        raw_codes = unpack_bits(packed, header.bits, header.group_size)
        if mode != "M1":
            if page.scales is None:
                raise ValueError("M0 page is missing scales")
            codes = raw_codes.astype(np.float32)
            scales = page.scales[:, group_index].astype(np.float32)[:, None]
            if header.quant_scheme != "affine":
                # Non-affine codes are recentred around 2**(bits - 1) - 1.
                zero_point = (1 << (header.bits - 1)) - 1
                group = scales * (codes - zero_point)
            else:
                if page.bias is None:
                    raise ValueError("affine pages require bias metadata")
                group = scales * codes + page.bias[:, group_index].astype(np.float32)[:, None]
        else:
            if page.codebooks is None:
                raise ValueError("M1 page is missing codebooks")
            group = dequantize_group_lut(raw_codes, codebook=np.asarray(page.codebooks[group_index], dtype=np.float32))
        # Each group owns a contiguous group_size-wide slice of the output.
        lo = group_index * header.group_size
        acc[lo : lo + header.group_size] += weights @ group
    return acc[: header.head_dim].copy()
def explicit_dequantized_score(query_slice: np.ndarray, page: EncodedPage) -> np.ndarray:
    """Score *query_slice* against *page* by decoding the whole page first.

    The page is decoded with ``decode_page`` and the result is matrix-
    multiplied by the float32 query. The decoded array is presumably
    ``[token_count, head_dim]`` — confirm against ``decode_page``.
    """
    decoded_keys = decode_page(page)
    return decoded_keys @ np.asarray(query_slice, dtype=np.float32)
def explicit_dequantized_mix(attn_weights: np.ndarray, page: EncodedPage) -> np.ndarray:
    """Mix *page* with *attn_weights* by decoding the whole page first.

    The page is decoded with ``decode_page`` and left-multiplied by the
    float32 weights. The decoded array is presumably
    ``[token_count, head_dim]`` — confirm against ``decode_page``.
    """
    decoded_values = decode_page(page)
    return np.asarray(attn_weights, dtype=np.float32) @ decoded_values
def run_attention_reference(query_slice: np.ndarray, key_page: EncodedPage, value_page: EncodedPage) -> tuple[np.ndarray, np.ndarray]:
    """Run single-query attention directly on encoded key and value pages.

    The query is scored against ``key_page`` with ``score_page_ref``, the
    logits are normalised with ``softmax``, and ``value_page`` is mixed with
    the resulting weights via ``mix_page_ref``.

    Returns:
        ``(logits, output)``: the pre-softmax logits and the mixed output.
    """
    raw_logits = score_page_ref(query_slice, key_page)
    mixed = mix_page_ref(softmax(raw_logits), value_page)
    return raw_logits, mixed
def explicit_dequantized_attention(query_slice: np.ndarray, key_page: EncodedPage, value_page: EncodedPage) -> tuple[np.ndarray, np.ndarray]:
    """Run single-query attention by fully decoding both pages.

    This uses ``explicit_dequantized_score`` and ``explicit_dequantized_mix``
    with a ``softmax`` in between, mirroring ``run_attention_reference``.

    Returns:
        ``(logits, output)``: the pre-softmax logits and the mixed output.
    """
    raw_logits = explicit_dequantized_score(query_slice, key_page)
    mixed = explicit_dequantized_mix(softmax(raw_logits), value_page)
    return raw_logits, mixed