OkeyMeta's picture
Add Reframr-RFM-v2-Base release files
52da7b7 verified
import math
from dataclasses import dataclass
import site
import sys
from pathlib import Path
from .linalg import Matrix, Vector, identity, invert_matrix, matvec
# Optional vendored dependencies live in a ".vendor" directory located in the
# parent of the directory that contains this (resolved) file.
_VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor"
for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"):
    # Only directories that actually exist are registered, each at most once.
    if _vendor_path.exists():
        vendor_text = str(_vendor_path)
        if vendor_text not in sys.path:
            # Inserted at the front so vendored packages are found before
            # anything already on sys.path.
            sys.path.insert(0, vendor_text)
# NumPy is optional: `np` ends up as the module, or None when it cannot be found.
try:
    import numpy as np
except ModuleNotFoundError:
    # Retry once with the per-user site-packages directory appended to sys.path.
    user_site = site.getusersitepackages()
    if user_site and user_site not in sys.path:
        sys.path.append(user_site)
    try:
        import numpy as np
    except ModuleNotFoundError:
        np = None
# Numba is optional as well; OSError is tolerated in addition to import errors.
try:
    from numba import njit as _numba_njit
except (ImportError, ModuleNotFoundError, OSError):
    _numba_njit = None
# True when the Numba-compiled kernels defined below are available.
HAS_COMPILED_HIPPO_KERNEL = _numba_njit is not None
if _numba_njit is not None:

    @_numba_njit(cache=True)
    def _hippo_legs_propagate_stack_numba(states: object, steps: object) -> object:
        """Run the implicit HiPPO-LegS solve on every row of ``states``.

        ``states`` is indexed as ``[row, column]`` and ``steps`` as ``[row]``;
        row ``r`` is propagated with step size ``steps[r]``. The result has the
        same shape and dtype as ``states`` (it is built with ``empty_like``).
        """
        rows = states.shape[0]
        width = states.shape[1]
        propagated = np.empty_like(states)
        # prefixes[r] holds sum_{c < column} sqrt(2c + 1) * propagated[r, c].
        prefixes = np.zeros(rows, dtype=states.dtype)
        for column in range(width):
            basis = math.sqrt(2 * column + 1)
            for row in range(rows):
                # Forward substitution for a lower-triangular system whose
                # diagonal entry is 1 + step * (column + 1).
                diagonal = 1.0 + (steps[row] * (column + 1))
                value = (states[row, column] - (steps[row] * basis * prefixes[row])) / diagonal
                propagated[row, column] = value
                prefixes[row] += basis * value
        return propagated

    @_numba_njit(cache=True)
    def _hippo_document_combined_states_numba(
        token_ids: object,
        embeddings: object,
        trace_embeddings: object,
        timescales: object,
        trace_gain: object,
        input_projection: object,
        drive_primary: object,
        drive_secondary: object,
        drive_tertiary: object,
        state_dim: int,
        embedding_dim: int,
    ) -> object:
        """Return one combined feature row per consumed token of a document.

        All tokens except the last one are consumed, so the result has
        ``max(0, len(token_ids) - 1)`` rows. Each row is laid out per timescale
        as ``state_dim`` hidden-state values followed by ``embedding_dim``
        trace values, i.e. ``len(timescales) * (state_dim + embedding_dim)``
        columns in total.

        ``drive_primary`` / ``drive_secondary`` / ``drive_tertiary`` are used as
        column indices into ``embeddings`` (one index per state column).
        """
        steps = max(0, token_ids.shape[0] - 1)
        timescale_count = timescales.shape[0]
        feature_count = timescale_count * (state_dim + embedding_dim)
        combined = np.zeros((steps, feature_count), dtype=embeddings.dtype)
        hidden = np.zeros((timescale_count, state_dim), dtype=embeddings.dtype)
        traces = np.zeros((timescale_count, embedding_dim), dtype=embeddings.dtype)
        prefixes = np.zeros(timescale_count, dtype=embeddings.dtype)
        for token_index in range(steps):
            token_id = token_ids[token_index]
            # The running prefix is per token, so it restarts at zero.
            for timescale_index in range(timescale_count):
                prefixes[timescale_index] = 0.0
            for column in range(state_dim):
                # Drive for this state column: e[p] + 0.5 * e[s] - 0.25 * e[t].
                embedding_value = (
                    embeddings[token_id, drive_primary[column]]
                    + (0.5 * embeddings[token_id, drive_secondary[column]])
                    - (0.25 * embeddings[token_id, drive_tertiary[column]])
                )
                basis = math.sqrt(2 * column + 1)
                for timescale_index in range(timescale_count):
                    step = timescales[timescale_index]
                    diagonal = 1.0 + (step * (column + 1))
                    value = (
                        hidden[timescale_index, column]
                        - (step * basis * prefixes[timescale_index])
                    ) / diagonal
                    value += input_projection[timescale_index, column] * embedding_value
                    hidden[timescale_index, column] = value
                    # NOTE(review): the prefix accumulates the value *after* the
                    # input term was added, unlike
                    # _hippo_legs_propagate_stack_numba, so this is not the same
                    # as "propagate, then add input" -- confirm this is intended.
                    prefixes[timescale_index] += basis * value
            for timescale_index in range(timescale_count):
                base = timescale_index * (state_dim + embedding_dim)
                for column in range(state_dim):
                    combined[token_index, base + column] = hidden[timescale_index, column]
                trace_base = base + state_dim
                gain = trace_gain[timescale_index]
                for column in range(embedding_dim):
                    # Traces are a running (undecayed) sum of gain-scaled rows.
                    traces[timescale_index, column] += gain * trace_embeddings[token_id, column]
                    combined[token_index, trace_base + column] = traces[timescale_index, column]
        return combined

    @_numba_njit(cache=True)
    def _hippo_document_selected_combined_states_numba(
        token_ids: object,
        selected_positions: object,
        embeddings: object,
        trace_embeddings: object,
        timescales: object,
        trace_gain: object,
        input_projection: object,
        drive_primary: object,
        drive_secondary: object,
        drive_tertiary: object,
        state_dim: int,
        embedding_dim: int,
    ) -> object:
        """Return combined feature rows only for ``selected_positions``.

        The recurrence is the same as in
        ``_hippo_document_combined_states_numba``; output row ``i`` is the
        combined state after consuming token ``selected_positions[i]``.

        Positions are matched with a single forward cursor, so they must be
        strictly increasing. A position that is never reached (out of order,
        duplicated, or ``>= len(token_ids) - 1``) leaves its row, and every
        later row, filled with zeros.
        """
        steps = max(0, token_ids.shape[0] - 1)
        selected_count = selected_positions.shape[0]
        timescale_count = timescales.shape[0]
        feature_count = timescale_count * (state_dim + embedding_dim)
        combined = np.zeros((selected_count, feature_count), dtype=embeddings.dtype)
        hidden = np.zeros((timescale_count, state_dim), dtype=embeddings.dtype)
        traces = np.zeros((timescale_count, embedding_dim), dtype=embeddings.dtype)
        prefixes = np.zeros(timescale_count, dtype=embeddings.dtype)
        selected_cursor = 0
        for token_index in range(steps):
            # Once every requested row has been written, later tokens cannot
            # influence the output, so the remaining recurrence is skipped.
            if selected_cursor >= selected_count:
                break
            token_id = token_ids[token_index]
            for timescale_index in range(timescale_count):
                prefixes[timescale_index] = 0.0
            for column in range(state_dim):
                embedding_value = (
                    embeddings[token_id, drive_primary[column]]
                    + (0.5 * embeddings[token_id, drive_secondary[column]])
                    - (0.25 * embeddings[token_id, drive_tertiary[column]])
                )
                basis = math.sqrt(2 * column + 1)
                for timescale_index in range(timescale_count):
                    step = timescales[timescale_index]
                    diagonal = 1.0 + (step * (column + 1))
                    value = (
                        hidden[timescale_index, column]
                        - (step * basis * prefixes[timescale_index])
                    ) / diagonal
                    value += input_projection[timescale_index, column] * embedding_value
                    hidden[timescale_index, column] = value
                    # NOTE(review): prefix includes the input term, see the
                    # matching loop in _hippo_document_combined_states_numba.
                    prefixes[timescale_index] += basis * value
            for timescale_index in range(timescale_count):
                gain = trace_gain[timescale_index]
                for column in range(embedding_dim):
                    traces[timescale_index, column] += gain * trace_embeddings[token_id, column]
            if (
                selected_cursor < selected_count
                and token_index == selected_positions[selected_cursor]
            ):
                # Snapshot the current hidden state and traces into the next row.
                for timescale_index in range(timescale_count):
                    base = timescale_index * (state_dim + embedding_dim)
                    for column in range(state_dim):
                        combined[selected_cursor, base + column] = hidden[timescale_index, column]
                    trace_base = base + state_dim
                    for column in range(embedding_dim):
                        combined[selected_cursor, trace_base + column] = traces[timescale_index, column]
                selected_cursor += 1
        return combined

