"""HiPPO-LegS state propagation helpers with optional NumPy and Numba acceleration."""

import math
import site
import sys
from dataclasses import dataclass, field
from pathlib import Path

from .linalg import Matrix, Vector, identity, invert_matrix, matvec

# Put vendored package directories (when present) at the front of sys.path
# before attempting the optional third-party imports below.
_VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor"
for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"):
    if _vendor_path.exists():
        vendor_text = str(_vendor_path)
        if vendor_text not in sys.path:
            sys.path.insert(0, vendor_text)

try:
    import numpy as np
except ModuleNotFoundError:
    # Retry once with the per-user site-packages directory appended.
    user_site = site.getusersitepackages()
    if user_site and user_site not in sys.path:
        sys.path.append(user_site)
    try:
        import numpy as np
    except ModuleNotFoundError:
        # Every *_fast helper falls back to the pure-Python path when np is None.
        np = None

try:
    from numba import njit as _numba_njit
except (ImportError, ModuleNotFoundError, OSError):
    _numba_njit = None

# True when the Numba-compiled kernels below are defined.
HAS_COMPILED_HIPPO_KERNEL = _numba_njit is not None


if _numba_njit is not None:

    @_numba_njit(cache=True)
    def _hippo_legs_propagate_stack_numba(states: object, steps: object) -> object:
        """Propagate every row of ``states`` with its own step size.

        Row ``r`` is solved by forward substitution, column by column, using
        ``steps[r]``; the arithmetic matches ``hippo_legs_propagate``.
        The result has the same shape and dtype as ``states``.
        """
        rows = states.shape[0]
        width = states.shape[1]
        propagated = np.empty_like(states)
        # Running sum of basis * value over the columns already solved, per row.
        prefixes = np.zeros(rows, dtype=states.dtype)
        for column in range(width):
            basis = math.sqrt(2 * column + 1)
            for row in range(rows):
                diagonal = 1.0 + (steps[row] * (column + 1))
                value = (states[row, column] - (steps[row] * basis * prefixes[row])) / diagonal
                propagated[row, column] = value
                prefixes[row] += basis * value
        return propagated

    @_numba_njit(cache=True)
    def _hippo_document_combined_states_numba(
        token_ids: object,
        embeddings: object,
        trace_embeddings: object,
        timescales: object,
        trace_gain: object,
        input_projection: object,
        drive_primary: object,
        drive_secondary: object,
        drive_tertiary: object,
        state_dim: int,
        embedding_dim: int,
    ) -> object:
        """Return the combined hidden/trace features after each token but the last.

        The output has shape ``(max(0, len(token_ids) - 1),
        timescale_count * (state_dim + embedding_dim))``. For each timescale the
        row holds ``state_dim`` hidden values followed by ``embedding_dim``
        accumulated trace values.
        """
        steps = max(0, token_ids.shape[0] - 1)
        timescale_count = timescales.shape[0]
        feature_count = timescale_count * (state_dim + embedding_dim)
        combined = np.zeros((steps, feature_count), dtype=embeddings.dtype)
        hidden = np.zeros((timescale_count, state_dim), dtype=embeddings.dtype)
        traces = np.zeros((timescale_count, embedding_dim), dtype=embeddings.dtype)
        prefixes = np.zeros(timescale_count, dtype=embeddings.dtype)
        for token_index in range(steps):
            token_id = token_ids[token_index]
            # The prefix sums run across columns, so they restart for every token.
            for timescale_index in range(timescale_count):
                prefixes[timescale_index] = 0.0
            for column in range(state_dim):
                # Drive for this column: a fixed mix of three embedding entries
                # chosen by the precomputed index arrays.
                embedding_value = (
                    embeddings[token_id, drive_primary[column]]
                    + (0.5 * embeddings[token_id, drive_secondary[column]])
                    - (0.25 * embeddings[token_id, drive_tertiary[column]])
                )
                basis = math.sqrt(2 * column + 1)
                for timescale_index in range(timescale_count):
                    step = timescales[timescale_index]
                    diagonal = 1.0 + (step * (column + 1))
                    value = (
                        hidden[timescale_index, column]
                        - (step * basis * prefixes[timescale_index])
                    ) / diagonal
                    value += input_projection[timescale_index, column] * embedding_value
                    hidden[timescale_index, column] = value
                    prefixes[timescale_index] += basis * value
            for timescale_index in range(timescale_count):
                base = timescale_index * (state_dim + embedding_dim)
                for column in range(state_dim):
                    combined[token_index, base + column] = hidden[timescale_index, column]
                trace_base = base + state_dim
                gain = trace_gain[timescale_index]
                for column in range(embedding_dim):
                    traces[timescale_index, column] += gain * trace_embeddings[token_id, column]
                    combined[token_index, trace_base + column] = traces[timescale_index, column]
        return combined

    @_numba_njit(cache=True)
    def _hippo_document_selected_combined_states_numba(
        token_ids: object,
        selected_positions: object,
        embeddings: object,
        trace_embeddings: object,
        timescales: object,
        trace_gain: object,
        input_projection: object,
        drive_primary: object,
        drive_secondary: object,
        drive_tertiary: object,
        state_dim: int,
        embedding_dim: int,
    ) -> object:
        """Like the full-document kernel, but emit rows only at selected positions.

        ``selected_positions`` is consumed with a single forward cursor, so it
        must be sorted in ascending order. Repeated positions each receive a
        copy of the same row. A position that is never reached (out of order,
        or ``>= len(token_ids) - 1``) leaves its row, and every later row,
        filled with zeros.
        """
        steps = max(0, token_ids.shape[0] - 1)
        selected_count = selected_positions.shape[0]
        timescale_count = timescales.shape[0]
        feature_count = timescale_count * (state_dim + embedding_dim)
        combined = np.zeros((selected_count, feature_count), dtype=embeddings.dtype)
        hidden = np.zeros((timescale_count, state_dim), dtype=embeddings.dtype)
        traces = np.zeros((timescale_count, embedding_dim), dtype=embeddings.dtype)
        prefixes = np.zeros(timescale_count, dtype=embeddings.dtype)
        selected_cursor = 0
        for token_index in range(steps):
            token_id = token_ids[token_index]
            # The prefix sums run across columns, so they restart for every token.
            for timescale_index in range(timescale_count):
                prefixes[timescale_index] = 0.0
            for column in range(state_dim):
                embedding_value = (
                    embeddings[token_id, drive_primary[column]]
                    + (0.5 * embeddings[token_id, drive_secondary[column]])
                    - (0.25 * embeddings[token_id, drive_tertiary[column]])
                )
                basis = math.sqrt(2 * column + 1)
                for timescale_index in range(timescale_count):
                    step = timescales[timescale_index]
                    diagonal = 1.0 + (step * (column + 1))
                    value = (
                        hidden[timescale_index, column]
                        - (step * basis * prefixes[timescale_index])
                    ) / diagonal
                    value += input_projection[timescale_index, column] * embedding_value
                    hidden[timescale_index, column] = value
                    prefixes[timescale_index] += basis * value
            for timescale_index in range(timescale_count):
                gain = trace_gain[timescale_index]
                for column in range(embedding_dim):
                    traces[timescale_index, column] += gain * trace_embeddings[token_id, column]
            # `while` rather than `if`: a duplicated position must not stall the
            # cursor and leave all remaining output rows at zero.
            while (
                selected_cursor < selected_count
                and token_index == selected_positions[selected_cursor]
            ):
                for timescale_index in range(timescale_count):
                    base = timescale_index * (state_dim + embedding_dim)
                    for column in range(state_dim):
                        combined[selected_cursor, base + column] = hidden[timescale_index, column]
                    trace_base = base + state_dim
                    for column in range(embedding_dim):
                        combined[selected_cursor, trace_base + column] = traces[
                            timescale_index, column
                        ]
                selected_cursor += 1
        return combined

