| import math |
| from dataclasses import dataclass |
| import site |
| import sys |
| from pathlib import Path |
|
|
| from .linalg import Matrix, Vector, identity, invert_matrix, matvec |
|
|
# Put the repository's vendored package directories (if present) at the front of
# sys.path so they are found before any other installation.
_VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor"
for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"):
    if not _vendor_path.exists():
        continue
    vendor_text = str(_vendor_path)
    if vendor_text in sys.path:
        continue
    # Each insert goes to index 0, so a later directory ends up ahead of an earlier one.
    sys.path.insert(0, vendor_text)
|
|
# NumPy is optional: `np` is the module when importable, otherwise None.
try:
    import numpy as np
except ModuleNotFoundError:
    np = None
    # Second chance: the per-user site directory may not be on sys.path yet.
    user_site = site.getusersitepackages()
    if user_site and user_site not in sys.path:
        sys.path.append(user_site)
    try:
        import numpy as np
    except ModuleNotFoundError:
        # Still unavailable; `np` stays None and callers use the pure-Python paths.
        pass
|
|
# Numba is optional. ModuleNotFoundError is a subclass of ImportError, so the
# ImportError entry covers both a missing and a broken installation.
try:
    from numba import njit as _numba_njit
except (ImportError, OSError):
    _numba_njit = None
    HAS_COMPILED_HIPPO_KERNEL = False
else:
    # True exactly when the njit decorator was imported successfully.
    HAS_COMPILED_HIPPO_KERNEL = True
