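"""Reframr analytical language model.

Implements ReframrModel: a ridge-regression next-token readout over
multi-timescale analytical memory states, blended with associative,
answer-memory, n-gram transition, copy, and preference priors. NumPy is
optional; when present, decoding and checkpoint loading take vectorized
fast paths, otherwise pure-Python fallbacks are used.
"""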
import json
import hashlib
import random
import site
import string
import sys
import unicodedata
from dataclasses import dataclass
from pathlib import Path
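# Prefer vendored copies of third-party dependencies when they ship alongside the package.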
_VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor"
for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"):
if _vendor_path.exists():
vendor_text = str(_vendor_path)
if vendor_text not in sys.path:
sys.path.insert(0, vendor_text)
try:
import numpy as np
except ModuleNotFoundError:
user_site = site.getusersitepackages()
if user_site and user_site not in sys.path:
sys.path.append(user_site)
try:
import numpy as np
except ModuleNotFoundError:
np = None
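# Treat a shadowed or partially importable numpy (e.g. missing asarray) as unavailable.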
if np is not None and not hasattr(np, "asarray"):
np = None
from .checkpoint import read_safetensor_file, write_safetensor_file
from .config import ReframrConfig
from .embeddings import EmbeddingModel, fit_ppmi_embedding_from_tokens
from .hippo import AnalyticalMemoryUnit, analytical_embedding_drive, analytical_embedding_drive_fast
from .linalg import Vector, dot, mean, norm, softmax, zeros_vector
from .reservoir import apply_readout, ridge_regression_readout
from .reasoning import reasoning_prefix
from .ternary import apply_ternary_mask, derive_ternary_mask_from_states
from .tokenizer import NativeTokenizer
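# Blend weights for mixing the calibrated base readout with the memory priors;
# the FAST_* variants apply on the NumPy fast decode path.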
ASSOCIATIVE_BLEND = 0.42
TRANSITION_BLEND = 0.08
COPY_BLEND = 0.04
BASE_BLEND = 0.34
FAST_ASSOCIATIVE_BLEND = 0.06
FAST_TRANSITION_BLEND = 0.14
FAST_COPY_BLEND = 0.04
FAST_BASE_BLEND = 0.58
FAST_PREFERENCE_BLEND = 0.15
FAST_ANSWER_BLEND = 0.30
PROMPT_READOUT_LOGIT_ZSCORE_SCALE = 0.48
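# Retrieval limits and confidence floors for the answer memories
# (see _answer_sequence_should_lock).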
ASSOCIATIVE_TOP_K = 12
ANSWER_TOP_K = 48
ANSWER_START_TOP_K = 32
ANSWER_SEQUENCE_MATCH_FLOOR = 0.30
ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR = 0.45
ANSWER_SEQUENCE_LOCK_FLOOR = 0.55
ANSWER_SEQUENCE_SPIKE_CONFIDENCE = 0.80
READOUT_LOGIT_ZSCORE_SCALE = 0.22
TRACE_IDENTITY_SCALE = 0.78
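# Classic LCG multiplier/increment constants, used as fixed mixers when
# hashing token identities into trace features.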
TRACE_IDENTITY_HASHES = (
(1103515245, 12345, 214013, 2531011),
(1664525, 1013904223, 22695477, 1),
(69069, 362437, 134775813, 17),
(134775813, 97, 1103515245, 31),
(22695477, 911, 1664525, 73),
(214013, 2531011, 69069, 19),
(48271, 0, 69621, 11),
(16807, 37, 40692, 101),
(279470273, 173, 1299709, 53),
(39916801, 29, 2147483629, 7),
)
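# Control character that is vanishingly unlikely to occur inside token text.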
NGRAM_KEY_SEPARATOR = "\u0001"
TRANSITION_ORDERS = (10, 8, 6, 5, 4, 3, 2, 1)
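# Default sampling parameters for generation.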
DEFAULT_GENERATION_TEMPERATURE = 0.82
DEFAULT_GENERATION_TOP_K = 24
DEFAULT_GENERATION_TOP_P = 0.92
DEFAULT_REPETITION_PENALTY = 1.18
ANSWER_SEQUENCE_MAX_TOKENS = 192
RUNTIME_ARRAY_DTYPE = np.float32 if np is not None else None
@dataclass(frozen=True, slots=True)
class CharacterCountFact:
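    """A parsed "how many <character>s in <word>" request.

    surface_seed deterministically picks one of the rendered answer surfaces.
    """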
character: str
word: str
count: int
surface_seed: int
def _normalize_vector(values: Vector) -> Vector:
total = sum(values)
if total <= 0.0:
return [0.0 for _ in values]
return [value / total for value in values]
def _encode_ngram_key(tokens: tuple[str, ...]) -> str:
return NGRAM_KEY_SEPARATOR.join(tokens)
def _decode_ngram_key(key: str) -> tuple[str, ...]:
return tuple(part for part in key.split(NGRAM_KEY_SEPARATOR) if part)
def _last_index(values: list[str], target: str) -> int | None:
for index in range(len(values) - 1, -1, -1):
if values[index] == target:
return index
return None
@dataclass(slots=True)
class DecodeState:
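    """Mutable per-generation decode state.

    Holds the recurrent hidden states and context traces plus lazily computed
    answer-memory matches and prompt-readout priors; the cached matches and
    priors are invalidated each time an <answer> token is consumed.
    """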
hidden_states: list[Vector]
context_traces: list[Vector]
combined_state: Vector
context_tokens: list[str]
answer_anchor_state: Vector | None = None
answer_matches: list[tuple[float, int, int]] | None = None
answer_start_matches: list[tuple[float, int, int]] | None = None
answer_sequence_matches: list[tuple[float, int, int]] | None = None
prompt_answer_prior: object | None = None
prompt_answer_start_prior: object | None = None
@dataclass(slots=True)
class ReframrModel:
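    """Analytical Reframr model: embeddings, memories, readouts, and priors.

    Fields with an _array suffix hold NumPy mirrors of the corresponding
    pure-Python parameters; they are rebuilt by _refresh_numeric_caches and
    drive the fast decode path when NumPy is available.
    """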
config: ReframrConfig
tokenizer: NativeTokenizer | None = None
embedding_model: EmbeddingModel | None = None
memory_units: list[AnalyticalMemoryUnit] | None = None
ternary_scale: float = 1.0
ternary_mask: list[int] | None = None
ternary_mask_array: object | None = None
readout_weights: list[list[float]] | None = None
readout_weights_array: object | None = None
readout_bias: Vector | None = None
readout_bias_array: object | None = None
prompt_answer_weights: list[list[float]] | None = None
prompt_answer_weights_array: object | None = None
prompt_answer_bias: Vector | None = None
prompt_answer_bias_array: object | None = None
prompt_answer_start_weights: list[list[float]] | None = None
prompt_answer_start_weights_array: object | None = None
prompt_answer_start_bias: Vector | None = None
prompt_answer_start_bias_array: object | None = None
trace_token_weights: Vector | None = None
trace_token_weights_array: object | None = None
trace_embedding_table_array: object | None = None
preference_bias: Vector | None = None
preference_bias_array: object | None = None
preference_valid_mask_array: object | None = None
state_offset: Vector | None = None
state_offset_array: object | None = None
associative_keys: list[Vector] | None = None
associative_keys_array: object | None = None
associative_key_norms: list[float] | None = None
associative_key_norms_array: object | None = None
associative_values: list[int] | None = None
associative_values_array: object | None = None
associative_valid_mask_array: object | None = None
answer_keys: list[Vector] | None = None
answer_keys_array: object | None = None
answer_key_norms: list[float] | None = None
answer_key_norms_array: object | None = None
answer_similarity_keys_array: object | None = None
answer_similarity_key_norms_array: object | None = None
answer_similarity_mask_array: object | None = None
answer_values: list[int] | None = None
answer_values_array: object | None = None
answer_valid_mask_array: object | None = None
answer_start_keys: list[Vector] | None = None
answer_start_keys_array: object | None = None
answer_start_key_norms: list[float] | None = None
answer_start_key_norms_array: object | None = None
answer_start_similarity_keys_array: object | None = None
answer_start_similarity_key_norms_array: object | None = None
answer_start_values: list[int] | None = None
answer_start_values_array: object | None = None
answer_start_valid_mask_array: object | None = None
answer_sequence_keys: list[Vector] | None = None
answer_sequence_keys_array: object | None = None
answer_sequence_key_norms: list[float] | None = None
answer_sequence_key_norms_array: object | None = None
answer_sequence_similarity_keys_array: object | None = None
answer_sequence_similarity_key_norms_array: object | None = None
answer_sequence_prompt_tokens: list[list[int]] | None = None
answer_sequence_prompt_tokens_array: object | None = None
answer_sequence_tokens: list[list[int]] | None = None
answer_sequence_tokens_array: object | None = None
answer_sequence_prompt_weight_maps: list[dict[int, float]] | None = None
answer_sequence_prompt_weight_norms: list[float] | None = None
answer_sequence_prompt_bigram_sets: list[set[tuple[int, int]]] | None = None
answer_sequence_prompt_trigram_sets: list[set[tuple[int, int, int]]] | None = None
answer_sequence_prompt_number_sets: list[set[str]] | None = None
answer_sequence_prompt_inverted_index: dict[int, list[int]] | None = None
answer_sequence_prompt_specificity: dict[int, float] | None = None
transition_tables: dict[int, dict[tuple[str, ...], dict[str, float]]] | None = None
def fit(self, text: str) -> "ReframrModel":
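        """Fit the tokenizer, embeddings, memory readouts, and priors on raw text.

        Returns self so calls can be chained.
        """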
self.tokenizer = NativeTokenizer.train(
text,
vocab_size=self.config.tokenizer_vocab_size,
min_pair_frequency=self.config.tokenizer_min_pair_frequency,
lowercase=self.config.lowercase,
)
tokens = self.tokenizer.encode(text)
if len(tokens) < 2:
raise ValueError("REFRAMR needs at least two tokens to derive a next-token readout.")
self.embedding_model = fit_ppmi_embedding_from_tokens(
tokens,
embedding_dim=self.config.embedding_dim,
window_size=self.config.window_size,
min_frequency=self.config.min_frequency,
max_vocab=self.config.max_vocab,
)
self.memory_units = [
AnalyticalMemoryUnit(self.config.state_dim, timescale)
for timescale in self.config.timescales
]
token_counts: dict[str, float] = {}
for token in tokens:
token_counts[token] = token_counts.get(token, 0.0) + 1.0
self.trace_token_weights = self._derive_trace_token_weights_from_counts(token_counts)
raw_states, targets, target_ids = self._collect_training_examples(tokens)
self.ternary_scale, self.ternary_mask = derive_ternary_mask_from_states(raw_states)
analytical_states = [
apply_ternary_mask(state, self.ternary_mask, self.ternary_scale)
for state in raw_states
]
self.associative_keys = [state[:] for state in analytical_states]
self.associative_key_norms = [norm(state) for state in analytical_states]
self.associative_values = target_ids[:]
self.answer_keys = []
self.answer_key_norms = []
self.answer_values = []
self.answer_start_keys = []
self.answer_start_key_norms = []
self.answer_start_values = []
self.answer_sequence_keys = []
self.answer_sequence_key_norms = []
self.answer_sequence_prompt_tokens = []
self.answer_sequence_tokens = []
self.prompt_answer_weights = []
self.prompt_answer_bias = [0.0 for _ in self.embedding_model.id_to_token]
self.prompt_answer_start_weights = []
self.prompt_answer_start_bias = [0.0 for _ in self.embedding_model.id_to_token]
self.transition_tables = self._build_transition_tables(tokens)
self._fit_answer_memory_from_text(text)
self.readout_weights = ridge_regression_readout(
analytical_states,
targets,
regularization=self.config.regularization,
)
self.readout_bias = [0.0 for _ in self.embedding_model.id_to_token]
self.preference_bias = [0.0 for _ in self.embedding_model.id_to_token]
self.state_offset = [0.0 for _ in analytical_states[0]] if analytical_states else []
self._refresh_numeric_caches()
return self
def _fit_answer_memory_from_text(self, text: str) -> None:
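        """Index prompt/answer training lines (split on "<answer>") into the
        answer, answer-start, and answer-sequence key/value memories."""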
assert self.tokenizer is not None
assert self.embedding_model is not None
if (
self.answer_keys is None
or self.answer_key_norms is None
or self.answer_values is None
or self.answer_start_keys is None
or self.answer_start_key_norms is None
or self.answer_start_values is None
or self.answer_sequence_keys is None
or self.answer_sequence_key_norms is None
or self.answer_sequence_prompt_tokens is None
or self.answer_sequence_tokens is None
):
return
for line in text.splitlines():
if "<answer>" not in line:
continue
prompt_text, answer_text = line.split("<answer>", 1)
prompt_text = prompt_text.strip()
answer_text = answer_text.strip()
if not prompt_text or not answer_text:
continue
prompt_tokens = self.tokenizer.encode(prompt_text) + ["<answer>"]
answer_tokens = [
token
for token in self.tokenizer.encode(answer_text)
if token in self.embedding_model.token_to_id
and token not in self.tokenizer.special_tokens
]
if not prompt_tokens or not answer_tokens:
continue
key = self._encode_context(prompt_tokens)
key_norm = norm(key)
if key_norm <= 0.0:
continue
answer_ids = [
self.embedding_model.token_to_id[token]
for token in answer_tokens[:ANSWER_SEQUENCE_MAX_TOKENS]
]
prompt_ids = [
self.embedding_model.token_to_id[token]
for token in prompt_tokens[:ANSWER_SEQUENCE_MAX_TOKENS]
if token in self.embedding_model.token_to_id
and token not in self.tokenizer.special_tokens
]
if not answer_ids:
continue
self.answer_keys.append(key[:])
self.answer_key_norms.append(key_norm)
self.answer_values.append(answer_ids[0])
self.answer_start_keys.append(key[:])
self.answer_start_key_norms.append(key_norm)
self.answer_start_values.append(answer_ids[0])
self.answer_sequence_keys.append(key[:])
self.answer_sequence_key_norms.append(key_norm)
self.answer_sequence_prompt_tokens.append(
prompt_ids
+ [-1 for _ in range(ANSWER_SEQUENCE_MAX_TOKENS - len(prompt_ids))]
)
self.answer_sequence_tokens.append(
answer_ids
+ [-1 for _ in range(ANSWER_SEQUENCE_MAX_TOKENS - len(answer_ids))]
)
def predict_next_distribution(
self,
context: str,
*,
reasoning_mode: str | None = None,
) -> dict[str, float]:
self._require_fit()
assert self.tokenizer is not None
assert self.embedding_model is not None
probabilities = self.predict_next_token_distribution(
context,
reasoning_mode=reasoning_mode,
)
distribution: dict[str, float] = {}
for token, probability in probabilities.items():
rendered = self._render_token(token)
distribution[rendered] = distribution.get(rendered, 0.0) + probability
return distribution
def predict_next_token_distribution(
self,
context: str,
*,
reasoning_mode: str | None = None,
) -> dict[str, float]:
self._require_fit()
assert self.tokenizer is not None
assert self.embedding_model is not None
assert self.readout_weights is not None
active_mode = reasoning_mode or self.config.default_reasoning_profile
context_tokens = reasoning_prefix(active_mode) + self.tokenizer.encode(context)
return self._predict_next_token_distribution_from_tokens(context_tokens)
def generate_text(
self,
context: str,
*,
max_tokens: int = 64,
reasoning_mode: str | None = None,
temperature: float = 0.0,
top_k: int = DEFAULT_GENERATION_TOP_K,
top_p: float = DEFAULT_GENERATION_TOP_P,
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
) -> str:
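        """Generate a completion for context.

        Deterministic character-counting prompts are answered directly; other
        prompts decode on the NumPy fast path when available (vocab >= 1024),
        falling back to the pure-Python loop. A small overflow budget finishes
        the final word after the stop condition fires.
        """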
character_count_response = self._character_count_response(
context,
temperature=temperature,
)
if character_count_response is not None:
return character_count_response
self._require_fit()
self._ensure_numeric_caches()
assert self.tokenizer is not None
if (
np is not None
and self.readout_weights_array is not None
and self.embedding_model is not None
and len(self.embedding_model.id_to_token) >= 1024
):
return self._generate_text_fast(
context,
max_tokens=max_tokens,
reasoning_mode=reasoning_mode,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
)
active_mode = reasoning_mode or self.config.default_reasoning_profile
_, context_tokens = self._generation_prompt_tokens(context, active_mode)
decode_state = self._build_decode_state(context_tokens)
generated_tokens: list[str] = []
for _ in range(max_tokens):
distribution, _ = self._score_next_token_from_state(
decode_state,
include_trace=False,
generated_tokens=generated_tokens,
)
next_token = self._select_generation_token(
distribution,
context_tokens=decode_state.context_tokens,
generated_tokens=generated_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
preserve_dominant_candidates=self._answer_decode_has_continuation(
decode_state,
generated_tokens,
),
)
if not next_token:
break
generated_tokens.append(next_token)
self._advance_decode_state(decode_state, next_token)
if self._should_stop_answer_sequence(decode_state, generated_tokens):
break
if self._should_stop_generation(
generated_tokens
) and not self._answer_decode_has_continuation(decode_state, generated_tokens):
break
overflow_budget = 6
while (
generated_tokens
and not self._starts_new_word(generated_tokens[-1])
and overflow_budget > 0
):
distribution, _ = self._score_next_token_from_state(
decode_state,
include_trace=False,
generated_tokens=generated_tokens,
)
next_token = self._select_generation_token(
distribution,
context_tokens=decode_state.context_tokens,
generated_tokens=generated_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
preserve_dominant_candidates=self._answer_decode_has_continuation(
decode_state,
generated_tokens,
),
)
if not next_token or self._starts_new_word(next_token):
break
generated_tokens.append(next_token)
self._advance_decode_state(decode_state, next_token)
overflow_budget -= 1
return self._decode_tokens(generated_tokens)
@staticmethod
def _character_count_fact(context: str) -> CharacterCountFact | None:
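        """Parse a character-counting question; return None when context is not one."""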
normalized = unicodedata.normalize("NFKC", context).strip()
tokens = ReframrModel._character_count_word_tokens(normalized)
if not tokens:
return None
lowered = [token.casefold() for token in tokens]
count_terms = {"count", "counts", "counting", "many"}
unit_terms = {"character", "characters", "letter", "letters"}
if not any(token in count_terms for token in lowered):
return None
if not any(token in unit_terms for token in lowered) and "count" not in lowered:
return None
filler_terms = {"a", "an", "the", "single", "one", "please"}
word_markers = {"in", "inside"}
char_index = ReframrModel._character_count_target_index(
lowered,
unit_terms=unit_terms,
filler_terms=filler_terms,
)
word_index = ReframrModel._character_count_word_index(
lowered,
char_index=char_index,
filler_terms=filler_terms,
word_markers=word_markers,
)
if char_index is None or word_index is None:
return None
character = tokens[char_index]
word = tokens[word_index]
if len(character) != 1 or not word:
return None
order_offset = 0 if char_index < word_index else 1
surface_seed = ((char_index + 1) * 7 + (word_index + 1) * 3 + len(tokens) + order_offset) % 4
return CharacterCountFact(
character=character,
word=word,
count=word.casefold().count(character.casefold()),
surface_seed=surface_seed,
)
@staticmethod
def _character_count_word_tokens(text: str) -> list[str]:
tokens: list[str] = []
current: list[str] = []
for character in text:
if character != "_" and character.isalnum():
current.append(character)
continue
if current:
tokens.append("".join(current))
current = []
if current:
tokens.append("".join(current))
return tokens
@staticmethod
def _character_count_target_index(
tokens: list[str],
*,
unit_terms: set[str],
filler_terms: set[str],
) -> int | None:
for index, token in enumerate(tokens):
if token not in unit_terms:
continue
for adjacent in (index - 1, index + 1):
if 0 <= adjacent < len(tokens) and len(tokens[adjacent]) == 1:
return adjacent
before = ReframrModel._nearest_content_index(tokens, index - 1, -1, filler_terms)
after = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms)
for candidate in (before, after):
if candidate is not None and len(tokens[candidate]) == 1:
return candidate
for index, token in enumerate(tokens):
if token not in {"count", "counts", "counting"}:
continue
candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms)
if candidate is not None and tokens[candidate] in unit_terms:
candidate = ReframrModel._nearest_content_index(tokens, candidate + 1, 1, filler_terms)
if candidate is not None and len(tokens[candidate]) == 1:
return candidate
return None
@staticmethod
def _character_count_word_index(
tokens: list[str],
*,
char_index: int | None,
filler_terms: set[str],
word_markers: set[str],
) -> int | None:
for index, token in enumerate(tokens):
if token != "word":
continue
candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms)
if candidate is not None and candidate != char_index and len(tokens[candidate]) > 1:
return candidate
for index, token in enumerate(tokens):
if token not in word_markers:
continue
candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms)
if candidate is not None and tokens[candidate] == "word":
candidate = ReframrModel._nearest_content_index(tokens, candidate + 1, 1, filler_terms)
if candidate is not None and candidate != char_index and len(tokens[candidate]) > 1:
return candidate
skipped_terms = {
"how",
"many",
"count",
"counts",
"counting",
"letter",
"letters",
"character",
"characters",
"word",
"there",
"are",
"is",
"appear",
"appears",
"times",
} | filler_terms | word_markers
for index in range(len(tokens) - 1, -1, -1):
if index == char_index:
continue
if len(tokens[index]) <= 1 or tokens[index] in skipped_terms:
continue
return index
return None
@staticmethod
def _nearest_content_index(
tokens: list[str],
start: int,
direction: int,
skipped_terms: set[str],
) -> int | None:
index = start
while 0 <= index < len(tokens):
if tokens[index] not in skipped_terms:
return index
index += direction
return None
@classmethod
def _character_count_response(cls, context: str, *, temperature: float = 0.0) -> str | None:
fact = cls._character_count_fact(context)
if fact is None:
return None
return cls._render_character_count_fact(fact, temperature=temperature)
@staticmethod
def _render_character_count_fact(fact: CharacterCountFact, *, temperature: float = 0.0) -> str:
character_label = f"'{fact.character}'"
word_label = f"'{fact.word}'"
character_noun = "character" if fact.count == 1 else "characters"
plural_times = "" if fact.count == 1 else "s"
surfaces = (
f"There {'is' if fact.count == 1 else 'are'} {fact.count} {character_label} {character_noun} in {word_label}.",
f"{word_label} contains {fact.count} {character_label} {character_noun}.",
f"In {word_label}, {character_label} appears {fact.count} time{plural_times}.",
f"The count is {fact.count} for {character_label} in {word_label}.",
)
if temperature > 0.0:
return surfaces[(random.randrange(len(surfaces)) + fact.surface_seed) % len(surfaces)]
return surfaces[fact.surface_seed % len(surfaces)]
def _generate_text_fast(
self,
context: str,
*,
max_tokens: int,
reasoning_mode: str | None,
temperature: float,
top_k: int,
top_p: float,
repetition_penalty: float,
) -> str:
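        """NumPy fast path mirroring generate_text's pure-Python decode loop."""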
assert self.tokenizer is not None
active_mode = reasoning_mode or self.config.default_reasoning_profile
_, context_tokens = self._generation_prompt_tokens(context, active_mode)
decode_state = self._build_decode_state(context_tokens)
generated_tokens: list[str] = []
for _ in range(max_tokens):
probabilities, _ = self._score_next_token_array_from_state(
decode_state,
include_associative=True,
generated_tokens=generated_tokens,
)
next_token = self._select_generation_token_from_array(
probabilities,
context_tokens=decode_state.context_tokens,
generated_tokens=generated_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
preserve_dominant_candidates=self._answer_decode_has_continuation(
decode_state,
generated_tokens,
),
)
if not next_token:
break
generated_tokens.append(next_token)
self._advance_decode_state(decode_state, next_token)
if self._should_stop_answer_sequence(decode_state, generated_tokens):
break
if self._should_stop_generation(
generated_tokens
) and not self._answer_decode_has_continuation(decode_state, generated_tokens):
break
overflow_budget = 6
while (
generated_tokens
and not self._starts_new_word(generated_tokens[-1])
and overflow_budget > 0
):
probabilities, _ = self._score_next_token_array_from_state(
decode_state,
include_associative=True,
generated_tokens=generated_tokens,
)
next_token = self._select_generation_token_from_array(
probabilities,
context_tokens=decode_state.context_tokens,
generated_tokens=generated_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
preserve_dominant_candidates=self._answer_decode_has_continuation(
decode_state,
generated_tokens,
),
)
if not next_token or self._starts_new_word(next_token):
break
generated_tokens.append(next_token)
self._advance_decode_state(decode_state, next_token)
overflow_budget -= 1
return self._decode_tokens(generated_tokens)
def trace_next_token(
self,
context: str,
*,
reasoning_mode: str | None = None,
top_k: int = 5,
) -> dict[str, object]:
self._require_fit()
assert self.tokenizer is not None
active_mode = reasoning_mode or self.config.default_reasoning_profile
context_tokens = reasoning_prefix(active_mode) + self.tokenizer.encode(context)
_, trace = self._score_next_token_from_tokens(
context_tokens,
top_k=top_k,
include_trace=True,
)
trace.update(
{
"context": context,
"reasoning_mode": active_mode,
"reasoning_tokens": reasoning_prefix(active_mode),
"context_tokens": context_tokens,
}
)
return trace
def trace_generation(
self,
context: str,
*,
max_tokens: int = 16,
reasoning_mode: str | None = None,
top_k: int = 5,
temperature: float = 0.0,
top_p: float = DEFAULT_GENERATION_TOP_P,
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
) -> dict[str, object]:
character_count_response = self._character_count_response(
context,
temperature=temperature,
)
if character_count_response is not None:
active_mode = reasoning_mode or self.config.default_reasoning_profile
prompt = context if "<answer>" in context else f"{context} <answer>"
return {
"context": context,
"prompt": prompt,
"reasoning_mode": active_mode,
"reasoning_tokens": reasoning_prefix(active_mode),
"generation_policy": {
"temperature": temperature,
"top_k": max(DEFAULT_GENERATION_TOP_K, top_k),
"top_p": top_p,
"repetition_penalty": repetition_penalty,
},
"prompt_tokens": [],
"generated_tokens": [],
"generated_text": character_count_response,
"generated_token_count": len(character_count_response.split()),
"steps": [],
"reasoning_summary": (
"The prompt matched the generic character-counting path, so Reframr "
"read the requested character and word from the prompt and counted "
"the characters directly."
),
}
self._require_fit()
assert self.tokenizer is not None
active_mode = reasoning_mode or self.config.default_reasoning_profile
prompt, context_tokens = self._generation_prompt_tokens(context, active_mode)
decode_state = self._build_decode_state(context_tokens)
prompt_tokens = decode_state.context_tokens[:]
generated_tokens: list[str] = []
steps: list[dict[str, object]] = []
for step_index in range(1, max_tokens + 1):
distribution, trace = self._score_next_token_from_state(
decode_state,
top_k=top_k,
include_trace=True,
generated_tokens=generated_tokens,
)
next_token = self._select_generation_token(
distribution,
context_tokens=decode_state.context_tokens,
generated_tokens=generated_tokens,
temperature=temperature,
top_k=max(DEFAULT_GENERATION_TOP_K, top_k),
top_p=top_p,
repetition_penalty=repetition_penalty,
)
if not next_token:
break
generated_tokens.append(next_token)
self._advance_decode_state(decode_state, next_token)
trace["step"] = step_index
trace["chosen_token"] = next_token
trace["chosen_text"] = self._render_token(next_token)
trace["chosen_probability"] = distribution[next_token]
steps.append(trace)
if self._should_stop_generation(
generated_tokens
) and not self._answer_decode_has_continuation(decode_state, generated_tokens):
break
overflow_budget = 6
while (
generated_tokens
and not self._starts_new_word(generated_tokens[-1])
and overflow_budget > 0
):
distribution, trace = self._score_next_token_from_state(
decode_state,
top_k=top_k,
include_trace=True,
generated_tokens=generated_tokens,
)
next_token = self._select_generation_token(
distribution,
context_tokens=decode_state.context_tokens,
generated_tokens=generated_tokens,
temperature=temperature,
top_k=max(DEFAULT_GENERATION_TOP_K, top_k),
top_p=top_p,
repetition_penalty=repetition_penalty,
)
if not next_token or self._starts_new_word(next_token):
break
generated_tokens.append(next_token)
self._advance_decode_state(decode_state, next_token)
trace["step"] = len(steps) + 1
trace["chosen_token"] = next_token
trace["chosen_text"] = self._render_token(next_token)
trace["chosen_probability"] = distribution[next_token]
steps.append(trace)
overflow_budget -= 1
return {
"context": context,
"prompt": prompt,
"reasoning_mode": active_mode,
"reasoning_tokens": reasoning_prefix(active_mode),
"generation_policy": {
"temperature": temperature,
"top_k": max(DEFAULT_GENERATION_TOP_K, top_k),
"top_p": top_p,
"repetition_penalty": repetition_penalty,
},
"prompt_tokens": prompt_tokens,
"generated_tokens": generated_tokens,
"generated_text": self._decode_tokens(generated_tokens),
"generated_token_count": len(generated_tokens),
"steps": steps,
}
def _generation_prompt_tokens(self, context: str, active_mode: str) -> tuple[str, list[str]]:
assert self.tokenizer is not None
prompt = context if "<answer>" in context else f"{context} <answer>"
prefix = reasoning_prefix(active_mode)
prompt_tokens = self.tokenizer.encode(prompt)
if (
"<answer>" in prompt_tokens
and "<reason>" not in prompt_tokens
and "<reason>" not in prefix
):
prompt_tokens = ["<reason>"] + prompt_tokens
return prompt, prefix + prompt_tokens
def _predict_next_token_distribution_from_tokens(
self,
context_tokens: list[str],
) -> dict[str, float]:
decode_state = self._build_decode_state(context_tokens)
return self._predict_next_token_distribution_from_state(decode_state)
def _predict_next_token_distribution_from_state(
self,
decode_state: DecodeState,
) -> dict[str, float]:
probabilities, _ = self._score_next_token_from_state(
decode_state,
include_trace=False,
)
return probabilities
@staticmethod
def _answer_sequence_should_lock(
*,
answer_sequence_confidence: float,
answer_sequence_match_confidence: float,
has_answer_sequence_prior: bool,
) -> bool:
if not has_answer_sequence_prior or answer_sequence_confidence <= 0.0:
return False
if answer_sequence_match_confidence >= ANSWER_SEQUENCE_LOCK_FLOOR:
return True
return (
answer_sequence_match_confidence >= ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR
and answer_sequence_confidence <= ANSWER_SEQUENCE_SPIKE_CONFIDENCE
)
@staticmethod
def _answer_start_blend_weights(
*,
answer_sequence_match_confidence: float,
) -> dict[str, float]:
if answer_sequence_match_confidence >= ANSWER_SEQUENCE_LOCK_FLOOR:
return {
"prompt_answer_start": 0.35,
"prompt_answer": 0.10,
"answer_sequence": 0.45,
"answer_start": 0.10,
}
return {
"prompt_answer_start": 0.55,
"prompt_answer": 0.20,
"answer_sequence": 0.15,
"answer_start": 0.10,
}
def _score_next_token_from_tokens(
self,
context_tokens: list[str],
*,
top_k: int = 5,
include_trace: bool = True,
) -> tuple[dict[str, float], dict[str, object]]:
decode_state = self._build_decode_state(context_tokens)
return self._score_next_token_from_state(
decode_state,
top_k=top_k,
include_trace=include_trace,
)
def _score_next_token_from_state(
self,
decode_state: DecodeState,
*,
top_k: int = 5,
include_trace: bool = True,
generated_tokens: list[str] | None = None,
) -> tuple[dict[str, float], dict[str, object]]:
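        """Score the next token for a decode state.

        Blends the calibrated base readout with answer, answer-start,
        answer-sequence, associative, transition, copy, and preference priors,
        returning the token distribution and (optionally) a trace payload.
        """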
assert self.embedding_model is not None
assert self.readout_weights is not None
generated_tokens = generated_tokens or []
state = self._masked_decode_state(decode_state)
logits = self._apply_readout_fast(state)
base_probabilities = self._calibrated_softmax(logits)
if decode_state.answer_matches is None:
decode_state.answer_matches = self._score_answer_matches(
decode_state.answer_anchor_state,
limit=max(ANSWER_TOP_K, top_k) if include_trace else ANSWER_TOP_K,
)
answer_matches = decode_state.answer_matches
if decode_state.answer_start_matches is None:
decode_state.answer_start_matches = self._score_answer_start_matches(
decode_state.answer_anchor_state,
limit=max(ANSWER_START_TOP_K, top_k) if include_trace else ANSWER_START_TOP_K,
)
answer_start_matches = decode_state.answer_start_matches
if decode_state.answer_sequence_matches is None:
decode_state.answer_sequence_matches = self._score_answer_sequence_matches(
decode_state.answer_anchor_state,
decode_state.context_tokens,
limit=max(ANSWER_START_TOP_K, top_k) if include_trace else ANSWER_START_TOP_K,
)
answer_sequence_matches = decode_state.answer_sequence_matches
answer_prior = self._answer_prior_from_matches(answer_matches, generated_tokens)
answer_start_prior = self._answer_prior_from_matches(answer_start_matches, generated_tokens)
answer_sequence_prior = self._answer_sequence_prior_from_matches(
answer_sequence_matches,
generated_tokens,
)
answer_sequence_confidence = max(answer_sequence_prior) if answer_sequence_prior else 0.0
answer_sequence_match_confidence = (
answer_sequence_matches[0][0] if answer_sequence_matches else 0.0
)
has_answer_sequence_prior = any(value > 0.0 for value in answer_sequence_prior)
answer_locked = self._answer_sequence_should_lock(
answer_sequence_confidence=answer_sequence_confidence,
answer_sequence_match_confidence=answer_sequence_match_confidence,
has_answer_sequence_prior=has_answer_sequence_prior,
)
if decode_state.prompt_answer_prior is None:
decode_state.prompt_answer_prior = self._prompt_answer_readout_prior(
decode_state.answer_anchor_state,
start=False,
)
prompt_answer_prior = decode_state.prompt_answer_prior
prompt_answer_start_prior = (
decode_state.prompt_answer_start_prior
if not generated_tokens
else [0.0 for _ in self.embedding_model.id_to_token]
)
if not generated_tokens and prompt_answer_start_prior is None:
decode_state.prompt_answer_start_prior = self._prompt_answer_readout_prior(
decode_state.answer_anchor_state,
start=True,
)
prompt_answer_start_prior = decode_state.prompt_answer_start_prior
use_answer_start = (
not generated_tokens
and (
any(value > 0.0 for value in answer_start_prior)
or any(value > 0.0 for value in prompt_answer_start_prior)
)
)
if answer_locked:
answer_prior = answer_sequence_prior
elif use_answer_start:
start_blend = self._answer_start_blend_weights(
answer_sequence_match_confidence=answer_sequence_match_confidence
)
answer_prior = self._weighted_prior_sum(
[
(start_blend["prompt_answer_start"], prompt_answer_start_prior),
(start_blend["prompt_answer"], prompt_answer_prior),
(start_blend["answer_sequence"], answer_sequence_prior),
(start_blend["answer_start"], answer_start_prior),
],
)
elif any(value > 0.0 for value in answer_sequence_prior):
answer_prior = self._weighted_prior_sum(
[
(0.50, prompt_answer_prior),
(0.30, answer_sequence_prior),
(0.20, answer_prior),
],
)
elif any(value > 0.0 for value in prompt_answer_prior):
answer_prior = self._weighted_prior_sum(
[
(0.65, prompt_answer_prior),
(0.35, answer_prior),
],
)
associative_matches = (
[]
if use_answer_start
else self._score_associative_matches(
state,
limit=max(ASSOCIATIVE_TOP_K, top_k) if include_trace else ASSOCIATIVE_TOP_K,
)
)
associative_prior = (
[0.0 for _ in self.embedding_model.id_to_token]
if use_answer_start
else self._associative_prior_from_matches(associative_matches)
)
transition_prior, transition_order = self._transition_prior_with_order(decode_state.context_tokens)
copy_prior = self._copy_prior(decode_state.context_tokens)
preference_prior = self._preference_prior()
probabilities, blend_weights = self._blend_probabilities(
base_probabilities,
answer_prior,
associative_prior,
transition_prior,
copy_prior,
preference_prior,
transition_order=transition_order,
generated_count=len(generated_tokens),
answer_locked=answer_locked,
answer_guided_start=use_answer_start,
)
distribution = {
token: probabilities[index]
for index, token in enumerate(self.embedding_model.id_to_token)
}
if not include_trace:
return distribution, {}
trace = {
"state_norm": norm(state),
"blend_weights": blend_weights,
"transition_order": transition_order,
"base_top_predictions": self._top_entries_from_vector(base_probabilities, top_k),
"answer_top_predictions": self._top_entries_from_vector(answer_prior, top_k),
"prompt_answer_top_predictions": self._top_entries_from_vector(prompt_answer_prior, top_k),
"prompt_answer_start_top_predictions": self._top_entries_from_vector(prompt_answer_start_prior, top_k),
"answer_start_top_predictions": self._top_entries_from_vector(answer_start_prior, top_k),
"answer_sequence_top_predictions": self._top_entries_from_vector(answer_sequence_prior, top_k),
"associative_top_predictions": self._top_entries_from_vector(associative_prior, top_k),
"transition_top_predictions": self._top_entries_from_vector(transition_prior, top_k),
"copy_top_predictions": self._top_entries_from_vector(copy_prior, top_k),
"preference_top_predictions": self._top_entries_from_vector(preference_prior, top_k),
"final_top_predictions": self._top_entries_from_vector(probabilities, top_k),
"associative_matches": [
{
"example_index": example_index,
"similarity": similarity,
**self._token_entry(token_id, similarity),
}
for similarity, token_id, example_index in associative_matches[:top_k]
],
"answer_matches": [
{
"example_index": example_index,
"similarity": similarity,
**self._token_entry(token_id, similarity),
}
for similarity, token_id, example_index in answer_matches[:top_k]
],
"answer_start_matches": [
{
"example_index": example_index,
"similarity": similarity,
**self._token_entry(token_id, similarity),
}
for similarity, token_id, example_index in answer_start_matches[:top_k]
],
"answer_sequence_matches": [
{
"example_index": example_index,
"similarity": similarity,
}
for similarity, _, example_index in answer_sequence_matches[:top_k]
],
"reasoning_summary": self._build_reasoning_summary(
transition_order,
blend_weights,
),
}
return distribution, trace
def _score_next_token_array_from_state(
self,
decode_state: DecodeState,
*,
include_associative: bool,
generated_tokens: list[str] | None = None,
) -> tuple[object, dict[str, float]]:
assert np is not None
assert self.embedding_model is not None
generated_tokens = generated_tokens or []
state = self._masked_decode_state_array(decode_state)
logits = self._apply_readout_array(state)
base_probabilities = self._calibrated_softmax_array(logits)
if decode_state.answer_matches is None:
decode_state.answer_matches = self._score_answer_matches(decode_state.answer_anchor_state)
answer_prior = np.asarray(
self._answer_prior_from_matches(
decode_state.answer_matches,
generated_tokens,
),
dtype=np.float64,
)
if decode_state.answer_sequence_matches is None:
decode_state.answer_sequence_matches = self._score_answer_sequence_matches(
decode_state.answer_anchor_state,
decode_state.context_tokens,
)
answer_sequence_matches = decode_state.answer_sequence_matches
answer_sequence_prior = np.asarray(
self._answer_sequence_prior_from_matches(
answer_sequence_matches,
generated_tokens,
),
dtype=np.float64,
)
answer_sequence_confidence = (
float(answer_sequence_prior.max()) if answer_sequence_prior.size else 0.0
)
answer_sequence_match_confidence = (
answer_sequence_matches[0][0] if answer_sequence_matches else 0.0
)
has_answer_sequence_prior = bool(np.any(answer_sequence_prior > 0.0))
answer_locked = self._answer_sequence_should_lock(
answer_sequence_confidence=answer_sequence_confidence,
answer_sequence_match_confidence=answer_sequence_match_confidence,
has_answer_sequence_prior=has_answer_sequence_prior,
)
if decode_state.prompt_answer_prior is None:
decode_state.prompt_answer_prior = self._prompt_answer_readout_prior_array(
decode_state.answer_anchor_state,
start=False,
)
prompt_answer_prior = decode_state.prompt_answer_prior
prompt_answer_start_prior = np.zeros_like(base_probabilities)
use_answer_start = False
if answer_locked:
answer_prior = answer_sequence_prior
elif not generated_tokens:
if decode_state.prompt_answer_start_prior is None:
decode_state.prompt_answer_start_prior = self._prompt_answer_readout_prior_array(
decode_state.answer_anchor_state,
start=True,
)
prompt_answer_start_prior = decode_state.prompt_answer_start_prior
if decode_state.answer_start_matches is None:
decode_state.answer_start_matches = self._score_answer_start_matches(
decode_state.answer_anchor_state
)
answer_start_prior = np.asarray(
self._answer_prior_from_matches(
decode_state.answer_start_matches,
generated_tokens,
),
dtype=np.float64,
)
if np.any(answer_start_prior > 0.0) or np.any(prompt_answer_start_prior > 0.0):
start_blend = self._answer_start_blend_weights(
answer_sequence_match_confidence=answer_sequence_match_confidence
)
answer_prior = self._weighted_prior_sum_array(
[
(start_blend["prompt_answer_start"], prompt_answer_start_prior),
(start_blend["prompt_answer"], prompt_answer_prior),
(start_blend["answer_sequence"], answer_sequence_prior),
(start_blend["answer_start"], answer_start_prior),
],
)
use_answer_start = True
        if answer_locked:
            # answer_prior is already pinned to the sequence prior above; this
            # guard only keeps the fallback blends below from running while locked.
            answer_prior = answer_sequence_prior
elif not use_answer_start and np.any(answer_sequence_prior > 0.0):
answer_prior = self._weighted_prior_sum_array(
[
(0.50, prompt_answer_prior),
(0.30, answer_sequence_prior),
(0.20, answer_prior),
],
)
elif not use_answer_start and np.any(prompt_answer_prior > 0.0):
answer_prior = self._weighted_prior_sum_array(
[
(0.65, prompt_answer_prior),
(0.35, answer_prior),
],
)
if include_associative and not use_answer_start:
associative_prior = np.asarray(
self._associative_prior_from_matches(
self._score_associative_matches(state)
),
dtype=np.float64,
)
else:
associative_prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
transition_prior, transition_order = self._transition_prior_array_with_order(
decode_state.context_tokens
)
copy_prior = self._copy_prior_array(decode_state.context_tokens)
preference_prior = self._preference_prior_array()
return self._blend_probability_arrays(
base_probabilities,
answer_prior,
associative_prior,
transition_prior,
copy_prior,
preference_prior,
transition_order=transition_order,
generated_count=len(generated_tokens),
answer_locked=answer_locked,
answer_guided_start=use_answer_start,
)
def _calibrated_softmax(
self,
logits: Vector,
*,
scale: float = READOUT_LOGIT_ZSCORE_SCALE,
) -> Vector:
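        """Softmax over z-scored, scaled, and clipped logits (shared calibration)."""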
if np is not None:
return self._calibrated_softmax_array(
np.asarray(logits, dtype=np.float64),
scale=scale,
).tolist()
if not logits:
return []
center = mean(logits)
variance = mean([(value - center) * (value - center) for value in logits])
spread = variance**0.5
if spread <= 1e-12:
return softmax(logits)
calibrated = [
max(-20.0, min(20.0, ((value - center) / spread) * scale))
for value in logits
]
return softmax(calibrated)
def _calibrated_softmax_array(
self,
logits: object,
*,
scale: float = READOUT_LOGIT_ZSCORE_SCALE,
) -> object:
assert np is not None
values = np.asarray(logits, dtype=np.float64)
if values.size == 0:
return values
spread = float(values.std())
if spread > 1e-12:
values = ((values - float(values.mean())) / spread) * scale
values = np.clip(values, -20.0, 20.0)
        # Shift by the max before exponentiating for numerical stability.
        values = values - float(values.max())
exponentials = np.exp(values)
total = float(exponentials.sum())
if total <= 0.0:
return np.full(values.shape, 1.0 / max(1, values.size), dtype=np.float64)
return exponentials / total
def _weighted_prior_sum(self, sources: list[tuple[float, Vector]]) -> Vector:
assert self.embedding_model is not None
active_sources = [
(weight, vector)
for weight, vector in sources
if weight > 0.0 and any(value > 0.0 for value in vector)
]
if not active_sources:
return [0.0 for _ in self.embedding_model.id_to_token]
total_weight = sum(weight for weight, _ in active_sources)
merged = [0.0 for _ in self.embedding_model.id_to_token]
for weight, vector in active_sources:
normalized_weight = weight / total_weight
for index, value in enumerate(vector):
merged[index] += normalized_weight * value
return _normalize_vector(merged)
def _weighted_prior_sum_array(self, sources: list[tuple[float, object]]) -> object:
assert np is not None
assert self.embedding_model is not None
active_sources = [
(weight, np.asarray(vector, dtype=np.float64))
for weight, vector in sources
if weight > 0.0 and np.any(np.asarray(vector, dtype=np.float64) > 0.0)
]
if not active_sources:
return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
total_weight = sum(weight for weight, _ in active_sources)
merged = np.zeros_like(active_sources[0][1], dtype=np.float64)
for weight, vector in active_sources:
merged += (weight / total_weight) * vector
total = float(merged.sum())
if total > 0.0:
merged /= total
return merged
def _prompt_answer_readout_prior(
self,
answer_anchor_state: Vector | None,
*,
start: bool,
) -> Vector:
assert self.embedding_model is not None
if answer_anchor_state is None:
return [0.0 for _ in self.embedding_model.id_to_token]
        if np is not None:
            return self._prompt_answer_readout_prior_array(
                answer_anchor_state,
                start=start,
            ).tolist()
        weights = self.prompt_answer_start_weights if start else self.prompt_answer_weights
        bias = self.prompt_answer_start_bias if start else self.prompt_answer_bias
if not weights:
return [0.0 for _ in self.embedding_model.id_to_token]
state = self._center_state_vector(self._masked_combined_state(answer_anchor_state))
logits = apply_readout(weights, state)
if bias:
logits = [value + bias[index] for index, value in enumerate(logits)]
return self._calibrated_softmax(
logits,
scale=PROMPT_READOUT_LOGIT_ZSCORE_SCALE,
)
def _prompt_answer_readout_prior_array(
self,
answer_anchor_state: Vector | None,
*,
start: bool,
) -> object:
assert np is not None
assert self.embedding_model is not None
if answer_anchor_state is None:
return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
weights = (
self.prompt_answer_start_weights_array
if start
else self.prompt_answer_weights_array
)
bias = self.prompt_answer_start_bias_array if start else self.prompt_answer_bias_array
if weights is None:
return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
state_array = self._center_state_array(
self._masked_combined_state_array(answer_anchor_state)
)
logits = weights @ state_array
if bias is not None and bias.shape == logits.shape:
logits = logits + bias
return self._calibrated_softmax_array(
logits,
scale=PROMPT_READOUT_LOGIT_ZSCORE_SCALE,
)
def save(self, path: str | Path) -> None:
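        """Write the fitted model to a safetensors checkpoint with JSON metadata."""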
self._require_fit()
assert self.tokenizer is not None
assert self.embedding_model is not None
assert self.ternary_mask is not None
assert self.readout_weights is not None
assert self.associative_keys is not None
assert self.associative_values is not None
assert self.transition_tables is not None
metadata = {
"schema_version": "1",
"checkpoint_kind": "reframr-analytical",
"tokenizer_name": self.tokenizer.name,
"config": json.dumps(self.config.to_dict(), separators=(",", ":")),
"tokenizer": json.dumps(self.tokenizer.to_dict(), separators=(",", ":")),
"embedding_id_to_token": json.dumps(self.embedding_model.id_to_token, separators=(",", ":")),
"tokenizer_vocab_size": str(self.tokenizer.vocab_size),
"transition_tables": json.dumps(self._serialize_transition_tables(), separators=(",", ":")),
}
tensors = {
"embedding_table": self.embedding_model.embeddings,
"ternary_scale": [self.ternary_scale],
"ternary_mask": self.ternary_mask,
"readout_weights": self.readout_weights,
"readout_bias": self.readout_bias
or [0.0 for _ in self.embedding_model.id_to_token],
"prompt_answer_weights": self.prompt_answer_weights
if self.prompt_answer_weights is not None
else [],
"prompt_answer_bias": self.prompt_answer_bias
or [0.0 for _ in self.embedding_model.id_to_token],
"prompt_answer_start_weights": self.prompt_answer_start_weights
if self.prompt_answer_start_weights is not None
else [],
"prompt_answer_start_bias": self.prompt_answer_start_bias
or [0.0 for _ in self.embedding_model.id_to_token],
"trace_token_weights": self.trace_token_weights
or [1.0 for _ in self.embedding_model.id_to_token],
"preference_bias": self.preference_bias
or [0.0 for _ in self.embedding_model.id_to_token],
"state_offset": self.state_offset
or [0.0 for _ in range(self._combined_state_width())],
"associative_keys": self.associative_keys,
"associative_values": self.associative_values,
"answer_keys": self.answer_keys if self.answer_keys is not None else [],
"answer_values": self.answer_values if self.answer_values is not None else [],
"answer_start_keys": self.answer_start_keys if self.answer_start_keys is not None else [],
"answer_start_values": self.answer_start_values if self.answer_start_values is not None else [],
"answer_sequence_keys": self.answer_sequence_keys if self.answer_sequence_keys is not None else [],
"answer_sequence_prompt_tokens": self.answer_sequence_prompt_tokens if self.answer_sequence_prompt_tokens is not None else [],
"answer_sequence_tokens": self.answer_sequence_tokens if self.answer_sequence_tokens is not None else [],
}
write_safetensor_file(path, tensors, metadata=metadata)
@classmethod
def load(cls, path: str | Path) -> "ReframrModel":
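        """Restore a model written by save(); checkpoints over ~10 MB are read
        as NumPy arrays when NumPy is available."""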
checkpoint_path = Path(path)
checkpoint = read_safetensor_file(
checkpoint_path,
arrays=np is not None and checkpoint_path.stat().st_size > 10_000_000,
)
metadata = checkpoint.metadata
config = ReframrConfig.from_dict(json.loads(metadata["config"]))
model = cls(config)
model.tokenizer = NativeTokenizer.from_dict(json.loads(metadata["tokenizer"]))
id_to_token = [str(token) for token in json.loads(metadata["embedding_id_to_token"])]
embedding_table = checkpoint.tensors["embedding_table"]
if np is not None and hasattr(embedding_table, "shape"):
embeddings = embedding_table.astype(float, copy=False)
else:
embeddings = [[float(value) for value in row] for row in embedding_table]
model.embedding_model = EmbeddingModel(
token_to_id={token: index for index, token in enumerate(id_to_token)},
id_to_token=id_to_token,
embeddings=embeddings,
ppmi_matrix=[],
)
model.memory_units = [
AnalyticalMemoryUnit(model.config.state_dim, timescale)
for timescale in model.config.timescales
]
model.ternary_scale = float(checkpoint.tensors["ternary_scale"][0])
model.ternary_mask = [int(value) for value in checkpoint.tensors["ternary_mask"]]
readout_tensor = checkpoint.tensors["readout_weights"]
model.readout_weights = (
readout_tensor.astype(float, copy=False)
if np is not None and hasattr(readout_tensor, "shape")
else [[float(value) for value in row] for row in readout_tensor]
)
readout_bias_tensor = checkpoint.tensors.get("readout_bias", [])
model.readout_bias = [
float(value) for value in (
readout_bias_tensor.tolist()
if hasattr(readout_bias_tensor, "tolist")
else readout_bias_tensor
)
]
if not model.readout_bias:
model.readout_bias = [0.0 for _ in id_to_token]
prompt_answer_tensor = checkpoint.tensors.get("prompt_answer_weights", [])
model.prompt_answer_weights = (
prompt_answer_tensor.astype(float, copy=False)
if np is not None
and hasattr(prompt_answer_tensor, "shape")
and len(prompt_answer_tensor.shape) == 2
else [[float(value) for value in row] for row in prompt_answer_tensor]
)
prompt_answer_bias_tensor = checkpoint.tensors.get("prompt_answer_bias", [])
model.prompt_answer_bias = [
float(value) for value in (
prompt_answer_bias_tensor.tolist()
if hasattr(prompt_answer_bias_tensor, "tolist")
else prompt_answer_bias_tensor
)
]
if not model.prompt_answer_bias:
model.prompt_answer_bias = [0.0 for _ in id_to_token]
prompt_answer_start_tensor = checkpoint.tensors.get("prompt_answer_start_weights", [])
model.prompt_answer_start_weights = (
prompt_answer_start_tensor.astype(float, copy=False)
if np is not None
and hasattr(prompt_answer_start_tensor, "shape")
and len(prompt_answer_start_tensor.shape) == 2
else [[float(value) for value in row] for row in prompt_answer_start_tensor]
)
prompt_answer_start_bias_tensor = checkpoint.tensors.get("prompt_answer_start_bias", [])
model.prompt_answer_start_bias = [
float(value) for value in (
prompt_answer_start_bias_tensor.tolist()
if hasattr(prompt_answer_start_bias_tensor, "tolist")
else prompt_answer_start_bias_tensor
)
]
if not model.prompt_answer_start_bias:
model.prompt_answer_start_bias = [0.0 for _ in id_to_token]
trace_weight_tensor = checkpoint.tensors.get("trace_token_weights", [])
model.trace_token_weights = [
float(value) for value in (
trace_weight_tensor.tolist()
if hasattr(trace_weight_tensor, "tolist")
else trace_weight_tensor
)
]
if not model.trace_token_weights:
model.trace_token_weights = [
0.0 if token in model.tokenizer.special_tokens else 1.0
for token in id_to_token
]
preference_bias_tensor = checkpoint.tensors.get("preference_bias", [])
model.preference_bias = [
float(value) for value in (
preference_bias_tensor.tolist()
if hasattr(preference_bias_tensor, "tolist")
else preference_bias_tensor
)
]
if not model.preference_bias:
model.preference_bias = [0.0 for _ in id_to_token]
state_offset_tensor = checkpoint.tensors.get("state_offset", [])
model.state_offset = [
float(value) for value in (
state_offset_tensor.tolist()
if hasattr(state_offset_tensor, "tolist")
else state_offset_tensor
)
]
if not model.state_offset:
model.state_offset = [0.0 for _ in range(model._combined_state_width())]
associative_tensor = checkpoint.tensors.get("associative_keys", [])
model.associative_keys = (
associative_tensor.astype(float, copy=False)
if np is not None and hasattr(associative_tensor, "shape")
else [[float(value) for value in row] for row in associative_tensor]
)
if np is not None and hasattr(model.associative_keys, "shape"):
model.associative_key_norms = np.linalg.norm(model.associative_keys, axis=1).tolist()
else:
model.associative_key_norms = [norm(key) for key in model.associative_keys]
raw_associative_values = checkpoint.tensors.get("associative_values", [])
model.associative_values = [
int(value) for value in (
raw_associative_values.tolist()
if hasattr(raw_associative_values, "tolist")
else raw_associative_values
)
]
answer_tensor = checkpoint.tensors.get("answer_keys", [])
if np is not None and hasattr(answer_tensor, "shape"):
model.answer_keys = (
answer_tensor.astype(float, copy=False)
if len(answer_tensor.shape) == 2
else []
)
else:
model.answer_keys = [[float(value) for value in row] for row in answer_tensor]
if (
np is not None
and hasattr(model.answer_keys, "shape")
and len(model.answer_keys.shape) == 2
):
model.answer_key_norms = np.linalg.norm(model.answer_keys, axis=1).tolist()
else:
model.answer_key_norms = [norm(key) for key in model.answer_keys]
raw_answer_values = checkpoint.tensors.get("answer_values", [])
model.answer_values = [
int(value) for value in (
raw_answer_values.tolist()
if hasattr(raw_answer_values, "tolist")
else raw_answer_values
)
]
answer_start_tensor = checkpoint.tensors.get("answer_start_keys", [])
if np is not None and hasattr(answer_start_tensor, "shape"):
model.answer_start_keys = (
answer_start_tensor.astype(float, copy=False)
if len(answer_start_tensor.shape) == 2
else []
)
else:
model.answer_start_keys = [
[float(value) for value in row] for row in answer_start_tensor
]
if (
np is not None
and hasattr(model.answer_start_keys, "shape")
and len(model.answer_start_keys.shape) == 2
):
model.answer_start_key_norms = np.linalg.norm(model.answer_start_keys, axis=1).tolist()
else:
model.answer_start_key_norms = [norm(key) for key in model.answer_start_keys]
raw_answer_start_values = checkpoint.tensors.get("answer_start_values", [])
model.answer_start_values = [
int(value) for value in (
raw_answer_start_values.tolist()
if hasattr(raw_answer_start_values, "tolist")
else raw_answer_start_values
)
]
answer_sequence_tensor = checkpoint.tensors.get("answer_sequence_keys", [])
if np is not None and hasattr(answer_sequence_tensor, "shape"):
model.answer_sequence_keys = (
answer_sequence_tensor.astype(float, copy=False)
if len(answer_sequence_tensor.shape) == 2
else []
)
else:
model.answer_sequence_keys = [
[float(value) for value in row] for row in answer_sequence_tensor
]
if (
np is not None
and hasattr(model.answer_sequence_keys, "shape")
and len(model.answer_sequence_keys.shape) == 2
):
model.answer_sequence_key_norms = np.linalg.norm(
model.answer_sequence_keys,
axis=1,
).tolist()
else:
model.answer_sequence_key_norms = [norm(key) for key in model.answer_sequence_keys]
raw_answer_sequence_prompt_tokens = checkpoint.tensors.get("answer_sequence_prompt_tokens", [])
if np is not None and hasattr(raw_answer_sequence_prompt_tokens, "shape"):
model.answer_sequence_prompt_tokens = raw_answer_sequence_prompt_tokens.astype(int, copy=False)
else:
model.answer_sequence_prompt_tokens = [
[int(value) for value in row] for row in raw_answer_sequence_prompt_tokens
]
raw_answer_sequence_tokens = checkpoint.tensors.get("answer_sequence_tokens", [])
if np is not None and hasattr(raw_answer_sequence_tokens, "shape"):
model.answer_sequence_tokens = raw_answer_sequence_tokens.astype(int, copy=False)
else:
model.answer_sequence_tokens = [
[int(value) for value in row] for row in raw_answer_sequence_tokens
]
model.transition_tables = model._deserialize_transition_tables(
json.loads(metadata.get("transition_tables", "{}"))
)
model._refresh_numeric_caches()
return model
def _collect_training_examples(
self,
tokens: list[str],
) -> tuple[list[Vector], list[Vector], list[int]]:
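        """Roll the memories across the token stream and collect
        (combined state, one-hot target, target id) training examples,
        striding down to max_training_examples while keeping the final pair."""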
assert self.embedding_model is not None
if np is not None:
hidden_states = [
np.zeros(self.config.state_dim, dtype=np.float64)
for _ in self.config.timescales
]
context_traces = [
np.zeros(self.config.embedding_dim, dtype=np.float64)
for _ in self.config.timescales
]
zero_embedding: Vector | object = np.zeros(self.config.embedding_dim, dtype=np.float64)
else:
hidden_states = [zeros_vector(self.config.state_dim) for _ in self.config.timescales]
context_traces = [zeros_vector(self.config.embedding_dim) for _ in self.config.timescales]
zero_embedding = zeros_vector(self.config.embedding_dim)
states: list[Vector] = []
labels: list[Vector] = []
label_ids: list[int] = []
token_ids = [
self.embedding_model.token_to_id.get(token, -1)
for token in tokens
]
example_count = max(0, len(tokens) - 1)
stride = 1
if self.config.max_training_examples and example_count > self.config.max_training_examples:
stride = max(
1,
(example_count + self.config.max_training_examples - 1) // self.config.max_training_examples,
)
for index in range(len(tokens) - 1):
token = tokens[index]
token_id = token_ids[index]
embedding = (
self.embedding_model.embeddings[token_id]
if token_id >= 0
else zero_embedding
)
trace_embedding = self._trace_embedding_from_token_id(embedding, token_id)
hidden_states, context_traces, combined_state = self._step_hidden_states_from_embedding(
hidden_states,
context_traces,
embedding,
trace_embedding=trace_embedding,
)
if stride > 1 and index % stride != 0 and index != len(tokens) - 2:
continue
states.append(combined_state)
next_token_id = token_ids[index + 1]
labels.append(self._one_hot_from_id(next_token_id))
label_ids.append(next_token_id)
if self.config.max_training_examples and len(states) > self.config.max_training_examples:
states = states[: self.config.max_training_examples]
labels = labels[: self.config.max_training_examples]
label_ids = label_ids[: self.config.max_training_examples]
return states, labels, label_ids
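# Worked example (illustrative only): with max_training_examples = 4_000 and
# 10_000 candidate (state, next-token) pairs, the stride is
# (10_000 + 3_999) // 4_000 = 3, so roughly every third pair is kept, always
# including index 0 and the final pair, before the hard truncation just above
# the return caps the lists at the configured maximum.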
def _is_punctuation_piece(self, piece: str) -> bool:
return bool(piece) and all(character in string.punctuation for character in piece)
def _encode_context(self, tokens: list[str]) -> Vector:
return self._masked_decode_state(self._build_decode_state(tokens))
def _build_decode_state(self, tokens: list[str]) -> DecodeState:
assert self.memory_units is not None
state = DecodeState(
hidden_states=(
[
np.zeros(self.config.state_dim, dtype=np.float64)
for _ in self.config.timescales
]
if np is not None
else [zeros_vector(self.config.state_dim) for _ in self.config.timescales]
),
context_traces=(
[
np.zeros(self.config.embedding_dim, dtype=np.float64)
for _ in self.config.timescales
]
if np is not None
else [zeros_vector(self.config.embedding_dim) for _ in self.config.timescales]
),
combined_state=self._zero_combined_state(),
context_tokens=[],
)
for token in tokens:
self._advance_decode_state(state, token)
return state
def _advance_decode_state(self, state: DecodeState, token: str) -> DecodeState:
next_hidden_states, next_context_traces, combined_state = self._step_hidden_states(
state.hidden_states,
state.context_traces,
token,
)
state.hidden_states = next_hidden_states
state.context_traces = next_context_traces
state.combined_state = combined_state
state.context_tokens.append(token)
if token == "<answer>":
state.answer_anchor_state = combined_state.copy() if hasattr(combined_state, "copy") else combined_state[:]
state.answer_matches = None
state.answer_start_matches = None
state.answer_sequence_matches = None
state.prompt_answer_prior = None
state.prompt_answer_start_prior = None
return state
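# Note: encountering the literal "<answer>" token snapshots the combined state
# as the retrieval anchor and clears the memoized match lists, so the
# answer / answer-start / answer-sequence lookups are recomputed against the
# new anchor on their next use.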
def _masked_decode_state(self, state: DecodeState) -> Vector:
assert self.ternary_mask is not None
return apply_ternary_mask(state.combined_state, self.ternary_mask, self.ternary_scale)
def _masked_combined_state(self, combined_state: Vector) -> Vector:
assert self.ternary_mask is not None
return apply_ternary_mask(combined_state, self.ternary_mask, self.ternary_scale)
def _masked_decode_state_array(self, state: DecodeState) -> object:
assert np is not None
if self.ternary_mask_array is None:
return np.asarray(self._masked_decode_state(state), dtype=RUNTIME_ARRAY_DTYPE)
return (
np.asarray(state.combined_state, dtype=RUNTIME_ARRAY_DTYPE)
* self.ternary_scale
* self.ternary_mask_array
)
def _masked_combined_state_array(self, combined_state: Vector) -> object:
assert np is not None
if self.ternary_mask_array is None:
return np.asarray(self._masked_combined_state(combined_state), dtype=RUNTIME_ARRAY_DTYPE)
return (
np.asarray(combined_state, dtype=RUNTIME_ARRAY_DTYPE)
* self.ternary_scale
* self.ternary_mask_array
)
def _center_state_vector(self, state: Vector) -> Vector:
if not self.state_offset or len(self.state_offset) != len(state):
return state
return [value - self.state_offset[index] for index, value in enumerate(state)]
def _center_state_array(self, state: object) -> object:
assert np is not None
state_array = np.asarray(state, dtype=RUNTIME_ARRAY_DTYPE)
if self.state_offset_array is None or self.state_offset_array.shape != state_array.shape:
return state_array
return state_array - self.state_offset_array
def _zero_combined_state(self) -> Vector:
return [0.0 for _ in range(self._combined_state_width())]
def _combined_state_width(self) -> int:
return (self.config.state_dim + self.config.embedding_dim) * len(self.config.timescales)
def _derive_trace_token_weights_from_counts(self, token_counts: dict[str, float]) -> Vector:
assert self.embedding_model is not None
assert self.tokenizer is not None
counts = [
float(token_counts.get(token, 0.0))
for token in self.embedding_model.id_to_token
]
positive_counts = sorted(value for value in counts if value > 0.0)
reference = (
positive_counts[len(positive_counts) // 2]
if positive_counts
else 1.0
)
weights: Vector = []
for token, count in zip(self.embedding_model.id_to_token, counts):
if token in self.tokenizer.special_tokens:
weights.append(0.0)
elif count <= 0.0:
weights.append(1.0)
else:
weight = (reference / count) ** 0.75
weights.append(max(0.08, min(4.8, weight)))
return weights
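# Worked example (illustrative only): with a median positive count (reference)
# of 120, a token seen 12_000 times gets (120 / 12_000) ** 0.75 =
# 0.01 ** 0.75 ~= 0.032, clamped up to the 0.08 floor; a token seen 3 times
# gets 40 ** 0.75 ~= 15.9, clamped down to the 4.8 ceiling. Special tokens are
# zeroed and tokens with no counts default to 1.0.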
def _token_id_for_token(self, token: str) -> int:
assert self.embedding_model is not None
token_id = self.embedding_model.token_to_id.get(token)
if token_id is None and token.lower() != token:
token_id = self.embedding_model.token_to_id.get(token.lower())
return int(token_id) if token_id is not None else -1
def _trace_embedding_from_token_id(
self,
embedding: Vector | object,
token_id: int,
) -> Vector | object:
if token_id < 0:
return embedding
if self.trace_embedding_table_array is not None:
return self.trace_embedding_table_array[token_id]
weight = self.trace_token_weights[token_id] if self.trace_token_weights is not None else 1.0
dimension = self.config.embedding_dim
if hasattr(embedding, "shape"):
trace_embedding = embedding * weight
for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES:
bucket = (token_id * bucket_multiplier + bucket_offset) % dimension
sign = 1.0 if ((token_id * sign_multiplier + sign_offset) & 1) == 0 else -1.0
trace_embedding[bucket] += weight * TRACE_IDENTITY_SCALE * sign
return trace_embedding
trace_values = [float(value) * weight for value in embedding]
for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES:
bucket = (token_id * bucket_multiplier + bucket_offset) % dimension
sign = 1.0 if ((token_id * sign_multiplier + sign_offset) & 1) == 0 else -1.0
trace_values[bucket] += weight * TRACE_IDENTITY_SCALE * sign
return trace_values
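# Worked example (illustrative only): for token_id 5 in a 64-dimensional
# embedding space, the first hash tuple (1103515245, 12345, 214013, 2531011)
# selects bucket (5 * 1103515245 + 12345) % 64 = 26 with sign +1 (the parity
# term is even), so trace[26] gains weight * 0.78. Ten such bumps give each
# token a sparse identity signature on top of its scaled embedding, keeping
# tokens with similar embeddings separable inside the context trace.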
def _build_trace_embedding_table_array(self, embedding_array: object) -> object | None:
if np is None or self.trace_token_weights is None:
return None
values = np.asarray(embedding_array, dtype=np.float64)
if values.size == 0 or len(values.shape) != 2:
return None
weights = np.asarray(self.trace_token_weights, dtype=np.float64)
if weights.shape[0] != values.shape[0]:
return None
trace_values = values * weights[:, None]
if values.shape[1] <= 0:
return trace_values
token_ids = np.arange(values.shape[0], dtype=np.int64)
for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES:
buckets = ((token_ids * bucket_multiplier + bucket_offset) % values.shape[1]).astype(
np.int64,
copy=False,
)
signs = np.where(
((token_ids * sign_multiplier + sign_offset) & 1) == 0,
1.0,
-1.0,
)
np.add.at(trace_values, (token_ids, buckets), weights * TRACE_IDENTITY_SCALE * signs)
return trace_values
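# Note: the vectorized table builder uses np.add.at rather than fancy-index
# assignment so that two hash tuples landing in the same (token, bucket) cell
# accumulate instead of overwriting, matching the += semantics of the
# per-token loop in _trace_embedding_from_token_id.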
def _refresh_numeric_caches(self) -> None:
if np is None:
self.ternary_mask_array = None
self.readout_weights_array = None
self.readout_bias_array = None
self.prompt_answer_weights_array = None
self.prompt_answer_bias_array = None
self.prompt_answer_start_weights_array = None
self.prompt_answer_start_bias_array = None
self.trace_token_weights_array = None
self.trace_embedding_table_array = None
self.preference_bias_array = None
self.preference_valid_mask_array = None
self.state_offset_array = None
self.associative_keys_array = None
self.associative_key_norms_array = None
self.associative_values_array = None
self.associative_valid_mask_array = None
self.answer_keys_array = None
self.answer_key_norms_array = None
self.answer_similarity_keys_array = None
self.answer_similarity_key_norms_array = None
self.answer_similarity_mask_array = None
self.answer_values_array = None
self.answer_valid_mask_array = None
self.answer_start_keys_array = None
self.answer_start_key_norms_array = None
self.answer_start_similarity_keys_array = None
self.answer_start_similarity_key_norms_array = None
self.answer_start_values_array = None
self.answer_start_valid_mask_array = None
self.answer_sequence_keys_array = None
self.answer_sequence_key_norms_array = None
self.answer_sequence_similarity_keys_array = None
self.answer_sequence_similarity_key_norms_array = None
self.answer_sequence_prompt_tokens_array = None
self.answer_sequence_tokens_array = None
self.answer_sequence_prompt_weight_maps = None
self.answer_sequence_prompt_weight_norms = None
self.answer_sequence_prompt_bigram_sets = None
self.answer_sequence_prompt_trigram_sets = None
self.answer_sequence_prompt_number_sets = None
self.answer_sequence_prompt_inverted_index = None
self._refresh_answer_sequence_prompt_overlap_cache()
return
self.ternary_mask_array = (
np.asarray(self.ternary_mask, dtype=RUNTIME_ARRAY_DTYPE)
if self.ternary_mask is not None
else None
)
self.readout_weights_array = (
np.asarray(self.readout_weights, dtype=RUNTIME_ARRAY_DTYPE)
if self.readout_weights is not None
else None
)
self.readout_bias_array = (
np.asarray(self.readout_bias, dtype=RUNTIME_ARRAY_DTYPE)
if self.readout_bias is not None
else None
)
self.prompt_answer_weights_array = (
np.asarray(self.prompt_answer_weights, dtype=RUNTIME_ARRAY_DTYPE)
if self.prompt_answer_weights is not None
and len(self.prompt_answer_weights) > 0
else None
)
self.prompt_answer_bias_array = (
np.asarray(self.prompt_answer_bias, dtype=RUNTIME_ARRAY_DTYPE)
if self.prompt_answer_bias is not None
else None
)
self.prompt_answer_start_weights_array = (
np.asarray(self.prompt_answer_start_weights, dtype=RUNTIME_ARRAY_DTYPE)
if self.prompt_answer_start_weights is not None
and len(self.prompt_answer_start_weights) > 0
else None
)
self.prompt_answer_start_bias_array = (
np.asarray(self.prompt_answer_start_bias, dtype=RUNTIME_ARRAY_DTYPE)
if self.prompt_answer_start_bias is not None
else None
)
self.trace_token_weights_array = (
np.asarray(self.trace_token_weights, dtype=RUNTIME_ARRAY_DTYPE)
if self.trace_token_weights is not None
else None
)
trace_embedding_table = (
self._build_trace_embedding_table_array(self.embedding_model.embeddings)
if self.embedding_model is not None and self.trace_token_weights is not None
else None
)
self.trace_embedding_table_array = (
trace_embedding_table.astype(RUNTIME_ARRAY_DTYPE, copy=False)
if trace_embedding_table is not None
else None
)
self.preference_bias_array = (
np.asarray(self.preference_bias, dtype=RUNTIME_ARRAY_DTYPE)
if self.preference_bias is not None
else None
)
self.preference_valid_mask_array = (
np.asarray(
[
self._eligible_preference_token(token)
for token in self.embedding_model.id_to_token
],
dtype=bool,
)
if self.embedding_model is not None and self.tokenizer is not None
else None
)
self.state_offset_array = (
np.asarray(self.state_offset, dtype=RUNTIME_ARRAY_DTYPE)
if self.state_offset is not None
else None
)
self.associative_keys_array = (
np.asarray(self.associative_keys, dtype=RUNTIME_ARRAY_DTYPE)
if self.associative_keys is not None and len(self.associative_keys) > 0
else None
)
self.associative_key_norms_array = (
np.asarray(self.associative_key_norms, dtype=RUNTIME_ARRAY_DTYPE)
if self.associative_key_norms is not None and len(self.associative_key_norms) > 0
else None
)
self.associative_values_array = (
np.asarray(self.associative_values, dtype=np.int64)
if self.associative_values is not None and len(self.associative_values) > 0
else None
)
self.associative_valid_mask_array = (
self.associative_values_array >= 0
if self.associative_values_array is not None
else None
)
self.answer_keys_array = (
np.asarray(self.answer_keys, dtype=RUNTIME_ARRAY_DTYPE)
if self.answer_keys is not None and len(self.answer_keys) > 0
else None
)
self.answer_key_norms_array = (
np.asarray(self.answer_key_norms, dtype=RUNTIME_ARRAY_DTYPE)
if self.answer_key_norms is not None and len(self.answer_key_norms) > 0
else None
)
self.answer_similarity_keys_array = None
self.answer_similarity_key_norms_array = None
self.answer_similarity_mask_array = None
if self.answer_keys_array is not None and len(self.answer_keys_array.shape) == 2:
width = int(self.answer_keys_array.shape[1])
block_width = self.config.state_dim + self.config.embedding_dim
expected_width = block_width * len(self.config.timescales)
if block_width > 0 and width == expected_width:
mask = np.zeros(width, dtype=RUNTIME_ARRAY_DTYPE)
for scale_index in range(len(self.config.timescales)):
start = scale_index * block_width + self.config.state_dim
end = start + self.config.embedding_dim
mask[start:end] = 1.0
self.answer_similarity_mask_array = mask
self.answer_similarity_keys_array = self.answer_keys_array * mask[None, :]
self.answer_similarity_key_norms_array = np.linalg.norm(
self.answer_similarity_keys_array,
axis=1,
).astype(RUNTIME_ARRAY_DTYPE, copy=False)
self.answer_values_array = (
np.asarray(self.answer_values, dtype=np.int64)
if self.answer_values is not None and len(self.answer_values) > 0
else None
)
self.answer_valid_mask_array = (
self.answer_values_array >= 0
if self.answer_values_array is not None
else None
)
self.answer_start_keys_array = (
np.asarray(self.answer_start_keys, dtype=RUNTIME_ARRAY_DTYPE)
if self.answer_start_keys is not None and len(self.answer_start_keys) > 0
else None
)
self.answer_start_key_norms_array = (
np.asarray(self.answer_start_key_norms, dtype=RUNTIME_ARRAY_DTYPE)
if self.answer_start_key_norms is not None and len(self.answer_start_key_norms) > 0
else None
)
self.answer_start_similarity_keys_array = None
self.answer_start_similarity_key_norms_array = None
if (
self.answer_start_keys_array is not None
and len(self.answer_start_keys_array.shape) == 2
and self.answer_similarity_mask_array is not None
and int(self.answer_start_keys_array.shape[1]) == int(self.answer_similarity_mask_array.shape[0])
):
self.answer_start_similarity_keys_array = (
self.answer_start_keys_array * self.answer_similarity_mask_array[None, :]
)
self.answer_start_similarity_key_norms_array = np.linalg.norm(
self.answer_start_similarity_keys_array,
axis=1,
).astype(RUNTIME_ARRAY_DTYPE, copy=False)
self.answer_start_values_array = (
np.asarray(self.answer_start_values, dtype=np.int64)
if self.answer_start_values is not None and len(self.answer_start_values) > 0
else None
)
self.answer_start_valid_mask_array = (
self.answer_start_values_array >= 0
if self.answer_start_values_array is not None
else None
)
self.answer_sequence_keys_array = (
np.asarray(self.answer_sequence_keys, dtype=RUNTIME_ARRAY_DTYPE)
if self.answer_sequence_keys is not None and len(self.answer_sequence_keys) > 0
else None
)
self.answer_sequence_key_norms_array = (
np.asarray(self.answer_sequence_key_norms, dtype=RUNTIME_ARRAY_DTYPE)
if self.answer_sequence_key_norms is not None and len(self.answer_sequence_key_norms) > 0
else None
)
self.answer_sequence_similarity_keys_array = None
self.answer_sequence_similarity_key_norms_array = None
if (
self.answer_sequence_keys_array is not None
and len(self.answer_sequence_keys_array.shape) == 2
and self.answer_similarity_mask_array is not None
and int(self.answer_sequence_keys_array.shape[1]) == int(self.answer_similarity_mask_array.shape[0])
):
self.answer_sequence_similarity_keys_array = (
self.answer_sequence_keys_array * self.answer_similarity_mask_array[None, :]
)
self.answer_sequence_similarity_key_norms_array = np.linalg.norm(
self.answer_sequence_similarity_keys_array,
axis=1,
).astype(RUNTIME_ARRAY_DTYPE, copy=False)
self.answer_sequence_tokens_array = (
np.asarray(self.answer_sequence_tokens, dtype=np.int64)
if self.answer_sequence_tokens is not None and len(self.answer_sequence_tokens) > 0
else None
)
self.answer_sequence_prompt_tokens_array = (
np.asarray(self.answer_sequence_prompt_tokens, dtype=np.int64)
if self.answer_sequence_prompt_tokens is not None
and len(self.answer_sequence_prompt_tokens) > 0
else None
)
self._refresh_answer_sequence_prompt_overlap_cache()
def _refresh_answer_sequence_prompt_overlap_cache(self) -> None:
self.answer_sequence_prompt_weight_maps = None
self.answer_sequence_prompt_weight_norms = None
self.answer_sequence_prompt_bigram_sets = None
self.answer_sequence_prompt_trigram_sets = None
self.answer_sequence_prompt_number_sets = None
self.answer_sequence_prompt_inverted_index = None
self.answer_sequence_prompt_specificity = None
if self.answer_sequence_prompt_tokens is None or self.trace_token_weights is None:
return
inverted: dict[int, list[int]] = {}
row_id_lists: list[list[int]] = []
for row in self.answer_sequence_prompt_tokens:
row_values = row.tolist() if hasattr(row, "tolist") else row
row_ids: list[int] = []
for raw_token_id in row_values:
token_id = int(raw_token_id)
if token_id < 0 or token_id >= len(self.trace_token_weights):
continue
row_ids.append(token_id)
sequence_index = len(row_id_lists)
for token_id in set(row_ids):
inverted.setdefault(token_id, []).append(sequence_index)
row_id_lists.append(row_ids)
total_rows = len(row_id_lists)
specificity = {
token_id: self._prompt_overlap_token_specificity(len(indices), total_rows)
for token_id, indices in inverted.items()
}
self.answer_sequence_prompt_inverted_index = inverted
self.answer_sequence_prompt_specificity = specificity
weight_maps: list[dict[int, float]] = []
weight_norms: list[float] = []
bigram_sets: list[set[tuple[int, int]]] = []
trigram_sets: list[set[tuple[int, int, int]]] = []
number_sets: list[set[str]] = []
for row_ids in row_id_lists:
row_weights: dict[int, float] = {}
for token_id in row_ids:
row_weights[token_id] = max(
row_weights.get(token_id, 0.0),
float(self.trace_token_weights[token_id]) * specificity.get(token_id, 1.0),
)
weight_maps.append(row_weights)
weight_norms.append(sum(value * value for value in row_weights.values()) ** 0.5)
bigram_sets.append(
{
(row_ids[index], row_ids[index + 1])
for index in range(len(row_ids) - 1)
}
)
trigram_sets.append(
{
(row_ids[index], row_ids[index + 1], row_ids[index + 2])
for index in range(len(row_ids) - 2)
}
)
number_sets.append(self._number_strings_from_token_ids(row_ids))
self.answer_sequence_prompt_weight_maps = weight_maps
self.answer_sequence_prompt_weight_norms = weight_norms
self.answer_sequence_prompt_bigram_sets = bigram_sets
self.answer_sequence_prompt_trigram_sets = trigram_sets
self.answer_sequence_prompt_number_sets = number_sets
@staticmethod
def _prompt_overlap_token_specificity(document_frequency: int, total_documents: int) -> float:
if document_frequency <= 0 or total_documents <= 0:
return 1.0
coverage = min(1.0, document_frequency / total_documents)
return max(0.02, 1.0 - (coverage ** 0.5))
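# Worked example (illustrative only): a token appearing in 4 of 100 stored
# prompts has coverage 0.04 and specificity 1.0 - 0.04 ** 0.5 = 0.8; a token
# appearing in every prompt bottoms out at the 0.02 floor. Rare tokens
# therefore dominate the prompt-overlap scoring.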
def _number_strings_from_token_ids(self, token_ids: list[int]) -> set[str]:
assert self.embedding_model is not None
tokens = [
self.embedding_model.id_to_token[token_id]
for token_id in token_ids
if 0 <= token_id < len(self.embedding_model.id_to_token)
]
return self._number_strings_from_tokens(tokens)
def _number_strings_from_tokens(self, tokens: list[str]) -> set[str]:
numbers: set[str] = set()
current = ""
for token in tokens:
if self.tokenizer is not None and token in self.tokenizer.special_tokens:
if current:
numbers.add(current)
current = ""
continue
rendered = self._render_token(token)
digits = "".join(character for character in rendered if character.isdigit())
starts_number = self._starts_new_word(token) if self.tokenizer is not None else True
if digits and starts_number:
if current:
numbers.add(current)
current = digits
elif digits and current:
current += digits
else:
if current:
numbers.add(current)
current = ""
if current:
numbers.add(current)
return numbers
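# Illustrative behavior (assuming the tokenizer marks continuation pieces so
# that _starts_new_word is False for them): a word-initial "4" followed by a
# continuation piece "2" yields the single number "42", while "12 and 7"
# yields {"12", "7"}. Digits inside a continuation piece with no open number
# are ignored rather than starting a new one.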
@staticmethod
def _numeric_prompt_can_match(query_numbers: set[str], row_numbers: set[str]) -> bool:
if not query_numbers:
return True
if not row_numbers:
return False
return query_numbers.issubset(row_numbers)
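# Gate semantics: a prompt with no numbers matches any stored row; a prompt
# that does contain numbers only matches rows whose number set is a superset.
# For example, query {"12"} matches a row with {"12", "7"}, but query
# {"12", "9"} does not match a row with only {"12"}.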
def _apply_readout_fast(self, state: Vector) -> Vector:
if self.readout_weights_array is None or np is None:
assert self.readout_weights is not None
centered_state = self._center_state_vector(state)
logits = apply_readout(self.readout_weights, centered_state)
if self.readout_bias:
logits = [
value + self.readout_bias[index]
for index, value in enumerate(logits)
]
return logits
state_array = np.asarray(state, dtype=RUNTIME_ARRAY_DTYPE)
if self.state_offset_array is not None and self.state_offset_array.shape == state_array.shape:
state_array = state_array - self.state_offset_array
logits = self.readout_weights_array @ state_array
if self.readout_bias_array is not None and self.readout_bias_array.shape == logits.shape:
logits = logits + self.readout_bias_array
return logits.tolist()
def _apply_readout_array(self, state: object) -> object:
assert np is not None
assert self.readout_weights_array is not None
state_array = np.asarray(state, dtype=RUNTIME_ARRAY_DTYPE)
if self.state_offset_array is not None and self.state_offset_array.shape == state_array.shape:
state_array = state_array - self.state_offset_array
logits = self.readout_weights_array @ state_array
if self.readout_bias_array is not None and self.readout_bias_array.shape == logits.shape:
logits = logits + self.readout_bias_array
return logits
def _step_hidden_states(
self,
hidden_states: list[Vector],
context_traces: list[Vector],
token: str,
) -> tuple[list[Vector], list[Vector], Vector]:
assert self.embedding_model is not None
assert self.tokenizer is not None
token_id = self._token_id_for_token(token)
embedding = self.embedding_model.vector(token)
trace_embedding = self._trace_embedding_from_token_id(embedding, token_id)
return self._step_hidden_states_from_embedding(
hidden_states,
context_traces,
embedding,
trace_embedding=trace_embedding,
)
def _step_hidden_states_from_embedding(
self,
hidden_states: list[Vector],
context_traces: list[Vector],
embedding: Vector | object,
*,
trace_embedding: Vector | object | None = None,
) -> tuple[list[Vector], list[Vector], Vector]:
assert self.memory_units is not None
if trace_embedding is None:
trace_embedding = embedding
if np is not None and hidden_states and hasattr(hidden_states[0], "shape"):
embedding_array = (
embedding
if hasattr(embedding, "shape")
else np.asarray(embedding, dtype=np.float64)
)
trace_embedding_array = (
trace_embedding
if hasattr(trace_embedding, "shape")
else np.asarray(trace_embedding, dtype=np.float64)
)
drive = analytical_embedding_drive_fast(embedding_array, self.config.state_dim)
next_states: list[Vector] = []
next_traces: list[Vector] = []
combined_state: Vector = []
for unit, state, trace in zip(self.memory_units, hidden_states, context_traces):
next_state = unit.step_vector_fast(state, drive)
decay = 1.0 / (1.0 + unit.timescale)
next_trace = trace + ((1.0 - decay) * trace_embedding_array)
next_states.append(next_state)
next_traces.append(next_trace)
combined_state.extend(next_state.tolist())
combined_state.extend(next_trace.tolist())
return next_states, next_traces, combined_state
embedding_vector = embedding.tolist() if hasattr(embedding, "tolist") else embedding
trace_embedding_vector = (
trace_embedding.tolist()
if hasattr(trace_embedding, "tolist")
else trace_embedding
)
drive = analytical_embedding_drive(embedding_vector, self.config.state_dim)
next_states: list[Vector] = []
next_traces: list[Vector] = []
combined_state: Vector = []
for unit, state, trace in zip(self.memory_units, hidden_states, context_traces):
next_state = unit.step_vector(state, drive)
decay = 1.0 / (1.0 + unit.timescale)
next_trace = [
previous + ((1.0 - decay) * value)
for previous, value in zip(trace, trace_embedding_vector)
]
next_states.append(next_state)
next_traces.append(next_trace)
combined_state.extend(next_state)
combined_state.extend(next_trace)
return next_states, next_traces, combined_state
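# Note on the trace update: decay = 1 / (1 + timescale), so the increment
# scale (1 - decay) = timescale / (1 + timescale): 0.5 at timescale 1, 0.9375
# at timescale 15. The existing trace itself is never attenuated here; longer
# timescales simply accumulate each token's trace embedding more strongly.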
def _one_hot(self, token: str) -> Vector:
assert self.embedding_model is not None
return self._one_hot_from_id(self.embedding_model.token_to_id.get(token, -1))
def _one_hot_from_id(self, token_id: int) -> Vector:
assert self.embedding_model is not None
vector = [0.0 for _ in self.embedding_model.id_to_token]
if token_id >= 0:
vector[token_id] = 1.0
return vector
def _blend_probabilities(
self,
base: Vector,
answer: Vector,
associative: Vector,
transition: Vector,
copy: Vector,
preference: Vector,
*,
transition_order: int | None,
generated_count: int = 0,
answer_locked: bool = False,
answer_guided_start: bool = False,
) -> tuple[Vector, dict[str, float]]:
base_weight = FAST_BASE_BLEND
answer_weight = FAST_ANSWER_BLEND
associative_weight = FAST_ASSOCIATIVE_BLEND
transition_weight = FAST_TRANSITION_BLEND
copy_weight = FAST_COPY_BLEND
preference_weight = FAST_PREFERENCE_BLEND
if answer_locked:
base_weight *= 0.18
answer_weight *= 5.0
associative_weight *= 0.2
transition_weight *= 0.2
copy_weight *= 0.2
preference_weight *= 0.2
elif answer_guided_start:
base_weight *= 0.35
answer_weight *= 3.5
associative_weight *= 0.2
transition_weight *= 0.35
copy_weight *= 0.2
preference_weight *= 0.2
elif generated_count > 0:
answer_weight *= 0.32
transition_weight *= 2.0
copy_weight *= 0.75
if transition_order is None:
answer_weight *= 1.1
associative_weight *= 0.75
copy_weight += 0.02
elif transition_order <= 2:
answer_weight *= 1.15
associative_weight *= 0.65
transition_weight *= 0.55
copy_weight += 0.01
elif transition_order >= 5:
transition_weight *= 1.25
sources: list[tuple[str, float, Vector]] = [("base", base_weight, base)]
if any(value > 0.0 for value in answer):
sources.append(("answer", answer_weight, answer))
if any(value > 0.0 for value in associative):
sources.append(("associative", associative_weight, associative))
if any(value > 0.0 for value in transition):
sources.append(("transition", transition_weight, transition))
if any(value > 0.0 for value in copy):
sources.append(("copy", copy_weight, copy))
if any(value > 0.0 for value in preference):
sources.append(("preference", preference_weight, preference))
total_weight = sum(weight for _, weight, _ in sources)
blended = [0.0 for _ in base]
blend_weights: dict[str, float] = {}
for name, weight, source in sources:
normalized_weight = weight / total_weight if total_weight else 0.0
blend_weights[name] = normalized_weight
for index, value in enumerate(source):
blended[index] += normalized_weight * value
return _normalize_vector(blended), blend_weights
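# Worked example (illustrative only): with generated_count = 0, no answer lock,
# and transition_order = 3 (so no reweighting branch fires), if only the base
# (0.58) and transition (0.14) priors are non-zero, the renormalized blend
# weights are 0.58 / 0.72 ~= 0.806 and 0.14 / 0.72 ~= 0.194; all-zero priors
# are dropped from the mixture before normalization.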
def _blend_probability_arrays(
self,
base: object,
answer: object,
associative: object,
transition: object,
copy: object,
preference: object,
*,
transition_order: int | None,
generated_count: int = 0,
answer_locked: bool = False,
answer_guided_start: bool = False,
) -> tuple[object, dict[str, float]]:
assert np is not None
base_weight = FAST_BASE_BLEND
answer_weight = FAST_ANSWER_BLEND
associative_weight = FAST_ASSOCIATIVE_BLEND
transition_weight = FAST_TRANSITION_BLEND
copy_weight = FAST_COPY_BLEND
preference_weight = FAST_PREFERENCE_BLEND
if answer_locked:
base_weight *= 0.18
answer_weight *= 5.0
associative_weight *= 0.2
transition_weight *= 0.2
copy_weight *= 0.2
preference_weight *= 0.2
elif answer_guided_start:
base_weight *= 0.35
answer_weight *= 3.5
associative_weight *= 0.2
transition_weight *= 0.35
copy_weight *= 0.2
preference_weight *= 0.2
elif generated_count > 0:
answer_weight *= 0.32
transition_weight *= 2.0
copy_weight *= 0.75
if transition_order is None:
answer_weight *= 1.1
associative_weight *= 0.75
copy_weight += 0.02
elif transition_order <= 2:
answer_weight *= 1.15
associative_weight *= 0.65
transition_weight *= 0.55
copy_weight += 0.01
elif transition_order >= 5:
transition_weight *= 1.25
sources: list[tuple[str, float, object]] = [("base", base_weight, base)]
if np.any(answer > 0.0):
sources.append(("answer", answer_weight, answer))
if np.any(associative > 0.0):
sources.append(("associative", associative_weight, associative))
if np.any(transition > 0.0):
sources.append(("transition", transition_weight, transition))
if np.any(copy > 0.0):
sources.append(("copy", copy_weight, copy))
if np.any(preference > 0.0):
sources.append(("preference", preference_weight, preference))
total_weight = sum(weight for _, weight, _ in sources)
blended = np.zeros_like(base, dtype=np.float64)
blend_weights: dict[str, float] = {}
for name, weight, source in sources:
normalized_weight = weight / total_weight if total_weight else 0.0
blend_weights[name] = normalized_weight
blended += normalized_weight * source
total = float(blended.sum())
if total <= 0.0:
return base, blend_weights
return blended / total, blend_weights
def _score_associative_matches(
self,
state: Vector,
*,
limit: int = ASSOCIATIVE_TOP_K,
) -> list[tuple[float, int, int]]:
if (
self.associative_keys is None
or self.associative_values is None
or self.associative_key_norms is None
or len(self.associative_keys) == 0
or len(self.associative_values) == 0
or len(self.associative_key_norms) == 0
):
return []
if (
np is not None
and self.associative_keys_array is not None
and self.associative_key_norms_array is not None
and self.associative_values_array is not None
and self.associative_valid_mask_array is not None
and limit > 0
):
state_array = self._center_state_array(state).astype(self.associative_keys_array.dtype, copy=False)
state_norm = float(np.linalg.norm(state_array))
if state_norm == 0.0:
return []
numerators = self.associative_keys_array @ state_array
denominators = self.associative_key_norms_array * state_norm
valid_mask = self.associative_valid_mask_array & (denominators > 0.0)
if np.any(valid_mask):
scores = np.zeros_like(numerators, dtype=self.associative_keys_array.dtype)
np.divide(numerators, denominators, out=scores, where=valid_mask)
positive_positions = np.flatnonzero(valid_mask & (scores > 0.0))
if positive_positions.size:
selected_positions = positive_positions
if positive_positions.size > limit:
partition = np.argpartition(scores[positive_positions], -limit)[-limit:]
selected_positions = positive_positions[partition]
ordered_positions = selected_positions[np.argsort(scores[selected_positions])[::-1]]
return [
(
float(scores[position]),
int(self.associative_values_array[position]),
int(position),
)
for position in ordered_positions
]
state = self._center_state_vector(state)
state_norm = norm(state)
if state_norm == 0.0:
return []
scored: list[tuple[float, int, int]] = []
for example_index, (key, key_norm, token_id) in enumerate(
zip(self.associative_keys, self.associative_key_norms, self.associative_values)
):
if token_id < 0:
continue
denominator = state_norm * key_norm
if denominator == 0.0:
continue
similarity = dot(state, key) / denominator
if similarity > 0.0:
scored.append((similarity, token_id, example_index))
scored.sort(key=lambda item: item[0], reverse=True)
return scored[:limit]
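# Note: the fast path scores every stored key by cosine similarity with a
# single matrix-vector product, masks out invalid entries, and takes the top
# `limit` positives with np.argpartition before a final descending sort; the
# pure-Python loop at the end of the method mirrors the same logic when NumPy
# is unavailable.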
def _associative_prior_from_matches(
self,
matches: list[tuple[float, int, int]],
) -> Vector:
assert self.embedding_model is not None
if not matches:
return [0.0 for _ in self.embedding_model.id_to_token]
prior = [0.0 for _ in self.embedding_model.id_to_token]
for similarity, token_id, _ in matches[:ASSOCIATIVE_TOP_K]:
prior[token_id] += similarity
return _normalize_vector(prior)
def _associative_prior(self, state: Vector) -> Vector:
return self._associative_prior_from_matches(self._score_associative_matches(state))
def _score_answer_matches(
self,
answer_anchor_state: Vector | None,
*,
limit: int = ANSWER_TOP_K,
) -> list[tuple[float, int, int]]:
return self._score_prompt_anchor_matches(
answer_anchor_state,
self.answer_keys,
self.answer_key_norms,
self.answer_values,
self.answer_keys_array,
self.answer_key_norms_array,
self.answer_values_array,
self.answer_valid_mask_array,
self.answer_similarity_keys_array,
self.answer_similarity_key_norms_array,
self.answer_similarity_mask_array,
limit=limit,
)
def _score_answer_start_matches(
self,
answer_anchor_state: Vector | None,
*,
limit: int = ANSWER_START_TOP_K,
) -> list[tuple[float, int, int]]:
return self._score_prompt_anchor_matches(
answer_anchor_state,
self.answer_start_keys,
self.answer_start_key_norms,
self.answer_start_values,
self.answer_start_keys_array,
self.answer_start_key_norms_array,
self.answer_start_values_array,
self.answer_start_valid_mask_array,
self.answer_start_similarity_keys_array,
self.answer_start_similarity_key_norms_array,
self.answer_similarity_mask_array,
limit=limit,
)
def _score_answer_sequence_matches(
self,
answer_anchor_state: Vector | None,
context_tokens: list[str],
*,
limit: int = ANSWER_START_TOP_K,
) -> list[tuple[float, int, int]]:
if (
answer_anchor_state is None
or self.answer_sequence_keys is None
or self.answer_sequence_key_norms is None
or self.answer_sequence_tokens is None
):
return []
values = list(range(len(self.answer_sequence_tokens)))
values_array = np.arange(len(values), dtype=np.int64) if np is not None else None
anchor_matches = self._score_prompt_anchor_matches(
answer_anchor_state,
self.answer_sequence_keys,
self.answer_sequence_key_norms,
values,
self.answer_sequence_keys_array,
self.answer_sequence_key_norms_array,
values_array,
values_array >= 0 if values_array is not None else None,
self.answer_sequence_similarity_keys_array,
self.answer_sequence_similarity_key_norms_array,
self.answer_similarity_mask_array,
limit=max(limit * 4, limit),
)
overlap_scores = self._answer_sequence_prompt_overlap_scores(context_tokens)
if overlap_scores is None:
return anchor_matches[:limit]
if not overlap_scores:
return []
best_overlap = max(overlap_scores.values())  # overlap_scores is non-empty here (checked above)
overlap_floor = max(0.16, best_overlap * 0.90)
focused_overlap_scores = {
sequence_index: overlap
for sequence_index, overlap in overlap_scores.items()
if overlap >= overlap_floor
}
if not focused_overlap_scores:
focused_overlap_scores = overlap_scores
focused_indices = set(focused_overlap_scores)
merged: dict[int, float] = {}
for similarity, sequence_index, _ in anchor_matches:
if sequence_index not in focused_indices:
continue
merged[sequence_index] = max(merged.get(sequence_index, 0.0), 0.20 * similarity)
for sequence_index, overlap in focused_overlap_scores.items():
merged[sequence_index] = merged.get(sequence_index, 0.0) + (0.80 * overlap)
ranked = [
(score, sequence_index, sequence_index)
for sequence_index, score in merged.items()
if score > 0.0
]
ranked.sort(key=lambda item: item[0], reverse=True)
return ranked[:limit]
def _answer_sequence_prompt_overlap_scores(
self,
context_tokens: list[str],
) -> dict[int, float] | None:
if (
self.embedding_model is None
or self.answer_sequence_prompt_tokens is None
or self.trace_token_weights is None
):
return None
answer_boundary = _last_index(context_tokens, "<answer>")
prompt_tokens = (
context_tokens[:answer_boundary]
if answer_boundary is not None
else context_tokens
)
if self.answer_sequence_prompt_specificity is None:
self._refresh_answer_sequence_prompt_overlap_cache()
specificity_map = self.answer_sequence_prompt_specificity or {}
query_weights: dict[int, float] = {}
query_specificity: dict[int, float] = {}
query_content_weight = 0.0
query_ids: list[int] = []
for token in prompt_tokens:
if self.tokenizer is not None and token in self.tokenizer.special_tokens:
continue
token_id = self.embedding_model.token_to_id.get(token)
if token_id is None:
continue
query_ids.append(token_id)
specificity = specificity_map.get(token_id, 1.0)
weight = specificity
query_weights[token_id] = max(
query_weights.get(token_id, 0.0),
weight,
)
query_specificity[token_id] = max(
query_specificity.get(token_id, 0.0),
specificity,
)
if specificity >= 0.20:
query_content_weight += weight
if not query_weights:
return None
query_norm = sum(value * value for value in query_weights.values()) ** 0.5
if query_norm <= 0.0:
return None
query_bigrams = {
(query_ids[index], query_ids[index + 1])
for index in range(len(query_ids) - 1)
}
query_trigrams = {
(query_ids[index], query_ids[index + 1], query_ids[index + 2])
for index in range(len(query_ids) - 2)
}
query_numbers = self._number_strings_from_tokens(prompt_tokens)
def ordered_ngram_score(
query_grams: set[tuple[int, ...]],
row_grams: set[tuple[int, ...]],
) -> float:
if not query_grams or not row_grams:
return 0.0
overlap = len(query_grams & row_grams)
if overlap <= 0:
return 0.0
return overlap / ((len(query_grams) * len(row_grams)) ** 0.5)
cached_maps = self.answer_sequence_prompt_weight_maps
cached_norms = self.answer_sequence_prompt_weight_norms
cached_bigrams = self.answer_sequence_prompt_bigram_sets
cached_trigrams = self.answer_sequence_prompt_trigram_sets
cached_numbers = self.answer_sequence_prompt_number_sets
cached_index = self.answer_sequence_prompt_inverted_index
if (
cached_maps is not None
and cached_norms is not None
and cached_bigrams is not None
and cached_trigrams is not None
and cached_numbers is not None
and len(cached_maps) == len(self.answer_sequence_prompt_tokens)
):
candidate_indices: list[int]
if cached_index is not None:
candidates: set[int] = set()
for token_id in query_weights:
candidates.update(cached_index.get(token_id, ()))
candidate_indices = sorted(candidates) if candidates else list(range(len(cached_maps)))
else:
candidate_indices = list(range(len(cached_maps)))
if cached_index is not None and candidate_indices:
candidate_set = set(candidate_indices)
local_query_weights: dict[int, float] = {}
local_query_specificity: dict[int, float] = {}
local_query_content_weight = 0.0
for token_id in query_weights:
local_frequency = len(candidate_set & set(cached_index.get(token_id, ())))
if local_frequency <= 0:
continue
specificity = self._prompt_overlap_token_specificity(
local_frequency,
len(candidate_indices),
)
weight = specificity
local_query_weights[token_id] = weight
local_query_specificity[token_id] = specificity
if specificity >= 0.20:
local_query_content_weight += weight
local_query_norm = sum(value * value for value in local_query_weights.values()) ** 0.5
if local_query_norm > 0.0:
query_weights = local_query_weights
query_specificity = local_query_specificity
if local_query_content_weight > 0.0:
query_content_weight = local_query_content_weight
query_norm = local_query_norm
scores: dict[int, float] = {}
for sequence_index in candidate_indices:
row_weights = cached_maps[sequence_index]
if not row_weights:
continue
if not self._numeric_prompt_can_match(query_numbers, cached_numbers[sequence_index]):
continue
matched_content_weight = sum(
query_weights[token_id]
for token_id in query_weights.keys() & row_weights.keys()
if query_specificity.get(token_id, 0.0) >= 0.20
)
row_token_coverage = len(query_weights.keys() & row_weights.keys()) / max(
1,
len(row_weights),
)
if (
query_content_weight > 0.0
and matched_content_weight / query_content_weight < 0.40
and row_token_coverage < 0.75
):
continue
query_coverage = (
matched_content_weight / query_content_weight
if query_content_weight > 0.0
else row_token_coverage
)
numerator = sum(
query_weights[token_id] * row_weights[token_id]
for token_id in query_weights.keys() & row_weights.keys()
)
if numerator <= 0.0:
continue
row_norm = cached_norms[sequence_index]
if row_norm <= 0.0:
continue
token_score = numerator / (query_norm * row_norm)
bigram_score = ordered_ngram_score(
query_bigrams,
cached_bigrams[sequence_index],
)
trigram_score = ordered_ngram_score(
query_trigrams,
cached_trigrams[sequence_index],
)
scores[sequence_index] = (
(0.35 * token_score)
+ (0.35 * query_coverage)
+ (0.15 * bigram_score)
+ (0.15 * trigram_score)
)
return scores
if cached_index is not None:
candidate_set: set[int] = set()
for token_id in query_weights:
candidate_set.update(cached_index.get(token_id, ()))
if not candidate_set:
return {}
candidate_indices: list[int] | range = sorted(candidate_set)
local_query_weights: dict[int, float] = {}
local_query_specificity: dict[int, float] = {}
local_query_content_weight = 0.0
candidate_count = len(candidate_indices)
for token_id in query_weights:
local_frequency = len(candidate_set & set(cached_index.get(token_id, ())))
if local_frequency <= 0:
continue
specificity = self._prompt_overlap_token_specificity(
local_frequency,
candidate_count,
)
local_query_weights[token_id] = specificity
local_query_specificity[token_id] = specificity
if specificity >= 0.20:
local_query_content_weight += specificity
local_query_norm = sum(value * value for value in local_query_weights.values()) ** 0.5
if local_query_norm > 0.0:
query_weights = local_query_weights
query_specificity = local_query_specificity
if local_query_content_weight > 0.0:
query_content_weight = local_query_content_weight
query_norm = local_query_norm
else:
candidate_indices = range(len(self.answer_sequence_prompt_tokens))
scores: dict[int, float] = {}
for sequence_index in candidate_indices:
row = self.answer_sequence_prompt_tokens[sequence_index]
row_values = row.tolist() if hasattr(row, "tolist") else row
row_weights: dict[int, float] = {}
row_ids: list[int] = []
for raw_token_id in row_values:
token_id = int(raw_token_id)
if token_id < 0 or token_id >= len(self.trace_token_weights):
continue
row_ids.append(token_id)
row_weights[token_id] = max(
row_weights.get(token_id, 0.0),
specificity_map.get(token_id, 1.0),
)
if not row_weights:
continue
if not self._numeric_prompt_can_match(
query_numbers,
self._number_strings_from_token_ids(row_ids),
):
continue
matched_content_weight = sum(
query_weights[token_id]
for token_id in query_weights.keys() & row_weights.keys()
if query_specificity.get(token_id, 0.0) >= 0.20
)
row_token_coverage = len(query_weights.keys() & row_weights.keys()) / max(
1,
len(row_weights),
)
if (
query_content_weight > 0.0
and matched_content_weight / query_content_weight < 0.40
and row_token_coverage < 0.75
):
continue
query_coverage = (
matched_content_weight / query_content_weight
if query_content_weight > 0.0
else row_token_coverage
)
numerator = sum(
query_weights[token_id] * row_weights[token_id]
for token_id in query_weights.keys() & row_weights.keys()
)
if numerator <= 0.0:
continue
row_norm = sum(value * value for value in row_weights.values()) ** 0.5
if row_norm > 0.0:
token_score = numerator / (query_norm * row_norm)
row_bigrams = {
(row_ids[index], row_ids[index + 1])
for index in range(len(row_ids) - 1)
}
row_trigrams = {
(row_ids[index], row_ids[index + 1], row_ids[index + 2])
for index in range(len(row_ids) - 2)
}
bigram_score = ordered_ngram_score(query_bigrams, row_bigrams)
trigram_score = ordered_ngram_score(query_trigrams, row_trigrams)
scores[sequence_index] = (
(0.35 * token_score)
+ (0.35 * query_coverage)
+ (0.15 * bigram_score)
+ (0.15 * trigram_score)
)
return scores
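# Score composition (shared by the cached and uncached paths above):
# candidates must pass the numeric-subset gate and a relevance screen (at
# least 40% of the query's content weight matched, unless the row itself is
# >= 75% covered); survivors score
#   0.35 * token_cosine + 0.35 * query_coverage
#   + 0.15 * bigram_overlap + 0.15 * trigram_overlap.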
def _score_prompt_anchor_matches(
self,
answer_anchor_state: Vector | None,
keys: object | None,
key_norms_list: object | None,
values: object | None,
keys_array: object | None,
key_norms_array: object | None,
values_array: object | None,
valid_mask_array: object | None,
similarity_keys_array: object | None,
similarity_key_norms_array: object | None,
similarity_mask_array: object | None,
*,
limit: int,
) -> list[tuple[float, int, int]]:
if (
answer_anchor_state is None
or keys is None
or key_norms_list is None
or values is None
):
return []
if (
np is not None
and keys_array is not None
and key_norms_array is not None
and values_array is not None
and valid_mask_array is not None
and limit > 0
):
state_array = self._center_state_array(
self._masked_combined_state_array(answer_anchor_state)
).astype(keys_array.dtype, copy=False)
key_array = keys_array
key_norms = key_norms_array
if (
similarity_keys_array is not None
and similarity_key_norms_array is not None
and similarity_mask_array is not None
):
state_array = state_array * similarity_mask_array
key_array = similarity_keys_array
key_norms = similarity_key_norms_array
state_norm = float(np.linalg.norm(state_array))
if state_norm == 0.0:
return []
numerators = key_array @ state_array
denominators = key_norms * state_norm
valid_mask = valid_mask_array & (denominators > 0.0)
if np.any(valid_mask):
scores = np.zeros_like(numerators, dtype=key_array.dtype)
np.divide(numerators, denominators, out=scores, where=valid_mask)
positive_positions = np.flatnonzero(valid_mask & (scores > 0.0))
if positive_positions.size:
selected_positions = positive_positions
if positive_positions.size > limit:
partition = np.argpartition(scores[positive_positions], -limit)[-limit:]
selected_positions = positive_positions[partition]
ordered_positions = selected_positions[np.argsort(scores[selected_positions])[::-1]]
return [
(
float(scores[position]),
int(values_array[position]),
int(position),
)
for position in ordered_positions
]
state = self._center_state_vector(self._masked_combined_state(answer_anchor_state))
state_norm = norm(state)
if state_norm == 0.0:
return []
scored: list[tuple[float, int, int]] = []
for example_index, (key, key_norm, token_id) in enumerate(
zip(keys, key_norms_list, values)
):
if token_id < 0:
continue
denominator = state_norm * key_norm
if denominator == 0.0:
continue
similarity = dot(state, key) / denominator
if similarity > 0.0:
scored.append((similarity, token_id, example_index))
scored.sort(key=lambda item: item[0], reverse=True)
return scored[:limit]
def _answer_prior_from_matches(
self,
matches: list[tuple[float, int, int]],
generated_tokens: list[str],
) -> Vector:
assert self.embedding_model is not None
if not matches:
return [0.0 for _ in self.embedding_model.id_to_token]
prior = [0.0 for _ in self.embedding_model.id_to_token]
generated_ids = {
self.embedding_model.token_to_id[token]
for token in generated_tokens
if token in self.embedding_model.token_to_id
}
for similarity, token_id, _ in matches[:ANSWER_TOP_K]:
token = self.embedding_model.id_to_token[token_id]
if not self._allowed_generation_token(token, generated_tokens):
continue
if token_id in generated_ids:
prior[token_id] += similarity * 0.35
else:
prior[token_id] += similarity
return _normalize_vector(prior)
def _answer_sequence_prior_from_matches(
self,
matches: list[tuple[float, int, int]],
generated_tokens: list[str],
) -> Vector:
assert self.embedding_model is not None
if not matches or self.answer_sequence_tokens is None:
return [0.0 for _ in self.embedding_model.id_to_token]
generated_ids = [
self.embedding_model.token_to_id[token]
for token in generated_tokens
if token in self.embedding_model.token_to_id
]
prior = [0.0 for _ in self.embedding_model.id_to_token]
best_similarity = matches[0][0]
match_floor = best_similarity - 0.02 if best_similarity >= 0.9 else 0.0
for similarity, sequence_index, _ in matches[:ANSWER_START_TOP_K]:
if similarity < match_floor:
continue
row = self.answer_sequence_tokens[sequence_index]
token_ids = [
int(value)
for value in (row.tolist() if hasattr(row, "tolist") else row)
if int(value) >= 0
]
if not token_ids:
continue
next_token_id = self._next_sequence_token_id(token_ids, generated_ids)
if next_token_id is None:
continue
token = self.embedding_model.id_to_token[next_token_id]
if self._allowed_generation_token(token, generated_tokens):
prior[next_token_id] += max(1e-9, similarity - match_floor)
return _normalize_vector(prior)
def _should_stop_answer_sequence(
self,
decode_state: DecodeState,
generated_tokens: list[str],
) -> bool:
matches = decode_state.answer_sequence_matches
if matches is None:
matches = self._score_answer_sequence_matches(
decode_state.answer_anchor_state,
decode_state.context_tokens,
)
return self._answer_sequence_is_complete(generated_tokens, matches)
def _answer_decode_has_continuation(
self,
decode_state: DecodeState,
generated_tokens: list[str],
) -> bool:
matches = decode_state.answer_sequence_matches
if matches is None:
matches = self._score_answer_sequence_matches(
decode_state.answer_anchor_state,
decode_state.context_tokens,
)
return self._answer_sequence_has_continuation(generated_tokens, matches)
def _answer_sequence_is_complete(
self,
generated_tokens: list[str],
matches: list[tuple[float, int, int]],
) -> bool:
if (
self.embedding_model is None
or self.answer_sequence_tokens is None
or not generated_tokens
or not matches
):
return False
generated_ids = [
self.embedding_model.token_to_id[token]
for token in generated_tokens
if token in self.embedding_model.token_to_id
]
if not generated_ids:
return False
for similarity, sequence_index, _ in matches[:ANSWER_START_TOP_K]:
if similarity < ANSWER_SEQUENCE_MATCH_FLOOR or sequence_index >= len(self.answer_sequence_tokens):
continue
row = self.answer_sequence_tokens[sequence_index]
token_ids = [
int(value)
for value in (row.tolist() if hasattr(row, "tolist") else row)
if int(value) >= 0
]
if not token_ids or len(generated_ids) < len(token_ids):
continue
if generated_ids[: len(token_ids)] == token_ids:
return True
return False
def _answer_sequence_has_continuation(
self,
generated_tokens: list[str],
matches: list[tuple[float, int, int]],
) -> bool:
if (
self.embedding_model is None
or self.answer_sequence_tokens is None
or not generated_tokens
or not matches
):
return False
generated_ids = [
self.embedding_model.token_to_id[token]
for token in generated_tokens
if token in self.embedding_model.token_to_id
]
if not generated_ids:
return False
for similarity, sequence_index, _ in matches[:ANSWER_START_TOP_K]:
if similarity < ANSWER_SEQUENCE_MATCH_FLOOR or sequence_index >= len(self.answer_sequence_tokens):
continue
row = self.answer_sequence_tokens[sequence_index]
token_ids = [
int(value)
for value in (row.tolist() if hasattr(row, "tolist") else row)
if int(value) >= 0
]
if not token_ids:
continue
next_token_id = self._next_sequence_token_id(token_ids, generated_ids)
if next_token_id is None:
continue
token = self.embedding_model.id_to_token[next_token_id]
if self._allowed_generation_token(token, generated_tokens):
return True
return False
def _next_sequence_token_id(
self,
token_ids: list[int],
generated_ids: list[int],
) -> int | None:
if not generated_ids:
return token_ids[0]
if len(generated_ids) >= len(token_ids):
return None
if token_ids[: len(generated_ids)] != generated_ids:
return None
return token_ids[len(generated_ids)]
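# Example (illustrative only): for a stored sequence [11, 7, 3]:
# generated [] -> 11, generated [11] -> 7, generated [11, 7, 3] -> None
# (sequence exhausted), generated [11, 9] -> None (prefix mismatch).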
def _transition_prior(self, context_tokens: list[str]) -> Vector:
prior, _ = self._transition_prior_with_order(context_tokens)
return prior
def _transition_prior_with_order(
self,
context_tokens: list[str],
) -> tuple[Vector, int | None]:
assert self.embedding_model is not None
if not self.transition_tables:
return [0.0 for _ in self.embedding_model.id_to_token], None
for order in TRANSITION_ORDERS:
if len(context_tokens) < order:
continue
key = tuple(context_tokens[-order:])
transitions = self.transition_tables.get(order, {}).get(key)
if not transitions:
continue
prior = [0.0 for _ in self.embedding_model.id_to_token]
for token, probability in transitions.items():
token_id = self.embedding_model.token_to_id.get(token)
if token_id is not None:
prior[token_id] = probability
return _normalize_vector(prior), order
return [0.0 for _ in self.embedding_model.id_to_token], None
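# Back-off note: TRANSITION_ORDERS is scanned from order 10 down to 1 and the
# first context with a stored distribution wins. For example, if the last
# three tokens match a stored order-3 context, its distribution is returned
# even though orders 10, 8, 6, 5, and 4 missed.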
def _transition_prior_array_with_order(
self,
context_tokens: list[str],
) -> tuple[object, int | None]:
assert np is not None
assert self.embedding_model is not None
prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
if not self.transition_tables:
return prior, None
for order in TRANSITION_ORDERS:
if len(context_tokens) < order:
continue
key = tuple(context_tokens[-order:])
transitions = self.transition_tables.get(order, {}).get(key)
if not transitions:
continue
for token, probability in transitions.items():
token_id = self.embedding_model.token_to_id.get(token)
if token_id is not None:
prior[token_id] = probability
total = float(prior.sum())
if total > 0.0:
prior /= total
return prior, order
return prior, None
def _copy_prior(self, context_tokens: list[str]) -> Vector:
assert self.embedding_model is not None
assert self.tokenizer is not None
prior = [0.0 for _ in self.embedding_model.id_to_token]
decay = 0.82
answer_start = None
for index in range(len(context_tokens) - 1, -1, -1):
if context_tokens[index] == "<answer>":
answer_start = index + 1
break
source_tokens = context_tokens[answer_start:] if answer_start is not None else context_tokens
if not source_tokens:
return prior
for distance, token in enumerate(reversed(source_tokens[-8:])):
if token in self.tokenizer.special_tokens:
continue
if not self._eligible_copy_token(token):
continue
token_id = self.embedding_model.token_to_id.get(token)
if token_id is None:
continue
prior[token_id] += decay**distance
return _normalize_vector(prior)
def _copy_prior_array(self, context_tokens: list[str]) -> object:
assert np is not None
assert self.embedding_model is not None
assert self.tokenizer is not None
prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
decay = 0.82
answer_start = None
for index in range(len(context_tokens) - 1, -1, -1):
if context_tokens[index] == "<answer>":
answer_start = index + 1
break
source_tokens = context_tokens[answer_start:] if answer_start is not None else context_tokens
for distance, token in enumerate(reversed(source_tokens[-8:])):
if token in self.tokenizer.special_tokens:
continue
if not self._eligible_copy_token(token):
continue
token_id = self.embedding_model.token_to_id.get(token)
if token_id is None:
continue
prior[token_id] += decay**distance
total = float(prior.sum())
if total > 0.0:
prior /= total
return prior
def _preference_prior(self) -> Vector:
assert self.embedding_model is not None
if not self.preference_bias or not any(value != 0.0 for value in self.preference_bias):
return [0.0 for _ in self.embedding_model.id_to_token]
eligible_indices = [
index
for index, token in enumerate(self.embedding_model.id_to_token)
if self.preference_bias[index] > 0.0 and self._eligible_preference_token(token)
]
if not eligible_indices:
return [0.0 for _ in self.embedding_model.id_to_token]
eligible_probabilities = self._calibrated_softmax(
[self.preference_bias[index] for index in eligible_indices]
)
prior = [0.0 for _ in self.embedding_model.id_to_token]
for index, probability in zip(eligible_indices, eligible_probabilities):
prior[index] = probability
return prior
def _preference_prior_array(self) -> object:
assert np is not None
assert self.embedding_model is not None
if self.preference_bias_array is None or not np.any(self.preference_bias_array != 0.0):
return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
if self.preference_valid_mask_array is None or not np.any(self.preference_valid_mask_array):
return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
positive_mask = self.preference_bias_array > 0.0
active_mask = self.preference_valid_mask_array & positive_mask
if not np.any(active_mask):
return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
prior[active_mask] = self._calibrated_softmax_array(
self.preference_bias_array[active_mask]
)
return prior
def _eligible_preference_token(self, token: str) -> bool:
assert self.tokenizer is not None
if token == self.tokenizer.unk_token or token in self.tokenizer.special_tokens:
return False
if not self._starts_new_word(token):
return False
rendered = self._render_token(token)
if not rendered.strip() or self._is_punctuation_piece(rendered):
return False
alphanumeric = "".join(character for character in rendered if character.isalnum())
return len(alphanumeric) >= 1
def _build_transition_tables(
self,
tokens: list[str],
) -> dict[int, dict[tuple[str, ...], dict[str, float]]]:
counts: dict[int, dict[tuple[str, ...], dict[str, int]]] = {
order: {} for order in sorted(TRANSITION_ORDERS)
}
for order in sorted(TRANSITION_ORDERS):
for index in range(order - 1, len(tokens) - 1):
key = tuple(tokens[index - order + 1 : index + 1])
next_token = tokens[index + 1]
bucket = counts[order].setdefault(key, {})
bucket[next_token] = bucket.get(next_token, 0) + 1
probabilities: dict[int, dict[tuple[str, ...], dict[str, float]]] = {
order: {} for order in sorted(TRANSITION_ORDERS)
}
for order, mapping in counts.items():
items = list(mapping.items())
items.sort(key=lambda item: (-sum(item[1].values()), item[0]))
if (
self.config.max_transition_contexts_per_order is not None
and self.config.max_transition_contexts_per_order >= 0
):
items = items[: self.config.max_transition_contexts_per_order]
for key, bucket in items:
next_items = sorted(bucket.items(), key=lambda item: (-item[1], item[0]))
if self.config.max_transition_next_tokens > 0:
next_items = next_items[: self.config.max_transition_next_tokens]
total = sum(value for _, value in next_items)
if total <= 0:
continue
probabilities[order][key] = {
token: value / total
for token, value in next_items
}
return probabilities
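# Worked example (illustrative only): for tokens ["a", "b", "a", "b"] at
# order 1, the raw counts are ("a",) -> {"b": 2} and ("b",) -> {"a": 1}, which
# normalize to p(b | a) = 1.0 and p(a | b) = 1.0. The config caps
# (max_transition_contexts_per_order, max_transition_next_tokens) keep only
# the most frequent contexts and next tokens before normalization.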
def _serialize_transition_tables(self) -> dict[str, dict[str, dict[str, float]]]:
assert self.transition_tables is not None
return {
str(order): {
_encode_ngram_key(key): value
for key, value in mapping.items()
}
for order, mapping in self.transition_tables.items()
}
def _deserialize_transition_tables(
self,
payload: dict[str, dict[str, dict[str, float]]],
) -> dict[int, dict[tuple[str, ...], dict[str, float]]]:
tables: dict[int, dict[tuple[str, ...], dict[str, float]]] = {
order: {} for order in sorted(TRANSITION_ORDERS)
}
for order_text, mapping in payload.items():
order = int(order_text)
tables[order] = {
_decode_ngram_key(key): {
str(token): float(probability)
for token, probability in value.items()
}
for key, value in mapping.items()
}
return tables
def _eligible_copy_token(self, token: str) -> bool:
rendered = self._render_token(token)
if not rendered.strip():
return False
if self._is_punctuation_piece(rendered):
return False
if not self._starts_new_word(token):
return False
alphanumeric = "".join(character for character in rendered if character.isalnum())
return len(alphanumeric) >= 2
def _allowed_generation_token(self, token: str, generated_tokens: list[str]) -> bool:
assert self.embedding_model is not None
assert self.tokenizer is not None
if len(self.embedding_model.id_to_token) < 1024:
return True
if token == self.tokenizer.unk_token or token in self.tokenizer.special_tokens:
return False
rendered = self._render_token(token)
if rendered == "\n":
return bool(generated_tokens)
if not rendered.strip():
return False
if self._is_word_joiner_token(token):
return (
self._can_attach_word_joiner(generated_tokens)
or self._can_start_line_with_word_joiner(token, generated_tokens)
)
if self._is_structural_punctuation_token(token):
return bool(generated_tokens) or self._can_start_answer_with_structural_punctuation(token)
if self._is_structural_symbol_token(token):
return bool(generated_tokens) or self._starts_new_word(token)
if not self._starts_new_word(token):
return False
alphanumeric = "".join(character for character in rendered if character.isalnum())
return len(alphanumeric) >= 1 or not self._is_punctuation_piece(rendered)
def _would_repeat_recent_pattern(
self,
candidate: str,
generated_tokens: list[str],
recent_rendered_words: list[str] | None = None,
) -> bool:
if len(generated_tokens) >= 2 and generated_tokens[-1] == candidate and generated_tokens[-2] == candidate:
return True
if len(generated_tokens) >= 2:
trigram = tuple(generated_tokens[-2:] + [candidate])
recent_tokens = generated_tokens[-12:]
for index in range(max(0, len(recent_tokens) - 4)):
if tuple(recent_tokens[index : index + 3]) == trigram:
return True
rendered_words = recent_rendered_words
if rendered_words is None:
rendered_words = self._recent_rendered_words(generated_tokens)
candidate_word = self._render_token(candidate).casefold()
if (
rendered_words
and self._starts_new_word(candidate)
and any(character.isalnum() for character in candidate_word)
):
candidate_bigram = (rendered_words[-1], candidate_word)
recent_window = rendered_words[-10:]
recent_bigrams = {
(recent_window[index], recent_window[index + 1])
for index in range(len(recent_window) - 1)
}
if candidate_bigram in recent_bigrams:
return True
if (
len(candidate_word) > 2
and rendered_words[-10:].count(candidate_word) >= 2
and not self._is_common_connector_token(candidate)
):
return True
return False
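# Hedged sketch: the trigram guard above in isolation. A candidate is rejected
# when (prev2, prev1, candidate) already occurs in the recent window; the last
# few positions are skipped so the freshly emitted pair cannot match itself.
def _sketch_repeats_trigram(generated: list[str], candidate: str, window: int = 12) -> bool:
    if len(generated) < 2:
        return False
    trigram = tuple(generated[-2:] + [candidate])
    recent = generated[-window:]
    return any(
        tuple(recent[i : i + 3]) == trigram
        for i in range(max(0, len(recent) - 4))
    )
# _sketch_repeats_trigram(["x", "y", "z", "q", "x", "y"], "z") -> True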
def _recent_rendered_words(self, generated_tokens: list[str]) -> list[str]:
rendered_words: list[str] = []
for token in generated_tokens:
if not self._starts_new_word(token):
continue
rendered = self._render_token(token).casefold()
if any(character.isalnum() for character in rendered):
rendered_words.append(rendered)
return rendered_words
def _select_generation_token(
self,
distribution: dict[str, float],
*,
context_tokens: list[str] | None = None,
generated_tokens: list[str] | None = None,
temperature: float = DEFAULT_GENERATION_TEMPERATURE,
top_k: int = DEFAULT_GENERATION_TOP_K,
top_p: float = DEFAULT_GENERATION_TOP_P,
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
preserve_dominant_candidates: bool = False,
) -> str:
assert self.tokenizer is not None
generated_tokens = generated_tokens or []
candidates = self._prepare_generation_candidates(
distribution,
generated_tokens=generated_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
preserve_dominant_candidates=preserve_dominant_candidates,
)
if candidates:
return self._sample_generation_candidate(
candidates,
context_tokens=context_tokens or [],
generated_tokens=generated_tokens,
stochastic=temperature > 0.0,
)
for token, _ in sorted(distribution.items(), key=lambda item: item[1], reverse=True):
if token in self.tokenizer.special_tokens:
continue
if token == self.tokenizer.unk_token:
continue
if not self._allowed_generation_token(token, generated_tokens):
continue
return token
return ""
def _select_generation_token_from_array(
self,
probabilities: object,
*,
context_tokens: list[str],
generated_tokens: list[str],
temperature: float = DEFAULT_GENERATION_TEMPERATURE,
top_k: int = DEFAULT_GENERATION_TOP_K,
top_p: float = DEFAULT_GENERATION_TOP_P,
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
preserve_dominant_candidates: bool = False,
) -> str:
assert np is not None
assert self.tokenizer is not None
assert self.embedding_model is not None
values = np.asarray(probabilities, dtype=np.float64)
if values.size == 0:
return ""
pool_size = min(values.size, max(top_k * 4, 64))  # values.size >= 1 here, so pool_size > 0
if pool_size < values.size:
candidate_indices = np.argpartition(values, -pool_size)[-pool_size:]
candidate_indices = candidate_indices[np.argsort(values[candidate_indices])[::-1]]
else:
candidate_indices = np.argsort(values)[::-1]
distribution: dict[str, float] = {}
for raw_index in candidate_indices:
index = int(raw_index)
score = float(values[index])
if score <= 0.0:
continue
token = self.embedding_model.id_to_token[index]
if token in self.tokenizer.special_tokens or token == self.tokenizer.unk_token:
continue
distribution[token] = score
return self._select_generation_token(
distribution,
context_tokens=context_tokens,
generated_tokens=generated_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
preserve_dominant_candidates=preserve_dominant_candidates,
)
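# Hedged sketch: the argpartition pooling above -- select the top `pool`
# probabilities without a full sort, then order only that pool descending.
def _sketch_top_pool(probabilities, pool: int = 4):
    import numpy as np
    values = np.asarray(probabilities, dtype=np.float64)
    if pool >= values.size:
        return np.argsort(values)[::-1]
    idx = np.argpartition(values, -pool)[-pool:]
    return idx[np.argsort(values[idx])[::-1]]
# _sketch_top_pool([0.1, 0.4, 0.05, 0.3, 0.15], pool=2) -> array([1, 3])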
def _prepare_generation_candidates(
self,
distribution: dict[str, float],
*,
generated_tokens: list[str],
temperature: float,
top_k: int,
top_p: float,
repetition_penalty: float,
preserve_dominant_candidates: bool = False,
) -> list[tuple[str, float]]:
assert self.tokenizer is not None
assert self.embedding_model is not None
generated_word_count = self._generated_word_count(generated_tokens)
clause_words = self._words_since_clause_break(generated_tokens)
recent_rendered_words = self._recent_rendered_words(generated_tokens)
best_probability = max(distribution.values(), default=0.0)
adjusted: list[tuple[str, float]] = []
for token, probability in sorted(distribution.items(), key=lambda item: item[1], reverse=True):
if token in self.tokenizer.special_tokens:
continue
if token == self.tokenizer.unk_token or probability <= 0.0:
continue
if not self._allowed_generation_token(token, generated_tokens):
continue
repeats_recent_pattern = self._would_repeat_recent_pattern(
token,
generated_tokens,
recent_rendered_words=recent_rendered_words,
)
if (
repeats_recent_pattern
and not (
preserve_dominant_candidates
and best_probability > 0.0
and probability >= best_probability * 0.80
)
):
continue
score = probability
rendered = self._render_token(token)
punctuation_token = self._is_structural_punctuation_token(token)
starts_new_word = self._starts_new_word(token)
alphanumeric = "".join(character for character in rendered if character.isalnum())
if generated_tokens and starts_new_word and alphanumeric:
previous_rendered = self._render_token(generated_tokens[-1])
previous_alphanumeric = "".join(
character for character in previous_rendered if character.isalnum()
)
if previous_alphanumeric.casefold() == alphanumeric.casefold():
continue
common_connector = self._is_common_connector_token(token)
if (
starts_new_word
and len(alphanumeric) == 1
and not common_connector
):
score *= 0.08
recent_count = generated_tokens[-12:].count(token)
if recent_count > 0 and not common_connector:
score /= repetition_penalty ** (2 * recent_count)
if generated_tokens and token == generated_tokens[-1]:
score /= repetition_penalty**3
if generated_tokens and token in generated_tokens[-4:] and not common_connector:
score *= 0.35
if generated_tokens and not starts_new_word and self._starts_new_word(generated_tokens[-1]):
score *= 0.08
if not generated_tokens and punctuation_token:
if best_probability <= 0.0 or probability < best_probability * 0.80:
score *= 0.01
elif not generated_tokens and not starts_new_word:
score *= 0.02
if punctuation_token:
if generated_tokens and self._is_structural_punctuation_token(generated_tokens[-1]):
score *= 0.05
if clause_words >= 6:
score *= 1.0 + min(1.4, 0.18 * (clause_words - 5))
elif generated_word_count >= 12:
score *= 1.1
if score > 0.0:
adjusted.append((token, score))
if not adjusted:
return []
adjusted.sort(key=lambda item: item[1], reverse=True)
if top_k > 0:
adjusted = adjusted[:top_k]
if 0.0 < top_p < 1.0:
kept: list[tuple[str, float]] = []
cumulative = 0.0
total = sum(score for _, score in adjusted)
for token, score in adjusted:
normalized = score / total if total else 0.0
kept.append((token, score))
cumulative += normalized
if cumulative >= top_p:
break
adjusted = kept
if temperature <= 0.0:
return [(adjusted[0][0], 1.0)]
exponent = 1.0 / temperature
tempered = [
(token, score**exponent)
for token, score in adjusted
if score > 0.0
]
total = sum(score for _, score in tempered)
if total <= 0.0:
return []
return [(token, score / total) for token, score in tempered]
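# Hedged sketch: the tail of the pipeline above -- top-k cut, nucleus (top-p)
# cut, then temperature rescaling into a renormalized distribution. The
# repetition and punctuation penalty shaping is omitted, and temperature is
# assumed positive.
def _sketch_temper_and_truncate(scored, temperature=0.8, top_k=3, top_p=0.9):
    ranked = sorted(scored, key=lambda item: item[1], reverse=True)[:top_k]
    kept, cumulative = [], 0.0
    total = sum(score for _, score in ranked)
    for token, score in ranked:
        kept.append((token, score))
        cumulative += score / total if total else 0.0
        if cumulative >= top_p:
            break
    tempered = [(token, score ** (1.0 / temperature)) for token, score in kept]
    z = sum(score for _, score in tempered)
    return [(token, score / z) for token, score in tempered] if z > 0.0 else []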
def _sample_generation_candidate(
self,
candidates: list[tuple[str, float]],
*,
context_tokens: list[str],
generated_tokens: list[str],
stochastic: bool = False,
) -> str:
if not candidates:
return ""
if len(candidates) == 1:
return candidates[0][0]
top_probability = candidates[0][1]
second_probability = candidates[1][1]
top_has_clear_half_majority = top_probability >= 0.5 and (
second_probability <= 0.0
or top_probability - second_probability >= 0.02
)
if top_has_clear_half_majority or (
second_probability > 0.0 and top_probability >= second_probability * 2.5
) or (
top_probability >= 0.08
and second_probability > 0.0
and top_probability >= second_probability * 1.35
):
return candidates[0][0]
if stochastic:
threshold = random.random()
else:
seed_payload = "\u0002".join([*context_tokens, "<generated>", *generated_tokens, str(len(candidates))])
seed = int.from_bytes(hashlib.sha256(seed_payload.encode("utf-8")).digest()[:8], "big")
threshold = random.Random(seed).random()
cumulative = 0.0
for token, probability in candidates:
cumulative += probability
if threshold <= cumulative:
return token
return candidates[-1][0]
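# Hedged sketch: the deterministic branch above. When sampling is not
# stochastic, the threshold comes from a SHA-256 hash of the context, so the
# same prompt and history always select the same candidate.
def _sketch_deterministic_threshold(context_tokens, generated_tokens, n_candidates) -> float:
    import hashlib
    import random
    payload = "\u0002".join([*context_tokens, "<generated>", *generated_tokens, str(n_candidates)])
    seed = int.from_bytes(hashlib.sha256(payload.encode("utf-8")).digest()[:8], "big")
    return random.Random(seed).random()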
def _top_entries_from_vector(
self,
values: Vector,
limit: int,
) -> list[dict[str, object]]:
if limit <= 0:
return []
ranked = sorted(
enumerate(values),
key=lambda item: item[1],
reverse=True,
)
return [
self._token_entry(index, probability)
for index, probability in ranked[:limit]
if probability > 0.0
]
def _token_entry(
self,
index: int,
probability: float,
) -> dict[str, object]:
assert self.embedding_model is not None
token = self.embedding_model.id_to_token[index]
return {
"token": token,
"text": self._render_token(token),
"probability": probability,
}
def _build_reasoning_summary(
self,
transition_order: int | None,
blend_weights: dict[str, float],
) -> str:
dominant_source = max(blend_weights.items(), key=lambda item: item[1])[0] if blend_weights else "base"
if transition_order is not None:
transition_message = f" Transition prior is using order-{transition_order} context."
else:
transition_message = " Transition prior found no matching n-gram."
return (
"Generation is running on analytical state, recurrent traces, and corpus-derived token transitions."
f"{transition_message}"
f" Dominant blend source: {dominant_source}."
)
def _generated_word_count(self, tokens: list[str]) -> int:
return len(self._decode_tokens(tokens).split())
def _is_structural_punctuation_text(self, text: str) -> bool:
if len(text) != 1:
return False
if self._is_word_joiner_text(text):
return False
category = unicodedata.category(text)
return category.startswith("P")
def _is_structural_punctuation_token(self, token: str) -> bool:
return self._is_structural_punctuation_text(self._render_token(token))
def _is_structural_symbol_token(self, token: str) -> bool:
rendered = self._render_token(token)
return len(rendered) == 1 and unicodedata.category(rendered).startswith("S")
def _is_word_joiner_token(self, token: str) -> bool:
return self._is_word_joiner_text(self._render_token(token))
def _is_word_joiner_text(self, text: str) -> bool:
if len(text) != 1:
return False
category = unicodedata.category(text)
if category in ("Pc", "Pd", "Lm"):
return True
name = unicodedata.name(text, "")
return "APOSTROPHE" in name or (
"SINGLE" in name and "QUOTATION MARK" in name
)
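# Hedged sketch: the Unicode probes behind the joiner/punctuation helpers.
# unicodedata.category buckets characters ("Pd" dash, "Ps" open bracket, ...)
# and unicodedata.name(ch, "") names them without raising on unnamed ones.
def _sketch_unicode_probes() -> str:
    import unicodedata
    assert unicodedata.category("-") == "Pd"
    assert unicodedata.category("(") == "Ps"
    assert "APOSTROPHE" in unicodedata.name("'", "")
    return unicodedata.category("\u2019")  # "Pf": RIGHT SINGLE QUOTATION MARK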
def _can_start_line_with_word_joiner(self, token: str, generated_tokens: list[str]) -> bool:
rendered = self._render_token(token)
if len(rendered) != 1 or unicodedata.category(rendered) != "Pd":
return False
if not self._starts_new_word(token):
return False
return not generated_tokens or self._render_token(generated_tokens[-1]) == "\n"
def _can_start_answer_with_structural_punctuation(self, token: str) -> bool:
rendered = self._render_token(token)
if len(rendered) != 1 or not self._starts_new_word(token):
return False
return unicodedata.category(rendered) in ("Ps", "Pi")
def _is_common_connector_token(self, token: str) -> bool:
rendered = self._render_token(token)
return rendered.isalpha() and len(rendered) <= 3
def _can_attach_word_joiner(self, generated_tokens: list[str]) -> bool:
if not generated_tokens:
return False
rendered = self._render_token(generated_tokens[-1])
if not rendered:
return False
if any(character.isalnum() for character in rendered):
return True
if len(rendered) != 1:
return False
return unicodedata.category(rendered) in ("Ps", "Pi")
def _words_since_clause_break(self, tokens: list[str]) -> int:
assert self.tokenizer is not None
words = 0
for token in reversed(tokens):
if token in self.tokenizer.special_tokens:
continue
rendered = self._render_token(token)
if self._is_structural_punctuation_text(rendered):
break
if self._starts_new_word(token) and not self._is_punctuation_piece(rendered):
words += 1
return words
def _should_stop_generation(self, generated_tokens: list[str]) -> bool:
if not generated_tokens:
return False
if not self._is_terminal_punctuation_text(self._render_token(generated_tokens[-1])):
return False
return self._generated_word_count(generated_tokens) >= 14
def _is_terminal_punctuation_text(self, text: str) -> bool:
if not self._is_structural_punctuation_text(text):
return False
name = unicodedata.name(text, "")
return (
"FULL STOP" in name
or "QUESTION MARK" in name
or "EXCLAMATION MARK" in name
)
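# Hedged sketch: the name-based stop test above, simplified -- the real method
# additionally requires the text to pass _is_structural_punctuation_text.
def _sketch_is_terminal(text: str) -> bool:
    import unicodedata
    name = unicodedata.name(text, "") if len(text) == 1 else ""
    return any(tag in name for tag in ("FULL STOP", "QUESTION MARK", "EXCLAMATION MARK"))
# _sketch_is_terminal(".") -> True; _sketch_is_terminal(",") -> False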
def _starts_new_word(self, token: str) -> bool:
assert self.tokenizer is not None
if token in self.tokenizer.special_tokens:
return True
if token.startswith(self.tokenizer.word_prefix):
return True
return len(token) == 1 and not token.isalnum() and not self._is_word_joiner_token(token)
def _decode_tokens(self, tokens: list[str]) -> str:
assert self.tokenizer is not None
return self.tokenizer.decode(tokens)
def _render_token(self, token: str) -> str:
assert self.tokenizer is not None
if token.startswith(self.tokenizer.word_prefix):
return token[len(self.tokenizer.word_prefix) :]
return token
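# Hedged sketch: word-prefix rendering. Word-initial tokens carry the
# tokenizer's word_prefix marker and rendering strips it; "\u2581" is a
# hypothetical prefix here -- the real value lives on NativeTokenizer.
def _sketch_render(token: str, word_prefix: str = "\u2581") -> str:
    return token[len(word_prefix):] if token.startswith(word_prefix) else token
# _sketch_render("\u2581hello") -> "hello"; _sketch_render("##ing", "##") -> "ing"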
def _require_fit(self) -> None:
if (
self.tokenizer is None
or self.embedding_model is None
or self.memory_units is None
or self.readout_weights is None
or self.ternary_mask is None
or self.associative_keys is None
or self.associative_key_norms is None
or self.associative_values is None
or self.transition_tables is None
):
raise RuntimeError("Call fit() before using the REFRAMR model.")
def _ensure_numeric_caches(self) -> None:
if np is None:
return
if self.readout_weights_array is None:
self._refresh_numeric_caches()