else:
    _hippo_legs_propagate_stack_numba = None
    _hippo_document_combined_states_numba = None
    _hippo_document_selected_combined_states_numba = None


def _as_fractional_array(values: object) -> object:
    """Return ``values`` as an array that can store fractional results.

    Objects without a ``shape`` attribute are converted to float64. NumPy
    arrays with a bool or integer dtype are promoted to float64, because
    ``np.empty_like`` on them would otherwise allocate an integer output.
    Anything else is returned unchanged. Only call this when ``np`` is not None.
    """
    array = values if hasattr(values, "shape") else np.asarray(values, dtype=np.float64)
    if isinstance(array, np.ndarray) and array.dtype.kind in "biu":
        array = array.astype(np.float64)
    return array


def hippo_legs_matrix(order: int) -> tuple[Matrix, Vector]:
    """Build the ``order`` x ``order`` HiPPO-LegS matrix and its input vector.

    With ``s(k) = sqrt(2k + 1)``: entries below the diagonal are
    ``-s(row) * s(col)``, diagonal entries are ``-(row + 1)``, entries above
    the diagonal are zero, and the input vector is ``[s(0), s(1), ...]``.
    """
    a_matrix = [[0.0 for _ in range(order)] for _ in range(order)]
    b_vector = [0.0 for _ in range(order)]
    for row in range(order):
        row_basis = math.sqrt(2 * row + 1)
        for col in range(row):
            a_matrix[row][col] = -row_basis * math.sqrt(2 * col + 1)
        # Stored as float so every entry of the matrix has the same type.
        a_matrix[row][row] = -float(row + 1)
        b_vector[row] = row_basis
    return a_matrix, b_vector