|
|
|
|
if _numba_njit is not None:

    @_numba_njit(cache=True)
    def _hippo_legs_propagate_stack_numba(states: object, steps: object) -> object:
        """Propagate every row of ``states`` with its own step from ``steps``.

        Row ``r`` is solved column by column: each entry is the input entry minus
        ``steps[r] * sqrt(2c + 1)`` times a running weighted sum of the entries
        already solved in that row, divided by ``1 + steps[r] * (c + 1)``.
        The result has the same shape and dtype as ``states``.
        """
        row_count = states.shape[0]
        column_count = states.shape[1]
        result = np.empty_like(states)
        # One running sum per row: sum of sqrt(2c + 1) * solved value so far.
        running = np.zeros(row_count, dtype=states.dtype)
        for col in range(column_count):
            root = math.sqrt(2 * col + 1)
            for r in range(row_count):
                denominator = 1.0 + (steps[r] * (col + 1))
                solved = (states[r, col] - (steps[r] * root * running[r])) / denominator
                result[r, col] = solved
                running[r] += root * solved
        return result

    @_numba_njit(cache=True)
    def _hippo_document_combined_states_numba(
        token_ids: object,
        embeddings: object,
        trace_embeddings: object,
        timescales: object,
        trace_gain: object,
        input_projection: object,
        drive_primary: object,
        drive_secondary: object,
        drive_tertiary: object,
        state_dim: int,
        embedding_dim: int,
    ) -> object:
        """Return one combined feature row per processed token of a document.

        All tokens except the last are processed, giving
        ``max(0, len(token_ids) - 1)`` rows. Each row holds, for every timescale
        in order, ``state_dim`` hidden-state values followed by ``embedding_dim``
        accumulated trace values.

        NOTE(review): the input term is added to a state entry before that entry
        joins the running sum, so it influences later columns within the same
        token step. ``AnalyticalMemoryUnit.step_vector_fast`` adds the input after
        propagation instead -- confirm this difference is intended.
        """
        emitted_rows = max(0, token_ids.shape[0] - 1)
        scale_count = timescales.shape[0]
        block_width = state_dim + embedding_dim
        combined = np.zeros((emitted_rows, scale_count * block_width), dtype=embeddings.dtype)
        hidden = np.zeros((scale_count, state_dim), dtype=embeddings.dtype)
        traces = np.zeros((scale_count, embedding_dim), dtype=embeddings.dtype)
        running = np.zeros(scale_count, dtype=embeddings.dtype)
        for position in range(emitted_rows):
            token = token_ids[position]
            # The running sums restart for every token.
            for scale in range(scale_count):
                running[scale] = 0.0
            for col in range(state_dim):
                # Drive for this state column: three embedding entries mixed with
                # weights 1, 0.5 and -0.25, picked by the drive_* index arrays.
                drive = (
                    embeddings[token, drive_primary[col]]
                    + (0.5 * embeddings[token, drive_secondary[col]])
                    - (0.25 * embeddings[token, drive_tertiary[col]])
                )
                root = math.sqrt(2 * col + 1)
                for scale in range(scale_count):
                    dt = timescales[scale]
                    denominator = 1.0 + (dt * (col + 1))
                    updated = (
                        hidden[scale, col]
                        - (dt * root * running[scale])
                    ) / denominator
                    updated += input_projection[scale, col] * drive
                    hidden[scale, col] = updated
                    running[scale] += root * updated
            for scale in range(scale_count):
                offset = scale * block_width
                for col in range(state_dim):
                    combined[position, offset + col] = hidden[scale, col]
                trace_offset = offset + state_dim
                gain = trace_gain[scale]
                for col in range(embedding_dim):
                    # Traces only accumulate; no decay is applied here.
                    traces[scale, col] += gain * trace_embeddings[token, col]
                    combined[position, trace_offset + col] = traces[scale, col]
        return combined

    @_numba_njit(cache=True)
    def _hippo_document_selected_combined_states_numba(
        token_ids: object,
        selected_positions: object,
        embeddings: object,
        trace_embeddings: object,
        timescales: object,
        trace_gain: object,
        input_projection: object,
        drive_primary: object,
        drive_secondary: object,
        drive_tertiary: object,
        state_dim: int,
        embedding_dim: int,
    ) -> object:
        """Run the same recurrence as the full kernel but keep selected rows only.

        The result has one row per entry of ``selected_positions``. A row is
        filled when the token index reaches the next pending position, so entries
        that are out of increasing order, repeated, or beyond the last processed
        token leave their row (and every later row) at zero.
        NOTE(review): presumably callers pass strictly increasing positions --
        confirm.
        """
        processed = max(0, token_ids.shape[0] - 1)
        wanted = selected_positions.shape[0]
        scale_count = timescales.shape[0]
        block_width = state_dim + embedding_dim
        combined = np.zeros((wanted, scale_count * block_width), dtype=embeddings.dtype)
        hidden = np.zeros((scale_count, state_dim), dtype=embeddings.dtype)
        traces = np.zeros((scale_count, embedding_dim), dtype=embeddings.dtype)
        running = np.zeros(scale_count, dtype=embeddings.dtype)
        cursor = 0
        for position in range(processed):
            token = token_ids[position]
            # The running sums restart for every token.
            for scale in range(scale_count):
                running[scale] = 0.0
            for col in range(state_dim):
                drive = (
                    embeddings[token, drive_primary[col]]
                    + (0.5 * embeddings[token, drive_secondary[col]])
                    - (0.25 * embeddings[token, drive_tertiary[col]])
                )
                root = math.sqrt(2 * col + 1)
                for scale in range(scale_count):
                    dt = timescales[scale]
                    denominator = 1.0 + (dt * (col + 1))
                    updated = (
                        hidden[scale, col]
                        - (dt * root * running[scale])
                    ) / denominator
                    updated += input_projection[scale, col] * drive
                    hidden[scale, col] = updated
                    running[scale] += root * updated
            # Traces advance on every token, whether or not it is selected.
            for scale in range(scale_count):
                gain = trace_gain[scale]
                for col in range(embedding_dim):
                    traces[scale, col] += gain * trace_embeddings[token, col]
            if (
                cursor < wanted
                and position == selected_positions[cursor]
            ):
                for scale in range(scale_count):
                    offset = scale * block_width
                    for col in range(state_dim):
                        combined[cursor, offset + col] = hidden[scale, col]
                    trace_offset = offset + state_dim
                    for col in range(embedding_dim):
                        combined[cursor, trace_offset + col] = traces[scale, col]
                cursor += 1
        return combined

else:
    # Without Numba the kernels are absent; callers check these names for None.
    _hippo_legs_propagate_stack_numba = None
    _hippo_document_combined_states_numba = None
    _hippo_document_selected_combined_states_numba = None
