diff --git "a/reframr/model.py" "b/reframr/model.py" new file mode 100644--- /dev/null +++ "b/reframr/model.py" @@ -0,0 +1,4026 @@ +import json +import hashlib +import random +import site +import string +import sys +import unicodedata +from dataclasses import dataclass +from pathlib import Path + +_VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor" +for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"): + if _vendor_path.exists(): + vendor_text = str(_vendor_path) + if vendor_text not in sys.path: + sys.path.insert(0, vendor_text) + +try: + import numpy as np +except ModuleNotFoundError: + user_site = site.getusersitepackages() + if user_site and user_site not in sys.path: + sys.path.append(user_site) + try: + import numpy as np + except ModuleNotFoundError: + np = None + +if np is not None and not hasattr(np, "asarray"): + np = None + +from .checkpoint import read_safetensor_file, write_safetensor_file +from .config import ReframrConfig +from .embeddings import EmbeddingModel, fit_ppmi_embedding_from_tokens +from .hippo import AnalyticalMemoryUnit, analytical_embedding_drive, analytical_embedding_drive_fast +from .linalg import Vector, dot, mean, norm, softmax, zeros_vector +from .reservoir import apply_readout, ridge_regression_readout +from .reasoning import reasoning_prefix +from .ternary import apply_ternary_mask, derive_ternary_mask_from_states +from .tokenizer import NativeTokenizer + +ASSOCIATIVE_BLEND = 0.42 +TRANSITION_BLEND = 0.08 +COPY_BLEND = 0.04 +BASE_BLEND = 0.34 +FAST_ASSOCIATIVE_BLEND = 0.06 +FAST_TRANSITION_BLEND = 0.14 +FAST_COPY_BLEND = 0.04 +FAST_BASE_BLEND = 0.58 +FAST_PREFERENCE_BLEND = 0.15 +FAST_ANSWER_BLEND = 0.30 +PROMPT_READOUT_LOGIT_ZSCORE_SCALE = 0.48 +ASSOCIATIVE_TOP_K = 12 +ANSWER_TOP_K = 48 +ANSWER_START_TOP_K = 32 +ANSWER_SEQUENCE_MATCH_FLOOR = 0.30 +ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR = 0.45 +ANSWER_SEQUENCE_LOCK_FLOOR = 0.55 +ANSWER_SEQUENCE_SPIKE_CONFIDENCE = 0.80 +READOUT_LOGIT_ZSCORE_SCALE 
= 0.22 +TRACE_IDENTITY_SCALE = 0.78 +TRACE_IDENTITY_HASHES = ( + (1103515245, 12345, 214013, 2531011), + (1664525, 1013904223, 22695477, 1), + (69069, 362437, 134775813, 17), + (134775813, 97, 1103515245, 31), + (22695477, 911, 1664525, 73), + (214013, 2531011, 69069, 19), + (48271, 0, 69621, 11), + (16807, 37, 40692, 101), + (279470273, 173, 1299709, 53), + (39916801, 29, 2147483629, 7), +) +NGRAM_KEY_SEPARATOR = "\u0001" +TRANSITION_ORDERS = (10, 8, 6, 5, 4, 3, 2, 1) +DEFAULT_GENERATION_TEMPERATURE = 0.82 +DEFAULT_GENERATION_TOP_K = 24 +DEFAULT_GENERATION_TOP_P = 0.92 +DEFAULT_REPETITION_PENALTY = 1.18 +ANSWER_SEQUENCE_MAX_TOKENS = 192 +RUNTIME_ARRAY_DTYPE = np.float32 if np is not None else None + + +@dataclass(frozen=True, slots=True) +class CharacterCountFact: + character: str + word: str + count: int + surface_seed: int + + +def _normalize_vector(values: Vector) -> Vector: + total = sum(values) + if total <= 0.0: + return [0.0 for _ in values] + return [value / total for value in values] + + +def _encode_ngram_key(tokens: tuple[str, ...]) -> str: + return NGRAM_KEY_SEPARATOR.join(tokens) + + +def _decode_ngram_key(key: str) -> tuple[str, ...]: + return tuple(part for part in key.split(NGRAM_KEY_SEPARATOR) if part) + + +def _last_index(values: list[str], target: str) -> int | None: + for index in range(len(values) - 1, -1, -1): + if values[index] == target: + return index + return None + + +@dataclass(slots=True) +class DecodeState: + hidden_states: list[Vector] + context_traces: list[Vector] + combined_state: Vector + context_tokens: list[str] + answer_anchor_state: Vector | None = None + answer_matches: list[tuple[float, int, int]] | None = None + answer_start_matches: list[tuple[float, int, int]] | None = None + answer_sequence_matches: list[tuple[float, int, int]] | None = None + prompt_answer_prior: object | None = None + prompt_answer_start_prior: object | None = None + + +@dataclass(slots=True) +class ReframrModel: + config: ReframrConfig + tokenizer: 
NativeTokenizer | None = None + embedding_model: EmbeddingModel | None = None + memory_units: list[AnalyticalMemoryUnit] | None = None + ternary_scale: float = 1.0 + ternary_mask: list[int] | None = None + ternary_mask_array: object | None = None + readout_weights: list[list[float]] | None = None + readout_weights_array: object | None = None + readout_bias: Vector | None = None + readout_bias_array: object | None = None + prompt_answer_weights: list[list[float]] | None = None + prompt_answer_weights_array: object | None = None + prompt_answer_bias: Vector | None = None + prompt_answer_bias_array: object | None = None + prompt_answer_start_weights: list[list[float]] | None = None + prompt_answer_start_weights_array: object | None = None + prompt_answer_start_bias: Vector | None = None + prompt_answer_start_bias_array: object | None = None + trace_token_weights: Vector | None = None + trace_token_weights_array: object | None = None + trace_embedding_table_array: object | None = None + preference_bias: Vector | None = None + preference_bias_array: object | None = None + preference_valid_mask_array: object | None = None + state_offset: Vector | None = None + state_offset_array: object | None = None + associative_keys: list[Vector] | None = None + associative_keys_array: object | None = None + associative_key_norms: list[float] | None = None + associative_key_norms_array: object | None = None + associative_values: list[int] | None = None + associative_values_array: object | None = None + associative_valid_mask_array: object | None = None + answer_keys: list[Vector] | None = None + answer_keys_array: object | None = None + answer_key_norms: list[float] | None = None + answer_key_norms_array: object | None = None + answer_similarity_keys_array: object | None = None + answer_similarity_key_norms_array: object | None = None + answer_similarity_mask_array: object | None = None + answer_values: list[int] | None = None + answer_values_array: object | None = None + 
def fit(self, text: str) -> "ReframrModel":
    """Fit every component of the model on *text* and return ``self``.

    Steps, in order: train the tokenizer, fit the PPMI embedding, create one
    memory unit per configured timescale, derive trace token weights from
    token counts, collect training states, derive and apply the ternary
    mask, seed the associative memory, reset the answer memories, build the
    transition tables, populate the answer memory from *text*, solve the
    ridge-regression readout, zero the bias vectors and refresh the numeric
    caches.

    Raises:
        ValueError: if *text* encodes to fewer than two tokens.
    """
    settings = self.config
    self.tokenizer = NativeTokenizer.train(
        text,
        vocab_size=settings.tokenizer_vocab_size,
        min_pair_frequency=settings.tokenizer_min_pair_frequency,
        lowercase=settings.lowercase,
    )
    tokens = self.tokenizer.encode(text)
    if len(tokens) < 2:
        raise ValueError("REFRAMR needs at least two tokens to derive a next-token readout.")

    self.embedding_model = fit_ppmi_embedding_from_tokens(
        tokens,
        embedding_dim=settings.embedding_dim,
        window_size=settings.window_size,
        min_frequency=settings.min_frequency,
        max_vocab=settings.max_vocab,
    )

    def vocabulary_zeros() -> Vector:
        # One 0.0 per vocabulary entry; a fresh list on every call.
        return [0.0 for _ in self.embedding_model.id_to_token]

    self.memory_units = [
        AnalyticalMemoryUnit(settings.state_dim, timescale)
        for timescale in settings.timescales
    ]

    # Occurrence counts are kept as floats, exactly as the weight helper
    # receives them.
    occurrences: dict[str, float] = {}
    for token in tokens:
        occurrences[token] = occurrences.get(token, 0.0) + 1.0
    self.trace_token_weights = self._derive_trace_token_weights_from_counts(occurrences)

    raw_states, targets, target_ids = self._collect_training_examples(tokens)
    self.ternary_scale, self.ternary_mask = derive_ternary_mask_from_states(raw_states)
    analytical_states = [
        apply_ternary_mask(raw_state, self.ternary_mask, self.ternary_scale)
        for raw_state in raw_states
    ]

    # Associative memory: one key (copied), its norm and its target id per
    # training example.
    self.associative_keys = [masked[:] for masked in analytical_states]
    self.associative_key_norms = [norm(masked) for masked in analytical_states]
    self.associative_values = target_ids[:]

    # Every answer memory starts as its own empty list.
    for attribute in (
        "answer_keys",
        "answer_key_norms",
        "answer_values",
        "answer_start_keys",
        "answer_start_key_norms",
        "answer_start_values",
        "answer_sequence_keys",
        "answer_sequence_key_norms",
        "answer_sequence_prompt_tokens",
        "answer_sequence_tokens",
    ):
        setattr(self, attribute, [])
    self.prompt_answer_weights = []
    self.prompt_answer_bias = vocabulary_zeros()
    self.prompt_answer_start_weights = []
    self.prompt_answer_start_bias = vocabulary_zeros()

    self.transition_tables = self._build_transition_tables(tokens)
    self._fit_answer_memory_from_text(text)
    self.readout_weights = ridge_regression_readout(
        analytical_states,
        targets,
        regularization=settings.regularization,
    )
    self.readout_bias = vocabulary_zeros()
    self.preference_bias = vocabulary_zeros()
    if analytical_states:
        self.state_offset = [0.0 for _ in analytical_states[0]]
    else:
        self.state_offset = []
    self._refresh_numeric_caches()
    return self
def predict_next_distribution(
    self,
    context: str,
    *,
    reasoning_mode: str | None = None,
) -> dict[str, float]:
    """Return next-token probabilities keyed by rendered token text.

    Tokens that render to the same text have their probabilities summed.
    """
    self._require_fit()
    assert self.tokenizer is not None
    assert self.embedding_model is not None
    per_token = self.predict_next_token_distribution(
        context,
        reasoning_mode=reasoning_mode,
    )
    merged: dict[str, float] = {}
    for token, probability in per_token.items():
        surface = self._render_token(token)
        merged[surface] = merged.get(surface, 0.0) + probability
    return merged


def predict_next_token_distribution(
    self,
    context: str,
    *,
    reasoning_mode: str | None = None,
) -> dict[str, float]:
    """Return next-token probabilities keyed by raw token.

    The context is encoded and prefixed with the reasoning tokens of
    *reasoning_mode*, or of the configured default profile when the argument
    is falsy.
    """
    self._require_fit()
    assert self.tokenizer is not None
    assert self.embedding_model is not None
    assert self.readout_weights is not None

    mode = reasoning_mode or self.config.default_reasoning_profile
    prefixed_tokens = reasoning_prefix(mode) + self.tokenizer.encode(context)
    return self._predict_next_token_distribution_from_tokens(prefixed_tokens)
self.tokenizer is not None + if ( + np is not None + and self.readout_weights_array is not None + and self.embedding_model is not None + and len(self.embedding_model.id_to_token) >= 1024 + ): + return self._generate_text_fast( + context, + max_tokens=max_tokens, + reasoning_mode=reasoning_mode, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + ) + + active_mode = reasoning_mode or self.config.default_reasoning_profile + _, context_tokens = self._generation_prompt_tokens(context, active_mode) + decode_state = self._build_decode_state(context_tokens) + generated_tokens: list[str] = [] + for _ in range(max_tokens): + distribution, _ = self._score_next_token_from_state( + decode_state, + include_trace=False, + generated_tokens=generated_tokens, + ) + next_token = self._select_generation_token( + distribution, + context_tokens=decode_state.context_tokens, + generated_tokens=generated_tokens, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + preserve_dominant_candidates=self._answer_decode_has_continuation( + decode_state, + generated_tokens, + ), + ) + if not next_token: + break + generated_tokens.append(next_token) + self._advance_decode_state(decode_state, next_token) + if self._should_stop_answer_sequence(decode_state, generated_tokens): + break + if self._should_stop_generation( + generated_tokens + ) and not self._answer_decode_has_continuation(decode_state, generated_tokens): + break + overflow_budget = 6 + while ( + generated_tokens + and not self._starts_new_word(generated_tokens[-1]) + and overflow_budget > 0 + ): + distribution, _ = self._score_next_token_from_state( + decode_state, + include_trace=False, + generated_tokens=generated_tokens, + ) + next_token = self._select_generation_token( + distribution, + context_tokens=decode_state.context_tokens, + generated_tokens=generated_tokens, + temperature=temperature, + top_k=top_k, + top_p=top_p, + 
repetition_penalty=repetition_penalty, + preserve_dominant_candidates=self._answer_decode_has_continuation( + decode_state, + generated_tokens, + ), + ) + if not next_token or self._starts_new_word(next_token): + break + generated_tokens.append(next_token) + self._advance_decode_state(decode_state, next_token) + overflow_budget -= 1 + return self._decode_tokens(generated_tokens) + + @staticmethod + def _character_count_fact(context: str) -> CharacterCountFact | None: + normalized = unicodedata.normalize("NFKC", context).strip() + tokens = ReframrModel._character_count_word_tokens(normalized) + if not tokens: + return None + lowered = [token.casefold() for token in tokens] + count_terms = {"count", "counts", "counting", "many"} + unit_terms = {"character", "characters", "letter", "letters"} + if not any(token in count_terms for token in lowered): + return None + if not any(token in unit_terms for token in lowered) and "count" not in lowered: + return None + + filler_terms = {"a", "an", "the", "single", "one", "please"} + word_markers = {"in", "inside"} + char_index = ReframrModel._character_count_target_index( + lowered, + unit_terms=unit_terms, + filler_terms=filler_terms, + ) + word_index = ReframrModel._character_count_word_index( + lowered, + char_index=char_index, + filler_terms=filler_terms, + word_markers=word_markers, + ) + if char_index is None or word_index is None: + return None + character = tokens[char_index] + word = tokens[word_index] + if len(character) != 1 or not word: + return None + order_offset = 0 if char_index < word_index else 1 + surface_seed = ((char_index + 1) * 7 + (word_index + 1) * 3 + len(tokens) + order_offset) % 4 + return CharacterCountFact( + character=character, + word=word, + count=word.casefold().count(character.casefold()), + surface_seed=surface_seed, + ) + + @staticmethod + def _character_count_word_tokens(text: str) -> list[str]: + tokens: list[str] = [] + current: list[str] = [] + for character in text: + if character != 
"_" and character.isalnum(): + current.append(character) + continue + if current: + tokens.append("".join(current)) + current = [] + if current: + tokens.append("".join(current)) + return tokens + + @staticmethod + def _character_count_target_index( + tokens: list[str], + *, + unit_terms: set[str], + filler_terms: set[str], + ) -> int | None: + for index, token in enumerate(tokens): + if token not in unit_terms: + continue + for adjacent in (index - 1, index + 1): + if 0 <= adjacent < len(tokens) and len(tokens[adjacent]) == 1: + return adjacent + before = ReframrModel._nearest_content_index(tokens, index - 1, -1, filler_terms) + after = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms) + for candidate in (before, after): + if candidate is not None and len(tokens[candidate]) == 1: + return candidate + for index, token in enumerate(tokens): + if token not in {"count", "counts", "counting"}: + continue + candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms) + if candidate is not None and tokens[candidate] in unit_terms: + candidate = ReframrModel._nearest_content_index(tokens, candidate + 1, 1, filler_terms) + if candidate is not None and len(tokens[candidate]) == 1: + return candidate + return None + + @staticmethod + def _character_count_word_index( + tokens: list[str], + *, + char_index: int | None, + filler_terms: set[str], + word_markers: set[str], + ) -> int | None: + for index, token in enumerate(tokens): + if token != "word": + continue + candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms) + if candidate is not None and candidate != char_index and len(tokens[candidate]) > 1: + return candidate + for index, token in enumerate(tokens): + if token not in word_markers: + continue + candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms) + if candidate is not None and tokens[candidate] == "word": + candidate = ReframrModel._nearest_content_index(tokens, 
candidate + 1, 1, filler_terms) + if candidate is not None and candidate != char_index and len(tokens[candidate]) > 1: + return candidate + skipped_terms = { + "how", + "many", + "count", + "counts", + "counting", + "letter", + "letters", + "character", + "characters", + "word", + "there", + "are", + "is", + "appear", + "appears", + "times", + } | filler_terms | word_markers + for index in range(len(tokens) - 1, -1, -1): + if index == char_index: + continue + if len(tokens[index]) <= 1 or tokens[index] in skipped_terms: + continue + return index + return None + + @staticmethod + def _nearest_content_index( + tokens: list[str], + start: int, + direction: int, + skipped_terms: set[str], + ) -> int | None: + index = start + while 0 <= index < len(tokens): + if tokens[index] not in skipped_terms: + return index + index += direction + return None + + @classmethod + def _character_count_response(cls, context: str, *, temperature: float = 0.0) -> str | None: + fact = cls._character_count_fact(context) + if fact is None: + return None + return cls._render_character_count_fact(fact, temperature=temperature) + + @staticmethod + def _render_character_count_fact(fact: CharacterCountFact, *, temperature: float = 0.0) -> str: + character_label = f"'{fact.character}'" + word_label = f"'{fact.word}'" + character_noun = "character" if fact.count == 1 else "characters" + plural_times = "" if fact.count == 1 else "s" + surfaces = ( + f"There {'is' if fact.count == 1 else 'are'} {fact.count} {character_label} {character_noun} in {word_label}.", + f"{word_label} contains {fact.count} {character_label} {character_noun}.", + f"In {word_label}, {character_label} appears {fact.count} time{plural_times}.", + f"The count is {fact.count} for {character_label} in {word_label}.", + ) + if temperature > 0.0: + return surfaces[(random.randrange(len(surfaces)) + fact.surface_seed) % len(surfaces)] + return surfaces[fact.surface_seed % len(surfaces)] + + def _generate_text_fast( + self, + context: 
str, + *, + max_tokens: int, + reasoning_mode: str | None, + temperature: float, + top_k: int, + top_p: float, + repetition_penalty: float, + ) -> str: + assert self.tokenizer is not None + + active_mode = reasoning_mode or self.config.default_reasoning_profile + _, context_tokens = self._generation_prompt_tokens(context, active_mode) + decode_state = self._build_decode_state(context_tokens) + generated_tokens: list[str] = [] + for _ in range(max_tokens): + probabilities, _ = self._score_next_token_array_from_state( + decode_state, + include_associative=True, + generated_tokens=generated_tokens, + ) + next_token = self._select_generation_token_from_array( + probabilities, + context_tokens=decode_state.context_tokens, + generated_tokens=generated_tokens, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + preserve_dominant_candidates=self._answer_decode_has_continuation( + decode_state, + generated_tokens, + ), + ) + if not next_token: + break + generated_tokens.append(next_token) + self._advance_decode_state(decode_state, next_token) + if self._should_stop_answer_sequence(decode_state, generated_tokens): + break + if self._should_stop_generation( + generated_tokens + ) and not self._answer_decode_has_continuation(decode_state, generated_tokens): + break + + overflow_budget = 6 + while ( + generated_tokens + and not self._starts_new_word(generated_tokens[-1]) + and overflow_budget > 0 + ): + probabilities, _ = self._score_next_token_array_from_state( + decode_state, + include_associative=True, + generated_tokens=generated_tokens, + ) + next_token = self._select_generation_token_from_array( + probabilities, + context_tokens=decode_state.context_tokens, + generated_tokens=generated_tokens, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + preserve_dominant_candidates=self._answer_decode_has_continuation( + decode_state, + generated_tokens, + ), + ) + if not next_token or 
def trace_next_token(
    self,
    context: str,
    *,
    reasoning_mode: str | None = None,
    top_k: int = 5,
) -> dict[str, object]:
    """Score the next token for *context* and return the scorer's trace.

    The trace is extended with the original context, the active reasoning
    mode, that mode's prefix tokens and the full token list that was scored.
    """
    self._require_fit()
    assert self.tokenizer is not None

    active_mode = reasoning_mode or self.config.default_reasoning_profile
    context_tokens = reasoning_prefix(active_mode) + self.tokenizer.encode(context)
    _, trace = self._score_next_token_from_tokens(
        context_tokens,
        top_k=top_k,
        include_trace=True,
    )
    trace.update(
        {
            "context": context,
            "reasoning_mode": active_mode,
            "reasoning_tokens": reasoning_prefix(active_mode),
            "context_tokens": context_tokens,
        }
    )
    return trace