def analytical_embedding_drive(embedding: Vector, state_dim: int) -> Vector:
    """Mix embedding entries into a ``state_dim``-long drive vector.

    Entry ``i`` is ``e[i] + 0.5 * e[3i + 1] - 0.25 * e[5i + 2]`` with all
    indices taken modulo ``len(embedding)``. An empty embedding yields zeros.
    """
    if not embedding:
        return [0.0 for _ in range(state_dim)]
    width = len(embedding)
    return [
        (
            embedding[index % width]
            + 0.5 * embedding[(3 * index + 1) % width]
            - 0.25 * embedding[(5 * index + 2) % width]
        )
        for index in range(state_dim)
    ]


def analytical_embedding_drive_fast(embedding: object, state_dim: int) -> object:
    """NumPy version of ``analytical_embedding_drive``.

    Returns an array when NumPy is available; otherwise delegates to the
    list-based implementation and returns a list.
    """
    if np is None:
        embedding_vector = embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)
        return analytical_embedding_drive(embedding_vector, state_dim)
    embedding_array = (
        embedding if hasattr(embedding, "shape") else np.asarray(embedding, dtype=np.float64)
    )
    if embedding_array.size == 0:
        return np.zeros(state_dim, dtype=np.float64)
    indices = np.arange(state_dim, dtype=np.int64)
    width = int(embedding_array.shape[0])
    return (
        embedding_array[indices % width]
        + 0.5 * embedding_array[(3 * indices + 1) % width]
        - 0.25 * embedding_array[(5 * indices + 2) % width]
    )


def hippo_legs_propagate(state: Vector, step: float) -> Vector:
    """Apply the implicit HiPPO-LegS transition without materializing its inverse.

    Solves ``(I - step * A) x = state`` by forward substitution, where ``A`` is
    the lower-triangular matrix produced by ``hippo_legs_matrix``.
    """
    propagated: Vector = []
    # Running sum of basis * solved_value over the rows already processed.
    prefix = 0.0
    for row, value in enumerate(state):
        basis = math.sqrt(2 * row + 1)
        diagonal = 1.0 + (step * (row + 1))
        next_value = (value - (step * basis * prefix)) / diagonal
        propagated.append(next_value)
        prefix += basis * next_value
    return propagated