else:
    # Without Numba the kernels are absent; callers check for None.
    _hippo_legs_propagate_stack_numba = None
    _hippo_document_combined_states_numba = None
    _hippo_document_selected_combined_states_numba = None
def hippo_legs_matrix(order: int) -> tuple[Matrix, Vector]:
    """Build the HiPPO-LegS ``(A, B)`` pair of the given order.

    ``A`` is lower triangular: ``A[r][c] = -sqrt(2r + 1) * sqrt(2c + 1)`` below
    the diagonal, ``A[r][r] = -(r + 1)`` on it, and ``0.0`` above it.
    ``B[r] = sqrt(2r + 1)``.
    """
    roots = [math.sqrt(2 * index + 1) for index in range(order)]
    a_matrix = []
    for row in range(order):
        below = [-roots[row] * roots[col] for col in range(row)]
        above = [0.0] * (order - row - 1)
        a_matrix.append([*below, -(row + 1), *above])
    return a_matrix, list(roots)
def analytical_embedding_drive(embedding: Vector, state_dim: int) -> Vector:
    """Mix embedding entries into a ``state_dim``-long drive vector.

    Entry ``i`` is ``e[i] + 0.5 * e[3i + 1] - 0.25 * e[5i + 2]`` with every
    index taken modulo ``len(embedding)``. An empty embedding yields zeros.
    """
    if not embedding:
        return [0.0] * state_dim
    width = len(embedding)
    drive = []
    for index in range(state_dim):
        primary = embedding[index % width]
        secondary = embedding[(3 * index + 1) % width]
        tertiary = embedding[(5 * index + 2) % width]
        drive.append(primary + 0.5 * secondary - 0.25 * tertiary)
    return drive