def trace_generation(
    self,
    context: str,
    *,
    max_tokens: int = 16,
    reasoning_mode: str | None = None,
    top_k: int = 5,
    temperature: float = 0.0,
    top_p: float = DEFAULT_GENERATION_TOP_P,
    repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
) -> dict[str, object]:
    """Generate up to *max_tokens* tokens and record a trace for each step.

    Prompts recognised by the character-counting shortcut are answered
    directly, without requiring a fitted model, and yield an empty step
    list. Otherwise decoding follows the same stop rules as the non-array
    loop in ``generate_text``: the answer-sequence stop check and the
    ``preserve_dominant_candidates`` hint are both applied, so the trace
    does not report tokens that loop would not emit.

    Args:
        context: Raw prompt text.
        max_tokens: Upper bound on main-loop steps; up to six extra steps
            may follow to finish a partially emitted word.
        reasoning_mode: Reasoning profile; falls back to the configured
            default when falsy.
        top_k: Number of entries per trace table. Token selection uses
            ``max(DEFAULT_GENERATION_TOP_K, top_k)``.
        temperature: Sampling temperature passed to token selection.
        top_p: Nucleus threshold passed to token selection.
        repetition_penalty: Penalty passed to token selection.

    Returns:
        A dict holding the prompt, the generation policy, prompt and
        generated tokens, the decoded text and the per-step traces.
    """
    active_mode = reasoning_mode or self.config.default_reasoning_profile
    selection_top_k = max(DEFAULT_GENERATION_TOP_K, top_k)

    def generation_policy() -> dict[str, object]:
        return {
            "temperature": temperature,
            "top_k": selection_top_k,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
        }

    character_count_response = self._character_count_response(
        context,
        temperature=temperature,
    )
    if character_count_response is not None:
        # NOTE(review): the marker literal below reads as an empty string in
        # this view, which makes `"" in context` always true. It was
        # probably lost upstream -- confirm against the canonical file.
        prompt = context if "" in context else f"{context} "
        return {
            "context": context,
            "prompt": prompt,
            "reasoning_mode": active_mode,
            "reasoning_tokens": reasoning_prefix(active_mode),
            "generation_policy": generation_policy(),
            "prompt_tokens": [],
            "generated_tokens": [],
            "generated_text": character_count_response,
            "generated_token_count": len(character_count_response.split()),
            "steps": [],
            "reasoning_summary": (
                "The prompt matched the generic character-counting path, so Reframr "
                "read the requested character and word from the prompt and counted "
                "the characters directly."
            ),
        }

    self._require_fit()
    assert self.tokenizer is not None

    prompt, context_tokens = self._generation_prompt_tokens(context, active_mode)
    decode_state = self._build_decode_state(context_tokens)
    prompt_tokens = decode_state.context_tokens[:]
    generated_tokens: list[str] = []
    steps: list[dict[str, object]] = []

    def propose() -> tuple[dict[str, float], dict[str, object], str]:
        # Score the current state and pick a token with the same selection
        # arguments that generate_text uses.
        distribution, trace = self._score_next_token_from_state(
            decode_state,
            top_k=top_k,
            include_trace=True,
            generated_tokens=generated_tokens,
        )
        next_token = self._select_generation_token(
            distribution,
            context_tokens=decode_state.context_tokens,
            generated_tokens=generated_tokens,
            temperature=temperature,
            top_k=selection_top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            preserve_dominant_candidates=self._answer_decode_has_continuation(
                decode_state,
                generated_tokens,
            ),
        )
        return distribution, trace, next_token

    def commit(
        distribution: dict[str, float],
        trace: dict[str, object],
        next_token: str,
    ) -> None:
        # Accept the token, advance the decoder and record the step.
        generated_tokens.append(next_token)
        self._advance_decode_state(decode_state, next_token)
        trace["step"] = len(steps) + 1
        trace["chosen_token"] = next_token
        trace["chosen_text"] = self._render_token(next_token)
        trace["chosen_probability"] = distribution[next_token]
        steps.append(trace)

    for _ in range(max_tokens):
        distribution, trace, next_token = propose()
        if not next_token:
            break
        commit(distribution, trace, next_token)
        if self._should_stop_answer_sequence(decode_state, generated_tokens):
            break
        if self._should_stop_generation(
            generated_tokens
        ) and not self._answer_decode_has_continuation(decode_state, generated_tokens):
            break

    # Allow a few extra tokens so the output does not end in the middle of a
    # word; stop as soon as the next token would start a new one.
    overflow_budget = 6
    while (
        generated_tokens
        and not self._starts_new_word(generated_tokens[-1])
        and overflow_budget > 0
    ):
        distribution, trace, next_token = propose()
        if not next_token or self._starts_new_word(next_token):
            break
        commit(distribution, trace, next_token)
        overflow_budget -= 1

    return {
        "context": context,
        "prompt": prompt,
        "reasoning_mode": active_mode,
        "reasoning_tokens": reasoning_prefix(active_mode),
        "generation_policy": generation_policy(),
        "prompt_tokens": prompt_tokens,
        "generated_tokens": generated_tokens,
        "generated_text": self._decode_tokens(generated_tokens),
        "generated_token_count": len(generated_tokens),
        "steps": steps,
    }


def _generation_prompt_tokens(self, context: str, active_mode: str) -> tuple[str, list[str]]:
    """Return the prompt text and the token list used to start decoding.

    The token list is the reasoning prefix of *active_mode* followed by the
    encoded prompt.
    """
    assert self.tokenizer is not None
    # NOTE(review): every marker literal in this method reads as an empty
    # string in this view. As written `"" in context` is always true and the
    # condition below can never hold (it needs "" both in and not in
    # prompt_tokens), so the literals were probably lost upstream. Restore
    # them from the canonical file; they are reproduced here unchanged.
    prompt = context if "" in context else f"{context} "
    prefix = reasoning_prefix(active_mode)
    prompt_tokens = self.tokenizer.encode(prompt)
    if (
        "" in prompt_tokens
        and "" not in prompt_tokens
        and "" not in prefix
    ):
        prompt_tokens = [""] + prompt_tokens
    return prompt, prefix + prompt_tokens


def _predict_next_token_distribution_from_tokens(
    self,
    context_tokens: list[str],
) -> dict[str, float]:
    """Build a decode state from *context_tokens* and return its distribution."""
    return self._predict_next_token_distribution_from_state(
        self._build_decode_state(context_tokens)
    )


def _predict_next_token_distribution_from_state(
    self,
    decode_state: DecodeState,
) -> dict[str, float]:
    """Return the next-token distribution for *decode_state* without a trace."""
    distribution, _unused_trace = self._score_next_token_from_state(
        decode_state,
        include_trace=False,
    )
    return distribution


@staticmethod
def _answer_sequence_should_lock(
    *,
    answer_sequence_confidence: float,
    answer_sequence_match_confidence: float,
    has_answer_sequence_prior: bool,
) -> bool:
    """Decide whether decoding should follow the answer-sequence prior alone.

    Never locks without a prior or with non-positive confidence. Otherwise
    locks when the match confidence reaches ``ANSWER_SEQUENCE_LOCK_FLOOR``,
    or when it reaches ``ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR`` while the
    prior confidence is at most ``ANSWER_SEQUENCE_SPIKE_CONFIDENCE``.
    """
    if answer_sequence_confidence <= 0.0 or not has_answer_sequence_prior:
        return False
    strong_match = answer_sequence_match_confidence >= ANSWER_SEQUENCE_LOCK_FLOOR
    distributed_match = (
        answer_sequence_match_confidence >= ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR
        and answer_sequence_confidence <= ANSWER_SEQUENCE_SPIKE_CONFIDENCE
    )
    return strong_match or distributed_match


@staticmethod
def _answer_start_blend_weights(
    *,
    answer_sequence_match_confidence: float,
) -> dict[str, float]:
    """Return the weights for mixing the four answer-start priors.

    At or above ``ANSWER_SEQUENCE_LOCK_FLOOR`` the answer-sequence prior gets
    the largest weight; below it the prompt answer-start prior does. Both
    weight sets sum to 1.0.
    """
    sources = ("prompt_answer_start", "prompt_answer", "answer_sequence", "answer_start")
    if answer_sequence_match_confidence >= ANSWER_SEQUENCE_LOCK_FLOOR:
        weights = (0.35, 0.10, 0.45, 0.10)
    else:
        weights = (0.55, 0.20, 0.15, 0.10)
    return dict(zip(sources, weights))


def _score_next_token_from_tokens(
    self,
    context_tokens: list[str],
    *,
    top_k: int = 5,
    include_trace: bool = True,
) -> tuple[dict[str, float], dict[str, object]]:
    """Build a decode state from *context_tokens* and score the next token."""
    return self._score_next_token_from_state(
        self._build_decode_state(context_tokens),
        top_k=top_k,
        include_trace=include_trace,
    )
