File size: 9,276 Bytes
751ad26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
from __future__ import annotations

import math
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Literal

import numpy as np

# Codec mode tag read by ``simulate_state_codec``: "M0" selects grouped
# low-bit quantization, "M3" returns the tile unquantized and only charges
# ``escape_dtype`` bytes per element.
StateMode = Literal["M0", "M3"]


@dataclass(slots=True)
class StateTileSpec:
    state_rows: int
    state_cols: int
    group_size: int = 32
    bits: int = 8
    mode: StateMode = "M0"
    escape_dtype: str = "float32"

    def to_dict(self) -> dict[str, int | str]:
        return asdict(self)


@dataclass(slots=True)
class StateLayerRecord:
    layer_id: int
    layer_type: str
    state_family: str
    conv_state_bytes: int
    recurrent_state_bytes: int
    layer_state_bytes: int
    state_shapes: dict[str, list[int]] = field(default_factory=dict)
    state_delta_norms: list[dict[str, float | int]] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        return asdict(self)


@dataclass(slots=True)
class StateAblationResult:
    stage_name: str
    bits: int | None
    max_abs_error: float
    max_rel_error: float
    output_max_abs_error: float
    error_grows_step_to_step: bool
    per_layer_max_abs_error: dict[str, float] = field(default_factory=dict)
    per_layer_max_rel_error: dict[str, float] = field(default_factory=dict)
    per_layer_output_max_abs_error: dict[str, float] = field(default_factory=dict)
    per_step_output_max_abs_error: list[float] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        return asdict(self)


@dataclass(slots=True)
class StateSimResult:
    mode: StateMode
    bits: int
    group_size: int
    renorm_interval: int
    bytes_per_token: int
    bytes_per_layer: int
    effective_compression_ratio: float
    update_error_curve: list[float] = field(default_factory=list)
    readout_error_curve: list[float] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        return asdict(self)


@dataclass(slots=True)
class CapturedStateSample:
    source: str
    state_kind: str
    layer_id: int
    prompt_length: int
    token_indices: list[int]
    initial_state: np.ndarray
    update_deltas: np.ndarray

    @property
    def state_rows(self) -> int:
        return int(np.prod(self.initial_state.shape[:-1], dtype=np.int64)) if self.initial_state.ndim > 1 else 1

    @property
    def state_cols(self) -> int:
        return int(self.initial_state.shape[-1])

    @property
    def steps(self) -> int:
        return int(self.update_deltas.shape[0])

    def to_dict(self) -> dict[str, Any]:
        return {
            "source": self.source,
            "state_kind": self.state_kind,
            "layer_id": self.layer_id,
            "prompt_length": self.prompt_length,
            "token_indices": list(self.token_indices),
            "state_rows": self.state_rows,
            "state_cols": self.state_cols,
            "steps": self.steps,
        }


def _codec_dtype_bytes(escape_dtype: str) -> int:
    """Return the per-element byte size of a supported escape dtype name.

    Raises:
        ValueError: if ``escape_dtype`` is neither "float16" nor "float32".
    """
    supported = (("float16", np.float16), ("float32", np.float32))
    for dtype_name, scalar_type in supported:
        if escape_dtype == dtype_name:
            return np.dtype(scalar_type).itemsize
    raise ValueError(f"unsupported StateCache escape_dtype {escape_dtype!r}")


def _renorm_rows(array: np.ndarray) -> np.ndarray:
    """Scale every last-axis vector of ``array`` to unit Euclidean norm.

    The computation runs on a float32 copy, so the input is left untouched.
    Norms are floored at 1e-8, which keeps all-zero rows at zero instead of
    dividing by zero. The result has the same shape as ``array``.
    """
    original_shape = array.shape
    rows = array.reshape(-1, original_shape[-1]).astype(np.float32, copy=True)
    norms = np.linalg.norm(rows, axis=-1, keepdims=True)
    safe_norms = np.maximum(norms, 1e-8)
    return (rows / safe_norms).reshape(original_shape)