def analytical_embedding_drive_fast(embedding: object, state_dim: int) -> object:
    """NumPy variant of ``analytical_embedding_drive``.

    Falls back to the pure-Python implementation when NumPy is unavailable.
    Inputs that already expose ``.shape`` are used as-is; anything else is
    converted to a float64 array. An empty embedding yields float64 zeros.
    """
    if np is None:
        as_list = embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)
        return analytical_embedding_drive(as_list, state_dim)
    if hasattr(embedding, "shape"):
        values = embedding
    else:
        values = np.asarray(embedding, dtype=np.float64)
    if values.size == 0:
        return np.zeros(state_dim, dtype=np.float64)
    positions = np.arange(state_dim, dtype=np.int64)
    width = int(values.shape[0])
    # Gather the three strided views of the embedding, then mix them.
    primary = values[positions % width]
    secondary = values[(3 * positions + 1) % width]
    tertiary = values[(5 * positions + 2) % width]
    return primary + 0.5 * secondary - 0.25 * tertiary
def hippo_legs_propagate(state: Vector, step: float) -> Vector:
    """Apply the implicit HiPPO-LegS transition without materializing its inverse.

    The lower-triangular system is solved by forward substitution: entry ``r``
    is ``(state[r] - step * sqrt(2r + 1) * prefix) / (1 + step * (r + 1))``,
    where ``prefix`` is the sum of ``sqrt(2c + 1) * result[c]`` over ``c < r``.
    """
    solved = []
    weighted_sum = 0.0
    for order, entry in enumerate(state):
        root = math.sqrt(2 * order + 1)
        coupling = step * root * weighted_sum
        solved_entry = (entry - coupling) / (1.0 + (step * (order + 1)))
        solved.append(solved_entry)
        weighted_sum += root * solved_entry
    return solved
