Spaces:
Paused
Paused
| from __future__ import annotations | |
| import numpy as np | |
| from .decode_reference import decode_page | |
| from .modes.m1_lut import dequantize_group_lut | |
| from .modes.m2_key_sketch import segment_ids_for_token_count | |
| from .modes.m4_key_project import fixed_project_basis | |
| from .modes.m3_escape import decode_escape_payload | |
| from .modes.turbo3 import fwht_last_dim | |
| from .page_format import load_group_words | |
| from .packing import unpack_bits | |
| from .types import EncodedPage | |
| def _pad_query(query_slice: np.ndarray, padded_head_dim: int) -> np.ndarray: | |
| query = np.asarray(query_slice, dtype=np.float32) | |
| if query.ndim != 1: | |
| raise ValueError("query_slice must have shape [head_dim]") | |
| if query.shape[0] > padded_head_dim: | |
| raise ValueError("query head_dim exceeds padded_head_dim") | |
| if query.shape[0] == padded_head_dim: | |
| return query | |
| return np.pad(query, (0, padded_head_dim - query.shape[0]), mode="constant") | |
def softmax(logits: np.ndarray) -> np.ndarray:
    """Return the softmax of ``logits`` normalised over all of its elements.

    The input is cast to float32 and shifted by its maximum before
    exponentiation so that large logits do not overflow.

    Args:
        logits: Array-like of scores.

    Returns:
        A float32 array with the same shape as ``logits`` whose entries sum
        to one. An empty input yields an empty float32 array.
    """
    values = np.asarray(logits, dtype=np.float32)
    if values.size == 0:
        # np.max has no identity for zero-size input and would raise; an
        # empty set of logits simply has no weights to distribute.
        return values.copy()
    shifted = values - np.max(values)
    weights = np.exp(shifted)
    return weights / np.sum(weights)