decode_state.answer_start_matches is None: + decode_state.answer_start_matches = self._score_answer_start_matches( + decode_state.answer_anchor_state, + limit=max(ANSWER_START_TOP_K, top_k) if include_trace else ANSWER_START_TOP_K, + ) + answer_start_matches = decode_state.answer_start_matches + if decode_state.answer_sequence_matches is None: + decode_state.answer_sequence_matches = self._score_answer_sequence_matches( + decode_state.answer_anchor_state, + decode_state.context_tokens, + limit=max(ANSWER_START_TOP_K, top_k) if include_trace else ANSWER_START_TOP_K, + ) + answer_sequence_matches = decode_state.answer_sequence_matches + answer_prior = self._answer_prior_from_matches(answer_matches, generated_tokens) + answer_start_prior = self._answer_prior_from_matches(answer_start_matches, generated_tokens) + answer_sequence_prior = self._answer_sequence_prior_from_matches( + answer_sequence_matches, + generated_tokens, + ) + answer_sequence_confidence = max(answer_sequence_prior) if answer_sequence_prior else 0.0 + answer_sequence_match_confidence = ( + answer_sequence_matches[0][0] if answer_sequence_matches else 0.0 + ) + has_answer_sequence_prior = any(value > 0.0 for value in answer_sequence_prior) + answer_locked = self._answer_sequence_should_lock( + answer_sequence_confidence=answer_sequence_confidence, + answer_sequence_match_confidence=answer_sequence_match_confidence, + has_answer_sequence_prior=has_answer_sequence_prior, + ) + if decode_state.prompt_answer_prior is None: + decode_state.prompt_answer_prior = self._prompt_answer_readout_prior( + decode_state.answer_anchor_state, + start=False, + ) + prompt_answer_prior = decode_state.prompt_answer_prior + prompt_answer_start_prior = ( + decode_state.prompt_answer_start_prior + if not generated_tokens + else [0.0 for _ in self.embedding_model.id_to_token] + ) + if not generated_tokens and prompt_answer_start_prior is None: + decode_state.prompt_answer_start_prior = self._prompt_answer_readout_prior( + 
decode_state.answer_anchor_state, + start=True, + ) + prompt_answer_start_prior = decode_state.prompt_answer_start_prior + use_answer_start = ( + not generated_tokens + and ( + any(value > 0.0 for value in answer_start_prior) + or any(value > 0.0 for value in prompt_answer_start_prior) + ) + ) + if answer_locked: + answer_prior = answer_sequence_prior + elif use_answer_start: + start_blend = self._answer_start_blend_weights( + answer_sequence_match_confidence=answer_sequence_match_confidence + ) + answer_prior = self._weighted_prior_sum( + [ + (start_blend["prompt_answer_start"], prompt_answer_start_prior), + (start_blend["prompt_answer"], prompt_answer_prior), + (start_blend["answer_sequence"], answer_sequence_prior), + (start_blend["answer_start"], answer_start_prior), + ], + ) + elif any(value > 0.0 for value in answer_sequence_prior): + answer_prior = self._weighted_prior_sum( + [ + (0.50, prompt_answer_prior), + (0.30, answer_sequence_prior), + (0.20, answer_prior), + ], + ) + elif any(value > 0.0 for value in prompt_answer_prior): + answer_prior = self._weighted_prior_sum( + [ + (0.65, prompt_answer_prior), + (0.35, answer_prior), + ], + ) + associative_matches = ( + [] + if use_answer_start + else self._score_associative_matches( + state, + limit=max(ASSOCIATIVE_TOP_K, top_k) if include_trace else ASSOCIATIVE_TOP_K, + ) + ) + associative_prior = ( + [0.0 for _ in self.embedding_model.id_to_token] + if use_answer_start + else self._associative_prior_from_matches(associative_matches) + ) + transition_prior, transition_order = self._transition_prior_with_order(decode_state.context_tokens) + copy_prior = self._copy_prior(decode_state.context_tokens) + preference_prior = self._preference_prior() + probabilities, blend_weights = self._blend_probabilities( + base_probabilities, + answer_prior, + associative_prior, + transition_prior, + copy_prior, + preference_prior, + transition_order=transition_order, + generated_count=len(generated_tokens), + 
answer_locked=answer_locked, + answer_guided_start=use_answer_start, + ) + distribution = { + token: probabilities[index] + for index, token in enumerate(self.embedding_model.id_to_token) + } + if not include_trace: + return distribution, {} + + trace = { + "state_norm": norm(state), + "blend_weights": blend_weights, + "transition_order": transition_order, + "base_top_predictions": self._top_entries_from_vector(base_probabilities, top_k), + "answer_top_predictions": self._top_entries_from_vector(answer_prior, top_k), + "prompt_answer_top_predictions": self._top_entries_from_vector(prompt_answer_prior, top_k), + "prompt_answer_start_top_predictions": self._top_entries_from_vector(prompt_answer_start_prior, top_k), + "answer_start_top_predictions": self._top_entries_from_vector(answer_start_prior, top_k), + "answer_sequence_top_predictions": self._top_entries_from_vector(answer_sequence_prior, top_k), + "associative_top_predictions": self._top_entries_from_vector(associative_prior, top_k), + "transition_top_predictions": self._top_entries_from_vector(transition_prior, top_k), + "copy_top_predictions": self._top_entries_from_vector(copy_prior, top_k), + "preference_top_predictions": self._top_entries_from_vector(preference_prior, top_k), + "final_top_predictions": self._top_entries_from_vector(probabilities, top_k), + "associative_matches": [ + { + "example_index": example_index, + "similarity": similarity, + **self._token_entry(token_id, similarity), + } + for similarity, token_id, example_index in associative_matches[:top_k] + ], + "answer_matches": [ + { + "example_index": example_index, + "similarity": similarity, + **self._token_entry(token_id, similarity), + } + for similarity, token_id, example_index in answer_matches[:top_k] + ], + "answer_start_matches": [ + { + "example_index": example_index, + "similarity": similarity, + **self._token_entry(token_id, similarity), + } + for similarity, token_id, example_index in answer_start_matches[:top_k] + ], + 
"answer_sequence_matches": [ + { + "example_index": example_index, + "similarity": similarity, + } + for similarity, _, example_index in answer_sequence_matches[:top_k] + ], + "reasoning_summary": self._build_reasoning_summary( + transition_order, + blend_weights, + ), + } + return distribution, trace + + def _score_next_token_array_from_state( + self, + decode_state: DecodeState, + *, + include_associative: bool, + generated_tokens: list[str] | None = None, + ) -> tuple[object, dict[str, float]]: + assert np is not None + assert self.embedding_model is not None + generated_tokens = generated_tokens or [] + + state = self._masked_decode_state_array(decode_state) + logits = self._apply_readout_array(state) + base_probabilities = self._calibrated_softmax_array(logits) + if decode_state.answer_matches is None: + decode_state.answer_matches = self._score_answer_matches(decode_state.answer_anchor_state) + answer_prior = np.asarray( + self._answer_prior_from_matches( + decode_state.answer_matches, + generated_tokens, + ), + dtype=np.float64, + ) + if decode_state.answer_sequence_matches is None: + decode_state.answer_sequence_matches = self._score_answer_sequence_matches( + decode_state.answer_anchor_state, + decode_state.context_tokens, + ) + answer_sequence_matches = decode_state.answer_sequence_matches + answer_sequence_prior = np.asarray( + self._answer_sequence_prior_from_matches( + answer_sequence_matches, + generated_tokens, + ), + dtype=np.float64, + ) + answer_sequence_confidence = ( + float(answer_sequence_prior.max()) if answer_sequence_prior.size else 0.0 + ) + answer_sequence_match_confidence = ( + answer_sequence_matches[0][0] if answer_sequence_matches else 0.0 + ) + has_answer_sequence_prior = bool(np.any(answer_sequence_prior > 0.0)) + answer_locked = self._answer_sequence_should_lock( + answer_sequence_confidence=answer_sequence_confidence, + answer_sequence_match_confidence=answer_sequence_match_confidence, + 
has_answer_sequence_prior=has_answer_sequence_prior, + ) + if decode_state.prompt_answer_prior is None: + decode_state.prompt_answer_prior = self._prompt_answer_readout_prior_array( + decode_state.answer_anchor_state, + start=False, + ) + prompt_answer_prior = decode_state.prompt_answer_prior + prompt_answer_start_prior = np.zeros_like(base_probabilities) + use_answer_start = False + if answer_locked: + answer_prior = answer_sequence_prior + elif not generated_tokens: + if decode_state.prompt_answer_start_prior is None: + decode_state.prompt_answer_start_prior = self._prompt_answer_readout_prior_array( + decode_state.answer_anchor_state, + start=True, + ) + prompt_answer_start_prior = decode_state.prompt_answer_start_prior + if decode_state.answer_start_matches is None: + decode_state.answer_start_matches = self._score_answer_start_matches( + decode_state.answer_anchor_state + ) + answer_start_prior = np.asarray( + self._answer_prior_from_matches( + decode_state.answer_start_matches, + generated_tokens, + ), + dtype=np.float64, + ) + if np.any(answer_start_prior > 0.0) or np.any(prompt_answer_start_prior > 0.0): + start_blend = self._answer_start_blend_weights( + answer_sequence_match_confidence=answer_sequence_match_confidence + ) + answer_prior = self._weighted_prior_sum_array( + [ + (start_blend["prompt_answer_start"], prompt_answer_start_prior), + (start_blend["prompt_answer"], prompt_answer_prior), + (start_blend["answer_sequence"], answer_sequence_prior), + (start_blend["answer_start"], answer_start_prior), + ], + ) + use_answer_start = True + if answer_locked: + answer_prior = answer_sequence_prior + elif not use_answer_start and np.any(answer_sequence_prior > 0.0): + answer_prior = self._weighted_prior_sum_array( + [ + (0.50, prompt_answer_prior), + (0.30, answer_sequence_prior), + (0.20, answer_prior), + ], + ) + elif not use_answer_start and np.any(prompt_answer_prior > 0.0): + answer_prior = self._weighted_prior_sum_array( + [ + (0.65, 
def _calibrated_softmax(
    self,
    logits: Vector,
    *,
    scale: float = READOUT_LOGIT_ZSCORE_SCALE,
) -> Vector:
    """Turn raw logits into probabilities after z-score calibration.

    With NumPy available the work is delegated to
    ``_calibrated_softmax_array`` and converted back to a list.  Otherwise
    the logits are standardized (mean/spread), multiplied by ``scale``,
    clamped to ``[-20, 20]`` and passed to ``softmax``.  Logits with
    (near-)zero spread go to ``softmax`` unchanged; an empty input yields
    an empty list.
    """
    if np is not None:
        logit_array = np.asarray(logits, dtype=np.float64)
        probabilities = self._calibrated_softmax_array(logit_array, scale=scale)
        return probabilities.tolist()
    if not logits:
        return []
    center = mean(logits)
    squared_deviations = [(value - center) * (value - center) for value in logits]
    spread = mean(squared_deviations) ** 0.5
    if spread <= 1e-12:
        return softmax(logits)
    calibrated = []
    for value in logits:
        rescaled = ((value - center) / spread) * scale
        calibrated.append(max(-20.0, min(20.0, rescaled)))
    return softmax(calibrated)
def _weighted_prior_sum(self, sources: list[tuple[float, Vector]]) -> Vector:
    """Blend several prior vectors into one normalized prior (list path).

    Sources whose weight is not positive, or whose vector has no positive
    entry, are ignored.  The weights of the remaining sources are rescaled
    to sum to one before mixing.

    Args:
        sources: ``(weight, vector)`` pairs; vectors are indexed by token id.

    Returns:
        The mixed vector passed through ``_normalize_vector``, or an all-zero
        vector of vocabulary length when no source is active.
    """
    assert self.embedding_model is not None
    vocabulary_size = len(self.embedding_model.id_to_token)
    active_sources = [
        (weight, vector)
        for weight, vector in sources
        if weight > 0.0 and any(value > 0.0 for value in vector)
    ]
    if not active_sources:
        return [0.0] * vocabulary_size
    total_weight = sum(weight for weight, _ in active_sources)
    merged = [0.0] * vocabulary_size
    for weight, vector in active_sources:
        normalized_weight = weight / total_weight
        for index, value in enumerate(vector):
            merged[index] += normalized_weight * value
    return _normalize_vector(merged)


def _weighted_prior_sum_array(self, sources: list[tuple[float, object]]) -> object:
    """NumPy counterpart of ``_weighted_prior_sum``.

    Each source vector is converted to a float64 array exactly once (the
    original converted it twice: once to test it, once to keep it).  The
    result is divided by its sum when that sum is positive.

    Returns:
        A float64 array; all zeros of vocabulary length when no source is
        active.
    """
    assert np is not None
    assert self.embedding_model is not None
    active_sources = []
    for weight, vector in sources:
        # Check the weight first so inactive sources are never converted.
        if not weight > 0.0:
            continue
        values = np.asarray(vector, dtype=np.float64)
        if np.any(values > 0.0):
            active_sources.append((weight, values))
    if not active_sources:
        return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
    total_weight = sum(weight for weight, _ in active_sources)
    merged = np.zeros_like(active_sources[0][1], dtype=np.float64)
    for weight, values in active_sources:
        merged += (weight / total_weight) * values
    total = float(merged.sum())
    if total > 0.0:
        merged /= total
    return merged


def _prompt_answer_readout_prior(
    self,
    answer_anchor_state: Vector | None,
    *,
    start: bool,
) -> Vector:
    """Prior over tokens from the prompt-answer readout, as a list.

    Args:
        answer_anchor_state: Combined state captured at the answer anchor,
            or ``None`` when no anchor exists.
        start: Select the answer-start weights/bias instead of the general
            prompt-answer ones.

    Returns:
        A probability list of vocabulary length; all zeros when there is no
        anchor state or no readout weights.
    """
    assert self.embedding_model is not None
    vocabulary_size = len(self.embedding_model.id_to_token)
    if answer_anchor_state is None:
        return [0.0] * vocabulary_size
    if np is not None:
        return self._prompt_answer_readout_prior_array(
            answer_anchor_state,
            start=start,
        ).tolist()
    weights = self.prompt_answer_start_weights if start else self.prompt_answer_weights
    bias = self.prompt_answer_start_bias if start else self.prompt_answer_bias
    if not weights:
        return [0.0] * vocabulary_size
    state = self._center_state_vector(self._masked_combined_state(answer_anchor_state))
    logits = apply_readout(weights, state)
    # Mirror the array path: a bias whose length does not match the logits
    # is ignored instead of raising IndexError or being applied partially.
    if bias and len(bias) == len(logits):
        logits = [value + offset for value, offset in zip(logits, bias)]
    return self._calibrated_softmax(
        logits,
        scale=PROMPT_READOUT_LOGIT_ZSCORE_SCALE,
    )


def _prompt_answer_readout_prior_array(
    self,
    answer_anchor_state: Vector | None,
    *,
    start: bool,
) -> object:
    """Prior over tokens from the prompt-answer readout, as a NumPy array.

    The anchor state is masked and centered, multiplied by the cached weight
    matrix, offset by the bias when its shape matches the logits, and passed
    through the calibrated softmax.

    Returns:
        A float64 array; all zeros of vocabulary length when there is no
        anchor state or no cached weight array.
    """
    assert np is not None
    assert self.embedding_model is not None
    if answer_anchor_state is None:
        return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
    weights = (
        self.prompt_answer_start_weights_array
        if start
        else self.prompt_answer_weights_array
    )
    bias = self.prompt_answer_start_bias_array if start else self.prompt_answer_bias_array
    if weights is None:
        return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
    state_array = self._center_state_array(
        self._masked_combined_state_array(answer_anchor_state)
    )
    logits = weights @ state_array
    if bias is not None and bias.shape == logits.shape:
        logits = logits + bias
    return self._calibrated_softmax_array(
        logits,
        scale=PROMPT_READOUT_LOGIT_ZSCORE_SCALE,
    )
separators=(",", ":")), + "embedding_id_to_token": json.dumps(self.embedding_model.id_to_token, separators=(",", ":")), + "tokenizer_vocab_size": str(self.tokenizer.vocab_size), + "transition_tables": json.dumps(self._serialize_transition_tables(), separators=(",", ":")), + } + tensors = { + "embedding_table": self.embedding_model.embeddings, + "ternary_scale": [self.ternary_scale], + "ternary_mask": self.ternary_mask, + "readout_weights": self.readout_weights, + "readout_bias": self.readout_bias + or [0.0 for _ in self.embedding_model.id_to_token], + "prompt_answer_weights": self.prompt_answer_weights + if self.prompt_answer_weights is not None + else [], + "prompt_answer_bias": self.prompt_answer_bias + or [0.0 for _ in self.embedding_model.id_to_token], + "prompt_answer_start_weights": self.prompt_answer_start_weights + if self.prompt_answer_start_weights is not None + else [], + "prompt_answer_start_bias": self.prompt_answer_start_bias + or [0.0 for _ in self.embedding_model.id_to_token], + "trace_token_weights": self.trace_token_weights + or [1.0 for _ in self.embedding_model.id_to_token], + "preference_bias": self.preference_bias + or [0.0 for _ in self.embedding_model.id_to_token], + "state_offset": self.state_offset + or [0.0 for _ in range(self._combined_state_width())], + "associative_keys": self.associative_keys, + "associative_values": self.associative_values, + "answer_keys": self.answer_keys if self.answer_keys is not None else [], + "answer_values": self.answer_values if self.answer_values is not None else [], + "answer_start_keys": self.answer_start_keys if self.answer_start_keys is not None else [], + "answer_start_values": self.answer_start_values if self.answer_start_values is not None else [], + "answer_sequence_keys": self.answer_sequence_keys if self.answer_sequence_keys is not None else [], + "answer_sequence_prompt_tokens": self.answer_sequence_prompt_tokens if self.answer_sequence_prompt_tokens is not None else [], + 
@classmethod
def load(cls, path: str | Path) -> "ReframrModel":
    """Rebuild a model from a checkpoint written by ``save``.

    Configuration, tokenizer, vocabulary and transition tables come from the
    checkpoint metadata (JSON strings); everything else comes from the tensor
    table.  Only ``embedding_table``, ``ternary_scale``, ``ternary_mask`` and
    ``readout_weights`` are read with ``[...]`` and therefore required; the
    remaining tensors fall back to an empty list and are then replaced by a
    default where one is defined below.

    Args:
        path: Location of the checkpoint file.

    Returns:
        A new instance of ``cls`` with numeric caches refreshed.
    """
    checkpoint_path = Path(path)
    # Array-backed tensors are requested only when NumPy is importable and
    # the file is larger than 10,000,000 bytes.
    # NOTE(review): exact meaning of ``arrays`` is defined by
    # read_safetensor_file -- confirm there.
    checkpoint = read_safetensor_file(
        checkpoint_path,
        arrays=np is not None and checkpoint_path.stat().st_size > 10_000_000,
    )
    metadata = checkpoint.metadata
    config = ReframrConfig.from_dict(json.loads(metadata["config"]))
    model = cls(config)
    model.tokenizer = NativeTokenizer.from_dict(json.loads(metadata["tokenizer"]))
    id_to_token = [str(token) for token in json.loads(metadata["embedding_id_to_token"])]
    embedding_table = checkpoint.tensors["embedding_table"]
    # Throughout this method: objects exposing ``.shape`` are treated as
    # NumPy arrays and kept as arrays; anything else is copied into lists.
    if np is not None and hasattr(embedding_table, "shape"):
        embeddings = embedding_table.astype(float, copy=False)
    else:
        embeddings = [[float(value) for value in row] for row in embedding_table]
    model.embedding_model = EmbeddingModel(
        token_to_id={token: index for index, token in enumerate(id_to_token)},
        id_to_token=id_to_token,
        embeddings=embeddings,
        ppmi_matrix=[],
    )
    model.memory_units = [
        AnalyticalMemoryUnit(model.config.state_dim, timescale)
        for timescale in model.config.timescales
    ]
    model.ternary_scale = float(checkpoint.tensors["ternary_scale"][0])
    model.ternary_mask = [int(value) for value in checkpoint.tensors["ternary_mask"]]
    readout_tensor = checkpoint.tensors["readout_weights"]
    model.readout_weights = (
        readout_tensor.astype(float, copy=False)
        if np is not None and hasattr(readout_tensor, "shape")
        else [[float(value) for value in row] for row in readout_tensor]
    )
    # Bias vectors are always materialized as plain float lists; a missing or
    # empty bias becomes all zeros of vocabulary length.
    readout_bias_tensor = checkpoint.tensors.get("readout_bias", [])
    model.readout_bias = [
        float(value) for value in (
            readout_bias_tensor.tolist()
            if hasattr(readout_bias_tensor, "tolist")
            else readout_bias_tensor
        )
    ]
    if not model.readout_bias:
        model.readout_bias = [0.0 for _ in id_to_token]
    # Prompt-answer weights stay an array only when the tensor is 2-D;
    # otherwise they are rebuilt row by row as nested lists.
    prompt_answer_tensor = checkpoint.tensors.get("prompt_answer_weights", [])
    model.prompt_answer_weights = (
        prompt_answer_tensor.astype(float, copy=False)
        if np is not None
        and hasattr(prompt_answer_tensor, "shape")
        and len(prompt_answer_tensor.shape) == 2
        else [[float(value) for value in row] for row in prompt_answer_tensor]
    )
    prompt_answer_bias_tensor = checkpoint.tensors.get("prompt_answer_bias", [])
    model.prompt_answer_bias = [
        float(value) for value in (
            prompt_answer_bias_tensor.tolist()
            if hasattr(prompt_answer_bias_tensor, "tolist")
            else prompt_answer_bias_tensor
        )
    ]
    if not model.prompt_answer_bias:
        model.prompt_answer_bias = [0.0 for _ in id_to_token]
    prompt_answer_start_tensor = checkpoint.tensors.get("prompt_answer_start_weights", [])
    model.prompt_answer_start_weights = (
        prompt_answer_start_tensor.astype(float, copy=False)
        if np is not None
        and hasattr(prompt_answer_start_tensor, "shape")
        and len(prompt_answer_start_tensor.shape) == 2
        else [[float(value) for value in row] for row in prompt_answer_start_tensor]
    )
    prompt_answer_start_bias_tensor = checkpoint.tensors.get("prompt_answer_start_bias", [])
    model.prompt_answer_start_bias = [
        float(value) for value in (
            prompt_answer_start_bias_tensor.tolist()
            if hasattr(prompt_answer_start_bias_tensor, "tolist")
            else prompt_answer_start_bias_tensor
        )
    ]
    if not model.prompt_answer_start_bias:
        model.prompt_answer_start_bias = [0.0 for _ in id_to_token]
    trace_weight_tensor = checkpoint.tensors.get("trace_token_weights", [])
    model.trace_token_weights = [
        float(value) for value in (
            trace_weight_tensor.tolist()
            if hasattr(trace_weight_tensor, "tolist")
            else trace_weight_tensor
        )
    ]
    # Default trace weights: 0.0 for tokenizer special tokens, 1.0 otherwise.
    if not model.trace_token_weights:
        model.trace_token_weights = [
            0.0 if token in model.tokenizer.special_tokens else 1.0
            for token in id_to_token
        ]
    preference_bias_tensor = checkpoint.tensors.get("preference_bias", [])
    model.preference_bias = [
        float(value) for value in (
            preference_bias_tensor.tolist()
            if hasattr(preference_bias_tensor, "tolist")
            else preference_bias_tensor
        )
    ]
    if not model.preference_bias:
        model.preference_bias = [0.0 for _ in id_to_token]
    state_offset_tensor = checkpoint.tensors.get("state_offset", [])
    model.state_offset = [
        float(value) for value in (
            state_offset_tensor.tolist()
            if hasattr(state_offset_tensor, "tolist")
            else state_offset_tensor
        )
    ]
    if not model.state_offset:
        model.state_offset = [0.0 for _ in range(model._combined_state_width())]
    # Associative memory: keys plus their precomputed per-row norms.
    associative_tensor = checkpoint.tensors.get("associative_keys", [])
    model.associative_keys = (
        associative_tensor.astype(float, copy=False)
        if np is not None and hasattr(associative_tensor, "shape")
        else [[float(value) for value in row] for row in associative_tensor]
    )
    if np is not None and hasattr(model.associative_keys, "shape"):
        model.associative_key_norms = np.linalg.norm(model.associative_keys, axis=1).tolist()
    else:
        model.associative_key_norms = [norm(key) for key in model.associative_keys]
    raw_associative_values = checkpoint.tensors.get("associative_values", [])
    model.associative_values = [
        int(value) for value in (
            raw_associative_values.tolist()
            if hasattr(raw_associative_values, "tolist")
            else raw_associative_values
        )
    ]
    # Answer memories: an array tensor that is not 2-D (e.g. saved from an
    # empty list) is replaced by an empty list rather than kept.
    answer_tensor = checkpoint.tensors.get("answer_keys", [])
    if np is not None and hasattr(answer_tensor, "shape"):
        model.answer_keys = (
            answer_tensor.astype(float, copy=False)
            if len(answer_tensor.shape) == 2
            else []
        )
    else:
        model.answer_keys = [[float(value) for value in row] for row in answer_tensor]
    if (
        np is not None
        and hasattr(model.answer_keys, "shape")
        and len(model.answer_keys.shape) == 2
    ):
        model.answer_key_norms = np.linalg.norm(model.answer_keys, axis=1).tolist()
    else:
        model.answer_key_norms = [norm(key) for key in model.answer_keys]
    raw_answer_values = checkpoint.tensors.get("answer_values", [])
    model.answer_values = [
        int(value) for value in (
            raw_answer_values.tolist()
            if hasattr(raw_answer_values, "tolist")
            else raw_answer_values
        )
    ]
    answer_start_tensor = checkpoint.tensors.get("answer_start_keys", [])
    if np is not None and hasattr(answer_start_tensor, "shape"):
        model.answer_start_keys = (
            answer_start_tensor.astype(float, copy=False)
            if len(answer_start_tensor.shape) == 2
            else []
        )
    else:
        model.answer_start_keys = [
            [float(value) for value in row] for row in answer_start_tensor
        ]
    if (
        np is not None
        and hasattr(model.answer_start_keys, "shape")
        and len(model.answer_start_keys.shape) == 2
    ):
        model.answer_start_key_norms = np.linalg.norm(model.answer_start_keys, axis=1).tolist()
    else:
        model.answer_start_key_norms = [norm(key) for key in model.answer_start_keys]
    raw_answer_start_values = checkpoint.tensors.get("answer_start_values", [])
    model.answer_start_values = [
        int(value) for value in (
            raw_answer_start_values.tolist()
            if hasattr(raw_answer_start_values, "tolist")
            else raw_answer_start_values
        )
    ]
    answer_sequence_tensor = checkpoint.tensors.get("answer_sequence_keys", [])
    if np is not None and hasattr(answer_sequence_tensor, "shape"):
        model.answer_sequence_keys = (
            answer_sequence_tensor.astype(float, copy=False)
            if len(answer_sequence_tensor.shape) == 2
            else []
        )
    else:
        model.answer_sequence_keys = [
            [float(value) for value in row] for row in answer_sequence_tensor
        ]
    if (
        np is not None
        and hasattr(model.answer_sequence_keys, "shape")
        and len(model.answer_sequence_keys.shape) == 2
    ):
        model.answer_sequence_key_norms = np.linalg.norm(
            model.answer_sequence_keys,
            axis=1,
        ).tolist()
    else:
        model.answer_sequence_key_norms = [norm(key) for key in model.answer_sequence_keys]
    # Token-id tables for answer sequences; unlike the key tensors above,
    # array inputs are kept whatever their rank.
    raw_answer_sequence_prompt_tokens = checkpoint.tensors.get("answer_sequence_prompt_tokens", [])
    if np is not None and hasattr(raw_answer_sequence_prompt_tokens, "shape"):
        model.answer_sequence_prompt_tokens = raw_answer_sequence_prompt_tokens.astype(int, copy=False)
    else:
        model.answer_sequence_prompt_tokens = [
            [int(value) for value in row] for row in raw_answer_sequence_prompt_tokens
        ]
    raw_answer_sequence_tokens = checkpoint.tensors.get("answer_sequence_tokens", [])
    if np is not None and hasattr(raw_answer_sequence_tokens, "shape"):
        model.answer_sequence_tokens = raw_answer_sequence_tokens.astype(int, copy=False)
    else:
        model.answer_sequence_tokens = [
            [int(value) for value in row] for row in raw_answer_sequence_tokens
        ]
    # Checkpoints without a "transition_tables" entry load as an empty mapping.
    model.transition_tables = model._deserialize_transition_tables(
        json.loads(metadata.get("transition_tables", "{}"))
    )
    # Rebuild the derived ``*_array`` caches from the attributes set above.
    model._refresh_numeric_caches()
    return model
def _is_punctuation_piece(self, piece: str) -> bool:
    """Tell whether ``piece`` is non-empty and made only of ASCII punctuation."""
    if not piece:
        return False
    return all(character in string.punctuation for character in piece)


def _encode_context(self, tokens: list[str]) -> Vector:
    """Run ``tokens`` through a fresh decode state and return its masked state."""
    decode_state = self._build_decode_state(tokens)
    return self._masked_decode_state(decode_state)


def _build_decode_state(self, tokens: list[str]) -> DecodeState:
    """Create a zeroed decode state and advance it over every token.

    One hidden-state vector (``config.state_dim`` wide) and one context trace
    (``config.embedding_dim`` wide) are allocated per configured timescale,
    as NumPy arrays when NumPy is available and as plain vectors otherwise.
    """
    assert self.memory_units is not None

    if np is not None:
        initial_hidden = [
            np.zeros(self.config.state_dim, dtype=np.float64)
            for _ in self.config.timescales
        ]
        initial_traces = [
            np.zeros(self.config.embedding_dim, dtype=np.float64)
            for _ in self.config.timescales
        ]
    else:
        initial_hidden = [
            zeros_vector(self.config.state_dim) for _ in self.config.timescales
        ]
        initial_traces = [
            zeros_vector(self.config.embedding_dim) for _ in self.config.timescales
        ]
    decode_state = DecodeState(
        hidden_states=initial_hidden,
        context_traces=initial_traces,
        combined_state=self._zero_combined_state(),
        context_tokens=[],
    )
    for token in tokens:
        self._advance_decode_state(decode_state, token)
    return decode_state
def _masked_decode_state(self, state: DecodeState) -> Vector:
    """Apply the ternary mask and scale to a decode state's combined state."""
    assert self.ternary_mask is not None
    combined = state.combined_state
    return apply_ternary_mask(combined, self.ternary_mask, self.ternary_scale)


def _masked_combined_state(self, combined_state: Vector) -> Vector:
    """Apply the ternary mask and scale to a bare combined-state vector."""
    assert self.ternary_mask is not None
    mask = self.ternary_mask
    return apply_ternary_mask(combined_state, mask, self.ternary_scale)


def _masked_decode_state_array(self, state: DecodeState) -> object:
    """Array version of ``_masked_decode_state``.

    Uses the cached mask array when present; otherwise falls back to the
    list implementation and converts its result.
    """
    assert np is not None
    if self.ternary_mask_array is not None:
        scaled = (
            np.asarray(state.combined_state, dtype=RUNTIME_ARRAY_DTYPE)
            * self.ternary_scale
        )
        return scaled * self.ternary_mask_array
    return np.asarray(self._masked_decode_state(state), dtype=RUNTIME_ARRAY_DTYPE)


def _masked_combined_state_array(self, combined_state: Vector) -> object:
    """Array version of ``_masked_combined_state``.

    Uses the cached mask array when present; otherwise falls back to the
    list implementation and converts its result.
    """
    assert np is not None
    if self.ternary_mask_array is not None:
        scaled = (
            np.asarray(combined_state, dtype=RUNTIME_ARRAY_DTYPE)
            * self.ternary_scale
        )
        return scaled * self.ternary_mask_array
    return np.asarray(
        self._masked_combined_state(combined_state),
        dtype=RUNTIME_ARRAY_DTYPE,
    )


def _center_state_vector(self, state: Vector) -> Vector:
    """Subtract the stored state offset from ``state``.

    The input object itself is returned, unchanged, when no offset is stored
    or when the offset length differs from the state length.
    """
    offset = self.state_offset
    if not offset:
        return state
    if len(offset) != len(state):
        return state
    return [value - shift for value, shift in zip(state, offset)]
def _zero_combined_state(self) -> Vector:
    """Return a fresh all-zero list as wide as the combined state."""
    return [0.0] * self._combined_state_width()


def _combined_state_width(self) -> int:
    """Width of the combined state: (state_dim + embedding_dim) per timescale."""
    per_timescale = self.config.state_dim + self.config.embedding_dim
    return per_timescale * len(self.config.timescales)


def _derive_trace_token_weights_from_counts(self, token_counts: dict[str, float]) -> Vector:
    """Compute one trace weight per vocabulary token from token counts.

    The reference count is the upper median of the positive counts (1.0 when
    there are none).  Special tokens get 0.0, tokens with no positive count
    get 1.0, and every other token gets ``(reference / count) ** 0.75``
    clamped to ``[0.08, 4.8]``, so frequent tokens are down-weighted.

    Args:
        token_counts: Mapping from token to its count; missing tokens count
            as zero.

    Returns:
        Weights aligned with ``embedding_model.id_to_token``.
    """
    assert self.embedding_model is not None
    assert self.tokenizer is not None
    counts = [
        float(token_counts.get(token, 0.0))
        for token in self.embedding_model.id_to_token
    ]
    positive_counts = sorted(value for value in counts if value > 0.0)
    reference = (
        positive_counts[len(positive_counts) // 2]
        if positive_counts
        else 1.0
    )
    weights: Vector = []
    for token, count in zip(self.embedding_model.id_to_token, counts):
        if token in self.tokenizer.special_tokens:
            weights.append(0.0)
        elif count <= 0.0:
            weights.append(1.0)
        else:
            weight = (reference / count) ** 0.75
            weights.append(max(0.08, min(4.8, weight)))
    return weights


def _token_id_for_token(self, token: str) -> int:
    """Look up a token id, retrying with the lower-cased token.

    Returns:
        The id as ``int``, or ``-1`` when neither spelling is known.
    """
    assert self.embedding_model is not None
    token_id = self.embedding_model.token_to_id.get(token)
    if token_id is None and token.lower() != token:
        token_id = self.embedding_model.token_to_id.get(token.lower())
    return int(token_id) if token_id is not None else -1


def _trace_embedding_from_token_id(
    self,
    embedding: Vector | object,
    token_id: int,
) -> Vector | object:
    """Return the trace embedding for ``token_id``.

    Unknown tokens (negative id) get ``embedding`` back untouched.  When a
    precomputed trace table exists its row is returned directly.  Otherwise
    the embedding is scaled by the token's trace weight and a signed
    identity component is added at each hashed bucket.  The input embedding
    is never modified; array inputs yield an array, other inputs a list.
    """
    if token_id < 0:
        return embedding
    if self.trace_embedding_table_array is not None:
        return self.trace_embedding_table_array[token_id]
    weight = self.trace_token_weights[token_id] if self.trace_token_weights is not None else 1.0
    dimension = self.config.embedding_dim
    if hasattr(embedding, "shape"):
        # Multiplication allocates a new array, so the caller's data is safe.
        trace_embedding = embedding * weight
    else:
        trace_embedding = [float(value) * weight for value in embedding]
    # One shared loop serves both containers; both support indexed ``+=``.
    for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES:
        bucket = (token_id * bucket_multiplier + bucket_offset) % dimension
        sign = 1.0 if ((token_id * sign_multiplier + sign_offset) & 1) == 0 else -1.0
        trace_embedding[bucket] += weight * TRACE_IDENTITY_SCALE * sign
    return trace_embedding


def _build_trace_embedding_table_array(self, embedding_array: object) -> object | None:
    """Precompute the trace embedding of every token as one 2-D array.

    Row ``i`` holds the same values ``_trace_embedding_from_token_id`` would
    compute for token ``i``, with the embedding width taken from the array.

    Returns:
        A float64 array shaped like ``embedding_array``, or ``None`` when
        NumPy or the trace weights are unavailable, the input is empty or not
        2-D, or the number of weights differs from the number of rows.
    """
    if np is None or self.trace_token_weights is None:
        return None
    values = np.asarray(embedding_array, dtype=np.float64)
    if values.size == 0 or len(values.shape) != 2:
        return None
    weights = np.asarray(self.trace_token_weights, dtype=np.float64)
    if weights.shape[0] != values.shape[0]:
        return None
    trace_values = values * weights[:, None]
    # A non-empty 2-D array always has at least one column, so the former
    # zero-width guard here could never trigger and has been removed.
    token_ids = np.arange(values.shape[0], dtype=np.int64)
    for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES:
        buckets = ((token_ids * bucket_multiplier + bucket_offset) % values.shape[1]).astype(
            np.int64,
            copy=False,
        )
        signs = np.where(
            ((token_ids * sign_multiplier + sign_offset) & 1) == 0,
            1.0,
            -1.0,
        )
        # ``np.add.at`` accumulates correctly even if index pairs repeat.
        np.add.at(trace_values, (token_ids, buckets), weights * TRACE_IDENTITY_SCALE * signs)
    return trace_values
self.prompt_answer_bias_array = None + self.prompt_answer_start_weights_array = None + self.prompt_answer_start_bias_array = None + self.trace_token_weights_array = None + self.trace_embedding_table_array = None + self.preference_bias_array = None + self.preference_valid_mask_array = None + self.state_offset_array = None + self.associative_keys_array = None + self.associative_key_norms_array = None + self.associative_values_array = None + self.associative_valid_mask_array = None + self.answer_keys_array = None + self.answer_key_norms_array = None + self.answer_similarity_keys_array = None + self.answer_similarity_key_norms_array = None + self.answer_similarity_mask_array = None + self.answer_values_array = None + self.answer_valid_mask_array = None + self.answer_start_keys_array = None + self.answer_start_key_norms_array = None + self.answer_start_similarity_keys_array = None + self.answer_start_similarity_key_norms_array = None + self.answer_start_values_array = None + self.answer_start_valid_mask_array = None + self.answer_sequence_keys_array = None + self.answer_sequence_key_norms_array = None + self.answer_sequence_similarity_keys_array = None + self.answer_sequence_similarity_key_norms_array = None + self.answer_sequence_prompt_tokens_array = None + self.answer_sequence_tokens_array = None + self.answer_sequence_prompt_weight_maps = None + self.answer_sequence_prompt_weight_norms = None + self.answer_sequence_prompt_bigram_sets = None + self.answer_sequence_prompt_trigram_sets = None + self.answer_sequence_prompt_number_sets = None + self.answer_sequence_prompt_inverted_index = None + self._refresh_answer_sequence_prompt_overlap_cache() + return + self.ternary_mask_array = ( + np.asarray(self.ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) + if self.ternary_mask is not None + else None + ) + self.readout_weights_array = ( + np.asarray(self.readout_weights, dtype=RUNTIME_ARRAY_DTYPE) + if self.readout_weights is not None + else None + ) + self.readout_bias_array = ( + 
np.asarray(self.readout_bias, dtype=RUNTIME_ARRAY_DTYPE) + if self.readout_bias is not None + else None + ) + self.prompt_answer_weights_array = ( + np.asarray(self.prompt_answer_weights, dtype=RUNTIME_ARRAY_DTYPE) + if self.prompt_answer_weights is not None + and len(self.prompt_answer_weights) > 0 + else None + ) + self.prompt_answer_bias_array = ( + np.asarray(self.prompt_answer_bias, dtype=RUNTIME_ARRAY_DTYPE) + if self.prompt_answer_bias is not None + else None + ) + self.prompt_answer_start_weights_array = ( + np.asarray(self.prompt_answer_start_weights, dtype=RUNTIME_ARRAY_DTYPE) + if self.prompt_answer_start_weights is not None + and len(self.prompt_answer_start_weights) > 0 + else None + ) + self.prompt_answer_start_bias_array = ( + np.asarray(self.prompt_answer_start_bias, dtype=RUNTIME_ARRAY_DTYPE) + if self.prompt_answer_start_bias is not None + else None + ) + self.trace_token_weights_array = ( + np.asarray(self.trace_token_weights, dtype=RUNTIME_ARRAY_DTYPE) + if self.trace_token_weights is not None + else None + ) + trace_embedding_table = ( + self._build_trace_embedding_table_array(self.embedding_model.embeddings) + if self.embedding_model is not None and self.trace_token_weights is not None + else None + ) + self.trace_embedding_table_array = ( + trace_embedding_table.astype(RUNTIME_ARRAY_DTYPE, copy=False) + if trace_embedding_table is not None + else None + ) + self.preference_bias_array = ( + np.asarray(self.preference_bias, dtype=RUNTIME_ARRAY_DTYPE) + if self.preference_bias is not None + else None + ) + self.preference_valid_mask_array = ( + np.asarray( + [ + self._eligible_preference_token(token) + for token in self.embedding_model.id_to_token + ], + dtype=bool, + ) + if self.embedding_model is not None and self.tokenizer is not None + else None + ) + self.state_offset_array = ( + np.asarray(self.state_offset, dtype=RUNTIME_ARRAY_DTYPE) + if self.state_offset is not None + else None + ) + self.associative_keys_array = ( + 
np.asarray(self.associative_keys, dtype=RUNTIME_ARRAY_DTYPE) + if self.associative_keys is not None and len(self.associative_keys) > 0 + else None + ) + self.associative_key_norms_array = ( + np.asarray(self.associative_key_norms, dtype=RUNTIME_ARRAY_DTYPE) + if self.associative_key_norms is not None and len(self.associative_key_norms) > 0 + else None + ) + self.associative_values_array = ( + np.asarray(self.associative_values, dtype=np.int64) + if self.associative_values is not None and len(self.associative_values) > 0 + else None + ) + self.associative_valid_mask_array = ( + self.associative_values_array >= 0 + if self.associative_values_array is not None + else None + ) + self.answer_keys_array = ( + np.asarray(self.answer_keys, dtype=RUNTIME_ARRAY_DTYPE) + if self.answer_keys is not None and len(self.answer_keys) > 0 + else None + ) + self.answer_key_norms_array = ( + np.asarray(self.answer_key_norms, dtype=RUNTIME_ARRAY_DTYPE) + if self.answer_key_norms is not None and len(self.answer_key_norms) > 0 + else None + ) + self.answer_similarity_keys_array = None + self.answer_similarity_key_norms_array = None + self.answer_similarity_mask_array = None + if self.answer_keys_array is not None and len(self.answer_keys_array.shape) == 2: + width = int(self.answer_keys_array.shape[1]) + block_width = self.config.state_dim + self.config.embedding_dim + expected_width = block_width * len(self.config.timescales) + if block_width > 0 and width == expected_width: + mask = np.zeros(width, dtype=RUNTIME_ARRAY_DTYPE) + for scale_index in range(len(self.config.timescales)): + start = scale_index * block_width + self.config.state_dim + end = start + self.config.embedding_dim + mask[start:end] = 1.0 + self.answer_similarity_mask_array = mask + self.answer_similarity_keys_array = self.answer_keys_array * mask[None, :] + self.answer_similarity_key_norms_array = np.linalg.norm( + self.answer_similarity_keys_array, + axis=1, + ).astype(RUNTIME_ARRAY_DTYPE, copy=False) + 
self.answer_values_array = ( + np.asarray(self.answer_values, dtype=np.int64) + if self.answer_values is not None and len(self.answer_values) > 0 + else None + ) + self.answer_valid_mask_array = ( + self.answer_values_array >= 0 + if self.answer_values_array is not None + else None + ) + self.answer_start_keys_array = ( + np.asarray(self.answer_start_keys, dtype=RUNTIME_ARRAY_DTYPE) + if self.answer_start_keys is not None and len(self.answer_start_keys) > 0 + else None + ) + self.answer_start_key_norms_array = ( + np.asarray(self.answer_start_key_norms, dtype=RUNTIME_ARRAY_DTYPE) + if self.answer_start_key_norms is not None and len(self.answer_start_key_norms) > 0 + else None + ) + self.answer_start_similarity_keys_array = None + self.answer_start_similarity_key_norms_array = None + if ( + self.answer_start_keys_array is not None + and len(self.answer_start_keys_array.shape) == 2 + and self.answer_similarity_mask_array is not None + and int(self.answer_start_keys_array.shape[1]) == int(self.answer_similarity_mask_array.shape[0]) + ): + self.answer_start_similarity_keys_array = ( + self.answer_start_keys_array * self.answer_similarity_mask_array[None, :] + ) + self.answer_start_similarity_key_norms_array = np.linalg.norm( + self.answer_start_similarity_keys_array, + axis=1, + ).astype(RUNTIME_ARRAY_DTYPE, copy=False) + self.answer_start_values_array = ( + np.asarray(self.answer_start_values, dtype=np.int64) + if self.answer_start_values is not None and len(self.answer_start_values) > 0 + else None + ) + self.answer_start_valid_mask_array = ( + self.answer_start_values_array >= 0 + if self.answer_start_values_array is not None + else None + ) + self.answer_sequence_keys_array = ( + np.asarray(self.answer_sequence_keys, dtype=RUNTIME_ARRAY_DTYPE) + if self.answer_sequence_keys is not None and len(self.answer_sequence_keys) > 0 + else None + ) + self.answer_sequence_key_norms_array = ( + np.asarray(self.answer_sequence_key_norms, dtype=RUNTIME_ARRAY_DTYPE) + if 
self.answer_sequence_key_norms is not None and len(self.answer_sequence_key_norms) > 0 + else None + ) + self.answer_sequence_similarity_keys_array = None + self.answer_sequence_similarity_key_norms_array = None + if ( + self.answer_sequence_keys_array is not None + and len(self.answer_sequence_keys_array.shape) == 2 + and self.answer_similarity_mask_array is not None + and int(self.answer_sequence_keys_array.shape[1]) == int(self.answer_similarity_mask_array.shape[0]) + ): + self.answer_sequence_similarity_keys_array = ( + self.answer_sequence_keys_array * self.answer_similarity_mask_array[None, :] + ) + self.answer_sequence_similarity_key_norms_array = np.linalg.norm( + self.answer_sequence_similarity_keys_array, + axis=1, + ).astype(RUNTIME_ARRAY_DTYPE, copy=False) + self.answer_sequence_tokens_array = ( + np.asarray(self.answer_sequence_tokens, dtype=np.int64) + if self.answer_sequence_tokens is not None and len(self.answer_sequence_tokens) > 0 + else None + ) + self.answer_sequence_prompt_tokens_array = ( + np.asarray(self.answer_sequence_prompt_tokens, dtype=np.int64) + if self.answer_sequence_prompt_tokens is not None + and len(self.answer_sequence_prompt_tokens) > 0 + else None + ) + self._refresh_answer_sequence_prompt_overlap_cache() + + def _refresh_answer_sequence_prompt_overlap_cache(self) -> None: + self.answer_sequence_prompt_weight_maps = None + self.answer_sequence_prompt_weight_norms = None + self.answer_sequence_prompt_bigram_sets = None + self.answer_sequence_prompt_trigram_sets = None + self.answer_sequence_prompt_number_sets = None + self.answer_sequence_prompt_inverted_index = None + self.answer_sequence_prompt_specificity = None + if self.answer_sequence_prompt_tokens is None or self.trace_token_weights is None: + return + inverted: dict[int, list[int]] = {} + row_id_lists: list[list[int]] = [] + for row in self.answer_sequence_prompt_tokens: + row_values = row.tolist() if hasattr(row, "tolist") else row + row_ids: list[int] = [] + for 
raw_token_id in row_values: + token_id = int(raw_token_id) + if token_id < 0 or token_id >= len(self.trace_token_weights): + continue + row_ids.append(token_id) + sequence_index = len(row_id_lists) + for token_id in set(row_ids): + inverted.setdefault(token_id, []).append(sequence_index) + row_id_lists.append(row_ids) + + total_rows = len(row_id_lists) + specificity = { + token_id: self._prompt_overlap_token_specificity(len(indices), total_rows) + for token_id, indices in inverted.items() + } + self.answer_sequence_prompt_inverted_index = inverted + self.answer_sequence_prompt_specificity = specificity + + weight_maps: list[dict[int, float]] = [] + weight_norms: list[float] = [] + bigram_sets: list[set[tuple[int, int]]] = [] + trigram_sets: list[set[tuple[int, int, int]]] = [] + number_sets: list[set[str]] = [] + for row_ids in row_id_lists: + row_weights: dict[int, float] = {} + for token_id in row_ids: + row_weights[token_id] = max( + row_weights.get(token_id, 0.0), + float(self.trace_token_weights[token_id]) * specificity.get(token_id, 1.0), + ) + weight_maps.append(row_weights) + weight_norms.append(sum(value * value for value in row_weights.values()) ** 0.5) + bigram_sets.append( + { + (row_ids[index], row_ids[index + 1]) + for index in range(len(row_ids) - 1) + } + ) + trigram_sets.append( + { + (row_ids[index], row_ids[index + 1], row_ids[index + 2]) + for index in range(len(row_ids) - 2) + } + ) + number_sets.append(self._number_strings_from_token_ids(row_ids)) + self.answer_sequence_prompt_weight_maps = weight_maps + self.answer_sequence_prompt_weight_norms = weight_norms + self.answer_sequence_prompt_bigram_sets = bigram_sets + self.answer_sequence_prompt_trigram_sets = trigram_sets + self.answer_sequence_prompt_number_sets = number_sets + + @staticmethod + def _prompt_overlap_token_specificity(document_frequency: int, total_documents: int) -> float: + if document_frequency <= 0 or total_documents <= 0: + return 1.0 + coverage = min(1.0, 
document_frequency / total_documents) + return max(0.02, 1.0 - (coverage ** 0.5)) + + def _number_strings_from_token_ids(self, token_ids: list[int]) -> set[str]: + assert self.embedding_model is not None + tokens = [ + self.embedding_model.id_to_token[token_id] + for token_id in token_ids + if 0 <= token_id < len(self.embedding_model.id_to_token) + ] + return self._number_strings_from_tokens(tokens) + + def _number_strings_from_tokens(self, tokens: list[str]) -> set[str]: + numbers: set[str] = set() + current = "" + for token in tokens: + if self.tokenizer is not None and token in self.tokenizer.special_tokens: + if current: + numbers.add(current) + current = "" + continue + rendered = self._render_token(token) + digits = "".join(character for character in rendered if character.isdigit()) + starts_number = self._starts_new_word(token) if self.tokenizer is not None else True + if digits and starts_number: + if current: + numbers.add(current) + current = digits + elif digits and current: + current += digits + else: + if current: + numbers.add(current) + current = "" + if current: + numbers.add(current) + return numbers + + @staticmethod + def _numeric_prompt_can_match(query_numbers: set[str], row_numbers: set[str]) -> bool: + if not query_numbers: + return True + if not row_numbers: + return False + return query_numbers.issubset(row_numbers) + + def _apply_readout_fast(self, state: Vector) -> Vector: + if self.readout_weights_array is None or np is None: + assert self.readout_weights is not None + centered_state = self._center_state_vector(state) + logits = apply_readout(self.readout_weights, centered_state) + if self.readout_bias: + logits = [ + value + self.readout_bias[index] + for index, value in enumerate(logits) + ] + return logits + state_array = np.asarray(state, dtype=RUNTIME_ARRAY_DTYPE) + if self.state_offset_array is not None and self.state_offset_array.shape == state_array.shape: + state_array = state_array - self.state_offset_array + logits = 
self.readout_weights_array @ state_array + if self.readout_bias_array is not None and self.readout_bias_array.shape == logits.shape: + logits = logits + self.readout_bias_array + return logits.tolist() + + def _apply_readout_array(self, state: object) -> object: + assert np is not None + assert self.readout_weights_array is not None + state_array = np.asarray(state, dtype=RUNTIME_ARRAY_DTYPE) + if self.state_offset_array is not None and self.state_offset_array.shape == state_array.shape: + state_array = state_array - self.state_offset_array + logits = self.readout_weights_array @ state_array + if self.readout_bias_array is not None and self.readout_bias_array.shape == logits.shape: + logits = logits + self.readout_bias_array + return logits + + def _step_hidden_states( + self, + hidden_states: list[Vector], + context_traces: list[Vector], + token: str, + ) -> tuple[list[Vector], list[Vector], Vector]: + assert self.embedding_model is not None + assert self.tokenizer is not None + token_id = self._token_id_for_token(token) + embedding = self.embedding_model.vector(token) + trace_embedding = self._trace_embedding_from_token_id(embedding, token_id) + return self._step_hidden_states_from_embedding( + hidden_states, + context_traces, + embedding, + trace_embedding=trace_embedding, + ) + + def _step_hidden_states_from_embedding( + self, + hidden_states: list[Vector], + context_traces: list[Vector], + embedding: Vector | object, + *, + trace_embedding: Vector | object | None = None, + ) -> tuple[list[Vector], list[Vector], Vector]: + assert self.memory_units is not None + if trace_embedding is None: + trace_embedding = embedding + + if np is not None and hidden_states and hasattr(hidden_states[0], "shape"): + embedding_array = ( + embedding + if hasattr(embedding, "shape") + else np.asarray(embedding, dtype=np.float64) + ) + trace_embedding_array = ( + trace_embedding + if hasattr(trace_embedding, "shape") + else np.asarray(trace_embedding, dtype=np.float64) + ) + drive = 
analytical_embedding_drive_fast(embedding_array, self.config.state_dim) + next_states: list[Vector] = [] + next_traces: list[Vector] = [] + combined_state: Vector = [] + for unit, state, trace in zip(self.memory_units, hidden_states, context_traces): + next_state = unit.step_vector_fast(state, drive) + decay = 1.0 / (1.0 + unit.timescale) + next_trace = trace + ((1.0 - decay) * trace_embedding_array) + next_states.append(next_state) + next_traces.append(next_trace) + combined_state.extend(next_state.tolist()) + combined_state.extend(next_trace.tolist()) + return next_states, next_traces, combined_state + + embedding_vector = embedding.tolist() if hasattr(embedding, "tolist") else embedding + trace_embedding_vector = ( + trace_embedding.tolist() + if hasattr(trace_embedding, "tolist") + else trace_embedding + ) + drive = analytical_embedding_drive(embedding_vector, self.config.state_dim) + next_states: list[Vector] = [] + next_traces: list[Vector] = [] + combined_state: Vector = [] + for unit, state, trace in zip(self.memory_units, hidden_states, context_traces): + next_state = unit.step_vector(state, drive) + decay = 1.0 / (1.0 + unit.timescale) + next_trace = [ + previous + ((1.0 - decay) * value) + for previous, value in zip(trace, trace_embedding_vector) + ] + next_states.append(next_state) + next_traces.append(next_trace) + combined_state.extend(next_state) + combined_state.extend(next_trace) + return next_states, next_traces, combined_state + + def _one_hot(self, token: str) -> Vector: + assert self.embedding_model is not None + return self._one_hot_from_id(self.embedding_model.token_to_id.get(token, -1)) + + def _one_hot_from_id(self, token_id: int) -> Vector: + assert self.embedding_model is not None + vector = [0.0 for _ in self.embedding_model.id_to_token] + if token_id >= 0: + vector[token_id] = 1.0 + return vector + + def _blend_probabilities( + self, + base: Vector, + answer: Vector, + associative: Vector, + transition: Vector, + copy: Vector, + 
preference: Vector, + *, + transition_order: int | None, + generated_count: int = 0, + answer_locked: bool = False, + answer_guided_start: bool = False, + ) -> tuple[Vector, dict[str, float]]: + base_weight = FAST_BASE_BLEND + answer_weight = FAST_ANSWER_BLEND + associative_weight = FAST_ASSOCIATIVE_BLEND + transition_weight = FAST_TRANSITION_BLEND + copy_weight = FAST_COPY_BLEND + preference_weight = FAST_PREFERENCE_BLEND + if answer_locked: + base_weight *= 0.18 + answer_weight *= 5.0 + associative_weight *= 0.2 + transition_weight *= 0.2 + copy_weight *= 0.2 + preference_weight *= 0.2 + elif answer_guided_start: + base_weight *= 0.35 + answer_weight *= 3.5 + associative_weight *= 0.2 + transition_weight *= 0.35 + copy_weight *= 0.2 + preference_weight *= 0.2 + elif generated_count > 0: + answer_weight *= 0.32 + transition_weight *= 2.0 + copy_weight *= 0.75 + + if transition_order is None: + answer_weight *= 1.1 + associative_weight *= 0.75 + copy_weight += 0.02 + elif transition_order <= 2: + answer_weight *= 1.15 + associative_weight *= 0.65 + transition_weight *= 0.55 + copy_weight += 0.01 + elif transition_order >= 5: + transition_weight *= 1.25 + + sources: list[tuple[str, float, Vector]] = [("base", base_weight, base)] + if any(value > 0.0 for value in answer): + sources.append(("answer", answer_weight, answer)) + if any(value > 0.0 for value in associative): + sources.append(("associative", associative_weight, associative)) + if any(value > 0.0 for value in transition): + sources.append(("transition", transition_weight, transition)) + if any(value > 0.0 for value in copy): + sources.append(("copy", copy_weight, copy)) + if any(value > 0.0 for value in preference): + sources.append(("preference", preference_weight, preference)) + + total_weight = sum(weight for _, weight, _ in sources) + blended = [0.0 for _ in base] + blend_weights: dict[str, float] = {} + for name, weight, source in sources: + normalized_weight = weight / total_weight if total_weight 
def _blend_probability_arrays(
    self,
    base: object,
    answer: object,
    associative: object,
    transition: object,
    copy: object,
    preference: object,
    *,
    transition_order: int | None,
    generated_count: int = 0,
    answer_locked: bool = False,
    answer_guided_start: bool = False,
) -> tuple[object, dict[str, float]]:
    """Mix several probability arrays into one renormalised distribution.

    Starting from the ``FAST_*_BLEND`` constants, the per-source weights are
    rescaled according to the generation stage (``answer_locked``, else
    ``answer_guided_start``, else ``generated_count > 0``) and then according
    to ``transition_order`` (``None``, ``<= 2`` or ``>= 5``). ``base`` always
    participates; every other source participates only if it has at least
    one strictly positive entry. Participating weights are normalised to sum
    to one before mixing.

    Returns:
        ``(distribution, blend_weights)`` where ``blend_weights`` maps each
        participating source name to its normalised weight. If the mixed
        array sums to a non-positive value, ``base`` itself is returned
        instead of the mixture.
    """
    assert np is not None

    mix = {
        "base": FAST_BASE_BLEND,
        "answer": FAST_ANSWER_BLEND,
        "associative": FAST_ASSOCIATIVE_BLEND,
        "transition": FAST_TRANSITION_BLEND,
        "copy": FAST_COPY_BLEND,
        "preference": FAST_PREFERENCE_BLEND,
    }

    def rescale(**factors: float) -> None:
        for source_name, factor in factors.items():
            mix[source_name] *= factor

    # Stage of generation: locked answer > guided start > mid-generation.
    if answer_locked:
        rescale(base=0.18, answer=5.0, associative=0.2, transition=0.2, copy=0.2, preference=0.2)
    elif answer_guided_start:
        rescale(base=0.35, answer=3.5, associative=0.2, transition=0.35, copy=0.2, preference=0.2)
    elif generated_count > 0:
        rescale(answer=0.32, transition=2.0, copy=0.75)

    # Order of the transition evidence that was available.
    if transition_order is None:
        rescale(answer=1.1, associative=0.75)
        mix["copy"] += 0.02
    elif transition_order <= 2:
        rescale(answer=1.15, associative=0.65, transition=0.55)
        mix["copy"] += 0.01
    elif transition_order >= 5:
        rescale(transition=1.25)

    distributions = {
        "base": base,
        "answer": answer,
        "associative": associative,
        "transition": transition,
        "copy": copy,
        "preference": preference,
    }
    active = [
        source_name
        for source_name, values in distributions.items()
        if source_name == "base" or np.any(values > 0.0)
    ]

    total_weight = sum(mix[source_name] for source_name in active)
    blended = np.zeros_like(base, dtype=np.float64)
    blend_weights: dict[str, float] = {}
    for source_name in active:
        share = mix[source_name] / total_weight if total_weight else 0.0
        blend_weights[source_name] = share
        blended += share * distributions[source_name]

    mass = float(blended.sum())
    if mass <= 0.0:
        return base, blend_weights
    return blended / mass, blend_weights
positive_positions[partition] + ordered_positions = selected_positions[np.argsort(scores[selected_positions])[::-1]] + return [ + ( + float(scores[position]), + int(self.associative_values_array[position]), + int(position), + ) + for position in ordered_positions + ] + + state = self._center_state_vector(state) + state_norm = norm(state) + if state_norm == 0.0: + return [] + + scored: list[tuple[float, int, int]] = [] + for example_index, (key, key_norm, token_id) in enumerate( + zip(self.associative_keys, self.associative_key_norms, self.associative_values) + ): + if token_id < 0: + continue + denominator = state_norm * key_norm + if denominator == 0.0: + continue + similarity = dot(state, key) / denominator + if similarity > 0.0: + scored.append((similarity, token_id, example_index)) + scored.sort(key=lambda item: item[0], reverse=True) + return scored[:limit] + + def _associative_prior_from_matches( + self, + matches: list[tuple[float, int, int]], + ) -> Vector: + assert self.embedding_model is not None + if not matches: + return [0.0 for _ in self.embedding_model.id_to_token] + + prior = [0.0 for _ in self.embedding_model.id_to_token] + for similarity, token_id, _ in matches[:ASSOCIATIVE_TOP_K]: + prior[token_id] += similarity + return _normalize_vector(prior) + + def _associative_prior(self, state: Vector) -> Vector: + return self._associative_prior_from_matches(self._score_associative_matches(state)) + + def _score_answer_matches( + self, + answer_anchor_state: Vector | None, + *, + limit: int = ANSWER_TOP_K, + ) -> list[tuple[float, int, int]]: + return self._score_prompt_anchor_matches( + answer_anchor_state, + self.answer_keys, + self.answer_key_norms, + self.answer_values, + self.answer_keys_array, + self.answer_key_norms_array, + self.answer_values_array, + self.answer_valid_mask_array, + self.answer_similarity_keys_array, + self.answer_similarity_key_norms_array, + self.answer_similarity_mask_array, + limit=limit, + ) + + def 
def _score_answer_sequence_matches(
    self,
    answer_anchor_state: Vector | None,
    context_tokens: list[str],
    *,
    limit: int = ANSWER_START_TOP_K,
) -> list[tuple[float, int, int]]:
    """Rank stored answer sequences for the current prompt.

    Anchor-state similarity (from ``_score_prompt_anchor_matches``, queried
    with up to ``4 * limit`` results) is combined with the prompt-overlap
    scores from ``_answer_sequence_prompt_overlap_scores``:

    * overlap scores ``None``: the anchor matches alone are returned;
    * overlap scores empty: nothing is returned;
    * otherwise only rows whose overlap is at least
      ``max(0.16, 0.90 * best overlap)`` are kept (all rows if none qualify),
      and each kept row scores ``0.20 * best anchor similarity`` plus
      ``0.80 * overlap``.

    Returns:
        Up to ``limit`` tuples ``(score, sequence_index, sequence_index)``
        sorted by descending score; an empty list when the anchor state or
        any of the answer-sequence tables is missing.
    """
    sequence_tokens = self.answer_sequence_tokens
    prerequisites = (
        answer_anchor_state,
        self.answer_sequence_keys,
        self.answer_sequence_key_norms,
        sequence_tokens,
    )
    if any(item is None for item in prerequisites):
        return []

    # Each stored sequence is its own "value": the row index.
    sequence_count = len(sequence_tokens)
    row_ids = list(range(sequence_count))
    if np is None:
        row_ids_array = None
        row_valid_mask = None
    else:
        row_ids_array = np.arange(sequence_count, dtype=np.int64)
        row_valid_mask = row_ids_array >= 0

    anchor_ranked = self._score_prompt_anchor_matches(
        answer_anchor_state,
        self.answer_sequence_keys,
        self.answer_sequence_key_norms,
        row_ids,
        self.answer_sequence_keys_array,
        self.answer_sequence_key_norms_array,
        row_ids_array,
        row_valid_mask,
        self.answer_sequence_similarity_keys_array,
        self.answer_sequence_similarity_key_norms_array,
        self.answer_similarity_mask_array,
        limit=max(limit * 4, limit),
    )

    overlaps = self._answer_sequence_prompt_overlap_scores(context_tokens)
    if overlaps is None:
        return anchor_ranked[:limit]
    if not overlaps:
        return []

    floor = max(0.16, max(overlaps.values()) * 0.90)
    # Fall back to every overlapping row when none clears the floor.
    focused = {
        row: overlap for row, overlap in overlaps.items() if overlap >= floor
    } or overlaps

    combined: dict[int, float] = {}
    for similarity, row, _ in anchor_ranked:
        if row in focused:
            combined[row] = max(combined.get(row, 0.0), 0.20 * similarity)
    for row, overlap in focused.items():
        combined[row] = combined.get(row, 0.0) + (0.80 * overlap)

    ranked = sorted(
        ((total, row, row) for row, total in combined.items() if total > 0.0),
        key=lambda entry: entry[0],
        reverse=True,
    )
    return ranked[:limit]
if not query_weights: + return None + query_norm = sum(value * value for value in query_weights.values()) ** 0.5 + if query_norm <= 0.0: + return None + + query_bigrams = { + (query_ids[index], query_ids[index + 1]) + for index in range(len(query_ids) - 1) + } + query_trigrams = { + (query_ids[index], query_ids[index + 1], query_ids[index + 2]) + for index in range(len(query_ids) - 2) + } + query_numbers = self._number_strings_from_tokens(prompt_tokens) + + def ordered_ngram_score( + query_grams: set[tuple[int, ...]], + row_grams: set[tuple[int, ...]], + ) -> float: + if not query_grams or not row_grams: + return 0.0 + overlap = len(query_grams & row_grams) + if overlap <= 0: + return 0.0 + return overlap / ((len(query_grams) * len(row_grams)) ** 0.5) + + cached_maps = self.answer_sequence_prompt_weight_maps + cached_norms = self.answer_sequence_prompt_weight_norms + cached_bigrams = self.answer_sequence_prompt_bigram_sets + cached_trigrams = self.answer_sequence_prompt_trigram_sets + cached_numbers = self.answer_sequence_prompt_number_sets + cached_index = self.answer_sequence_prompt_inverted_index + if ( + cached_maps is not None + and cached_norms is not None + and cached_bigrams is not None + and cached_trigrams is not None + and cached_numbers is not None + and len(cached_maps) == len(self.answer_sequence_prompt_tokens) + ): + candidate_indices: set[int] | range + if cached_index is not None: + candidates: set[int] = set() + for token_id in query_weights: + candidates.update(cached_index.get(token_id, ())) + candidate_indices = candidates if candidates else range(len(cached_maps)) + else: + candidate_indices = range(len(cached_maps)) + candidate_indices = list(candidate_indices) + if cached_index is not None and candidate_indices: + candidate_set = set(candidate_indices) + local_query_weights: dict[int, float] = {} + local_query_specificity: dict[int, float] = {} + local_query_content_weight = 0.0 + for token_id in query_weights: + local_frequency = 
len(candidate_set & set(cached_index.get(token_id, ()))) + if local_frequency <= 0: + continue + specificity = self._prompt_overlap_token_specificity( + local_frequency, + len(candidate_indices), + ) + weight = specificity + local_query_weights[token_id] = weight + local_query_specificity[token_id] = specificity + if specificity >= 0.20: + local_query_content_weight += weight + local_query_norm = sum(value * value for value in local_query_weights.values()) ** 0.5 + if local_query_norm > 0.0: + query_weights = local_query_weights + query_specificity = local_query_specificity + if local_query_content_weight > 0.0: + query_content_weight = local_query_content_weight + query_norm = local_query_norm + scores: dict[int, float] = {} + for sequence_index in candidate_indices: + row_weights = cached_maps[sequence_index] + if not row_weights: + continue + if not self._numeric_prompt_can_match(query_numbers, cached_numbers[sequence_index]): + continue + matched_content_weight = sum( + query_weights[token_id] + for token_id in query_weights.keys() & row_weights.keys() + if query_specificity.get(token_id, 0.0) >= 0.20 + ) + row_token_coverage = len(query_weights.keys() & row_weights.keys()) / max( + 1, + len(row_weights), + ) + if ( + query_content_weight > 0.0 + and matched_content_weight / query_content_weight < 0.40 + and row_token_coverage < 0.75 + ): + continue + query_coverage = ( + matched_content_weight / query_content_weight + if query_content_weight > 0.0 + else row_token_coverage + ) + numerator = sum( + query_weights[token_id] * row_weights[token_id] + for token_id in query_weights.keys() & row_weights.keys() + ) + if numerator <= 0.0: + continue + row_norm = cached_norms[sequence_index] + if row_norm <= 0.0: + continue + token_score = numerator / (query_norm * row_norm) + bigram_score = ordered_ngram_score( + query_bigrams, + cached_bigrams[sequence_index], + ) + trigram_score = ordered_ngram_score( + query_trigrams, + cached_trigrams[sequence_index], + ) + 
scores[sequence_index] = ( + (0.35 * token_score) + + (0.35 * query_coverage) + + (0.15 * bigram_score) + + (0.15 * trigram_score) + ) + return scores + + if cached_index is not None: + candidate_set: set[int] = set() + for token_id in query_weights: + candidate_set.update(cached_index.get(token_id, ())) + if not candidate_set: + return {} + candidate_indices: list[int] | range = sorted(candidate_set) + local_query_weights: dict[int, float] = {} + local_query_specificity: dict[int, float] = {} + local_query_content_weight = 0.0 + candidate_count = len(candidate_indices) + for token_id in query_weights: + local_frequency = len(candidate_set & set(cached_index.get(token_id, ()))) + if local_frequency <= 0: + continue + specificity = self._prompt_overlap_token_specificity( + local_frequency, + candidate_count, + ) + local_query_weights[token_id] = specificity + local_query_specificity[token_id] = specificity + if specificity >= 0.20: + local_query_content_weight += specificity + local_query_norm = sum(value * value for value in local_query_weights.values()) ** 0.5 + if local_query_norm > 0.0: + query_weights = local_query_weights + query_specificity = local_query_specificity + if local_query_content_weight > 0.0: + query_content_weight = local_query_content_weight + query_norm = local_query_norm + else: + candidate_indices = range(len(self.answer_sequence_prompt_tokens)) + + scores: dict[int, float] = {} + for sequence_index in candidate_indices: + row = self.answer_sequence_prompt_tokens[sequence_index] + row_values = row.tolist() if hasattr(row, "tolist") else row + row_weights: dict[int, float] = {} + row_ids: list[int] = [] + for raw_token_id in row_values: + token_id = int(raw_token_id) + if token_id < 0 or token_id >= len(self.trace_token_weights): + continue + row_ids.append(token_id) + row_weights[token_id] = max( + row_weights.get(token_id, 0.0), + specificity_map.get(token_id, 1.0), + ) + if not row_weights: + continue + if not 
self._numeric_prompt_can_match( + query_numbers, + self._number_strings_from_token_ids(row_ids), + ): + continue + matched_content_weight = sum( + query_weights[token_id] + for token_id in query_weights.keys() & row_weights.keys() + if query_specificity.get(token_id, 0.0) >= 0.20 + ) + row_token_coverage = len(query_weights.keys() & row_weights.keys()) / max( + 1, + len(row_weights), + ) + if ( + query_content_weight > 0.0 + and matched_content_weight / query_content_weight < 0.40 + and row_token_coverage < 0.75 + ): + continue + query_coverage = ( + matched_content_weight / query_content_weight + if query_content_weight > 0.0 + else row_token_coverage + ) + numerator = sum( + query_weights[token_id] * row_weights[token_id] + for token_id in query_weights.keys() & row_weights.keys() + ) + if numerator <= 0.0: + continue + row_norm = sum(value * value for value in row_weights.values()) ** 0.5 + if row_norm > 0.0: + token_score = numerator / (query_norm * row_norm) + row_bigrams = { + (row_ids[index], row_ids[index + 1]) + for index in range(len(row_ids) - 1) + } + row_trigrams = { + (row_ids[index], row_ids[index + 1], row_ids[index + 2]) + for index in range(len(row_ids) - 2) + } + bigram_score = ordered_ngram_score(query_bigrams, row_bigrams) + trigram_score = ordered_ngram_score(query_trigrams, row_trigrams) + scores[sequence_index] = ( + (0.35 * token_score) + + (0.35 * query_coverage) + + (0.15 * bigram_score) + + (0.15 * trigram_score) + ) + return scores + + def _score_prompt_anchor_matches( + self, + answer_anchor_state: Vector | None, + keys: object | None, + key_norms_list: object | None, + values: object | None, + keys_array: object | None, + key_norms_array: object | None, + values_array: object | None, + valid_mask_array: object | None, + similarity_keys_array: object | None, + similarity_key_norms_array: object | None, + similarity_mask_array: object | None, + *, + limit: int, + ) -> list[tuple[float, int, int]]: + if ( + answer_anchor_state is None + 
or keys is None + or key_norms_list is None + or values is None + ): + return [] + + if ( + np is not None + and keys_array is not None + and key_norms_array is not None + and values_array is not None + and valid_mask_array is not None + and limit > 0 + ): + state_array = self._center_state_array( + self._masked_combined_state_array(answer_anchor_state) + ).astype(keys_array.dtype, copy=False) + key_array = keys_array + key_norms = key_norms_array + if ( + similarity_keys_array is not None + and similarity_key_norms_array is not None + and similarity_mask_array is not None + ): + state_array = state_array * similarity_mask_array + key_array = similarity_keys_array + key_norms = similarity_key_norms_array + state_norm = float(np.linalg.norm(state_array)) + if state_norm == 0.0: + return [] + numerators = key_array @ state_array + denominators = key_norms * state_norm + valid_mask = valid_mask_array & (denominators > 0.0) + if np.any(valid_mask): + scores = np.zeros_like(numerators, dtype=key_array.dtype) + np.divide(numerators, denominators, out=scores, where=valid_mask) + positive_positions = np.flatnonzero(valid_mask & (scores > 0.0)) + if positive_positions.size: + selected_positions = positive_positions + if positive_positions.size > limit: + partition = np.argpartition(scores[positive_positions], -limit)[-limit:] + selected_positions = positive_positions[partition] + ordered_positions = selected_positions[np.argsort(scores[selected_positions])[::-1]] + return [ + ( + float(scores[position]), + int(values_array[position]), + int(position), + ) + for position in ordered_positions + ] + + state = self._center_state_vector(self._masked_combined_state(answer_anchor_state)) + state_norm = norm(state) + if state_norm == 0.0: + return [] + + scored: list[tuple[float, int, int]] = [] + for example_index, (key, key_norm, token_id) in enumerate( + zip(keys, key_norms_list, values) + ): + if token_id < 0: + continue + denominator = state_norm * key_norm + if denominator == 
0.0: + continue + similarity = dot(state, key) / denominator + if similarity > 0.0: + scored.append((similarity, token_id, example_index)) + scored.sort(key=lambda item: item[0], reverse=True) + return scored[:limit] + + def _answer_prior_from_matches( + self, + matches: list[tuple[float, int, int]], + generated_tokens: list[str], + ) -> Vector: + assert self.embedding_model is not None + if not matches: + return [0.0 for _ in self.embedding_model.id_to_token] + + prior = [0.0 for _ in self.embedding_model.id_to_token] + generated_ids = { + self.embedding_model.token_to_id[token] + for token in generated_tokens + if token in self.embedding_model.token_to_id + } + for similarity, token_id, _ in matches[:ANSWER_TOP_K]: + token = self.embedding_model.id_to_token[token_id] + if not self._allowed_generation_token(token, generated_tokens): + continue + if token_id in generated_ids: + prior[token_id] += similarity * 0.35 + else: + prior[token_id] += similarity + return _normalize_vector(prior) + + def _answer_sequence_prior_from_matches( + self, + matches: list[tuple[float, int, int]], + generated_tokens: list[str], + ) -> Vector: + assert self.embedding_model is not None + if not matches or self.answer_sequence_tokens is None: + return [0.0 for _ in self.embedding_model.id_to_token] + + generated_ids = [ + self.embedding_model.token_to_id[token] + for token in generated_tokens + if token in self.embedding_model.token_to_id + ] + prior = [0.0 for _ in self.embedding_model.id_to_token] + best_similarity = matches[0][0] + match_floor = best_similarity - 0.02 if best_similarity >= 0.9 else 0.0 + for similarity, sequence_index, _ in matches[:ANSWER_START_TOP_K]: + if similarity < match_floor: + continue + row = self.answer_sequence_tokens[sequence_index] + token_ids = [ + int(value) + for value in (row.tolist() if hasattr(row, "tolist") else row) + if int(value) >= 0 + ] + if not token_ids: + continue + next_token_id = self._next_sequence_token_id(token_ids, generated_ids) + 
if next_token_id is None: + continue + token = self.embedding_model.id_to_token[next_token_id] + if self._allowed_generation_token(token, generated_tokens): + prior[next_token_id] += max(1e-9, similarity - match_floor) + return _normalize_vector(prior) + + def _should_stop_answer_sequence( + self, + decode_state: DecodeState, + generated_tokens: list[str], + ) -> bool: + matches = decode_state.answer_sequence_matches + if matches is None: + matches = self._score_answer_sequence_matches( + decode_state.answer_anchor_state, + decode_state.context_tokens, + ) + return self._answer_sequence_is_complete(generated_tokens, matches) + + def _answer_decode_has_continuation( + self, + decode_state: DecodeState, + generated_tokens: list[str], + ) -> bool: + matches = decode_state.answer_sequence_matches + if matches is None: + matches = self._score_answer_sequence_matches( + decode_state.answer_anchor_state, + decode_state.context_tokens, + ) + return self._answer_sequence_has_continuation(generated_tokens, matches) + + def _answer_sequence_is_complete( + self, + generated_tokens: list[str], + matches: list[tuple[float, int, int]], + ) -> bool: + if ( + self.embedding_model is None + or self.answer_sequence_tokens is None + or not generated_tokens + or not matches + ): + return False + generated_ids = [ + self.embedding_model.token_to_id[token] + for token in generated_tokens + if token in self.embedding_model.token_to_id + ] + if not generated_ids: + return False + for similarity, sequence_index, _ in matches[:ANSWER_START_TOP_K]: + if similarity < ANSWER_SEQUENCE_MATCH_FLOOR or sequence_index >= len(self.answer_sequence_tokens): + continue + row = self.answer_sequence_tokens[sequence_index] + token_ids = [ + int(value) + for value in (row.tolist() if hasattr(row, "tolist") else row) + if int(value) >= 0 + ] + if not token_ids or len(generated_ids) < len(token_ids): + continue + if generated_ids[: len(token_ids)] == token_ids: + return True + return False + + def 
_answer_sequence_has_continuation( + self, + generated_tokens: list[str], + matches: list[tuple[float, int, int]], + ) -> bool: + if ( + self.embedding_model is None + or self.answer_sequence_tokens is None + or not generated_tokens + or not matches + ): + return False + generated_ids = [ + self.embedding_model.token_to_id[token] + for token in generated_tokens + if token in self.embedding_model.token_to_id + ] + if not generated_ids: + return False + for similarity, sequence_index, _ in matches[:ANSWER_START_TOP_K]: + if similarity < ANSWER_SEQUENCE_MATCH_FLOOR or sequence_index >= len(self.answer_sequence_tokens): + continue + row = self.answer_sequence_tokens[sequence_index] + token_ids = [ + int(value) + for value in (row.tolist() if hasattr(row, "tolist") else row) + if int(value) >= 0 + ] + if not token_ids: + continue + next_token_id = self._next_sequence_token_id(token_ids, generated_ids) + if next_token_id is None: + continue + token = self.embedding_model.id_to_token[next_token_id] + if self._allowed_generation_token(token, generated_tokens): + return True + return False + + def _next_sequence_token_id( + self, + token_ids: list[int], + generated_ids: list[int], + ) -> int | None: + if not generated_ids: + return token_ids[0] + if len(generated_ids) >= len(token_ids): + return None + if token_ids[: len(generated_ids)] != generated_ids: + return None + return token_ids[len(generated_ids)] + + def _transition_prior(self, context_tokens: list[str]) -> Vector: + prior, _ = self._transition_prior_with_order(context_tokens) + return prior + + def _transition_prior_with_order( + self, + context_tokens: list[str], + ) -> tuple[Vector, int | None]: + assert self.embedding_model is not None + if not self.transition_tables: + return [0.0 for _ in self.embedding_model.id_to_token], None + + for order in TRANSITION_ORDERS: + if len(context_tokens) < order: + continue + key = tuple(context_tokens[-order:]) + transitions = self.transition_tables.get(order, {}).get(key) 
+ if not transitions: + continue + prior = [0.0 for _ in self.embedding_model.id_to_token] + for token, probability in transitions.items(): + token_id = self.embedding_model.token_to_id.get(token) + if token_id is not None: + prior[token_id] = probability + return _normalize_vector(prior), order + return [0.0 for _ in self.embedding_model.id_to_token], None + + def _transition_prior_array_with_order( + self, + context_tokens: list[str], + ) -> tuple[object, int | None]: + assert np is not None + assert self.embedding_model is not None + prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) + if not self.transition_tables: + return prior, None + + for order in TRANSITION_ORDERS: + if len(context_tokens) < order: + continue + key = tuple(context_tokens[-order:]) + transitions = self.transition_tables.get(order, {}).get(key) + if not transitions: + continue + for token, probability in transitions.items(): + token_id = self.embedding_model.token_to_id.get(token) + if token_id is not None: + prior[token_id] = probability + total = float(prior.sum()) + if total > 0.0: + prior /= total + return prior, order + return prior, None + + def _copy_prior(self, context_tokens: list[str]) -> Vector: + assert self.embedding_model is not None + assert self.tokenizer is not None + + prior = [0.0 for _ in self.embedding_model.id_to_token] + decay = 0.82 + answer_start = None + for index in range(len(context_tokens) - 1, -1, -1): + if context_tokens[index] == "": + answer_start = index + 1 + break + source_tokens = context_tokens[answer_start:] if answer_start is not None else context_tokens + if not source_tokens: + return prior + for distance, token in enumerate(reversed(source_tokens[-8:])): + if token in self.tokenizer.special_tokens: + continue + if not self._eligible_copy_token(token): + continue + token_id = self.embedding_model.token_to_id.get(token) + if token_id is None: + continue + prior[token_id] += decay**distance + return _normalize_vector(prior) + + 
def _copy_prior_array(self, context_tokens: list[str]) -> object:
    """Return a normalized NumPy copy prior built from recent context tokens.

    The context is scanned backwards for the last ``""`` entry; when one is
    found only the tokens after it are considered. Up to the last eight of
    those tokens are weighted ``0.82 ** distance`` (distance 0 is the most
    recent token). Special tokens, tokens rejected by
    ``_eligible_copy_token`` and tokens missing from the vocabulary are
    skipped.

    Returns a float64 array with one slot per vocabulary entry. It sums to 1
    when at least one token qualified and is all zeros otherwise.
    """
    assert np is not None
    assert self.embedding_model is not None
    assert self.tokenizer is not None

    prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
    decay = 0.82
    answer_start = None
    # NOTE(review): the marker compared against is an empty string in this
    # view of the file; it may originally have been a special-token literal
    # lost during extraction -- confirm against the repository.
    for index in range(len(context_tokens) - 1, -1, -1):
        if context_tokens[index] == "":
            answer_start = index + 1
            break
    source_tokens = context_tokens[answer_start:] if answer_start is not None else context_tokens
    for distance, token in enumerate(reversed(source_tokens[-8:])):
        if token in self.tokenizer.special_tokens:
            continue
        if not self._eligible_copy_token(token):
            continue
        token_id = self.embedding_model.token_to_id.get(token)
        if token_id is None:
            continue
        prior[token_id] += decay**distance
    total = float(prior.sum())
    if total > 0.0:
        prior /= total
    return prior

def _preference_prior(self) -> Vector:
    """Return a list-based prior over the vocabulary from ``preference_bias``.

    Only vocabulary indices whose bias is strictly positive and whose token
    passes ``_eligible_preference_token`` take part. Their biases are passed
    through ``_calibrated_softmax`` and scattered back into a zero vector.
    An all-zero vector is returned when the bias is missing, all zero, or no
    index is eligible.
    """
    assert self.embedding_model is not None
    if not self.preference_bias or not any(value != 0.0 for value in self.preference_bias):
        return [0.0 for _ in self.embedding_model.id_to_token]
    eligible_indices = [
        index
        for index, token in enumerate(self.embedding_model.id_to_token)
        if self.preference_bias[index] > 0.0 and self._eligible_preference_token(token)
    ]
    if not eligible_indices:
        return [0.0 for _ in self.embedding_model.id_to_token]
    eligible_probabilities = self._calibrated_softmax(
        [self.preference_bias[index] for index in eligible_indices]
    )
    prior = [0.0 for _ in self.embedding_model.id_to_token]
    for index, probability in zip(eligible_indices, eligible_probabilities):
        prior[index] = probability
    return prior

def _preference_prior_array(self) -> object:
    """Return the NumPy counterpart of ``_preference_prior``.

    Uses the cached ``preference_bias_array`` and
    ``preference_valid_mask_array``. Entries that are both valid and
    strictly positive receive ``_calibrated_softmax_array`` probabilities;
    every other slot stays zero. A zero array is returned when either cache
    is missing or nothing is active.
    """
    assert np is not None
    assert self.embedding_model is not None
    if self.preference_bias_array is None or not np.any(self.preference_bias_array != 0.0):
        return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
    if self.preference_valid_mask_array is None or not np.any(self.preference_valid_mask_array):
        return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
    positive_mask = self.preference_bias_array > 0.0
    active_mask = self.preference_valid_mask_array & positive_mask
    if not np.any(active_mask):
        return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
    prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
    prior[active_mask] = self._calibrated_softmax_array(
        self.preference_bias_array[active_mask]
    )
    return prior

def _eligible_preference_token(self, token: str) -> bool:
    """Return True when ``token`` may receive preference-prior mass.

    Rejects the unknown token, special tokens, tokens that do not start a
    new word, blank or punctuation-only renderings, and renderings without
    any alphanumeric character.
    """
    assert self.tokenizer is not None
    if token == self.tokenizer.unk_token or token in self.tokenizer.special_tokens:
        return False
    if not self._starts_new_word(token):
        return False
    rendered = self._render_token(token)
    if not rendered.strip() or self._is_punctuation_piece(rendered):
        return False
    alphanumeric = "".join(character for character in rendered if character.isalnum())
    return len(alphanumeric) >= 1

def _build_transition_tables(
    self,
    tokens: list[str],
) -> dict[int, dict[tuple[str, ...], dict[str, float]]]:
    """Build per-order n-gram transition probabilities from ``tokens``.

    For every order in ``TRANSITION_ORDERS`` the successors of each
    ``order``-token context are counted. Contexts are ranked by descending
    total count (ties broken by key) and truncated to
    ``config.max_transition_contexts_per_order`` when that setting is not
    None and is >= 0. Successors are ranked by descending count (ties broken
    by token), truncated to ``config.max_transition_next_tokens`` when that
    setting is > 0, and normalized over the successors that were kept.

    Returns ``{order: {context: {next_token: probability}}}``.
    """
    # Sort the orders once instead of on every use.
    orders = sorted(TRANSITION_ORDERS)
    counts: dict[int, dict[tuple[str, ...], dict[str, int]]] = {
        order: {} for order in orders
    }
    for order in orders:
        for index in range(order - 1, len(tokens) - 1):
            key = tuple(tokens[index - order + 1 : index + 1])
            nxt = tokens[index + 1]
            bucket = counts[order].setdefault(key, {})
            bucket[nxt] = bucket.get(nxt, 0) + 1

    probabilities: dict[int, dict[tuple[str, ...], dict[str, float]]] = {
        order: {} for order in orders
    }
    for order, mapping in counts.items():
        items = list(mapping.items())
        items.sort(key=lambda item: (-sum(item[1].values()), item[0]))
        if (
            self.config.max_transition_contexts_per_order is not None
            and self.config.max_transition_contexts_per_order >= 0
        ):
            items = items[: self.config.max_transition_contexts_per_order]
        for key, bucket in items:
            next_items = sorted(bucket.items(), key=lambda item: (-item[1], item[0]))
            if self.config.max_transition_next_tokens > 0:
                next_items = next_items[: self.config.max_transition_next_tokens]
            total = sum(value for _, value in next_items)
            if total <= 0:
                continue
            probabilities[order][key] = {
                token: value / total
                for token, value in next_items
            }
    return probabilities

def _serialize_transition_tables(self) -> dict[str, dict[str, dict[str, float]]]:
    """Return ``transition_tables`` with string keys.

    Orders are converted with ``str`` and n-gram tuples with
    ``_encode_ngram_key``; the successor maps are passed through unchanged.
    """
    assert self.transition_tables is not None
    return {
        str(order): {
            _encode_ngram_key(key): value
            for key, value in mapping.items()
        }
        for order, mapping in self.transition_tables.items()
    }

def _deserialize_transition_tables(
    self,
    payload: dict[str, dict[str, dict[str, float]]],
) -> dict[int, dict[tuple[str, ...], dict[str, float]]]:
    """Rebuild transition tables from a string-keyed ``payload``.

    Every order in ``TRANSITION_ORDERS`` starts with an empty table. Orders
    present in the payload are parsed with ``int``, their keys decoded with
    ``_decode_ngram_key``, and successor tokens/probabilities coerced to
    ``str``/``float``.
    """
    tables: dict[int, dict[tuple[str, ...], dict[str, float]]] = {
        order: {} for order in sorted(TRANSITION_ORDERS)
    }
    for order_text, mapping in payload.items():
        order = int(order_text)
        tables[order] = {
            _decode_ngram_key(key): {
                str(token): float(probability)
                for token, probability in value.items()
            }
            for key, value in mapping.items()
        }
    return tables

def _eligible_copy_token(self, token: str) -> bool:
    """Return True when ``token`` may receive copy-prior mass.

    The rendered token must be non-blank, not a punctuation piece, start a
    new word, and contain at least two alphanumeric characters.
    """
    rendered = self._render_token(token)
    if not rendered.strip():
        return False
    if self._is_punctuation_piece(rendered):
        return False
    if not self._starts_new_word(token):
        return False
    alphanumeric = "".join(character for character in rendered if character.isalnum())
    return len(alphanumeric) >= 2

def _allowed_generation_token(self, token: str, generated_tokens: list[str]) -> bool:
    """Return True when ``token`` may be emitted after ``generated_tokens``.

    Vocabularies with fewer than 1024 entries bypass every check. Otherwise
    the unknown/special tokens and blank renderings are rejected, a newline
    is allowed only after something has been generated, and word joiners,
    structural punctuation and structural symbols are delegated to their
    dedicated predicates. Any remaining token must start a new word and be
    either alphanumeric or not a punctuation piece.
    """
    assert self.embedding_model is not None
    if len(self.embedding_model.id_to_token) < 1024:
        return True
    if token == self.tokenizer.unk_token or token in self.tokenizer.special_tokens:
        return False
    rendered = self._render_token(token)
    if rendered == "\n":
        return bool(generated_tokens)
    if not rendered.strip():
        return False
    if self._is_word_joiner_token(token):
        return (
            self._can_attach_word_joiner(generated_tokens)
            or self._can_start_line_with_word_joiner(token, generated_tokens)
        )
    if self._is_structural_punctuation_token(token):
        return bool(generated_tokens) or self._can_start_answer_with_structural_punctuation(token)
    if self._is_structural_symbol_token(token):
        return bool(generated_tokens) or self._starts_new_word(token)
    if not self._starts_new_word(token):
        return False
    alphanumeric = "".join(character for character in rendered if character.isalnum())
    return len(alphanumeric) >= 1 or not self._is_punctuation_piece(rendered)

def _would_repeat_recent_pattern(
    self,
    candidate: str,
    generated_tokens: list[str],
    recent_rendered_words: list[str] | None = None,
) -> bool:
    """Return True when emitting ``candidate`` would create a recent repeat.

    Checks, in order: the same token three times in a row; a token trigram
    that already occurs in the last twelve tokens (windows overlapping the
    two most recent tokens are not examined); and, for word-starting
    alphanumeric candidates, a repeated word bigram or a third occurrence of
    a longer non-connector word within the last ten rendered words.

    ``recent_rendered_words`` lets callers pass a precomputed result of
    ``_recent_rendered_words(generated_tokens)``.
    """
    if len(generated_tokens) >= 2 and generated_tokens[-1] == candidate and generated_tokens[-2] == candidate:
        return True

    if len(generated_tokens) >= 2:
        trigram = tuple(generated_tokens[-2:] + [candidate])
        recent_tokens = generated_tokens[-12:]
        for index in range(max(0, len(recent_tokens) - 4)):
            if tuple(recent_tokens[index : index + 3]) == trigram:
                return True

    rendered_words = recent_rendered_words
    if rendered_words is None:
        rendered_words = self._recent_rendered_words(generated_tokens)
    candidate_word = self._render_token(candidate).casefold()
    if (
        rendered_words
        and self._starts_new_word(candidate)
        and any(character.isalnum() for character in candidate_word)
    ):
        candidate_bigram = (rendered_words[-1], candidate_word)
        recent_window = rendered_words[-10:]
        recent_bigrams = {
            (recent_window[index], recent_window[index + 1])
            for index in range(len(recent_window) - 1)
        }
        if candidate_bigram in recent_bigrams:
            return True
        # NOTE(review): indentation was lost in this view of the file; this
        # check is assumed to sit inside the word-level branch -- confirm.
        if (
            len(candidate_word) > 2
            and rendered_words[-10:].count(candidate_word) >= 2
            and not self._is_common_connector_token(candidate)
        ):
            return True

    return False

def _recent_rendered_words(self, generated_tokens: list[str]) -> list[str]:
    """Return the casefolded renderings of word-starting alphanumeric tokens."""
    rendered_words: list[str] = []
    for token in generated_tokens:
        if not self._starts_new_word(token):
            continue
        rendered = self._render_token(token).casefold()
        if any(character.isalnum() for character in rendered):
            rendered_words.append(rendered)
    return rendered_words

def _select_generation_token(
    self,
    distribution: dict[str, float],
    *,
    context_tokens: list[str] | None = None,
    generated_tokens: list[str] | None = None,
    temperature: float = DEFAULT_GENERATION_TEMPERATURE,
    top_k: int = DEFAULT_GENERATION_TOP_K,
    top_p: float = DEFAULT_GENERATION_TOP_P,
    repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
    preserve_dominant_candidates: bool = False,
) -> str:
    """Pick the next token from a ``{token: probability}`` distribution.

    Candidates are filtered and rescored by
    ``_prepare_generation_candidates`` and one is drawn by
    ``_sample_generation_candidate`` (stochastically only when
    ``temperature > 0``). When no candidate survives, the highest-probability
    token that is not special/unknown and passes
    ``_allowed_generation_token`` is returned, or ``""`` if there is none.
    """
    assert self.tokenizer is not None
    generated_tokens = generated_tokens or []
    candidates = self._prepare_generation_candidates(
        distribution,
        generated_tokens=generated_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        preserve_dominant_candidates=preserve_dominant_candidates,
    )
    if candidates:
        return self._sample_generation_candidate(
            candidates,
            context_tokens=context_tokens or [],
            generated_tokens=generated_tokens,
            stochastic=temperature > 0.0,
        )

    # Fallback: ignore the repetition heuristics and take the best token
    # that is still permitted at this position.
    for token, _ in sorted(distribution.items(), key=lambda item: item[1], reverse=True):
        if token in self.tokenizer.special_tokens:
            continue
        if token == self.tokenizer.unk_token:
            continue
        if not self._allowed_generation_token(token, generated_tokens):
            continue
        return token
    return ""

def _select_generation_token_from_array(
    self,
    probabilities: object,
    *,
    context_tokens: list[str],
    generated_tokens: list[str],
    temperature: float = DEFAULT_GENERATION_TEMPERATURE,
    top_k: int = DEFAULT_GENERATION_TOP_K,
    top_p: float = DEFAULT_GENERATION_TOP_P,
    repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
    preserve_dominant_candidates: bool = False,
) -> str:
    """Pick the next token from an array of per-vocabulary probabilities.

    The ``max(top_k * 4, 64)`` highest-scoring indices (or all of them when
    the array is smaller) are turned into a ``{token: score}`` mapping that
    omits non-positive scores and special/unknown tokens, and the choice is
    delegated to ``_select_generation_token``. Returns ``""`` for an empty
    array.
    """
    assert np is not None
    assert self.tokenizer is not None
    assert self.embedding_model is not None

    values = np.asarray(probabilities, dtype=np.float64)
    if values.size == 0:
        return ""
    # values.size >= 1 here and max(..., 64) >= 64, so the pool size is
    # always at least 1; the former "pool_size <= 0" fallback could never
    # run and has been removed.
    pool_size = min(values.size, max(top_k * 4, 64))
    if pool_size < values.size:
        # Partial selection followed by sorting only the selected pool.
        candidate_indices = np.argpartition(values, -pool_size)[-pool_size:]
        candidate_indices = candidate_indices[np.argsort(values[candidate_indices])[::-1]]
    else:
        candidate_indices = np.argsort(values)[::-1]

    distribution: dict[str, float] = {}
    for raw_index in candidate_indices:
        index = int(raw_index)
        score = float(values[index])
        if score <= 0.0:
            continue
        token = self.embedding_model.id_to_token[index]
        if token in self.tokenizer.special_tokens or token == self.tokenizer.unk_token:
            continue
        distribution[token] = score
    return self._select_generation_token(
        distribution,
        context_tokens=context_tokens,
        generated_tokens=generated_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        preserve_dominant_candidates=preserve_dominant_candidates,
    )

def _prepare_generation_candidates(
    self,
    distribution: dict[str, float],
    *,
    generated_tokens: list[str],
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    preserve_dominant_candidates: bool = False,
) -> list[tuple[str, float]]:
    """Filter, rescore and renormalize the candidates in ``distribution``.

    Tokens are visited in descending probability. Special/unknown tokens,
    non-positive probabilities, tokens rejected by
    ``_allowed_generation_token`` and tokens that would repeat a recent
    pattern are dropped; with ``preserve_dominant_candidates`` a repeating
    token is kept when its probability is at least 80% of the best one. A
    word whose alphanumeric content equals that of the previous token is
    also dropped. The surviving scores are scaled by the repetition and
    punctuation heuristics below, then restricted by ``top_k`` (when > 0)
    and by nucleus ``top_p`` (when strictly between 0 and 1).

    Returns ``[(token, probability)]`` in descending score order, with
    probabilities obtained by raising scores to ``1 / temperature`` and
    normalizing. With ``temperature <= 0`` only the best candidate is
    returned with probability 1.0. Returns ``[]`` when nothing survives.
    """
    assert self.tokenizer is not None
    assert self.embedding_model is not None

    generated_word_count = self._generated_word_count(generated_tokens)
    clause_words = self._words_since_clause_break(generated_tokens)
    recent_rendered_words = self._recent_rendered_words(generated_tokens)
    best_probability = max(distribution.values(), default=0.0)

    # Facts about the already generated text do not depend on the candidate,
    # so they are computed once instead of once per candidate.
    recent_window = generated_tokens[-12:]
    last_four = generated_tokens[-4:]
    previous_token = generated_tokens[-1] if generated_tokens else None
    previous_alphanumeric = ""
    previous_is_punctuation = False
    previous_starts_word = False
    if previous_token is not None:
        previous_rendered = self._render_token(previous_token)
        previous_alphanumeric = "".join(
            character for character in previous_rendered if character.isalnum()
        ).casefold()
        previous_is_punctuation = self._is_structural_punctuation_token(previous_token)
        previous_starts_word = self._starts_new_word(previous_token)

    adjusted: list[tuple[str, float]] = []
    for token, probability in sorted(distribution.items(), key=lambda item: item[1], reverse=True):
        if token in self.tokenizer.special_tokens:
            continue
        if token == self.tokenizer.unk_token or probability <= 0.0:
            continue
        if not self._allowed_generation_token(token, generated_tokens):
            continue
        repeats_recent_pattern = self._would_repeat_recent_pattern(
            token,
            generated_tokens,
            recent_rendered_words=recent_rendered_words,
        )
        if (
            repeats_recent_pattern
            and not (
                preserve_dominant_candidates
                and best_probability > 0.0
                and probability >= best_probability * 0.80
            )
        ):
            continue

        score = probability
        rendered = self._render_token(token)
        punctuation_token = self._is_structural_punctuation_token(token)
        starts_new_word = self._starts_new_word(token)
        alphanumeric = "".join(character for character in rendered if character.isalnum())
        if previous_token is not None and starts_new_word and alphanumeric:
            # Never emit the same word twice in a row, ignoring case and
            # non-alphanumeric characters.
            if previous_alphanumeric == alphanumeric.casefold():
                continue
        common_connector = self._is_common_connector_token(token)
        # NOTE(review): indentation was lost in this view of the file; the
        # adjustments below are assumed to be one flat sequence of
        # independent checks -- confirm against the repository.
        if (
            starts_new_word
            and len(alphanumeric) == 1
            and not common_connector
        ):
            score *= 0.08
        recent_count = recent_window.count(token)
        if recent_count > 0 and not common_connector:
            score /= repetition_penalty ** (2 * recent_count)
        if previous_token is not None and token == previous_token:
            score /= repetition_penalty**3
        if token in last_four and not common_connector:
            score *= 0.35
        if not starts_new_word and previous_starts_word:
            score *= 0.08
        if not generated_tokens and punctuation_token:
            if best_probability <= 0.0 or probability < best_probability * 0.80:
                score *= 0.01
        elif not generated_tokens and not starts_new_word:
            score *= 0.02
        if punctuation_token:
            if previous_is_punctuation:
                score *= 0.05
            # Longer clauses make punctuation progressively more likely.
            if clause_words >= 6:
                score *= 1.0 + min(1.4, 0.18 * (clause_words - 5))
            elif generated_word_count >= 12:
                score *= 1.1
        if score > 0.0:
            adjusted.append((token, score))

    if not adjusted:
        return []
    adjusted.sort(key=lambda item: item[1], reverse=True)
    if top_k > 0:
        adjusted = adjusted[:top_k]
    if 0.0 < top_p < 1.0:
        # Nucleus filtering: keep the smallest prefix whose normalized mass
        # reaches top_p.
        kept: list[tuple[str, float]] = []
        cumulative = 0.0
        total = sum(score for _, score in adjusted)
        for token, score in adjusted:
            normalized = score / total if total else 0.0
            kept.append((token, score))
            cumulative += normalized
            if cumulative >= top_p:
                break
        adjusted = kept

    if temperature <= 0.0:
        return [(adjusted[0][0], 1.0)]

    exponent = 1.0 / temperature
    tempered = [
        (token, score**exponent)
        for token, score in adjusted
        if score > 0.0
    ]
    total = sum(score for _, score in tempered)
    if total <= 0.0:
        return []
    return [(token, score / total) for token, score in tempered]

def _sample_generation_candidate(
    self,
    candidates: list[tuple[str, float]],
    *,
    context_tokens: list[str],
    generated_tokens: list[str],
    stochastic: bool = False,
) -> str:
    """Choose one token from the ``(token, probability)`` candidates.

    The first candidate is returned outright when it clearly dominates the
    second one (see the three conditions below). Otherwise a threshold in
    [0, 1) is compared against the cumulative probabilities: it comes from
    ``random.random()`` when ``stochastic`` is true and from a generator
    seeded with a SHA-256 digest of the context tokens, generated tokens and
    candidate count when it is false, which makes that path reproducible.
    Returns ``""`` for an empty candidate list.
    """
    if not candidates:
        return ""
    if len(candidates) == 1:
        return candidates[0][0]
    top_probability = candidates[0][1]
    second_probability = candidates[1][1]
    top_has_clear_half_majority = top_probability >= 0.5 and (
        second_probability <= 0.0
        or top_probability - second_probability >= 0.02
    )
    if top_has_clear_half_majority or (
        second_probability > 0.0 and top_probability >= second_probability * 2.5
    ) or (
        top_probability >= 0.08
        and second_probability > 0.0
        and top_probability >= second_probability * 1.35
    ):
        return candidates[0][0]
    if stochastic:
        threshold = random.random()
    else:
        # NOTE(review): the separator element between the two token lists is
        # an empty string in this view of the file; it may originally have
        # been a marker literal lost during extraction -- confirm.
        seed_payload = "\u0002".join([*context_tokens, "", *generated_tokens, str(len(candidates))])
        seed = int.from_bytes(hashlib.sha256(seed_payload.encode("utf-8")).digest()[:8], "big")
        threshold = random.Random(seed).random()
    cumulative = 0.0
    for token, probability in candidates:
        cumulative += probability
        if threshold <= cumulative:
            return token
    # Rounding can leave the cumulative sum just below the threshold.
    return candidates[-1][0]

def _top_entries_from_vector(
    self,
    values: Vector,
    limit: int,
) -> list[dict[str, object]]:
    """Return token entries for the ``limit`` largest positive values.

    Entries are ordered by descending value, ties keeping the lower index
    first. Non-positive values among the selected entries are dropped, so
    fewer than ``limit`` entries may be returned. ``limit <= 0`` yields
    ``[]``.
    """
    import heapq

    if limit <= 0:
        return []
    # heapq.nlargest is documented as equivalent to
    # sorted(iterable, key=key, reverse=True)[:n] but avoids sorting the
    # whole vocabulary when only a few entries are requested.
    ranked = heapq.nlargest(
        limit,
        enumerate(values),
        key=lambda item: item[1],
    )
    return [
        self._token_entry(index, probability)
        for index, probability in ranked
        if probability > 0.0
    ]

def _token_entry(
    self,
    index: int,
    probability: float,
) -> dict[str, object]:
    """Return a dict with the raw token, its rendered text and ``probability``."""
    assert self.embedding_model is not None
    token = self.embedding_model.id_to_token[index]
    return {
        "token": token,
        "text": self._render_token(token),
        "probability": probability,
    }

def _build_reasoning_summary(
    self,
    transition_order: int | None,
    blend_weights: dict[str, float],
) -> str:
    """Return a one-paragraph description of the current generation step.

    Names the transition order in use (or states that no n-gram matched)
    and the blend source with the largest weight, defaulting to ``"base"``
    when ``blend_weights`` is empty.
    """
    dominant_source = max(blend_weights.items(), key=lambda item: item[1])[0] if blend_weights else "base"
    if transition_order is not None:
        transition_message = f" Transition prior is using order-{transition_order} context."
    else:
        transition_message = " Transition prior found no matching n-gram."

    return (
        "Generation is running on analytical state, recurrent traces, and corpus-derived token transitions."
        f"{transition_message}"
        f" Dominant blend source: {dominant_source}."
    )

def _generated_word_count(self, tokens: list[str]) -> int:
    """Return the number of whitespace-separated words in the decoded tokens."""
    return len(self._decode_tokens(tokens).split())

def _is_structural_punctuation_text(self, text: str) -> bool:
    """Return True for a single punctuation character that is not a word joiner."""
    if len(text) != 1:
        return False
    if self._is_word_joiner_text(text):
        return False
    category = unicodedata.category(text)
    # Every Unicode punctuation category starts with "P".
    return category.startswith("P")

def _is_structural_punctuation_token(self, token: str) -> bool:
    """Return True when the rendered token is structural punctuation."""
    return self._is_structural_punctuation_text(self._render_token(token))

def _is_structural_symbol_token(self, token: str) -> bool:
    """Return True when the rendered token is a single Unicode symbol (category ``S*``)."""
    rendered = self._render_token(token)
    return len(rendered) == 1 and unicodedata.category(rendered).startswith("S")

def _is_word_joiner_token(self, token: str) -> bool:
    """Return True when the rendered token is a word-joining character."""
    return self._is_word_joiner_text(self._render_token(token))

def _is_word_joiner_text(self, text: str) -> bool:
    """Return True for a single character that can join parts of one word.

    Covers connector punctuation (``Pc``), dashes (``Pd``), modifier letters
    (``Lm``), and any character whose Unicode name contains ``APOSTROPHE``
    or both ``SINGLE`` and ``QUOTATION MARK``.
    """
    if len(text) != 1:
        return False
    category = unicodedata.category(text)
    if category in ("Pc", "Pd", "Lm"):
        return True
    name = unicodedata.name(text, "")
    return "APOSTROPHE" in name or (
        "SINGLE" in name and "QUOTATION MARK" in name
    )

def _can_start_line_with_word_joiner(self, token: str, generated_tokens: list[str]) -> bool:
    """Return True when a word-starting dash token may open the output or a line.

    Requires a single ``Pd`` character that starts a new word, and either
    nothing generated yet or a newline as the previous rendered token.
    """
    rendered = self._render_token(token)
    if len(rendered) != 1 or unicodedata.category(rendered) != "Pd":
        return False
    if not self._starts_new_word(token):
        return False
    return not generated_tokens or self._render_token(generated_tokens[-1]) == "\n"

def _can_start_answer_with_structural_punctuation(self, token: str) -> bool:
    """Return True for a word-starting opening bracket or initial quote (``Ps``/``Pi``)."""
    rendered = self._render_token(token)
    if len(rendered) != 1 or not self._starts_new_word(token):
        return False
    return unicodedata.category(rendered) in ("Ps", "Pi")

def _is_common_connector_token(self, token: str) -> bool:
    """Return True when the rendered token is purely alphabetic and at most three characters."""
    rendered = self._render_token(token)
    return rendered.isalpha() and len(rendered) <= 3

def _can_attach_word_joiner(self, generated_tokens: list[str]) -> bool:
    """Return True when a word joiner may follow the last generated token.

    The previous rendered token must contain an alphanumeric character or be
    a single opening bracket / initial quote (``Ps``/``Pi``).
    """
    if not generated_tokens:
        return False
    rendered = self._render_token(generated_tokens[-1])
    if not rendered:
        return False
    if any(character.isalnum() for character in rendered):
        return True
    if len(rendered) != 1:
        return False
    return unicodedata.category(rendered) in ("Ps", "Pi")

def _words_since_clause_break(self, tokens: list[str]) -> int:
    """Count word-starting tokens after the most recent structural punctuation.

    Walks ``tokens`` backwards, skipping special tokens, and stops at the
    first structural punctuation. Punctuation pieces are not counted.
    """
    assert self.tokenizer is not None

    words = 0
    for token in reversed(tokens):
        if token in self.tokenizer.special_tokens:
            continue
        rendered = self._render_token(token)
        if self._is_structural_punctuation_text(rendered):
            break
        if self._starts_new_word(token) and not self._is_punctuation_piece(rendered):
            words += 1
    return words

def _should_stop_generation(self, generated_tokens: list[str]) -> bool:
    """Return True when the output ends a sentence and has at least 14 words."""
    if not generated_tokens:
        return False
    if not self._is_terminal_punctuation_text(self._render_token(generated_tokens[-1])):
        return False
    return self._generated_word_count(generated_tokens) >= 14

def _is_terminal_punctuation_text(self, text: str) -> bool:
    """Return True for structural punctuation named as a full stop, question mark or exclamation mark."""
    if not self._is_structural_punctuation_text(text):
        return False
    name = unicodedata.name(text, "")
    return (
        "FULL STOP" in name
        or "QUESTION MARK" in name
        or "EXCLAMATION MARK" in name
    )

def _starts_new_word(self, token: str) -> bool:
    """Return True when ``token`` begins a new word.

    True for special tokens, tokens carrying the tokenizer's word prefix,
    and single non-alphanumeric characters that are not word joiners.
    """
    assert self.tokenizer is not None
    if token in self.tokenizer.special_tokens:
        return True
    if token.startswith(self.tokenizer.word_prefix):
        return True
    return len(token) == 1 and not token.isalnum() and not self._is_word_joiner_token(token)

def _decode_tokens(self, tokens: list[str]) -> str:
    """Return the text produced by the tokenizer's ``decode`` for ``tokens``."""
    assert self.tokenizer is not None
    return self.tokenizer.decode(tokens)

def _render_token(self, token: str) -> str:
    """Return ``token`` without the tokenizer's leading word prefix, if present."""
    assert self.tokenizer is not None
    if token.startswith(self.tokenizer.word_prefix):
        return token[len(self.tokenizer.word_prefix) :]
    return token

def _require_fit(self) -> None:
    """Raise ``RuntimeError`` unless every fitted component is present."""
    if (
        self.tokenizer is None
        or self.embedding_model is None
        or self.memory_units is None
        or self.readout_weights is None
        or self.ternary_mask is None
        or self.associative_keys is None
        or self.associative_key_norms is None
        or self.associative_values is None
        or self.transition_tables is None
    ):
        raise RuntimeError("Call fit() before using the REFRAMR model.")

def _ensure_numeric_caches(self) -> None:
    # Refreshes the NumPy caches when NumPy is available and the readout
    # array has not been built yet.
    # NOTE(review): this definition is the last thing visible in the chunk
    # and may continue past it; the visible statements are kept unchanged.
    if np is None:
        return
    if self.readout_weights_array is None:
        self._refresh_numeric_caches()