def hippo_legs_propagate_fast(state: object, step: float) -> object:
    """Vector-friendly HiPPO-LegS implicit solve; exact up to floating precision.

    Without NumPy this delegates to ``hippo_legs_propagate`` and returns a
    list. With NumPy, inputs exposing ``.shape`` are used as-is and anything
    else is converted to a float64 array; the result is a new array.

    Floating inputs keep their dtype. Integer and boolean arrays produce a
    float64 result, because writing the solved values back into an integer
    buffer would silently truncate them.
    """
    if np is None:
        state_vector = state.tolist() if hasattr(state, "tolist") else list(state)
        return hippo_legs_propagate(state_vector, step)
    state_array = state if hasattr(state, "shape") else np.asarray(state, dtype=np.float64)
    propagated = np.empty_like(state_array)
    if propagated.dtype.kind in "biu":
        # empty_like inherited a bool/int dtype; use a float buffer instead.
        propagated = np.empty(propagated.shape, dtype=np.float64)
    prefix = 0.0
    for row in range(int(state_array.shape[0])):
        basis = math.sqrt(2 * row + 1)
        diagonal = 1.0 + (step * (row + 1))
        value = (float(state_array[row]) - (step * basis * prefix)) / diagonal
        propagated[row] = value
        # Running sum of sqrt(2c + 1) * result[c] over the rows solved so far.
        prefix += basis * value
    return propagated
def hippo_legs_propagate_stack_fast(states: object, steps: object) -> object:
    """Apply structured HiPPO-LegS propagation to a stack of timescale states.

    Row ``r`` of ``states`` is propagated with step size ``steps[r]``.

    Without NumPy a list of lists is returned (one ``hippo_legs_propagate``
    call per row). With NumPy the compiled kernel is used when available,
    otherwise a column-wise vectorized forward substitution.

    Integer and boolean state matrices are promoted to float64 first: the
    compiled kernel would truncate the solved values into the integer buffer,
    and the NumPy fallback would fail when accumulating floats in place into
    an integer prefix array.
    """
    if np is None:
        state_rows = states.tolist() if hasattr(states, "tolist") else list(states)
        step_values = steps.tolist() if hasattr(steps, "tolist") else list(steps)
        return [
            hippo_legs_propagate(row, float(step))
            for row, step in zip(state_rows, step_values)
        ]
    state_matrix = states if hasattr(states, "shape") else np.asarray(states, dtype=np.float64)
    step_array = steps if hasattr(steps, "shape") else np.asarray(steps, dtype=np.float64)
    if state_matrix.dtype.kind in "biu":
        state_matrix = state_matrix.astype(np.float64)
    if _hippo_legs_propagate_stack_numba is not None:
        return _hippo_legs_propagate_stack_numba(state_matrix, step_array)
    propagated = np.empty_like(state_matrix)
    rows, width = state_matrix.shape
    # prefixes[r] is the running sum of sqrt(2c + 1) * propagated[r, c].
    prefixes = np.zeros(rows, dtype=state_matrix.dtype)
    for column in range(int(width)):
        basis = math.sqrt(2 * column + 1)
        # All rows are solved for this column at once; only the column loop
        # is sequential because each column depends on the earlier ones.
        diagonal = 1.0 + (step_array * (column + 1))
        values = (state_matrix[:, column] - (step_array * basis * prefixes)) / diagonal
        propagated[:, column] = values
        prefixes += basis * values
    return propagated
def hippo_document_combined_states_fast(
    token_ids: object,
    embeddings: object,
    trace_embeddings: object,
    timescales: object,
    trace_gain: object,
    input_projection: object,
    drive_primary: object,
    drive_secondary: object,
    drive_tertiary: object,
    *,
    state_dim: int,
    embedding_dim: int,
) -> object | None:
    """Compute all per-token combined states for one document in a compiled kernel.

    Returns ``None`` when the compiled kernel is unavailable; otherwise the
    arguments are forwarded unchanged and the kernel's result is returned.
    """
    kernel = _hippo_document_combined_states_numba
    if kernel is None:
        return None
    # The kernel takes the two dimensions positionally, after the arrays.
    array_arguments = (
        token_ids,
        embeddings,
        trace_embeddings,
        timescales,
        trace_gain,
        input_projection,
        drive_primary,
        drive_secondary,
        drive_tertiary,
    )
    return kernel(*array_arguments, state_dim, embedding_dim)