|
|
|
|
def hippo_legs_matrix(order: int) -> tuple[Matrix, Vector]:
    """Build the lower-triangular HiPPO-LegS ``A`` matrix and ``B`` vector.

    For row ``r`` and column ``c``:
    ``A[r][c] = -sqrt(2r + 1) * sqrt(2c + 1)`` below the diagonal,
    ``A[r][r] = -(r + 1)`` on it, and ``0.0`` above it.
    ``B[r] = sqrt(2r + 1)``.

    Returns:
        ``(a_matrix, b_vector)`` as nested lists; both are empty when
        ``order`` is not positive.
    """
    a_matrix = []
    b_vector = []
    for row in range(order):
        row_root = math.sqrt(2 * row + 1)
        below_diagonal = [-row_root * math.sqrt(2 * col + 1) for col in range(row)]
        above_diagonal = [0.0] * (order - row - 1)
        # The diagonal entry is the plain integer -(row + 1).
        a_matrix.append(below_diagonal + [-(row + 1)] + above_diagonal)
        b_vector.append(row_root)
    return a_matrix, b_vector
|
|
|
|
def analytical_embedding_drive(embedding: Vector, state_dim: int) -> Vector:
    """Map an embedding to a ``state_dim``-long drive vector.

    Entry ``i`` mixes three embedding components, chosen with wrap-around at
    indices ``i``, ``3i + 1`` and ``5i + 2``, using weights 1, 0.5 and -0.25.
    An empty embedding yields all zeros.
    """
    if not embedding:
        return [0.0] * state_dim
    width = len(embedding)
    drive = []
    for position in range(state_dim):
        primary = embedding[position % width]
        secondary = embedding[(3 * position + 1) % width]
        tertiary = embedding[(5 * position + 2) % width]
        drive.append(primary + 0.5 * secondary - 0.25 * tertiary)
    return drive
|
|
|
|
def analytical_embedding_drive_fast(embedding: object, state_dim: int) -> object:
    """NumPy counterpart of ``analytical_embedding_drive``.

    Falls back to the list implementation when NumPy is unavailable. Inputs
    that already expose ``shape`` are used as-is; anything else is converted
    to a float64 array. An empty embedding yields a float64 zero vector.
    """
    if np is None:
        as_list = embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)
        return analytical_embedding_drive(as_list, state_dim)
    if hasattr(embedding, "shape"):
        embedding_array = embedding
    else:
        embedding_array = np.asarray(embedding, dtype=np.float64)
    if embedding_array.size == 0:
        return np.zeros(state_dim, dtype=np.float64)
    positions = np.arange(state_dim, dtype=np.int64)
    width = int(embedding_array.shape[0])
    # Same three wrap-around index patterns as the list implementation.
    primary = embedding_array[positions % width]
    secondary = embedding_array[(3 * positions + 1) % width]
    tertiary = embedding_array[(5 * positions + 2) % width]
    return primary + 0.5 * secondary - 0.25 * tertiary
|
|
|
|
def hippo_legs_propagate(state: Vector, step: float) -> Vector:
    """Apply the implicit HiPPO-LegS transition without materializing its inverse.

    Entries are solved in order. Each one is the input entry minus
    ``step * sqrt(2i + 1)`` times a running weighted sum of the entries
    already solved, divided by ``1 + step * (i + 1)``.
    """
    solved: Vector = []
    running_sum = 0.0
    for index in range(len(state)):
        root = math.sqrt(2 * index + 1)
        denominator = 1.0 + (step * (index + 1))
        entry = (state[index] - (step * root * running_sum)) / denominator
        solved.append(entry)
        # Accumulate sqrt(2i + 1) * x_i for use by the later entries.
        running_sum += root * entry
    return solved