def simulate_state_codec(tile: np.ndarray, spec: StateTileSpec) -> tuple[np.ndarray, int, int]:
    """Simulate one encode/decode round trip of a state tile.

    The tile is viewed as rows of its last axis. In mode "M0" every row is
    cut into runs of ``spec.group_size`` elements (the last run may be
    shorter) and each run is quantized to ``2**bits - 1`` uniform steps over
    the range ``[min(run, 0), max(run, 0)]``, i.e. the range always contains
    zero. Runs whose range is at most 1e-8 wide decode to their minimum. In
    mode "M3" the values are returned unchanged and only the byte accounting
    depends on ``spec.escape_dtype``.

    NOTE(review): "M3" returns float32 values without any rounding even when
    ``escape_dtype`` is "float16" -- confirm whether float16 rounding should
    be simulated as well.

    Args:
        tile: array-like with at least two axes; converted to float32.
        spec: codec settings (``mode``, ``bits``, ``group_size``,
            ``escape_dtype``).

    Returns:
        ``(decoded, payload_nbytes, metadata_nbytes)``: a new float32 array
        with the tile's shape, the simulated payload size, and 8 bytes of
        metadata per quantization group (0 for "M3").

    Raises:
        ValueError: if the tile has fewer than two axes, the mode is unknown,
            ``spec.bits`` is not positive ("M0"), or the escape dtype is
            unsupported ("M3").
    """
    array = np.asarray(tile, dtype=np.float32)
    if array.ndim < 2:
        raise ValueError("state tiles must be at least 2D")
    if spec.mode == "M3":
        payload_nbytes = int(array.size * _codec_dtype_bytes(spec.escape_dtype))
        return array.astype(np.float32, copy=True), payload_nbytes, 0
    if spec.mode != "M0":
        raise ValueError(f"unsupported StateCache mode {spec.mode!r}")
    if spec.bits <= 0:
        raise ValueError("spec.bits must be positive")
    if array.size == 0:
        # Nothing to quantize. Returning early also covers zero-width tiles,
        # for which reshape(-1, 0) below would be ambiguous and raise.
        return array.copy(), 0, 0

    flat = array.reshape(-1, array.shape[-1])
    num_rows, num_cols = flat.shape
    group_size = max(int(spec.group_size), 1)
    levels = max((1 << int(spec.bits)) - 1, 1)

    # Column index where each group starts, and how many columns it spans.
    group_starts = np.arange(0, num_cols, group_size)
    group_widths = np.diff(np.append(group_starts, num_cols))

    # Per-(row, group) range, widened so that it always includes zero.
    zero = np.float32(0.0)
    lo = np.minimum(np.minimum.reduceat(flat, group_starts, axis=1), zero)
    hi = np.maximum(np.maximum.reduceat(flat, group_starts, axis=1), zero)

    # The step size is computed in float64 and then rounded to float32.
    span = hi.astype(np.float64) - lo.astype(np.float64)
    degenerate = span <= 1e-8
    # A placeholder step of 1 avoids 0/0 in groups that are replaced below.
    scale = np.where(degenerate, 1.0, span / float(levels)).astype(np.float32)

    # Broadcast the per-group parameters back to one value per element.
    lo_cols = np.repeat(lo, group_widths, axis=1)
    scale_cols = np.repeat(scale, group_widths, axis=1)
    degenerate_cols = np.repeat(degenerate, group_widths, axis=1)

    codes = np.rint((flat - lo_cols) / scale_cols).clip(0, levels)
    decoded = np.where(degenerate_cols, lo_cols, codes * scale_cols + lo_cols)

    num_groups = int(num_rows * group_starts.size)
    # Integer ceiling division: whole bytes needed for size * bits bits.
    payload_nbytes = int((array.size * int(spec.bits) + 7) // 8)
    metadata_nbytes = int(num_groups * 8)
    return decoded.reshape(array.shape), payload_nbytes, metadata_nbytes


def simulate_state_sequence(
    initial_state: np.ndarray,
    update_deltas: np.ndarray,
    readout_projections: np.ndarray,
    *,
    spec: StateTileSpec,
    renorm_interval: int = 0,
) -> StateSimResult:
    """Run a dense and a codec-simulated state side by side and record the error.

    For every step the simulated state is passed through
    ``simulate_state_codec``, the step's delta is added to both states, both
    are optionally row-renormalized, and the simulated state goes through the
    codec once more. The maximum absolute difference of the states and of
    their readouts (``state @ projection``) is recorded per step.

    Args:
        initial_state: starting state; converted to float32 and copied.
        update_deltas: per-step additive updates, ``[steps, *state shape]``.
        readout_projections: per-step matrices, ``[steps, cols, output_dim]``.
        spec: codec settings forwarded to ``simulate_state_codec``.
        renorm_interval: when positive, both states are renormalized with
            ``_renorm_rows`` after every ``renorm_interval``-th step.

    Returns:
        A ``StateSimResult`` holding the byte accounting and both error curves.

    Raises:
        ValueError: if the argument shapes are inconsistent with each other.
    """
    reference = np.asarray(initial_state, dtype=np.float32).copy()
    simulated = reference.copy()
    step_deltas = np.asarray(update_deltas, dtype=np.float32)
    step_projections = np.asarray(readout_projections, dtype=np.float32)

    # Shape validation, from the coarsest check to the most specific one.
    if step_deltas.ndim < 3:
        raise ValueError("update_deltas must have shape [steps, rows, cols]")
    if step_projections.ndim != 3:
        raise ValueError("readout_projections must have shape [steps, cols, output_dim]")
    if step_deltas.shape[0] != step_projections.shape[0]:
        raise ValueError("update_deltas and readout_projections must have the same step count")
    if tuple(step_deltas.shape[1:]) != tuple(reference.shape):
        raise ValueError("update_deltas shape must match initial_state")
    if step_projections.shape[1] != reference.shape[-1]:
        raise ValueError("readout projection input dimension must match state_cols")

    dense_nbytes = reference.nbytes
    # Stays at the dense size when there are no steps to simulate.
    encoded_nbytes = dense_nbytes
    state_errors: list[float] = []
    readout_errors: list[float] = []

    for step_number, (delta, projection) in enumerate(zip(step_deltas, step_projections), start=1):
        restored, payload_nbytes, metadata_nbytes = simulate_state_codec(simulated, spec)
        encoded_nbytes = int(payload_nbytes + metadata_nbytes)

        reference = reference + delta
        updated = restored + delta

        renorm_due = renorm_interval > 0 and step_number % int(renorm_interval) == 0
        if renorm_due:
            reference = _renorm_rows(reference)
            updated = _renorm_rows(updated)

        # Only the decoded array of the second codec pass is needed.
        simulated = simulate_state_codec(updated, spec)[0]

        state_gap = np.abs(simulated - reference)
        readout_gap = np.abs(simulated @ projection - reference @ projection)
        state_errors.append(float(state_gap.max()))
        readout_errors.append(float(readout_gap.max()))

    per_layer = int(encoded_nbytes)
    return StateSimResult(
        mode=spec.mode,
        bits=int(spec.bits),
        group_size=int(spec.group_size),
        renorm_interval=int(renorm_interval),
        bytes_per_token=int(per_layer * 2),
        bytes_per_layer=per_layer,
        effective_compression_ratio=float(dense_nbytes / max(per_layer, 1)),
        update_error_curve=state_errors,
        readout_error_curve=readout_errors,
    )


def load_captured_state_sample(path: str | Path) -> CapturedStateSample:
    """Load a captured state sample from a NumPy archive and validate its shapes.

    The archive must provide the entries ``source``, ``state_kind``,
    ``layer_id``, ``prompt_length``, ``token_indices``, ``initial_state`` and
    ``update_deltas``. Pickled objects are not allowed. Both arrays are
    converted to float32.

    Raises:
        ValueError: if ``initial_state`` has fewer than two axes,
            ``update_deltas`` has fewer than three, or the trailing axes of
            ``update_deltas`` differ from the shape of ``initial_state``.
    """
    # The dict literal is evaluated top to bottom, so the entries are read
    # from the archive in this exact order while the file is still open.
    with np.load(Path(path), allow_pickle=False) as archive:
        loaded = {
            "source": str(archive["source"].item()),
            "state_kind": str(archive["state_kind"].item()),
            "layer_id": int(archive["layer_id"].item()),
            "prompt_length": int(archive["prompt_length"].item()),
            "token_indices": [int(index) for index in np.asarray(archive["token_indices"]).tolist()],
            "initial_state": np.asarray(archive["initial_state"], dtype=np.float32),
            "update_deltas": np.asarray(archive["update_deltas"], dtype=np.float32),
        }

    state = loaded["initial_state"]
    deltas = loaded["update_deltas"]
    if state.ndim < 2:
        raise ValueError("captured initial_state must be at least 2D")
    if deltas.ndim < 3:
        raise ValueError("captured update_deltas must be at least 3D")
    if tuple(deltas.shape[1:]) != tuple(state.shape):
        raise ValueError("captured update_deltas shape must match initial_state shape")
    return CapturedStateSample(**loaded)


# Names exported by ``from <module> import *``; the underscore-prefixed
# helpers in this module are intentionally left out.
__all__ = [
    "CapturedStateSample",
    "StateAblationResult",
    "StateLayerRecord",
    "StateSimResult",
    "StateTileSpec",
    "load_captured_state_sample",
    "simulate_state_codec",
    "simulate_state_sequence",
]