def hippo_document_selected_combined_states_fast(
    token_ids: object,
    selected_positions: object,
    embeddings: object,
    trace_embeddings: object,
    timescales: object,
    trace_gain: object,
    input_projection: object,
    drive_primary: object,
    drive_secondary: object,
    drive_tertiary: object,
    *,
    state_dim: int,
    embedding_dim: int,
) -> object | None:
    """Compute per-token combined states only at requested document positions.

    Returns ``None`` when the compiled kernel is unavailable; otherwise the
    arguments are forwarded unchanged and the kernel's result is returned.
    """
    kernel = _hippo_document_selected_combined_states_numba
    if kernel is None:
        return None
    # The kernel takes the two dimensions positionally, after the arrays.
    array_arguments = (
        token_ids,
        selected_positions,
        embeddings,
        trace_embeddings,
        timescales,
        trace_gain,
        input_projection,
        drive_primary,
        drive_secondary,
        drive_tertiary,
    )
    return kernel(*array_arguments, state_dim, embedding_dim)
@dataclass(slots=True)
class AnalyticalMemoryUnit:
state_dim: int
timescale: float
def __post_init__(self) -> None:
a_matrix, b_vector = hippo_legs_matrix(self.state_dim)
self.transition, self.input_projection = self._discretize_transition(
a_matrix,
b_vector,
self.timescale,
)
transition: Matrix = None # type: ignore[assignment]
input_projection: Vector = None # type: ignore[assignment]
transition_array: object | None = None # type: ignore[assignment]
input_projection_array: object | None = None # type: ignore[assignment]
@staticmethod
def _discretize_transition(
a_matrix: Matrix,
b_vector: Vector,
step: float,
) -> tuple[Matrix, Vector]:
implicit_system = [
[
identity_value - step * a_value
for identity_value, a_value in zip(identity_row, a_row)
]
for identity_row, a_row in zip(identity(len(a_matrix)), a_matrix)
]
transition = invert_matrix(implicit_system)
input_projection = matvec(transition, [step * value for value in b_vector])
return transition, input_projection
def step(self, state: Vector, scalar_input: float) -> Vector:
if np is not None and self.transition_array is None:
self.transition_array = np.asarray(self.transition, dtype=np.float64)
self.input_projection_array = np.asarray(self.input_projection, dtype=np.float64)
propagated = matvec(self.transition, state)
return [
propagated[index] + self.input_projection[index] * scalar_input
for index in range(self.state_dim)
]
def step_vector(self, state: Vector, drive: Vector) -> Vector:
propagated = matvec(self.transition, state)
return [
propagated[index] + self.input_projection[index] * drive[index]
for index in range(self.state_dim)
]
def step_fast(self, state: object, scalar_input: float) -> object:
if np is None:
state_vector = state.tolist() if hasattr(state, "tolist") else list(state)
return self.step(state_vector, scalar_input)
if self.transition_array is None or self.input_projection_array is None:
self.transition_array = np.asarray(self.transition, dtype=np.float64)
self.input_projection_array = np.asarray(self.input_projection, dtype=np.float64)
state_array = state if hasattr(state, "shape") else np.asarray(state, dtype=np.float64)
return (self.transition_array @ state_array) + (
self.input_projection_array * scalar_input
)
def step_vector_fast(self, state: object, drive: object) -> object:
if np is None:
state_vector = state.tolist() if hasattr(state, "tolist") else list(state)
drive_vector = drive.tolist() if hasattr(drive, "tolist") else list(drive)
return self.step_vector(state_vector, drive_vector)
if self.transition_array is None or self.input_projection_array is None:
self.transition_array = np.asarray(self.transition, dtype=np.float64)
self.input_projection_array = np.asarray(self.input_projection, dtype=np.float64)
state_array = state if hasattr(state, "shape") else np.asarray(state, dtype=np.float64)
drive_array = drive if hasattr(drive, "shape") else np.asarray(drive, dtype=np.float64)
return (self.transition_array @ state_array) + (
self.input_projection_array * drive_array
)