def score_page_ref(query_slice: np.ndarray, page: EncodedPage) -> np.ndarray:
    """Compute attention logits of one query against an encoded key page.

    The query is cast to float32 and zero-padded to
    ``header.padded_head_dim``; the scoring path is then selected by
    ``header.mode_default``:

    * ``"M3"``: the escape payload is decoded to a dense matrix and multiplied
      by the first ``header.head_dim`` query entries.
    * ``"M2"``: per group, the query is projected with the stored basis and
      dotted with the sketch; the mean contribution is added when a mean is
      stored. A non-2-D basis is treated as one basis per segment, with
      segments assigned by ``segment_ids_for_token_count``.
    * ``"M4"``: like the 2-D ``"M2"`` path, but the mean is mandatory and the
      basis falls back to ``fixed_project_basis`` when none is stored.
    * ``"T3"``: query groups are passed through ``fwht_last_dim`` and scored
      against the scaled codebook centroids selected by the unpacked codes.
    * anything else: unpacked codes are scored through the per-group lookup
      table (``"M1"``) or through scales with either an affine bias or a
      symmetric zero point.

    Args:
        query_slice: One-dimensional query vector with at most
            ``header.padded_head_dim`` entries.
        page: Encoded key page.

    Returns:
        float32 logits; every non-``"M3"`` path returns shape
        ``[header.token_count]``.

    Raises:
        ValueError: If the query shape is invalid or the page lacks the
            payload or metadata required by its mode.
    """
    header = page.header
    query = _pad_query(query_slice, header.padded_head_dim)
    if header.mode_default == "M3":
        if page.escape_payload is None:
            raise ValueError("escape payload is missing")
        dense = decode_escape_payload(page.escape_payload, head_dim=header.head_dim, scales=page.escape_scales)
        # Use the validated float32 query trimmed to head_dim (the width the
        # payload is decoded with) so this mode accepts the same query inputs
        # as every other mode instead of requiring an ndarray of exact length.
        return dense @ query[: header.head_dim]
    if header.mode_default == "M2":
        if page.m2_sketch is None or page.m2_basis is None:
            raise ValueError("M2 page is missing sketch payload")
        query_groups = query.reshape(header.num_groups, header.group_size)
        logits = np.zeros(header.token_count, dtype=np.float32)
        for group_index in range(header.num_groups):
            group_mean = None if page.m2_mean is None else page.m2_mean[group_index].astype(np.float32)
            group_basis = page.m2_basis[group_index].astype(np.float32)
            if group_basis.ndim == 2:
                # Single basis shared by all tokens of this group.
                q_proj = group_basis @ query_groups[group_index]
                logits += page.m2_sketch[:, group_index, :].astype(np.float32) @ q_proj.astype(np.float32)
                if group_mean is not None:
                    logits += np.dot(group_mean, query_groups[group_index]).astype(np.float32)
                continue
            # Otherwise the leading basis axis indexes segments: project the
            # query once per segment and pick each token's segment row.
            segment_ids = segment_ids_for_token_count(header.token_count, int(group_basis.shape[0]))
            q_proj = np.einsum("srg,g->sr", group_basis, query_groups[group_index])
            logits += np.einsum("tr,tr->t", page.m2_sketch[:, group_index, :].astype(np.float32), q_proj[segment_ids])
            if group_mean is not None:
                logits += group_mean[segment_ids].astype(np.float32) @ query_groups[group_index]
        return logits
    if header.mode_default == "M4":
        if page.m2_sketch is None or page.m2_mean is None:
            raise ValueError("M4 page is missing projected payload")
        query_groups = query.reshape(header.num_groups, header.group_size)
        logits = np.zeros(header.token_count, dtype=np.float32)
        for group_index in range(header.num_groups):
            # Prefer a stored basis; otherwise rebuild the fixed projection
            # basis from the group size, sketch rank and header setting.
            basis = (
                np.asarray(page.m2_basis[group_index], dtype=np.float32)
                if page.m2_basis is not None
                else fixed_project_basis(header.group_size, int(page.m2_sketch.shape[-1]), header.project_basis)
            )
            q_proj = basis @ query_groups[group_index]
            logits += page.m2_sketch[:, group_index, :].astype(np.float32) @ q_proj.astype(np.float32)
            logits += np.dot(page.m2_mean[group_index].astype(np.float32), query_groups[group_index]).astype(np.float32)
        return logits
    if header.mode_default == "T3":
        if page.payload is None or page.scales is None or page.codebooks is None:
            raise ValueError("T3 page is missing payload or correction metadata")
        # The query is transformed instead of the keys, so the stored codes
        # can be scored directly.
        rotated_query_groups = fwht_last_dim(query.reshape(header.num_groups, header.group_size))
        logits = np.zeros(header.token_count, dtype=np.float32)
        centroids = np.asarray(page.codebooks, dtype=np.float32)
        for group_index in range(header.num_groups):
            words = load_group_words(page, group_index)
            # int64 so the codes can be used as indices into the centroid table.
            codes_u8 = unpack_bits(words, header.bits, header.group_size).astype(np.int64, copy=False)
            corrected = centroids[codes_u8] * page.scales[:, group_index].astype(np.float32)[:, None]
            logits += corrected @ rotated_query_groups[group_index]
        return logits
    if page.payload is None:
        raise ValueError(f"{header.mode_default} page is missing payload")
    query_groups = query.reshape(header.num_groups, header.group_size)
    # Per-group query sums let the affine bias be applied without
    # materialising dequantized keys.
    query_group_sums = query_groups.sum(axis=-1)
    logits = np.zeros(header.token_count, dtype=np.float32)
    for group_index in range(header.num_groups):
        words = load_group_words(page, group_index)
        codes_u8 = unpack_bits(words, header.bits, header.group_size)
        qg = query_groups[group_index]
        if header.mode_default == "M1":
            if page.codebooks is None:
                raise ValueError("M1 page is missing codebooks")
            group = dequantize_group_lut(codes_u8, codebook=np.asarray(page.codebooks[group_index], dtype=np.float32))
            logits += group @ qg
            continue
        if page.scales is None:
            raise ValueError("M0 page is missing scales")
        codes = codes_u8.astype(np.float32)
        scales = page.scales[:, group_index].astype(np.float32)
        if header.quant_scheme == "affine":
            if page.bias is None:
                raise ValueError("affine pages require bias metadata")
            # (scale * code + bias) . q == scale * (code . q) + bias * sum(q)
            int_dot = codes @ qg
            bias = page.bias[:, group_index].astype(np.float32)
            logits += scales * int_dot + bias * query_group_sums[group_index]
            continue
        # Symmetric scheme: codes are offset by a fixed mid-range zero point.
        zero_point = (1 << (header.bits - 1)) - 1
        logits += scales * ((codes - zero_point) @ qg)
    return logits