def hippo_legs_propagate_fast(state: object, step: float) -> object:
    """Vector-friendly HiPPO-LegS implicit solve; exact up to floating precision.

    Returns an array shaped like ``state`` when NumPy is available, otherwise
    a list. Integer and bool arrays are promoted to float64 so the fractional
    results are not truncated on store.
    """
    if np is None:
        state_vector = state.tolist() if hasattr(state, "tolist") else list(state)
        return hippo_legs_propagate(state_vector, step)
    state_array = _as_fractional_array(state)
    propagated = np.empty_like(state_array)
    prefix = 0.0
    for row in range(int(state_array.shape[0])):
        basis = math.sqrt(2 * row + 1)
        diagonal = 1.0 + (step * (row + 1))
        value = (float(state_array[row]) - (step * basis * prefix)) / diagonal
        propagated[row] = value
        prefix += basis * value
    return propagated


def hippo_legs_propagate_stack_fast(states: object, steps: object) -> object:
    """Apply structured HiPPO-LegS propagation to a stack of timescale states.

    Row ``r`` of ``states`` is propagated with ``steps[r]``. Uses the compiled
    kernel when it exists, a column-vectorized NumPy loop otherwise, and a list
    of lists when NumPy is unavailable. Integer and bool state arrays are
    promoted to float64 before any path runs.
    """
    if np is None:
        state_rows = states.tolist() if hasattr(states, "tolist") else list(states)
        step_values = steps.tolist() if hasattr(steps, "tolist") else list(steps)
        return [
            hippo_legs_propagate(row, float(step))
            for row, step in zip(state_rows, step_values)
        ]
    state_matrix = _as_fractional_array(states)
    step_array = steps if hasattr(steps, "shape") else np.asarray(steps, dtype=np.float64)
    if _hippo_legs_propagate_stack_numba is not None:
        return _hippo_legs_propagate_stack_numba(state_matrix, step_array)
    propagated = np.empty_like(state_matrix)
    rows, width = state_matrix.shape
    prefixes = np.zeros(rows, dtype=state_matrix.dtype)
    for column in range(int(width)):
        basis = math.sqrt(2 * column + 1)
        diagonal = 1.0 + (step_array * (column + 1))
        values = (state_matrix[:, column] - (step_array * basis * prefixes)) / diagonal
        propagated[:, column] = values
        prefixes += basis * values
    return propagated


def hippo_document_combined_states_fast(
    token_ids: object,
    embeddings: object,
    trace_embeddings: object,
    timescales: object,
    trace_gain: object,
    input_projection: object,
    drive_primary: object,
    drive_secondary: object,
    drive_tertiary: object,
    *,
    state_dim: int,
    embedding_dim: int,
) -> object | None:
    """Compute all per-token combined states for one document in a compiled kernel.

    Returns ``None`` when the compiled kernel is unavailable, so callers must
    provide their own fallback.
    """
    if _hippo_document_combined_states_numba is None:
        return None
    return _hippo_document_combined_states_numba(
        token_ids,
        embeddings,
        trace_embeddings,
        timescales,
        trace_gain,
        input_projection,
        drive_primary,
        drive_secondary,
        drive_tertiary,
        state_dim,
        embedding_dim,
    )


def hippo_document_selected_combined_states_fast(
    token_ids: object,
    selected_positions: object,
    embeddings: object,
    trace_embeddings: object,
    timescales: object,
    trace_gain: object,
    input_projection: object,
    drive_primary: object,
    drive_secondary: object,
    drive_tertiary: object,
    *,
    state_dim: int,
    embedding_dim: int,
) -> object | None:
    """Compute per-token combined states only at requested document positions.

    ``selected_positions`` must be sorted in ascending order. Returns ``None``
    when the compiled kernel is unavailable.
    """
    if _hippo_document_selected_combined_states_numba is None:
        return None
    return _hippo_document_selected_combined_states_numba(
        token_ids,
        selected_positions,
        embeddings,
        trace_embeddings,
        timescales,
        trace_gain,
        input_projection,
        drive_primary,
        drive_secondary,
        drive_tertiary,
        state_dim,
        embedding_dim,
    )


