File size: 3,393 Bytes
751ad26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from __future__ import annotations

import json
from dataclasses import asdict, dataclass
from typing import Literal

import numpy as np

# Closed sets of string tags used to annotate the PageHeader fields `kind`,
# `layout`, `mode_default` and `quant_scheme` respectively. They are static
# typing aids only; nothing in this module checks the values at runtime.
Kind = Literal["K", "V"]
Layout = Literal["group_major", "token_major"]
Mode = Literal["M0", "M1", "M2", "M3", "M4", "T3"]
QuantScheme = Literal["affine", "symmetric", "lut", "sketch", "project", "turbo3"]


@dataclass(slots=True)
class PageHeader:
    layer_id: int
    kv_head_id: int
    kind: Kind
    token_start: int
    token_count: int
    head_dim: int
    padded_head_dim: int
    group_size: int
    num_groups: int
    bits: int
    words_per_group: int
    mode_default: Mode
    layout: Layout
    quant_scheme: QuantScheme
    policy_id: str = "exact_baseline"
    sensitivity_tier: str = "exact"
    fallback_reason: str = ""
    age_bucket: str = "aged"
    escape_dtype: str = "float16"
    project_basis: str = "hadamard"

    def to_dict(self) -> dict[str, int | str]:
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, int | str]) -> "PageHeader":
        return cls(**data)

    def to_json(self) -> str:
        return json.dumps(self.to_dict(), sort_keys=True)

    @classmethod
    def from_json(cls, payload: str) -> "PageHeader":
        return cls.from_dict(json.loads(payload))


@dataclass(slots=True)
class EncodedPage:
    header: PageHeader
    payload: np.ndarray | None = None
    scales: np.ndarray | None = None
    bias: np.ndarray | None = None
    codebooks: np.ndarray | None = None
    m2_sketch: np.ndarray | None = None
    m2_basis: np.ndarray | None = None
    m2_mean: np.ndarray | None = None
    lut_segment_count: int = 1
    escape_payload: np.ndarray | None = None
    escape_scales: np.ndarray | None = None
    requested_mode: str | None = None
    trial_quant_error: float | None = None
    trial_token_p95_error: float | None = None
    runtime_page_mean: np.ndarray | None = None
    runtime_page_sketch: np.ndarray | None = None
    runtime_page_min: np.ndarray | None = None
    runtime_page_max: np.ndarray | None = None
    full_page_decode_calls: int = 0
    decode_group_calls: int = 0

    @property
    def payload_nbytes(self) -> int:
        total = 0
        if self.payload is not None:
            total += int(self.payload.nbytes)
        if self.escape_payload is not None:
            total += int(self.escape_payload.nbytes)
        return total

    @property
    def metadata_nbytes(self) -> int:
        total = len(self.header.to_json().encode("utf-8"))
        if self.scales is not None:
            total += int(self.scales.nbytes)
        if self.bias is not None:
            total += int(self.bias.nbytes)
        if self.codebooks is not None:
            total += int(self.codebooks.nbytes)
        if self.m2_sketch is not None:
            total += int(self.m2_sketch.nbytes)
        if self.m2_basis is not None:
            total += int(self.m2_basis.nbytes)
        if self.m2_mean is not None:
            total += int(self.m2_mean.nbytes)
        if self.escape_scales is not None:
            total += int(self.escape_scales.nbytes)
        return total

    @property
    def total_nbytes(self) -> int:
        return self.payload_nbytes + self.metadata_nbytes

    def record_full_decode(self) -> None:
        self.full_page_decode_calls += 1

    def record_group_decode(self) -> None:
        self.decode_group_calls += 1