def mix_page_ref(attn_weights: np.ndarray, page: EncodedPage, out_acc: np.ndarray | None = None) -> np.ndarray:
    """Accumulate the attention-weighted sum of an encoded value page.

    Each group of the page is reconstructed according to
    ``header.mode_default`` and ``attn_weights @ group`` is added to the
    matching slice of the accumulator.

    Args:
        attn_weights: Weights of shape ``[header.token_count]``.
        page: Encoded value page.
        out_acc: Optional accumulator of shape ``[header.padded_head_dim]``.
            A float32 ndarray is used as-is by ``np.asarray`` and is therefore
            updated in place; other inputs are converted to a new array.

    Returns:
        A copy of the first ``header.head_dim`` accumulator entries.

    Raises:
        ValueError: On a shape mismatch, for the key-only modes ``"M2"`` and
            ``"M4"``, or when the page lacks required payload or metadata.
    """
    header = page.header
    weights = np.asarray(attn_weights, dtype=np.float32)
    if weights.shape != (header.token_count,):
        raise ValueError("attn_weights must have shape [token_count]")

    if out_acc is None:
        output = np.zeros(header.padded_head_dim, dtype=np.float32)
    else:
        output = np.asarray(out_acc, dtype=np.float32)
    if output.shape != (header.padded_head_dim,):
        raise ValueError("out_acc must have shape [padded_head_dim]")

    if header.mode_default == "M3":
        if page.escape_payload is None:
            raise ValueError("escape payload is missing")
        dense_values = decode_escape_payload(page.escape_payload, head_dim=header.head_dim, scales=page.escape_scales)
        output[: header.head_dim] += weights @ dense_values
        return output[: header.head_dim].copy()

    if header.mode_default in {"M2", "M4"}:
        raise ValueError(f"{header.mode_default} is only supported for key scoring in this phase")

    if header.mode_default == "T3":
        if page.payload is None or page.scales is None or page.codebooks is None:
            raise ValueError("T3 page is missing payload or correction metadata")
        centroid_table = np.asarray(page.codebooks, dtype=np.float32)
        for group_index in range(header.num_groups):
            packed = load_group_words(page, group_index)
            # int64 so the codes can index the centroid table.
            indices = unpack_bits(packed, header.bits, header.group_size).astype(np.int64, copy=False)
            scaled_centroids = centroid_table[indices] * page.scales[:, group_index].astype(np.float32)[:, None]
            # Apply the transform to the scaled centroids before mixing.
            values = fwht_last_dim(scaled_centroids)
            lo = group_index * header.group_size
            hi = lo + header.group_size
            output[lo:hi] += weights @ values
        return output[: header.head_dim].copy()

    if page.payload is None:
        raise ValueError(f"{header.mode_default} page is missing payload")

    for group_index in range(header.num_groups):
        packed = load_group_words(page, group_index)
        raw_codes = unpack_bits(packed, header.bits, header.group_size)
        if header.mode_default == "M1":
            if page.codebooks is None:
                raise ValueError("M1 page is missing codebooks")
            values = dequantize_group_lut(raw_codes, codebook=np.asarray(page.codebooks[group_index], dtype=np.float32))
        else:
            if page.scales is None:
                raise ValueError("M0 page is missing scales")
            float_codes = raw_codes.astype(np.float32)
            token_scales = page.scales[:, group_index].astype(np.float32)[:, None]
            if header.quant_scheme != "affine":
                # Symmetric scheme: remove the fixed mid-range zero point.
                zero_point = (1 << (header.bits - 1)) - 1
                values = token_scales * (float_codes - zero_point)
            elif page.bias is None:
                raise ValueError("affine pages require bias metadata")
            else:
                values = token_scales * float_codes + page.bias[:, group_index].astype(np.float32)[:, None]
        lo = group_index * header.group_size
        hi = lo + header.group_size
        output[lo:hi] += weights @ values
    return output[: header.head_dim].copy()
def explicit_dequantized_score(query_slice: np.ndarray, page: EncodedPage) -> np.ndarray:
    """Score a query against a page by fully decoding it first.

    Args:
        query_slice: Query vector (any array-like), cast to float32.
        page: Encoded page passed to ``decode_page``.

    Returns:
        The matrix product of the decoded page and the query.
    """
    decoded_keys = decode_page(page)
    return decoded_keys @ np.asarray(query_slice, dtype=np.float32)
def explicit_dequantized_mix(attn_weights: np.ndarray, page: EncodedPage) -> np.ndarray:
    """Mix a page with attention weights by fully decoding it first.

    Args:
        attn_weights: Attention weights (any array-like), cast to float32.
        page: Encoded page passed to ``decode_page``.

    Returns:
        The matrix product of the weights and the decoded page.
    """
    decoded_values = decode_page(page)
    return np.asarray(attn_weights, dtype=np.float32) @ decoded_values
def run_attention_reference(query_slice: np.ndarray, key_page: EncodedPage, value_page: EncodedPage) -> tuple[np.ndarray, np.ndarray]:
    """Run single-query attention over one encoded key/value page pair.

    Logits come from ``score_page_ref``, are normalised with ``softmax`` and
    then used by ``mix_page_ref`` to combine the value page.

    Returns:
        A ``(logits, output)`` tuple.
    """
    scores = score_page_ref(query_slice, key_page)
    mixed = mix_page_ref(softmax(scores), value_page)
    return scores, mixed
def explicit_dequantized_attention(query_slice: np.ndarray, key_page: EncodedPage, value_page: EncodedPage) -> tuple[np.ndarray, np.ndarray]:
    """Run single-query attention using the fully decoded pages.

    Counterpart of ``run_attention_reference`` built on
    ``explicit_dequantized_score`` and ``explicit_dequantized_mix``.

    Returns:
        A ``(logits, output)`` tuple.
    """
    scores = explicit_dequantized_score(query_slice, key_page)
    mixed = explicit_dequantized_mix(softmax(scores), value_page)
    return scores, mixed