@dataclass(slots=True)
class AnalyticalMemoryUnit:
    """Dense implicit discretization of a HiPPO-LegS system of size ``state_dim``.

    ``transition`` and ``input_projection`` are always recomputed in
    ``__post_init__``; values passed to the constructor for them are
    overwritten.
    """

    state_dim: int
    timescale: float
    transition: Matrix = None  # type: ignore[assignment]
    input_projection: Vector = None  # type: ignore[assignment]
    # Lazily built NumPy copies of the two fields above. They are excluded
    # from == because comparing arrays inside the generated __eq__ raises
    # "truth value of an array is ambiguous".
    transition_array: object | None = field(default=None, compare=False)
    input_projection_array: object | None = field(default=None, compare=False)

    def __post_init__(self) -> None:
        a_matrix, b_vector = hippo_legs_matrix(self.state_dim)
        self.transition, self.input_projection = self._discretize_transition(
            a_matrix,
            b_vector,
            self.timescale,
        )

    @staticmethod
    def _discretize_transition(
        a_matrix: Matrix,
        b_vector: Vector,
        step: float,
    ) -> tuple[Matrix, Vector]:
        """Return ``(I - step*A)^-1`` and ``(I - step*A)^-1 @ (step * b)``."""
        implicit_system = [
            [
                identity_value - step * a_value
                for identity_value, a_value in zip(identity_row, a_row)
            ]
            for identity_row, a_row in zip(identity(len(a_matrix)), a_matrix)
        ]
        transition = invert_matrix(implicit_system)
        input_projection = matvec(transition, [step * value for value in b_vector])
        return transition, input_projection

    def _ensure_arrays(self) -> None:
        """Build the cached NumPy copies on first use (requires NumPy)."""
        if self.transition_array is None or self.input_projection_array is None:
            self.transition_array = np.asarray(self.transition, dtype=np.float64)
            self.input_projection_array = np.asarray(self.input_projection, dtype=np.float64)

    def step(self, state: Vector, scalar_input: float) -> Vector:
        """Advance ``state`` one step with the same scalar input on every channel."""
        # The result is computed with the list-based matvec; the array cache is
        # only populated here as a side effect for later *_fast calls.
        if np is not None:
            self._ensure_arrays()
        propagated = matvec(self.transition, state)
        return [
            propagated[index] + self.input_projection[index] * scalar_input
            for index in range(self.state_dim)
        ]

    def step_vector(self, state: Vector, drive: Vector) -> Vector:
        """Advance ``state`` one step with a separate drive value per channel."""
        propagated = matvec(self.transition, state)
        return [
            propagated[index] + self.input_projection[index] * drive[index]
            for index in range(self.state_dim)
        ]

    def step_fast(self, state: object, scalar_input: float) -> object:
        """NumPy version of ``step``; falls back to ``step`` without NumPy."""
        if np is None:
            state_vector = state.tolist() if hasattr(state, "tolist") else list(state)
            return self.step(state_vector, scalar_input)
        self._ensure_arrays()
        state_array = state if hasattr(state, "shape") else np.asarray(state, dtype=np.float64)
        return (self.transition_array @ state_array) + (
            self.input_projection_array * scalar_input
        )

    def step_vector_fast(self, state: object, drive: object) -> object:
        """NumPy version of ``step_vector``; falls back to it without NumPy."""
        if np is None:
            state_vector = state.tolist() if hasattr(state, "tolist") else list(state)
            drive_vector = drive.tolist() if hasattr(drive, "tolist") else list(drive)
            return self.step_vector(state_vector, drive_vector)
        self._ensure_arrays()
        state_array = state if hasattr(state, "shape") else np.asarray(state, dtype=np.float64)
        drive_array = drive if hasattr(drive, "shape") else np.asarray(drive, dtype=np.float64)
        return (self.transition_array @ state_array) + (
            self.input_projection_array * drive_array
        )