|
|
|
|
def hippo_legs_propagate_fast(state: object, step: float) -> object:
    """Vector-friendly HiPPO-LegS implicit solve; exact up to floating precision.

    Args:
        state: State vector. Objects exposing ``shape`` are used as-is; anything
            else is converted to a float64 array.
        step: Step size of the implicit transition.

    Returns:
        A new array shaped like ``state``. Floating dtypes of the input are
        kept; integer and boolean inputs produce a float64 result so the
        solved values are not truncated on store. Without NumPy the list
        implementation ``hippo_legs_propagate`` is used and a list is returned.
    """
    if np is None:
        state_vector = state.tolist() if hasattr(state, "tolist") else list(state)
        return hippo_legs_propagate(state_vector, step)
    state_array = state if hasattr(state, "shape") else np.asarray(state, dtype=np.float64)
    propagated = np.empty_like(state_array)
    if propagated.dtype.kind in "biu":
        # An integer/bool buffer would silently truncate the fractional results.
        propagated = np.empty_like(propagated, dtype=np.float64)
    prefix = 0.0
    for row in range(int(state_array.shape[0])):
        basis = math.sqrt(2 * row + 1)
        diagonal = 1.0 + (step * (row + 1))
        value = (float(state_array[row]) - (step * basis * prefix)) / diagonal
        propagated[row] = value
        # Running sum of sqrt(2i + 1) * x_i over the entries solved so far.
        prefix += basis * value
    return propagated
|
|
|
|
def hippo_legs_propagate_stack_fast(states: object, steps: object) -> object:
    """Apply structured HiPPO-LegS propagation to a stack of timescale states.

    Args:
        states: 2-D stack of state vectors, one row per timescale. Objects
            exposing ``shape`` are used as-is; anything else is converted to a
            float64 array.
        steps: One step size per row of ``states``.

    Returns:
        The propagated stack, shaped like ``states``. Floating dtypes are kept;
        integer and boolean stacks are promoted to float64 first, because
        otherwise the results would be truncated (compiled path) or the
        in-place accumulation would fail to cast (NumPy path). Without NumPy a
        list of lists is returned.
    """
    if np is None:
        state_rows = states.tolist() if hasattr(states, "tolist") else list(states)
        step_values = steps.tolist() if hasattr(steps, "tolist") else list(steps)
        return [
            hippo_legs_propagate(row, float(step))
            for row, step in zip(state_rows, step_values)
        ]
    state_matrix = states if hasattr(states, "shape") else np.asarray(states, dtype=np.float64)
    step_array = steps if hasattr(steps, "shape") else np.asarray(steps, dtype=np.float64)
    state_dtype = getattr(state_matrix, "dtype", None)
    if getattr(state_dtype, "kind", "") in ("b", "i", "u"):
        # Output buffers inherit the state dtype, so promote before solving.
        state_matrix = np.asarray(state_matrix, dtype=np.float64)
    if _hippo_legs_propagate_stack_numba is not None:
        return _hippo_legs_propagate_stack_numba(state_matrix, step_array)
    propagated = np.empty_like(state_matrix)
    rows, width = state_matrix.shape
    # One running sum per row: sum of sqrt(2c + 1) * solved value so far.
    prefixes = np.zeros(rows, dtype=state_matrix.dtype)
    for column in range(int(width)):
        basis = math.sqrt(2 * column + 1)
        diagonal = 1.0 + (step_array * (column + 1))
        values = (state_matrix[:, column] - (step_array * basis * prefixes)) / diagonal
        propagated[:, column] = values
        prefixes += basis * values
    return propagated
|
|
|
|
def hippo_document_combined_states_fast(
    token_ids: object,
    embeddings: object,
    trace_embeddings: object,
    timescales: object,
    trace_gain: object,
    input_projection: object,
    drive_primary: object,
    drive_secondary: object,
    drive_tertiary: object,
    *,
    state_dim: int,
    embedding_dim: int,
) -> object | None:
    """Compute all per-token combined states for one document in a compiled kernel.

    Returns ``None`` when the compiled kernel is unavailable. Otherwise all
    arguments are forwarded to it unchanged, with ``state_dim`` and
    ``embedding_dim`` passed positionally last.
    """
    kernel = _hippo_document_combined_states_numba
    if kernel is None:
        return None
    kernel_args = (
        token_ids,
        embeddings,
        trace_embeddings,
        timescales,
        trace_gain,
        input_projection,
        drive_primary,
        drive_secondary,
        drive_tertiary,
        state_dim,
        embedding_dim,
    )
    return kernel(*kernel_args)
|
|
|
|
def hippo_document_selected_combined_states_fast(
    token_ids: object,
    selected_positions: object,
    embeddings: object,
    trace_embeddings: object,
    timescales: object,
    trace_gain: object,
    input_projection: object,
    drive_primary: object,
    drive_secondary: object,
    drive_tertiary: object,
    *,
    state_dim: int,
    embedding_dim: int,
) -> object | None:
    """Compute per-token combined states only at requested document positions.

    Returns ``None`` when the compiled kernel is unavailable. Otherwise all
    arguments are forwarded to it unchanged, with ``state_dim`` and
    ``embedding_dim`` passed positionally last.
    """
    kernel = _hippo_document_selected_combined_states_numba
    if kernel is None:
        return None
    kernel_args = (
        token_ids,
        selected_positions,
        embeddings,
        trace_embeddings,
        timescales,
        trace_gain,
        input_projection,
        drive_primary,
        drive_secondary,
        drive_tertiary,
        state_dim,
        embedding_dim,
    )
    return kernel(*kernel_args)
|
|
|
|
| @dataclass(slots=True) |
| class AnalyticalMemoryUnit: |
| state_dim: int |
| timescale: float |
|
|
| def __post_init__(self) -> None: |
| a_matrix, b_vector = hippo_legs_matrix(self.state_dim) |
| self.transition, self.input_projection = self._discretize_transition( |
| a_matrix, |
| b_vector, |
| self.timescale, |
| ) |
|
|
| transition: Matrix = None |
| input_projection: Vector = None |
| transition_array: object | None = None |
| input_projection_array: object | None = None |
|
|
| @staticmethod |
| def _discretize_transition( |
| a_matrix: Matrix, |
| b_vector: Vector, |
| step: float, |
| ) -> tuple[Matrix, Vector]: |
| implicit_system = [ |
| [ |
| identity_value - step * a_value |
| for identity_value, a_value in zip(identity_row, a_row) |
| ] |
| for identity_row, a_row in zip(identity(len(a_matrix)), a_matrix) |
| ] |
| transition = invert_matrix(implicit_system) |
| input_projection = matvec(transition, [step * value for value in b_vector]) |
| return transition, input_projection |
|
|
| def step(self, state: Vector, scalar_input: float) -> Vector: |
| if np is not None and self.transition_array is None: |
| self.transition_array = np.asarray(self.transition, dtype=np.float64) |
| self.input_projection_array = np.asarray(self.input_projection, dtype=np.float64) |
| propagated = matvec(self.transition, state) |
| return [ |
| propagated[index] + self.input_projection[index] * scalar_input |
| for index in range(self.state_dim) |
| ] |
|
|
| def step_vector(self, state: Vector, drive: Vector) -> Vector: |
| propagated = matvec(self.transition, state) |
| return [ |
| propagated[index] + self.input_projection[index] * drive[index] |
| for index in range(self.state_dim) |
| ] |
|
|
| def step_fast(self, state: object, scalar_input: float) -> object: |
| if np is None: |
| state_vector = state.tolist() if hasattr(state, "tolist") else list(state) |
| return self.step(state_vector, scalar_input) |
| if self.transition_array is None or self.input_projection_array is None: |
| self.transition_array = np.asarray(self.transition, dtype=np.float64) |
| self.input_projection_array = np.asarray(self.input_projection, dtype=np.float64) |
| state_array = state if hasattr(state, "shape") else np.asarray(state, dtype=np.float64) |
| return (self.transition_array @ state_array) + ( |
| self.input_projection_array * scalar_input |
| ) |
|
|
| def step_vector_fast(self, state: object, drive: object) -> object: |
| if np is None: |
| state_vector = state.tolist() if hasattr(state, "tolist") else list(state) |
| drive_vector = drive.tolist() if hasattr(drive, "tolist") else list(drive) |
| return self.step_vector(state_vector, drive_vector) |
| if self.transition_array is None or self.input_projection_array is None: |
| self.transition_array = np.asarray(self.transition, dtype=np.float64) |
| self.input_projection_array = np.asarray(self.input_projection, dtype=np.float64) |
| state_array = state if hasattr(state, "shape") else np.asarray(state, dtype=np.float64) |
| drive_array = drive if hasattr(drive, "shape") else np.asarray(drive, dtype=np.float64) |
| return (self.transition_array @ state_array) + ( |
| self.input_projection_array * drive_array |
| ) |
|
|