import hashlib
import json
import random
import site
import string
import sys
import unicodedata
from dataclasses import dataclass, field
from pathlib import Path
from typing import Sequence

_VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor"
for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"):
    if _vendor_path.exists():
        vendor_text = str(_vendor_path)
        if vendor_text not in sys.path:
            sys.path.insert(0, vendor_text)

try:
    import numpy as np
except ModuleNotFoundError:
    user_site = site.getusersitepackages()
    if user_site and user_site not in sys.path:
        sys.path.append(user_site)
    try:
        import numpy as np
    except ModuleNotFoundError:
        np = None

if np is not None and not hasattr(np, "asarray"):
    np = None

from .checkpoint import read_safetensor_file, write_safetensor_file
from .config import ReframrConfig
from .embeddings import EmbeddingModel, fit_ppmi_embedding_from_tokens
from .hippo import (
    AnalyticalMemoryUnit,
    analytical_embedding_drive,
    analytical_embedding_drive_fast,
)
from .linalg import Vector, dot, mean, norm, softmax, zeros_vector
from .reasoning import TOOL_PROTOCOL_TOKENS, reasoning_prefix
from .reservoir import apply_readout, ridge_regression_readout
from .sparse_context import HashedSparseAttention
from .ternary import apply_ternary_mask, derive_ternary_mask_from_states
from .tokenizer import NativeTokenizer

ASSOCIATIVE_BLEND = 0.42
TRANSITION_BLEND = 0.08
COPY_BLEND = 0.04
BASE_BLEND = 0.34
FAST_ASSOCIATIVE_BLEND = 0.06
FAST_TRANSITION_BLEND = 0.14
FAST_COPY_BLEND = 0.12
FAST_BASE_BLEND = 0.72
FAST_PREFERENCE_BLEND = 0.15
FAST_ANSWER_BLEND = 0.16
FAST_SOURCE_EVIDENCE_BLEND = 0.52
PROMPT_READOUT_LOGIT_ZSCORE_SCALE = 0.48
PROMPT_START_READOUT_CONFIDENCE_FLOOR = 0.45
ASSOCIATIVE_TOP_K = 12
ANSWER_TOP_K = 48
ANSWER_START_TOP_K = 32
MIN_COMPLETE_ANSWER_WORDS = 6
MIN_COMPLETE_MULTI_SENTENCE_WORDS = 4
ANSWER_SEQUENCE_MATCH_FLOOR = 0.27
ANSWER_START_CONFIDENCE_FLOOR = 0.45
ANSWER_START_MATCH_SUPPORT_FLOOR = 0.18
ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR = 0.45
ANSWER_SEQUENCE_LOCK_FLOOR = 0.55
ANSWER_SEQUENCE_SPIKE_CONFIDENCE = 0.80
READOUT_LOGIT_ZSCORE_SCALE = 0.22
TRACE_IDENTITY_SCALE = 0.78
TRACE_IDENTITY_HASHES = (
    (1103515245, 12345, 214013, 2531011),
    (1664525, 1013904223, 22695477, 1),
    (69069, 362437, 134775813, 17),
    (134775813, 97, 1103515245, 31),
    (22695477, 911, 1664525, 73),
    (214013, 2531011, 69069, 19),
    (48271, 0, 69621, 11),
    (16807, 37, 40692, 101),
    (279470273, 173, 1299709, 53),
    (39916801, 29, 2147483629, 7),
)
PROMPT_ENVELOPE_TERMS = frozenset(
    {"system", "instruction", "user", "human", "assistant", "question", "answer"}
)
NGRAM_KEY_SEPARATOR = "\u0001"
TRANSITION_ORDERS = (10, 8, 6, 5, 4, 3, 2, 1)
DEFAULT_GENERATION_TEMPERATURE = 0.82
DEFAULT_GENERATION_TOP_K = 24
DEFAULT_GENERATION_TOP_P = 0.92
DEFAULT_REPETITION_PENALTY = 1.18
ANSWER_SEQUENCE_MAX_TOKENS = 192
ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT = 8192
ANSWER_SEQUENCE_VARIATION_TEMPERATURE = 0.65
ANSWER_SEQUENCE_VARIATION_MATCH_LIMIT = 4
ANSWER_SEQUENCE_CREATIVE_TEMPERATURE = 1.10
ANSWER_REPLAY_PREFIX_TEMPERATURE = 0.95
ANSWER_REPLAY_PREFIX_MIN_TOKENS = 64
ANSWER_REPLAY_PREFIX_PENALTY = 0.18
CREATIVE_EARLY_POOL_TEMPERATURE = 1.05
CREATIVE_EARLY_POOL_WORD_LIMIT = 6
CREATIVE_EARLY_POOL_MAX = 8
TOOL_CALL_CONTEXT_TERMS = frozenset(
    {
        "current",
        "latest",
        "today",
        "yesterday",
        "tonight",
        "now",
        "fresh",
        "recent",
        "web",
        "search",
        "real-time",
        "price",
        "weather",
        "election",
        "news",
        "official",
        "result",
        "live",
    }
)
RUNTIME_GENERATION_HISTORY_LIMIT = 8
AVOID_SEQUENCE_MIN_TOKENS = 6
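
# Illustrative sketch (assumption, not code from this module): the FAST_* blend
# weights above are consumed much later by the probability-mixing helpers. A
# reduced, hypothetical version of that mixing step would look roughly like
# this, where each *_prior is a per-token probability vector over the vocab:
#
#     mixed = (
#         FAST_BASE_BLEND * base_probabilities
#         + FAST_ASSOCIATIVE_BLEND * associative_prior
#         + FAST_TRANSITION_BLEND * transition_prior
#         + FAST_COPY_BLEND * copy_prior
#     )
#     mixed /= mixed.sum()  # renormalize; the blend weights need not sum to 1.0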
WORD_COMPLETION_OVERFLOW_TOKENS = 16
ANSWER_FINGERPRINT_WORDS = 4
SPARSE_CONTEXT_MIN_TOKENS = 16
SPARSE_CONTEXT_TOP_K = 64
SPARSE_CONTEXT_HASH_BITS = 12
SPARSE_CONTEXT_PROBE_RADIUS = 1
SPARSE_CONTEXT_CANDIDATE_MULTIPLIER = 16
SPARSE_CONTEXT_TRACE_BLEND = 0.35
RUNTIME_ARRAY_DTYPE = np.float32 if np is not None else None


@dataclass(frozen=True, slots=True)
class CharacterCountFact:
    character: str
    word: str
    count: int
    surface_seed: int
    focused: bool


@dataclass(frozen=True, slots=True)
class GenerationTokenMeta:
    rendered: str
    stripped: str
    starts_new_word: bool
    punctuation_piece: bool
    structural_punctuation: bool
    structural_symbol: bool
    word_joiner: bool
    alphanumeric: str
    common_connector: bool


def _normalize_vector(values: Vector) -> Vector:
    total = sum(values)
    if total <= 0.0:
        return [0.0 for _ in values]
    return [value / total for value in values]


def _encode_ngram_key(tokens: tuple[str, ...]) -> str:
    return NGRAM_KEY_SEPARATOR.join(tokens)


def _decode_ngram_key(key: str) -> tuple[str, ...]:
    return tuple(part for part in key.split(NGRAM_KEY_SEPARATOR) if part)


def _last_index(values: list[str], target: str) -> int | None:
    for index in range(len(values) - 1, -1, -1):
        if values[index] == target:
            return index
    return None


def _first_index(values: list[str], target: str) -> int | None:
    for index, value in enumerate(values):
        if value == target:
            return index
    return None


@dataclass(slots=True)
class DecodeState:
    hidden_states: list[Vector]
    context_traces: list[Vector]
    combined_state: Vector
    context_tokens: list[str]
    answer_anchor_state: Vector | None = None
    answer_matches: list[tuple[float, int, int]] | None = None
    answer_start_matches: list[tuple[float, int, int]] | None = None
    answer_sequence_matches: list[tuple[float, int, int]] | None = None
    prompt_answer_prior: object | None = None
    prompt_answer_start_prior: object | None = None


@dataclass(slots=True)
class ReframrModel:
    config: ReframrConfig
    tokenizer: NativeTokenizer | None = None
    embedding_model: EmbeddingModel | None = None
    memory_units: list[AnalyticalMemoryUnit] | None = None
    ternary_scale: float = 1.0
    ternary_mask: list[int] | None = None
    ternary_mask_array: object | None = None
    readout_weights: list[list[float]] | None = None
    readout_weights_array: object | None = None
    readout_bias: Vector | None = None
    readout_bias_array: object | None = None
    prompt_answer_weights: list[list[float]] | None = None
    prompt_answer_weights_array: object | None = None
    prompt_answer_bias: Vector | None = None
    prompt_answer_bias_array: object | None = None
    prompt_answer_start_weights: list[list[float]] | None = None
    prompt_answer_start_weights_array: object | None = None
    prompt_answer_start_bias: Vector | None = None
    prompt_answer_start_bias_array: object | None = None
    trace_token_weights: Vector | None = None
    trace_token_weights_array: object | None = None
    trace_embedding_table_array: object | None = None
    preference_bias: Vector | None = None
    preference_bias_array: object | None = None
    preference_valid_mask_array: object | None = None
    state_offset: Vector | None = None
    state_offset_array: object | None = None
    associative_keys: list[Vector] | None = None
    associative_keys_array: object | None = None
    associative_key_norms: list[float] | None = None
    associative_key_norms_array: object | None = None
    associative_values: list[int] | None = None
    associative_values_array: object | None = None
    associative_valid_mask_array: object | None = None
    answer_keys: list[Vector] | None = None
    answer_keys_array: object | None = None
    answer_key_norms: list[float] | None = None
    answer_key_norms_array: object | None = None
    answer_similarity_keys_array: object | None = None
    answer_similarity_key_norms_array: object | None = None
    answer_similarity_mask_array: object | None = None
    answer_values: list[int] | None = None
    answer_values_array: object | None = None
    answer_valid_mask_array: object | None = None
    answer_start_keys: list[Vector] | None = None
    answer_start_keys_array: object | None = None
    answer_start_key_norms: list[float] | None = None
    answer_start_key_norms_array: object | None = None
    answer_start_similarity_keys_array: object | None = None
    answer_start_similarity_key_norms_array: object | None = None
    answer_start_values: list[int] | None = None
    answer_start_values_array: object | None = None
    answer_start_valid_mask_array: object | None = None
    answer_sequence_keys: list[Vector] | None = None
    answer_sequence_keys_array: object | None = None
    answer_sequence_key_norms: list[float] | None = None
    answer_sequence_key_norms_array: object | None = None
    answer_sequence_similarity_keys_array: object | None = None
    answer_sequence_similarity_key_norms_array: object | None = None
    answer_sequence_prompt_tokens: list[list[int]] | None = None
    answer_sequence_prompt_tokens_array: object | None = None
    answer_sequence_tokens: list[list[int]] | None = None
    answer_sequence_tokens_array: object | None = None
    answer_sequence_token_id_rows: list[list[int]] | None = None
    answer_sequence_prompt_weight_maps: list[dict[int, float]] | None = None
    answer_sequence_prompt_weight_norms: list[float] | None = None
    answer_sequence_prompt_bigram_sets: list[set[tuple[int, int]]] | None = None
    answer_sequence_prompt_trigram_sets: list[set[tuple[int, int, int]]] | None = None
    answer_sequence_prompt_number_sets: list[set[str]] | None = None
    answer_sequence_prompt_inverted_index: dict[int, list[int]] | None = None
    answer_sequence_prompt_specificity: dict[int, float] | None = None
    prompt_overlap_valid_token_mask_array: object | None = None
    answer_fingerprint_hashes: set[tuple[int, ...]] | None = None
    answer_fingerprint_token_lengths: set[int] | None = None
    answer_fingerprint_token_sequences_by_length: dict[int, set[tuple[int, ...]]] | None = None
    answer_sequence_prefixes_by_length: dict[int, set[tuple[int, ...]]] | None = None
    transition_tables: dict[int, dict[tuple[str, ...], dict[str, float]]] | None = None
    transition_id_tables: dict[int, dict[tuple[int, ...], tuple[object, object]]] | None = None
    transition_tensor_cache: dict[str, object] | None = None
    transition_built_orders: set[int] | None = None
    generation_token_meta_cache: dict[str, GenerationTokenMeta] | None = None
    runtime_generation_history: dict[str, list[str]] = field(default_factory=dict, repr=False)

    def fit(self, text: str) -> "ReframrModel":
        self.generation_token_meta_cache = None
        self.answer_sequence_prefixes_by_length = None
        self.tokenizer = NativeTokenizer.train(
            text,
            vocab_size=self.config.tokenizer_vocab_size,
            min_pair_frequency=self.config.tokenizer_min_pair_frequency,
            lowercase=self.config.lowercase,
        )
        tokens = self.tokenizer.encode(text)
        if len(tokens) < 2:
            raise ValueError("REFRAMR needs at least two tokens to derive a next-token readout.")
        self.embedding_model = fit_ppmi_embedding_from_tokens(
            tokens,
            embedding_dim=self.config.embedding_dim,
            window_size=self.config.window_size,
            min_frequency=self.config.min_frequency,
            max_vocab=self.config.max_vocab,
            required_tokens=self.tokenizer.vocab,
        )
        self.memory_units = [
            AnalyticalMemoryUnit(self.config.state_dim, timescale)
            for timescale in self.config.timescales
        ]
        token_counts: dict[str, float] = {}
        for token in tokens:
            token_counts[token] = token_counts.get(token, 0.0) + 1.0
        self.trace_token_weights = self._derive_trace_token_weights_from_counts(token_counts)
        raw_states, targets, target_ids = self._collect_training_examples(tokens)
        self.ternary_scale, self.ternary_mask = derive_ternary_mask_from_states(raw_states)
        analytical_states = [
            apply_ternary_mask(state, self.ternary_mask, self.ternary_scale)
            for state in raw_states
        ]
        self.associative_keys = [state[:] for state in analytical_states]
        self.associative_key_norms = [norm(state) for state in analytical_states]
        self.associative_values = target_ids[:]
        self.answer_keys = []
        self.answer_key_norms = []
        self.answer_values = []
        self.answer_start_keys = []
        self.answer_start_key_norms = []
        self.answer_start_values = []
        self.answer_sequence_keys = []
        self.answer_sequence_key_norms = []
        self.answer_sequence_prompt_tokens = []
        self.answer_sequence_tokens = []
        self.prompt_answer_weights = []
        self.prompt_answer_bias = [0.0 for _ in self.embedding_model.id_to_token]
        self.prompt_answer_start_weights = []
        self.prompt_answer_start_bias = [0.0 for _ in self.embedding_model.id_to_token]
        self.transition_tables = self._build_transition_tables(tokens)
        self._fit_answer_memory_from_text(text)
        self._refresh_answer_fingerprint_hashes()
        self.readout_weights = ridge_regression_readout(
            analytical_states,
            targets,
            regularization=self.config.regularization,
        )
        self.readout_bias = [0.0 for _ in self.embedding_model.id_to_token]
        self.preference_bias = [0.0 for _ in self.embedding_model.id_to_token]
        self.state_offset = [0.0 for _ in analytical_states[0]] if analytical_states else []
        self._refresh_numeric_caches()
        return self

    def _fit_answer_memory_from_text(self, text: str) -> None:
        assert self.tokenizer is not None
        assert self.embedding_model is not None
        if (
            self.answer_keys is None
            or self.answer_key_norms is None
            or self.answer_values is None
            or self.answer_start_keys is None
            or self.answer_start_key_norms is None
            or self.answer_start_values is None
            or self.answer_sequence_keys is None
            or self.answer_sequence_key_norms is None
            or self.answer_sequence_prompt_tokens is None
            or self.answer_sequence_tokens is None
        ):
            return
        for line in text.splitlines():
            if "" not in line:
                continue
            prompt_text, answer_text = line.split("", 1)
            prompt_text = prompt_text.strip()
            answer_text = answer_text.strip()
            if not prompt_text or not answer_text:
                continue
            prompt_tokens = self.tokenizer.encode(prompt_text) + [""]
            answer_tokens = [
                token
                for token in self.tokenizer.encode(answer_text)
                if token in self.embedding_model.token_to_id
                and (
                    token not in self.tokenizer.special_tokens
                    or token in TOOL_PROTOCOL_TOKENS
                )
            ]
            if not prompt_tokens or not answer_tokens:
                continue
            key = self._encode_context(prompt_tokens)
            key_norm = norm(key)
            if key_norm <= 0.0:
                continue
            answer_ids = [
                self.embedding_model.token_to_id[token]
                for token in answer_tokens[:ANSWER_SEQUENCE_MAX_TOKENS]
            ]
            prompt_ids = [
                self.embedding_model.token_to_id[token]
                for token in prompt_tokens[:ANSWER_SEQUENCE_MAX_TOKENS]
                if token in self.embedding_model.token_to_id
                and (
                    token not in self.tokenizer.special_tokens
                    or token in TOOL_PROTOCOL_TOKENS
                )
            ]
            if not answer_ids:
                continue
            self.answer_keys.append(key[:])
            self.answer_key_norms.append(key_norm)
            self.answer_values.append(answer_ids[0])
            self.answer_start_keys.append(key[:])
            self.answer_start_key_norms.append(key_norm)
            self.answer_start_values.append(answer_ids[0])
            self.answer_sequence_keys.append(key[:])
            self.answer_sequence_key_norms.append(key_norm)
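            # The rows appended below are right-padded with -1 sentinels to a
            # fixed ANSWER_SEQUENCE_MAX_TOKENS width so they can be stacked into
            # rectangular arrays later; consumers filter the -1 entries back out.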
            self.answer_sequence_prompt_tokens.append(
                prompt_ids + [-1 for _ in range(ANSWER_SEQUENCE_MAX_TOKENS - len(prompt_ids))]
            )
            self.answer_sequence_tokens.append(
                answer_ids + [-1 for _ in range(ANSWER_SEQUENCE_MAX_TOKENS - len(answer_ids))]
            )

    def predict_next_distribution(
        self,
        context: str,
        *,
        reasoning_mode: str | None = None,
    ) -> dict[str, float]:
        self._require_fit()
        assert self.tokenizer is not None
        assert self.embedding_model is not None
        probabilities = self.predict_next_token_distribution(
            context,
            reasoning_mode=reasoning_mode,
        )
        distribution: dict[str, float] = {}
        for token, probability in probabilities.items():
            rendered = self._render_token(token)
            distribution[rendered] = distribution.get(rendered, 0.0) + probability
        return distribution

    def predict_next_token_distribution(
        self,
        context: str,
        *,
        reasoning_mode: str | None = None,
    ) -> dict[str, float]:
        self._require_fit()
        assert self.tokenizer is not None
        assert self.embedding_model is not None
        assert self.readout_weights is not None
        active_mode = reasoning_mode or self.config.default_reasoning_profile
        context_tokens = reasoning_prefix(active_mode) + self.tokenizer.encode(context)
        return self._predict_next_token_distribution_from_tokens(context_tokens)

    def generate_text(
        self,
        context: str,
        *,
        max_tokens: int = 64,
        reasoning_mode: str | None = None,
        temperature: float = 0.0,
        top_k: int = DEFAULT_GENERATION_TOP_K,
        top_p: float = DEFAULT_GENERATION_TOP_P,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        avoid_texts: Sequence[str] | None = None,
    ) -> str:
        character_count_response = self._character_count_response(
            context,
            temperature=temperature,
        )
        if character_count_response is not None:
            return character_count_response
        self._require_fit()
        self._ensure_numeric_caches()
        assert self.tokenizer is not None
        runtime_avoid_texts = self._runtime_avoid_texts(
            context,
            avoid_texts,
            temperature=temperature,
        )
        avoid_token_sequences = self._avoid_text_token_sequences(runtime_avoid_texts)
        if (
            np is not None
            and self.readout_weights_array is not None
            and self.embedding_model is not None
            and len(self.embedding_model.id_to_token) >= 1024
        ):
            generated_text = self._generate_text_fast(
                context,
                max_tokens=max_tokens,
                reasoning_mode=reasoning_mode,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                avoid_token_sequences=avoid_token_sequences,
            )
            self._remember_runtime_generation(
                context,
                generated_text,
                temperature=temperature,
            )
            return generated_text
        active_mode = reasoning_mode or self.config.default_reasoning_profile
        _, context_tokens = self._generation_prompt_tokens(context, active_mode)
        decode_state = self._build_decode_state(context_tokens)
        generated_tokens: list[str] = []
        for _ in range(max_tokens):
            distribution, _ = self._score_next_token_from_state(
                decode_state,
                include_trace=False,
                generated_tokens=generated_tokens,
                temperature=temperature,
                avoid_token_sequences=avoid_token_sequences,
            )
            forced_source_token = self._source_evidence_next_token(
                decode_state.context_tokens,
                generated_tokens,
            )
            next_token = forced_source_token or self._select_generation_token(
                distribution,
                context_tokens=decode_state.context_tokens,
                generated_tokens=generated_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                avoid_token_sequences=avoid_token_sequences,
                preserve_dominant_candidates=(
                    self._answer_decode_has_continuation(decode_state, generated_tokens)
                    or self._source_evidence_has_continuation(
                        decode_state.context_tokens,
                        generated_tokens,
                    )
                ),
            )
            if not next_token:
                break
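            # Commit the chosen token and advance the recurrent decode state,
            # then run the layered stop checks: memorized answer replay finished,
            # drift off a memorized answer path, source evidence fully emitted,
            # or the generic stop heuristic with no pending continuation left.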
            generated_tokens.append(next_token)
            self._advance_decode_state(decode_state, next_token)
            if self._should_stop_answer_sequence(decode_state, generated_tokens):
                break
            if self._should_stop_after_answer_path_drift(decode_state, generated_tokens):
                break
            if self._source_evidence_is_complete(decode_state.context_tokens, generated_tokens):
                break
            if (
                self._should_stop_generation(generated_tokens)
                and not self._answer_decode_has_continuation(decode_state, generated_tokens)
                and not self._source_evidence_has_continuation(
                    decode_state.context_tokens,
                    generated_tokens,
                )
            ):
                break
        overflow_budget = max(WORD_COMPLETION_OVERFLOW_TOKENS, max_tokens)
        while generated_tokens and overflow_budget > 0:
            has_answer_continuation = self._answer_decode_has_continuation(
                decode_state,
                generated_tokens,
            )
            has_source_continuation = self._source_evidence_has_continuation(
                decode_state.context_tokens,
                generated_tokens,
            )
            if (
                self._starts_new_word(generated_tokens[-1])
                and not has_answer_continuation
                and not has_source_continuation
            ):
                break
            distribution, _ = self._score_next_token_from_state(
                decode_state,
                include_trace=False,
                generated_tokens=generated_tokens,
                temperature=temperature,
                avoid_token_sequences=avoid_token_sequences,
            )
            forced_source_token = self._source_evidence_next_token(
                decode_state.context_tokens,
                generated_tokens,
            )
            next_token = forced_source_token or self._select_generation_token(
                distribution,
                context_tokens=decode_state.context_tokens,
                generated_tokens=generated_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                avoid_token_sequences=avoid_token_sequences,
                preserve_dominant_candidates=has_answer_continuation or has_source_continuation,
            )
            if not next_token:
                break
            if (
                self._starts_new_word(next_token)
                and not has_answer_continuation
                and not has_source_continuation
            ):
                break
            generated_tokens.append(next_token)
            self._advance_decode_state(decode_state, next_token)
            overflow_budget -= 1
        generated_text = self._finalize_generated_text(
            self._normalize_generated_tool_protocol_text(
                self._decode_tokens(generated_tokens),
                context=context,
            )
        )
        self._remember_runtime_generation(
            context,
            generated_text,
            temperature=temperature,
        )
        return generated_text

    @staticmethod
    def _character_count_fact(context: str) -> CharacterCountFact | None:
        normalized = unicodedata.normalize("NFKC", context).strip()
        tokens = ReframrModel._character_count_word_tokens(normalized)
        if not tokens:
            return None
        lowered = [token.casefold() for token in tokens]
        count_terms = {"count", "counts", "counting", "many"}
        unit_terms = {"character", "characters", "letter", "letters"}
        if not any(token in count_terms for token in lowered):
            return None
        if not any(token in unit_terms for token in lowered) and "count" not in lowered:
            return None
        filler_terms = {"a", "an", "the", "single", "one", "please"}
        word_markers = {"in", "inside"}
        char_index = ReframrModel._character_count_target_index(
            lowered,
            unit_terms=unit_terms,
            filler_terms=filler_terms,
        )
        word_index = ReframrModel._character_count_word_index(
            lowered,
            char_index=char_index,
            filler_terms=filler_terms,
            word_markers=word_markers,
        )
        if char_index is None or word_index is None:
            return None
        character = tokens[char_index]
        word = tokens[word_index]
        if len(character) != 1 or not word:
            return None
        order_offset = 0 if char_index < word_index else 1
        surface_seed = (
            (char_index + 1) * 7 + (word_index + 1) * 3 + len(tokens) + order_offset
        ) % 4
        structural_terms = (
            count_terms
            | unit_terms
            | filler_terms
            | word_markers
            | {
                "for", "of", "to", "how", "do", "does",
                "there", "are", "is", "appear", "appears", "times", "word",
            }
        )
        extra_content_tokens = [
            token
            for index, token in enumerate(lowered)
            if index not in {char_index, word_index} and token not in structural_terms
        ]
        return CharacterCountFact(
            character=character,
            word=word,
            count=word.casefold().count(character.casefold()),
            surface_seed=surface_seed,
            focused=not extra_content_tokens,
        )

    @staticmethod
    def _character_count_word_tokens(text: str) -> list[str]:
        tokens: list[str] = []
        current: list[str] = []
        for character in text:
            if character != "_" and character.isalnum():
                current.append(character)
                continue
            if current:
                tokens.append("".join(current))
                current = []
        if current:
            tokens.append("".join(current))
        return tokens

    @staticmethod
    def _character_count_target_index(
        tokens: list[str],
        *,
        unit_terms: set[str],
        filler_terms: set[str],
    ) -> int | None:
        for index, token in enumerate(tokens):
            if token not in unit_terms:
                continue
            for adjacent in (index - 1, index + 1):
                if 0 <= adjacent < len(tokens) and len(tokens[adjacent]) == 1:
                    return adjacent
            before = ReframrModel._nearest_content_index(tokens, index - 1, -1, filler_terms)
            after = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms)
            for candidate in (before, after):
                if candidate is not None and len(tokens[candidate]) == 1:
                    return candidate
        for index, token in enumerate(tokens):
            if token not in {"count", "counts", "counting"}:
                continue
            candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms)
            if candidate is not None and tokens[candidate] in unit_terms:
                candidate = ReframrModel._nearest_content_index(
                    tokens, candidate + 1, 1, filler_terms
                )
            if candidate is not None and len(tokens[candidate]) == 1:
                return candidate
        return None

    @staticmethod
    def _character_count_word_index(
        tokens: list[str],
        *,
        char_index: int | None,
        filler_terms: set[str],
        word_markers: set[str],
    ) -> int | None:
        for index, token in enumerate(tokens):
            if token != "word":
                continue
            candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms)
            if candidate is not None and candidate != char_index and len(tokens[candidate]) > 1:
                return candidate
        for index, token in enumerate(tokens):
            if token not in word_markers:
                continue
            candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms)
            if candidate is not None and tokens[candidate] == "word":
                candidate = ReframrModel._nearest_content_index(
                    tokens, candidate + 1, 1, filler_terms
                )
            if candidate is not None and candidate != char_index and len(tokens[candidate]) > 1:
                return candidate
        skipped_terms = {
            "how", "many", "do", "does", "count", "counts", "counting",
            "letter", "letters", "character", "characters", "word",
            "there", "are", "is", "appear", "appears", "times",
        } | filler_terms | word_markers
        for index in range(len(tokens) - 1, -1, -1):
            if index == char_index:
                continue
            if len(tokens[index]) <= 1 or tokens[index] in skipped_terms:
                continue
            return index
        return None

    @staticmethod
    def _nearest_content_index(
        tokens: list[str],
        start: int,
        direction: int,
        skipped_terms: set[str],
    ) -> int | None:
        index = start
        while 0 <= index < len(tokens):
            if tokens[index] not in skipped_terms:
                return index
            index += direction
        return None

    @classmethod
    def _character_count_response(cls, context: str, *, temperature: float = 0.0) -> str | None:
        fact = cls._character_count_fact(context)
        if fact is None:
            return None
        if not fact.focused:
            return None
        return cls._render_character_count_fact(fact, temperature=temperature)

    @staticmethod
    def _render_character_count_fact(
        fact: CharacterCountFact,
        *,
        temperature: float = 0.0,
    ) -> str:
        character_label = f"'{fact.character}'"
        word_label = f"'{fact.word}'"
        character_noun = "character" if fact.count == 1 else "characters"
        return f"{word_label} has {fact.count} {character_label} {character_noun}."

    @classmethod
    def _runtime_source_grounded_response(cls, context: str) -> str | None:
        return None

    @classmethod
    def _runtime_source_records(cls, context: str) -> list[tuple[str, str, str]]:
        records: list[tuple[str, str, str]] = []
        marker = ""
        search_from = 0
        while True:
            source_start = context.find(marker, search_from)
            if source_start < 0:
                break
            content_start = source_start + len(marker)
            content_end = cls._runtime_source_record_end(context, content_start)
            raw_record = context[content_start:content_end].strip()
            record = cls._parse_runtime_source_record(raw_record)
            if record is not None:
                records.append(record)
            search_from = max(content_end, content_start + 1)
        return records

    @staticmethod
    def _runtime_source_record_end(context: str, start: int) -> int:
        boundaries = [
            position
            for marker in (
                "\n",
                "",
                "",
                "",
                "",
                "",
                "",
            )
            if (position := context.find(marker, start)) >= 0
        ]
        return min(boundaries) if boundaries else len(context)

    @staticmethod
    def _parse_runtime_source_record(raw_record: str) -> tuple[str, str, str] | None:
        if not raw_record:
            return None
        pieces = [piece.strip() for piece in raw_record.split("|", 2)]
        if len(pieces) >= 3:
            title, url, snippet = pieces[0], pieces[1], pieces[2]
        else:
            title, url, snippet = "the provided source", "", pieces[-1]
        title = ReframrModel._clean_runtime_source_field(title) or "the provided source"
        url = ReframrModel._clean_runtime_source_field(url)
        snippet = ReframrModel._clean_runtime_source_field(snippet)
        if not snippet:
            return None
        return title, url, snippet

    @staticmethod
    def _clean_runtime_source_field(text: str) -> str:
        normalized = unicodedata.normalize("NFKC", text)
        cleaned = " ".join(normalized.split())
        return cleaned.strip(" \t\r\n|")

    def _generate_text_fast(
        self,
        context: str,
        *,
        max_tokens: int,
        reasoning_mode: str | None,
        temperature: float,
        top_k: int,
        top_p: float,
        repetition_penalty: float,
        avoid_token_sequences: Sequence[Sequence[str]] | None = None,
    ) -> str:
        assert self.tokenizer is not None
        active_mode = reasoning_mode or self.config.default_reasoning_profile
        _, context_tokens = self._generation_prompt_tokens(context, active_mode)
        decode_state = self._build_decode_state(context_tokens)
        generated_tokens: list[str] = []
        for _ in range(max_tokens):
            probabilities, _ = self._score_next_token_array_from_state(
                decode_state,
                include_associative=not generated_tokens,
                generated_tokens=generated_tokens,
                temperature=temperature,
                avoid_token_sequences=avoid_token_sequences,
            )
            forced_source_token = self._source_evidence_next_token(
                decode_state.context_tokens,
                generated_tokens,
            )
            next_token = forced_source_token or self._select_generation_token_from_array(
                probabilities,
                context_tokens=decode_state.context_tokens,
                generated_tokens=generated_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                avoid_token_sequences=avoid_token_sequences,
                preserve_dominant_candidates=(
                    self._answer_decode_has_continuation(decode_state, generated_tokens)
                    or self._source_evidence_has_continuation(
                        decode_state.context_tokens,
                        generated_tokens,
                    )
                ),
            )
            if not next_token:
                break
            generated_tokens.append(next_token)
            self._advance_decode_state(decode_state, next_token)
            if self._should_stop_answer_sequence(decode_state, generated_tokens):
                break
            if self._should_stop_after_answer_path_drift(decode_state, generated_tokens):
                break
            if self._source_evidence_is_complete(decode_state.context_tokens, generated_tokens):
                break
            if (
                self._should_stop_generation(generated_tokens)
                and not self._answer_decode_has_continuation(decode_state, generated_tokens)
                and not self._source_evidence_has_continuation(
                    decode_state.context_tokens,
                    generated_tokens,
                )
            ):
                break
        overflow_budget = max(WORD_COMPLETION_OVERFLOW_TOKENS, max_tokens)
        while generated_tokens and overflow_budget > 0:
            has_answer_continuation = self._answer_decode_has_continuation(
                decode_state,
                generated_tokens,
            )
            has_source_continuation = self._source_evidence_has_continuation(
                decode_state.context_tokens,
                generated_tokens,
            )
            if (
                self._starts_new_word(generated_tokens[-1])
                and not has_answer_continuation
                and not has_source_continuation
            ):
                break
            probabilities, _ = self._score_next_token_array_from_state(
                decode_state,
                include_associative=False,
                generated_tokens=generated_tokens,
                temperature=temperature,
                avoid_token_sequences=avoid_token_sequences,
            )
            forced_source_token = self._source_evidence_next_token(
                decode_state.context_tokens,
                generated_tokens,
            )
            next_token = forced_source_token or self._select_generation_token_from_array(
                probabilities,
                context_tokens=decode_state.context_tokens,
                generated_tokens=generated_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                avoid_token_sequences=avoid_token_sequences,
                preserve_dominant_candidates=has_answer_continuation or has_source_continuation,
            )
            if not next_token:
                break
            if (
                self._starts_new_word(next_token)
                and not has_answer_continuation
                and not has_source_continuation
            ):
                break
            generated_tokens.append(next_token)
            self._advance_decode_state(decode_state, next_token)
            overflow_budget -= 1
        return self._finalize_generated_text(
            self._normalize_generated_tool_protocol_text(
                self._decode_tokens(generated_tokens),
                context=context,
            )
        )

    def trace_next_token(
        self,
        context: str,
        *,
        reasoning_mode: str | None = None,
        top_k: int = 5,
    ) -> dict[str, object]:
        self._require_fit()
        assert self.tokenizer is not None
        active_mode = reasoning_mode or self.config.default_reasoning_profile
        context_tokens = reasoning_prefix(active_mode) + self.tokenizer.encode(context)
        _, trace = self._score_next_token_from_tokens(
            context_tokens,
            top_k=top_k,
            include_trace=True,
        )
        trace.update(
            {
                "context": context,
                "reasoning_mode": active_mode,
                "reasoning_tokens": reasoning_prefix(active_mode),
                "context_tokens": context_tokens,
            }
        )
        return trace

    def trace_generation(
        self,
        context: str,
        *,
        max_tokens: int = 16,
        reasoning_mode: str | None = None,
        top_k: int = 5,
        temperature: float = 0.0,
        top_p: float = DEFAULT_GENERATION_TOP_P,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
    ) -> dict[str, object]:
        character_count_response = self._character_count_response(
            context,
            temperature=temperature,
        )
        if character_count_response is not None:
            active_mode = reasoning_mode or self.config.default_reasoning_profile
            prompt = context if "" in context else f"{context} "
            return {
                "context": context,
                "prompt": prompt,
                "reasoning_mode": active_mode,
                "reasoning_tokens": reasoning_prefix(active_mode),
                "generation_policy": {
                    "temperature": temperature,
                    "top_k": max(DEFAULT_GENERATION_TOP_K, top_k),
                    "top_p": top_p,
                    "repetition_penalty": repetition_penalty,
                },
                "prompt_tokens": [],
                "generated_tokens": [],
                "generated_text": character_count_response,
                "generated_token_count": len(character_count_response.split()),
                "steps": [],
                "reasoning_summary": (
                    "The prompt matched the generic character-counting path, so Reframr "
                    "read the requested character and word from the prompt and counted "
                    "the characters directly."
                ),
            }
        self._require_fit()
        assert self.tokenizer is not None
        active_mode = reasoning_mode or self.config.default_reasoning_profile
        prompt, context_tokens = self._generation_prompt_tokens(context, active_mode)
        decode_state = self._build_decode_state(context_tokens)
        prompt_tokens = decode_state.context_tokens[:]
        generated_tokens: list[str] = []
        steps: list[dict[str, object]] = []
        for step_index in range(1, max_tokens + 1):
            distribution, trace = self._score_next_token_from_state(
                decode_state,
                top_k=top_k,
                include_trace=True,
                generated_tokens=generated_tokens,
                temperature=temperature,
            )
            forced_source_token = self._source_evidence_next_token(
                decode_state.context_tokens,
                generated_tokens,
            )
            next_token = forced_source_token or self._select_generation_token(
                distribution,
                context_tokens=decode_state.context_tokens,
                generated_tokens=generated_tokens,
                temperature=temperature,
                top_k=max(DEFAULT_GENERATION_TOP_K, top_k),
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                preserve_dominant_candidates=(
                    self._answer_decode_has_continuation(decode_state, generated_tokens)
                    or self._source_evidence_has_continuation(
                        decode_state.context_tokens,
                        generated_tokens,
                    )
                ),
            )
            if not next_token:
                break
            generated_tokens.append(next_token)
            self._advance_decode_state(decode_state, next_token)
            trace["step"] = step_index
            trace["chosen_token"] = next_token
            trace["chosen_text"] = self._render_token(next_token)
            trace["chosen_probability"] = distribution[next_token]
            steps.append(trace)
            if self._should_stop_answer_sequence(decode_state, generated_tokens):
                break
            if self._should_stop_after_answer_path_drift(decode_state, generated_tokens):
                break
            if self._source_evidence_is_complete(decode_state.context_tokens, generated_tokens):
                break
            if (
                self._should_stop_generation(generated_tokens)
                and not self._answer_decode_has_continuation(decode_state, generated_tokens)
                and not self._source_evidence_has_continuation(
                    decode_state.context_tokens,
                    generated_tokens,
                )
            ):
                break
        overflow_budget = max(WORD_COMPLETION_OVERFLOW_TOKENS, max_tokens)
        while generated_tokens and overflow_budget > 0:
            has_answer_continuation = self._answer_decode_has_continuation(
                decode_state,
                generated_tokens,
            )
            has_source_continuation = self._source_evidence_has_continuation(
                decode_state.context_tokens,
                generated_tokens,
            )
            if (
                self._starts_new_word(generated_tokens[-1])
                and not has_answer_continuation
                and not has_source_continuation
            ):
                break
            distribution, trace = self._score_next_token_from_state(
                decode_state,
                top_k=top_k,
                include_trace=True,
                generated_tokens=generated_tokens,
                temperature=temperature,
            )
            forced_source_token = self._source_evidence_next_token(
                decode_state.context_tokens,
                generated_tokens,
            )
            next_token = forced_source_token or self._select_generation_token(
                distribution,
                context_tokens=decode_state.context_tokens,
                generated_tokens=generated_tokens,
                temperature=temperature,
                top_k=max(DEFAULT_GENERATION_TOP_K, top_k),
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                preserve_dominant_candidates=has_answer_continuation or has_source_continuation,
            )
            if not next_token:
                break
            if (
                self._starts_new_word(next_token)
                and not has_answer_continuation
                and not has_source_continuation
            ):
                break
            generated_tokens.append(next_token)
            self._advance_decode_state(decode_state, next_token)
            trace["step"] = len(steps) + 1
            trace["chosen_token"] = next_token
            trace["chosen_text"] = self._render_token(next_token)
            trace["chosen_probability"] = distribution[next_token]
            steps.append(trace)
            if self._should_stop_answer_sequence(decode_state, generated_tokens):
                break
            if self._should_stop_after_answer_path_drift(decode_state, generated_tokens):
                break
            overflow_budget -= 1
        return {
            "context": context,
            "prompt": prompt,
            "reasoning_mode": active_mode,
            "reasoning_tokens": reasoning_prefix(active_mode),
            "generation_policy": {
                "temperature": temperature,
                "top_k": max(DEFAULT_GENERATION_TOP_K, top_k),
                "top_p": top_p,
                "repetition_penalty": repetition_penalty,
            },
            "prompt_tokens": prompt_tokens,
            "generated_tokens": generated_tokens,
            "generated_text": self._finalize_generated_text(
                self._normalize_generated_tool_protocol_text(
                    self._decode_tokens(generated_tokens),
                    context=context,
                )
            ),
            "generated_token_count": len(generated_tokens),
            "steps": steps,
        }

    def _generation_prompt_tokens(self, context: str, active_mode: str) -> tuple[str, list[str]]:
        assert self.tokenizer is not None
        prompt = context if "" in context else f"{context} "
        prefix = reasoning_prefix(active_mode)
        prompt_tokens = self.tokenizer.encode(prompt)
        if (
            "" in prompt_tokens
            and "" not in prompt_tokens
            and "" not in prefix
        ):
            prompt_tokens = [""] + prompt_tokens
        return prompt, prefix + prompt_tokens

    def _predict_next_token_distribution_from_tokens(
        self,
        context_tokens: list[str],
    ) -> dict[str, float]:
        decode_state = self._build_decode_state(context_tokens)
        return self._predict_next_token_distribution_from_state(decode_state)

    def _predict_next_token_distribution_from_state(
        self,
        decode_state: DecodeState,
    ) -> dict[str, float]:
        probabilities, _ = self._score_next_token_from_state(
            decode_state,
            include_trace=False,
        )
        return probabilities

    @staticmethod
    def _answer_memory_is_confident(
        *,
        answer_sequence_match_confidence: float,
        answer_start_confidence: float,
        generated_count: int,
    ) -> bool:
        if generated_count > 0:
            return answer_sequence_match_confidence >= ANSWER_SEQUENCE_MATCH_FLOOR
        if answer_sequence_match_confidence >= ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR:
            return True
        if answer_sequence_match_confidence >= ANSWER_SEQUENCE_MATCH_FLOOR:
            return True
        if answer_start_confidence >= ANSWER_START_CONFIDENCE_FLOOR + ANSWER_SEQUENCE_MATCH_FLOOR:
            return True
        return (
            answer_sequence_match_confidence >= ANSWER_START_MATCH_SUPPORT_FLOOR
            and answer_start_confidence >= ANSWER_START_CONFIDENCE_FLOOR
            and answer_start_confidence
            <= answer_sequence_match_confidence + ANSWER_START_CONFIDENCE_FLOOR
        )

    @staticmethod
    def _answer_sequence_should_lock(
        *,
        answer_sequence_confidence: float,
        answer_sequence_match_confidence: float,
        has_answer_sequence_prior: bool,
    ) -> bool:
        if not has_answer_sequence_prior or answer_sequence_confidence <= 0.0:
            return False
        if answer_sequence_match_confidence >= ANSWER_SEQUENCE_LOCK_FLOOR:
            return True
        if (
            answer_sequence_match_confidence >= ANSWER_SEQUENCE_MATCH_FLOOR
            and 0.30 <= answer_sequence_confidence <= 0.65
        ):
            return True
        return (
            answer_sequence_match_confidence >= ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR
            and answer_sequence_confidence <= ANSWER_SEQUENCE_SPIKE_CONFIDENCE
        )

    def _prompt_start_readout_is_confident(
        self,
        prior: object,
        tokens: Sequence[str] | None = None,
    ) -> bool:
        if self.tokenizer is None:
            return False
        if tokens is None:
            if self.embedding_model is None:
                return False
            tokens = self.embedding_model.id_to_token
        values = prior.tolist() if hasattr(prior, "tolist") else list(prior)
        if not values or not tokens:
            return False
        limit = min(len(values), len(tokens))
        if limit <= 0:
            return False
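        # Confidence gate: trust the start readout only when its argmax token
        # both clears PROMPT_START_READOUT_CONFIDENCE_FLOOR and looks like a
        # clean word start (alphanumeric, neither structural punctuation nor a
        # structural symbol).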
        best_index = max(range(limit), key=lambda index: float(values[index]))
        best_probability = float(values[best_index])
        if best_probability < PROMPT_START_READOUT_CONFIDENCE_FLOOR:
            return False
        meta = self._generation_token_meta(tokens[best_index])
        return (
            meta.starts_new_word
            and bool(meta.alphanumeric)
            and not meta.structural_punctuation
            and not meta.structural_symbol
        )

    def _locked_answer_sequence_matches(
        self,
        matches: list[tuple[float, int, int]],
        *,
        generated_tokens: list[str],
        temperature: float,
        answer_sequence_confidence: float,
        answer_sequence_match_confidence: float,
    ) -> list[tuple[float, int, int]]:
        if not matches:
            return []
        if generated_tokens:
            aligned_matches = [
                match
                for match in matches[:ANSWER_START_TOP_K]
                if self._answer_sequence_match_has_continuation(
                    match,
                    generated_tokens,
                )
            ]
            return aligned_matches[:ANSWER_SEQUENCE_VARIATION_MATCH_LIMIT] or matches[:1]
        best_similarity = matches[0][0]
        near_match_floor = max(ANSWER_SEQUENCE_MATCH_FLOOR, best_similarity - 0.08)
        varied = [
            match
            for match in matches[:ANSWER_SEQUENCE_VARIATION_MATCH_LIMIT]
            if match[0] >= near_match_floor
        ]
        if (
            temperature < ANSWER_SEQUENCE_VARIATION_TEMPERATURE
            and answer_sequence_match_confidence >= ANSWER_SEQUENCE_LOCK_FLOOR
            and len(varied) <= 1
        ):
            return matches[:1]
        return varied or matches[:1]

    @staticmethod
    def _answer_sequence_matches_are_ambiguous(
        matches: Sequence[tuple[float, int, int]],
    ) -> bool:
        if len(matches) < 2:
            return False
        best_similarity = float(matches[0][0])
        if best_similarity < ANSWER_SEQUENCE_MATCH_FLOOR:
            return False
        near_match_floor = max(ANSWER_SEQUENCE_MATCH_FLOOR, best_similarity - 0.08)
        return any(
            float(match[0]) >= near_match_floor
            for match in matches[1:ANSWER_SEQUENCE_VARIATION_MATCH_LIMIT]
        )

    def _answer_sequence_match_has_continuation(
        self,
        match: tuple[float, int, int],
        generated_tokens: list[str],
    ) -> bool:
        if (
            self.embedding_model is None
            or self.answer_sequence_tokens is None
            or not generated_tokens
        ):
            return False
        similarity, sequence_index, _ = match
        if (
            similarity < ANSWER_SEQUENCE_MATCH_FLOOR
            or sequence_index >= len(self.answer_sequence_tokens)
        ):
            return False
        generated_ids = [
            self.embedding_model.token_to_id[token]
            for token in generated_tokens
            if token in self.embedding_model.token_to_id
        ]
        if not generated_ids:
            return False
        row = self.answer_sequence_tokens[sequence_index]
        token_ids = [
            int(value)
            for value in (row.tolist() if hasattr(row, "tolist") else row)
            if int(value) >= 0
        ]
        if not token_ids:
            return False
        next_token_id = self._next_sequence_token_id(token_ids, generated_ids)
        if next_token_id is None:
            return False
        token = self.embedding_model.id_to_token[next_token_id]
        return self._allowed_answer_sequence_token(token, generated_tokens)

    def _allowed_answer_sequence_token(
        self,
        token: str,
        generated_tokens: list[str],
    ) -> bool:
        assert self.tokenizer is not None
        if token == self.tokenizer.unk_token:
            return False
        if token in self.tokenizer.special_tokens:
            return self._allowed_generation_token(token, generated_tokens)
        return True

    def _should_relax_answer_sequence_memory(
        self,
        matches: list[tuple[float, int, int]],
        answer_sequence_prior: Sequence[float],
        *,
        generated_tokens: list[str],
        temperature: float,
    ) -> bool:
        if temperature < ANSWER_SEQUENCE_CREATIVE_TEMPERATURE or not matches:
            return False
        if self._is_inside_tool_protocol_continuation(generated_tokens):
            return False
        if self._answer_sequence_prior_prefers_tool_protocol(answer_sequence_prior):
            return False
        return True

    def _answer_sequence_prior_prefers_tool_protocol(
        self,
        answer_sequence_prior: Sequence[float],
    ) -> bool:
        if self.embedding_model is None or not answer_sequence_prior:
            return False
        best_index = -1
        best_value = 0.0
        for index, value in enumerate(answer_sequence_prior):
            if value > best_value:
                best_index = index
                best_value = float(value)
        return (
            0 <= best_index < len(self.embedding_model.id_to_token)
            and best_value > 0.0
            and self.embedding_model.id_to_token[best_index] in TOOL_PROTOCOL_TOKENS
        )

    @staticmethod
    def _answer_start_blend_weights(
        *,
        answer_sequence_match_confidence: float,
        temperature: float = 0.0,
    ) -> dict[str, float]:
        if temperature >= ANSWER_SEQUENCE_CREATIVE_TEMPERATURE:
            return {
                "prompt_answer_start": 0.46,
                "prompt_answer": 0.24,
                "answer_sequence": 0.10,
                "answer_start": 0.20,
            }
        if answer_sequence_match_confidence >= ANSWER_SEQUENCE_LOCK_FLOOR:
            return {
                "prompt_answer_start": 0.35,
                "prompt_answer": 0.10,
                "answer_sequence": 0.45,
                "answer_start": 0.10,
            }
        if answer_sequence_match_confidence >= 0.40:
            return {
                "prompt_answer_start": 0.25,
                "prompt_answer": 0.12,
                "answer_sequence": 0.53,
                "answer_start": 0.10,
            }
        return {
            "prompt_answer_start": 0.08,
            "prompt_answer": 0.10,
            "answer_sequence": 0.02,
            "answer_start": 0.80,
        }

    def _score_next_token_from_tokens(
        self,
        context_tokens: list[str],
        *,
        top_k: int = 5,
        include_trace: bool = True,
    ) -> tuple[dict[str, float], dict[str, object]]:
        decode_state = self._build_decode_state(context_tokens)
        return self._score_next_token_from_state(
            decode_state,
            top_k=top_k,
            include_trace=include_trace,
        )

    def _score_next_token_from_state(
        self,
        decode_state: DecodeState,
        *,
        top_k: int = 5,
        include_trace: bool = True,
        generated_tokens: list[str] | None = None,
        temperature: float = 0.0,
        avoid_token_sequences: Sequence[Sequence[str]] | None = None,
    ) -> tuple[dict[str, float], dict[str, object]]:
        assert self.embedding_model is not None
        assert self.readout_weights is not None
        generated_tokens = generated_tokens or []
        state = self._masked_decode_state(decode_state)
        logits = self._apply_readout_fast(state)
        base_probabilities = self._calibrated_softmax(logits)
        if decode_state.answer_matches is None:
            decode_state.answer_matches = self._score_answer_matches(
                decode_state.answer_anchor_state,
                limit=max(ANSWER_TOP_K, top_k) if include_trace else ANSWER_TOP_K,
            )
        answer_matches = decode_state.answer_matches
        if decode_state.answer_start_matches is None:
            decode_state.answer_start_matches = self._score_answer_start_matches(
                decode_state.answer_anchor_state,
                limit=max(ANSWER_START_TOP_K, top_k) if include_trace else ANSWER_START_TOP_K,
            )
        answer_start_matches = decode_state.answer_start_matches
        if decode_state.answer_sequence_matches is None:
            decode_state.answer_sequence_matches = self._score_answer_sequence_matches(
                decode_state.answer_anchor_state,
                decode_state.context_tokens,
                limit=max(ANSWER_START_TOP_K, top_k) if include_trace else ANSWER_START_TOP_K,
            )
        answer_sequence_matches = self._filter_avoided_answer_sequence_matches(
            decode_state.answer_sequence_matches,
            avoid_token_sequences,
        )
        if not answer_start_matches and answer_sequence_matches:
            answer_start_matches = self._answer_start_matches_from_sequences(
                answer_sequence_matches
            )
            decode_state.answer_start_matches = answer_start_matches
        answer_prior = self._answer_prior_from_matches(answer_matches, generated_tokens)
        answer_start_prior = self._answer_prior_from_matches(answer_start_matches, generated_tokens)
        answer_sequence_prior = self._answer_sequence_prior_from_matches(
            answer_sequence_matches,
            generated_tokens,
            temperature=temperature,
        )
        answer_sequence_confidence = max(answer_sequence_prior) if answer_sequence_prior else 0.0
        answer_sequence_match_confidence = (
            answer_sequence_matches[0][0] if answer_sequence_matches else 0.0
        )
        answer_start_confidence = answer_start_matches[0][0] if answer_start_matches else 0.0
        prompt_copy_is_distinctive = (
            not generated_tokens
            and self._prompt_copy_evidence_is_distinctive(decode_state.context_tokens)
        )
        answer_memory_confident = self._answer_memory_is_confident(
            answer_sequence_match_confidence=answer_sequence_match_confidence,
            answer_start_confidence=answer_start_confidence,
            generated_count=len(generated_tokens),
        )
        if prompt_copy_is_distinctive and not answer_sequence_matches:
            answer_memory_confident = False
        has_answer_sequence_prior = any(value > 0.0 for value in answer_sequence_prior)
        if not answer_memory_confident:
            zero_prior = [0.0 for _ in self.embedding_model.id_to_token]
            answer_prior = zero_prior
            answer_start_prior = zero_prior
            answer_sequence_prior = zero_prior
            answer_sequence_confidence = 0.0
            has_answer_sequence_prior = False
        answer_locked = self._answer_sequence_should_lock(
            answer_sequence_confidence=answer_sequence_confidence,
            answer_sequence_match_confidence=answer_sequence_match_confidence,
            has_answer_sequence_prior=has_answer_sequence_prior,
        ) or (
            bool(generated_tokens)
            and temperature < ANSWER_SEQUENCE_CREATIVE_TEMPERATURE
            and self._answer_sequence_has_continuation(
                generated_tokens,
                answer_sequence_matches,
            )
        )
        if self._should_relax_answer_sequence_memory(
            answer_sequence_matches,
            answer_sequence_prior,
            generated_tokens=generated_tokens,
            temperature=temperature,
        ):
            answer_locked = False
        if decode_state.prompt_answer_prior is None:
            decode_state.prompt_answer_prior = self._prompt_answer_readout_prior(
                decode_state.answer_anchor_state,
                start=False,
            )
        prompt_answer_prior = decode_state.prompt_answer_prior
        prompt_answer_start_prior = (
            decode_state.prompt_answer_start_prior
            if not generated_tokens
            else [0.0 for _ in self.embedding_model.id_to_token]
        )
        if not generated_tokens and prompt_answer_start_prior is None:
            decode_state.prompt_answer_start_prior = self._prompt_answer_readout_prior(
                decode_state.answer_anchor_state,
                start=True,
            )
            prompt_answer_start_prior = decode_state.prompt_answer_start_prior
        prompt_start_readout_confident = (
            not generated_tokens
            and prompt_answer_start_prior is not None
            and self._prompt_start_readout_is_confident(prompt_answer_start_prior)
        )
        prompt_readout_supported = answer_memory_confident and (
            answer_sequence_match_confidence >= ANSWER_SEQUENCE_MATCH_FLOOR
            or answer_start_confidence >= ANSWER_START_CONFIDENCE_FLOOR
        )
        if prompt_start_readout_confident:
            prompt_readout_supported = True
        if not prompt_readout_supported:
            prompt_answer_prior = [0.0 for _ in self.embedding_model.id_to_token]
            prompt_answer_start_prior = [0.0 for _ in self.embedding_model.id_to_token]
        use_answer_start = not generated_tokens and (
            any(value > 0.0 for value in answer_start_prior)
            or any(value > 0.0 for value in prompt_answer_start_prior)
        )
        if answer_locked:
            locked_matches = self._locked_answer_sequence_matches(
                answer_sequence_matches,
                generated_tokens=generated_tokens,
                temperature=temperature,
                answer_sequence_confidence=answer_sequence_confidence,
                answer_sequence_match_confidence=answer_sequence_match_confidence,
            )
            answer_sequence_prior = self._answer_sequence_prior_from_matches(
                locked_matches,
                generated_tokens,
                temperature=temperature,
            )
            answer_prior = answer_sequence_prior
        elif use_answer_start:
            start_blend = self._answer_start_blend_weights(
                answer_sequence_match_confidence=answer_sequence_match_confidence,
                temperature=temperature,
            )
            answer_prior = self._weighted_prior_sum(
                [
                    (start_blend["prompt_answer_start"], prompt_answer_start_prior),
                    (start_blend["prompt_answer"], prompt_answer_prior),
                    (start_blend["answer_sequence"], answer_sequence_prior),
                    (start_blend["answer_start"], answer_start_prior),
                ],
            )
        elif any(value > 0.0 for value in answer_sequence_prior):
            sequence_weight = (
                0.10 if temperature >= ANSWER_SEQUENCE_CREATIVE_TEMPERATURE else 0.30
            )
            answer_prior = self._weighted_prior_sum(
                [
                    (0.55, prompt_answer_prior),
                    (sequence_weight, answer_sequence_prior),
                    (0.20, answer_prior),
                ],
            )
        elif any(value > 0.0 for value in prompt_answer_prior):
            answer_prior = self._weighted_prior_sum(
                [
                    (0.65, prompt_answer_prior),
                    (0.35, answer_prior),
                ],
            )
        answer_guided = max(answer_prior) >= 0.08 if answer_prior else False
        associative_matches = (
            []
            if use_answer_start or answer_guided
            else self._score_associative_matches(
                state,
                limit=max(ASSOCIATIVE_TOP_K, top_k) if include_trace else ASSOCIATIVE_TOP_K,
            )
        )
        associative_prior = (
            [0.0 for _ in self.embedding_model.id_to_token]
            if use_answer_start or answer_guided
            else self._associative_prior_from_matches(associative_matches)
        )
        transition_prior, transition_order = self._transition_prior_with_order(
            decode_state.context_tokens
        )
        copy_prior = self._copy_prior(decode_state.context_tokens)
        source_evidence_prior = self._source_evidence_prior(
            decode_state.context_tokens,
            generated_tokens,
        )
        preference_prior = self._preference_prior()
        probabilities, blend_weights = self._blend_probabilities(
            base_probabilities,
            answer_prior,
            associative_prior,
            transition_prior,
            copy_prior,
            source_evidence_prior,
            preference_prior,
            transition_order=transition_order,
            generated_count=len(generated_tokens),
            answer_locked=answer_locked,
            answer_guided_start=use_answer_start,
            copy_guided_start=prompt_copy_is_distinctive,
        )
        probabilities = self._focus_answer_start_probabilities(
            probabilities,
            answer_sequence_prior,
            generated_tokens=generated_tokens,
            answer_memory_confident=answer_memory_confident,
            has_answer_sequence_prior=has_answer_sequence_prior,
            sequence_focus_allowed=answer_sequence_match_confidence >= 0.40 or answer_locked,
            temperature=temperature,
        )
        distribution = {
            token: probabilities[index]
            for index, token in enumerate(self.embedding_model.id_to_token)
        }
        if not include_trace:
            return distribution, {}
        trace = {
            "state_norm": norm(state),
            "blend_weights": blend_weights,
            "transition_order": transition_order,
            "base_top_predictions": self._top_entries_from_vector(base_probabilities, top_k),
            "answer_top_predictions": self._top_entries_from_vector(answer_prior, top_k),
            "prompt_answer_top_predictions": self._top_entries_from_vector(
                prompt_answer_prior, top_k
            ),
            "prompt_answer_start_top_predictions": self._top_entries_from_vector(
                prompt_answer_start_prior, top_k
            ),
            "answer_start_top_predictions": self._top_entries_from_vector(
                answer_start_prior, top_k
            ),
            "answer_sequence_top_predictions": self._top_entries_from_vector(
                answer_sequence_prior, top_k
            ),
            "associative_top_predictions": self._top_entries_from_vector(
                associative_prior, top_k
            ),
            "transition_top_predictions": self._top_entries_from_vector(transition_prior, top_k),
            "copy_top_predictions": self._top_entries_from_vector(copy_prior, top_k),
            "source_evidence_top_predictions": self._top_entries_from_vector(
                source_evidence_prior, top_k
            ),
            "preference_top_predictions": self._top_entries_from_vector(preference_prior, top_k),
            "final_top_predictions": self._top_entries_from_vector(probabilities, top_k),
            "associative_matches": [
                {
                    "example_index": example_index,
                    "similarity": similarity,
                    **self._token_entry(token_id, similarity),
                }
                for similarity, token_id, example_index in associative_matches[:top_k]
            ],
            "answer_matches": [
                {
                    "example_index": example_index,
                    "similarity": similarity,
                    **self._token_entry(token_id, similarity),
                }
                for similarity, token_id, example_index in answer_matches[:top_k]
            ],
            "answer_start_matches": [
                {
                    "example_index": example_index,
                    "similarity": similarity,
                    **self._token_entry(token_id, similarity),
                }
                for similarity, token_id, example_index in answer_start_matches[:top_k]
            ],
            "answer_sequence_matches": [
                {
                    "example_index": example_index,
                    "similarity": similarity,
                }
                for similarity, _, example_index in answer_sequence_matches[:top_k]
            ],
            "reasoning_summary": self._build_reasoning_summary(
                transition_order,
                blend_weights,
            ),
        }
        return distribution, trace

    def _score_next_token_array_from_state(
        self,
        decode_state: DecodeState,
        *,
        include_associative: bool,
        generated_tokens: list[str] | None = None,
        temperature: float = 0.0,
        avoid_token_sequences: Sequence[Sequence[str]] | None = None,
    ) -> tuple[object, dict[str, float]]:
        assert np is not None
        assert self.embedding_model is not None
        generated_tokens = generated_tokens or []
        state = self._masked_decode_state_array(decode_state)
        logits = self._apply_readout_array(state)
        base_probabilities = self._calibrated_softmax_array(logits)
        if decode_state.answer_matches is None:
            decode_state.answer_matches = self._score_answer_matches(
                decode_state.answer_anchor_state
            )
        answer_prior = np.asarray(
            self._answer_prior_from_matches(
                decode_state.answer_matches,
                generated_tokens,
            ),
            dtype=np.float64,
        )
        if decode_state.answer_sequence_matches is None:
            decode_state.answer_sequence_matches = self._score_answer_sequence_matches(
                decode_state.answer_anchor_state,
                decode_state.context_tokens,
            )
        answer_sequence_matches = self._filter_avoided_answer_sequence_matches(
            decode_state.answer_sequence_matches,
            avoid_token_sequences,
        )
        if not decode_state.answer_start_matches and answer_sequence_matches:
            decode_state.answer_start_matches = self._answer_start_matches_from_sequences(
                answer_sequence_matches
            )
        answer_sequence_prior = np.asarray(
            self._answer_sequence_prior_from_matches(
                answer_sequence_matches,
                generated_tokens,
                temperature=temperature,
            ),
            dtype=np.float64,
        )
        answer_sequence_confidence = (
            float(answer_sequence_prior.max()) if answer_sequence_prior.size else 0.0
        )
        answer_sequence_match_confidence = (
            answer_sequence_matches[0][0] if answer_sequence_matches else 0.0
        )
        if not generated_tokens and decode_state.answer_start_matches is None:
            decode_state.answer_start_matches = self._score_answer_start_matches(
                decode_state.answer_anchor_state
            )
        answer_start_confidence = (
            decode_state.answer_start_matches[0][0]
            if not generated_tokens and decode_state.answer_start_matches
            else 0.0
        )
        prompt_copy_is_distinctive = (
            not generated_tokens
            and self._prompt_copy_evidence_is_distinctive(decode_state.context_tokens)
        )
        answer_memory_confident = self._answer_memory_is_confident(
            answer_sequence_match_confidence=answer_sequence_match_confidence,
            answer_start_confidence=answer_start_confidence,
            generated_count=len(generated_tokens),
        )
        if prompt_copy_is_distinctive and not answer_sequence_matches:
            answer_memory_confident = False
        has_answer_sequence_prior = bool(np.any(answer_sequence_prior > 0.0))
        if not answer_memory_confident:
            answer_prior = np.zeros_like(base_probabilities)
            answer_sequence_prior = np.zeros_like(base_probabilities)
            answer_sequence_confidence = 0.0
            has_answer_sequence_prior = False
        answer_locked = self._answer_sequence_should_lock(
            answer_sequence_confidence=answer_sequence_confidence,
            answer_sequence_match_confidence=answer_sequence_match_confidence,
            has_answer_sequence_prior=has_answer_sequence_prior,
        ) or (
            bool(generated_tokens)
            and temperature < ANSWER_SEQUENCE_CREATIVE_TEMPERATURE
            and self._answer_sequence_has_continuation(
                generated_tokens,
                answer_sequence_matches,
            )
        )
        if self._should_relax_answer_sequence_memory(
            answer_sequence_matches,
            answer_sequence_prior.tolist(),
            generated_tokens=generated_tokens,
            temperature=temperature,
        ):
            answer_locked = False
        if decode_state.prompt_answer_prior is None:
            decode_state.prompt_answer_prior = self._prompt_answer_readout_prior_array(
                decode_state.answer_anchor_state,
                start=False,
            )
        prompt_answer_prior = decode_state.prompt_answer_prior
        prompt_answer_start_prior = np.zeros_like(base_probabilities)
        use_answer_start = False
        if answer_locked:
            locked_matches = self._locked_answer_sequence_matches(
                answer_sequence_matches,
                generated_tokens=generated_tokens,
                temperature=temperature,
                answer_sequence_confidence=answer_sequence_confidence,
                answer_sequence_match_confidence=answer_sequence_match_confidence,
            )
            answer_sequence_prior = np.asarray(
                self._answer_sequence_prior_from_matches(
                    locked_matches,
                    generated_tokens,
                    temperature=temperature,
                ),
                dtype=np.float64,
            )
            answer_prior = answer_sequence_prior
        elif not generated_tokens:
            if decode_state.prompt_answer_start_prior is None:
                decode_state.prompt_answer_start_prior = self._prompt_answer_readout_prior_array(
                    decode_state.answer_anchor_state,
                    start=True,
                )
            prompt_answer_start_prior = (
                decode_state.prompt_answer_start_prior
                if decode_state.prompt_answer_start_prior is not None
                else np.zeros_like(base_probabilities)
            )
            prompt_start_readout_confident = self._prompt_start_readout_is_confident(
                prompt_answer_start_prior
            )
            prompt_readout_supported = answer_memory_confident and (
                answer_sequence_match_confidence >= ANSWER_SEQUENCE_MATCH_FLOOR
                or answer_start_confidence >= ANSWER_START_CONFIDENCE_FLOOR
            )
            if prompt_start_readout_confident:
                prompt_readout_supported = True
            if not prompt_readout_supported:
                prompt_answer_prior = np.zeros_like(base_probabilities)
                prompt_answer_start_prior = np.zeros_like(base_probabilities)
            answer_start_prior = np.asarray(
                self._answer_prior_from_matches(
                    decode_state.answer_start_matches,
                    generated_tokens,
                ),
                dtype=np.float64,
            )
            if not answer_memory_confident:
                answer_start_prior = np.zeros_like(base_probabilities)
            if np.any(answer_start_prior > 0.0) or np.any(prompt_answer_start_prior > 0.0):
                start_blend = self._answer_start_blend_weights(
                    answer_sequence_match_confidence=answer_sequence_match_confidence,
                    temperature=temperature,
                )
                answer_prior = self._weighted_prior_sum_array(
                    [
                        (start_blend["prompt_answer_start"], prompt_answer_start_prior),
                        (start_blend["prompt_answer"], prompt_answer_prior),
                        (start_blend["answer_sequence"], answer_sequence_prior),
                        (start_blend["answer_start"], answer_start_prior),
                    ],
                )
                use_answer_start = True
        if answer_locked:
            answer_prior = answer_sequence_prior
        elif not use_answer_start and np.any(answer_sequence_prior > 0.0):
            sequence_weight = (
                0.10 if temperature >= ANSWER_SEQUENCE_CREATIVE_TEMPERATURE else 0.30
            )
            answer_prior = self._weighted_prior_sum_array(
                [
                    (0.55, prompt_answer_prior),
                    (sequence_weight, answer_sequence_prior),
                    (0.20, answer_prior),
                ],
            )
        elif not use_answer_start and np.any(prompt_answer_prior > 0.0):
            answer_prior = self._weighted_prior_sum_array(
                [
                    (0.65, prompt_answer_prior),
                    (0.35, answer_prior),
                ],
            )
        answer_guided = bool(answer_prior.size and float(np.max(answer_prior)) >= 0.08)
        if include_associative and not use_answer_start and not answer_guided:
            associative_prior = np.asarray(
                self._associative_prior_from_matches(
                    self._score_associative_matches(state)
                ),
                dtype=np.float64,
            )
        else:
            associative_prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
        transition_prior, transition_order = self._transition_prior_array_with_order(
            decode_state.context_tokens
        )
        copy_prior = self._copy_prior_array(decode_state.context_tokens)
        source_evidence_prior = self._source_evidence_prior_array(
            decode_state.context_tokens,
            generated_tokens,
        )
        preference_prior = self._preference_prior_array()
        probabilities, blend_weights = self._blend_probability_arrays(
            base_probabilities,
            answer_prior,
            associative_prior,
            transition_prior,
            copy_prior,
            source_evidence_prior,
            preference_prior,
            transition_order=transition_order,
            generated_count=len(generated_tokens),
            answer_locked=answer_locked,
            answer_guided_start=use_answer_start,
        )
        probabilities = self._focus_answer_start_probability_array(
            probabilities,
            answer_sequence_prior,
            generated_tokens=generated_tokens,
            answer_memory_confident=answer_memory_confident,
            has_answer_sequence_prior=has_answer_sequence_prior,
            sequence_focus_allowed=answer_sequence_match_confidence >= 0.40 or answer_locked,
            temperature=temperature,
        )
        return probabilities, blend_weights

    @staticmethod
    def _focus_answer_start_probabilities(
        probabilities: Vector,
        answer_sequence_prior: Vector,
        *,
        generated_tokens: list[str],
        answer_memory_confident: bool,
        has_answer_sequence_prior: bool,
        sequence_focus_allowed: bool | None = None,
        temperature: float = 0.0,
    ) -> Vector:
        if sequence_focus_allowed is None:
            sequence_focus_allowed = has_answer_sequence_prior
        if temperature >= ANSWER_SEQUENCE_CREATIVE_TEMPERATURE:
            return probabilities
        if (
            generated_tokens
            or not answer_memory_confident
            or not has_answer_sequence_prior
            or not sequence_focus_allowed
        ):
            return probabilities
        if not probabilities or not answer_sequence_prior:
            return probabilities
        focused = [
            probability
            if index < len(answer_sequence_prior) and answer_sequence_prior[index] > 0.0
            else probability * 0.02
            for index, probability in enumerate(probabilities)
        ]
        total = sum(focused)
        if total <= 0.0:
            return probabilities
        return [value / total for value in focused]

    @staticmethod
    def _focus_answer_start_probability_array(
        probabilities: object,
        answer_sequence_prior: object,
        *,
        generated_tokens: list[str],
        answer_memory_confident: bool,
        has_answer_sequence_prior: bool,
        sequence_focus_allowed: bool | None = None,
        temperature: float = 0.0,
    ) -> object:
        if sequence_focus_allowed is None:
            sequence_focus_allowed = has_answer_sequence_prior
        if temperature >= ANSWER_SEQUENCE_CREATIVE_TEMPERATURE:
            return probabilities
        if (
            np is None
            or generated_tokens
            or not answer_memory_confident
            or not has_answer_sequence_prior
            or not sequence_focus_allowed
        ):
            return probabilities
        values = np.asarray(probabilities, dtype=np.float64)
        prior = np.asarray(answer_sequence_prior, dtype=np.float64)
        if values.size == 0 or prior.size != values.size or not np.any(prior > 0.0):
            return probabilities
        focused = values.copy()
        focused[prior <= 0.0] *= 0.02
        total = float(focused.sum())
        if total <= 0.0:
            return probabilities
        return focused / total

    def _calibrated_softmax(
        self,
        logits: Vector,
        *,
        scale: float = READOUT_LOGIT_ZSCORE_SCALE,
    ) -> Vector:
        if np is not None:
            return self._calibrated_softmax_array(
                np.asarray(logits, dtype=np.float64),
                scale=scale,
            ).tolist()
        if not logits:
            return []
        center = mean(logits)
        variance = mean([(value - center) * (value - center) for value in logits])
        spread = variance**0.5
        if spread <= 1e-12:
            return softmax(logits)
        calibrated = [
            max(-20.0, min(20.0, ((value - center) / spread) * scale))
            for value in logits
        ]
        return softmax(calibrated)

    def _calibrated_softmax_array(
        self,
        logits: object,
        *,
        scale: float = READOUT_LOGIT_ZSCORE_SCALE,
    ) -> object:
        assert np is not None
        values = np.asarray(logits, dtype=np.float64)
        if values.size == 0:
            return values
        spread = float(values.std())
        if spread > 1e-12:
            values = ((values - float(values.mean())) / spread) * scale
            values = np.clip(values, -20.0, 20.0)
        else:
            values = values - float(values.max())
        values = values - float(values.max())
        exponentials = np.exp(values)
        total = float(exponentials.sum())
        if total <= 0.0:
            return np.full(values.shape, 1.0 / max(1, values.size), dtype=np.float64)
        return exponentials / total

    def _weighted_prior_sum(self, sources: list[tuple[float, Vector]]) -> Vector:
        assert self.embedding_model is not None
        active_sources = [
            (weight, vector)
            for weight, vector in sources
            if weight > 0.0 and any(value > 0.0 for value in vector)
        ]
        if not active_sources:
            return [0.0 for _ in self.embedding_model.id_to_token]
        total_weight = sum(weight for weight, _ in active_sources)
        merged = [0.0 for _ in self.embedding_model.id_to_token]
        for weight, vector in active_sources:
            normalized_weight = weight / total_weight
            for index, value in enumerate(vector):
                merged[index] += normalized_weight * value
        return _normalize_vector(merged)

    def _weighted_prior_sum_array(self, sources: list[tuple[float, object]]) -> object:
        assert np is not None
        assert self.embedding_model is not None
        active_sources = [
            (weight, np.asarray(vector, dtype=np.float64))
            for weight, vector in sources
            if weight > 0.0 and np.any(np.asarray(vector, dtype=np.float64) > 0.0)
        ]
        if not active_sources:
            return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
        total_weight = sum(weight for weight, _ in active_sources)
        merged = np.zeros_like(active_sources[0][1], dtype=np.float64)
        for weight, vector in active_sources:
            merged += (weight / total_weight) * vector
        total = float(merged.sum())
        if total > 0.0:
            merged /= total
        return merged

    def _prompt_answer_readout_prior(
        self,
        answer_anchor_state: Vector | None,
        *,
        start: bool,
    ) -> Vector:
        assert self.embedding_model is not None
        if answer_anchor_state is None:
            return [0.0 for _ in self.embedding_model.id_to_token]
        weights = self.prompt_answer_start_weights if start else self.prompt_answer_weights
        bias = self.prompt_answer_start_bias if start else self.prompt_answer_bias
        if np is not None:
            return self._prompt_answer_readout_prior_array(
                answer_anchor_state,
                start=start,
            ).tolist()
        if not weights:
            return [0.0 for _ in self.embedding_model.id_to_token]
        state = self._center_state_vector(self._masked_combined_state(answer_anchor_state))
        logits = apply_readout(weights, state)
        if bias:
            logits = [value + bias[index] for index, value in enumerate(logits)]
        return self._calibrated_softmax(
            logits,
            scale=PROMPT_READOUT_LOGIT_ZSCORE_SCALE,
        )

    def _prompt_answer_readout_prior_array(
        self,
        answer_anchor_state: Vector | None,
        *,
        start: bool,
    ) -> object:
        assert np is not None
        assert self.embedding_model is not None
        if answer_anchor_state is None:
            return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
        weights = (
            self.prompt_answer_start_weights_array
            if start
            else self.prompt_answer_weights_array
        )
self.prompt_answer_weights_array ) bias = self.prompt_answer_start_bias_array if start else self.prompt_answer_bias_array if weights is None: return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) state_array = self._center_state_array( self._masked_combined_state_array(answer_anchor_state) ) logits = weights @ state_array if bias is not None and bias.shape == logits.shape: logits = logits + bias return self._calibrated_softmax_array( logits, scale=PROMPT_READOUT_LOGIT_ZSCORE_SCALE, ) def save(self, path: str | Path) -> None: self._require_fit() assert self.tokenizer is not None assert self.embedding_model is not None assert self.ternary_mask is not None assert self.readout_weights is not None assert self.associative_keys is not None assert self.associative_values is not None assert self.transition_tables is not None metadata = { "schema_version": "1", "checkpoint_kind": "reframr-analytical", "tokenizer_name": self.tokenizer.name, "config": json.dumps(self.config.to_dict(), separators=(",", ":")), "tokenizer": json.dumps(self.tokenizer.to_dict(), separators=(",", ":")), "embedding_id_to_token": json.dumps(self.embedding_model.id_to_token, separators=(",", ":")), "tokenizer_vocab_size": str(self.tokenizer.vocab_size), "transition_table_format": "tensor-v1", } self._refresh_answer_fingerprint_hashes() if np is not None: self._refresh_numeric_caches() transition_tensors = self._transition_table_tensors() tensors = { "embedding_table": self.embedding_model.embeddings, "ternary_scale": [self.ternary_scale], "ternary_mask": self.ternary_mask, "readout_weights": self.readout_weights, "readout_bias": self.readout_bias or [0.0 for _ in self.embedding_model.id_to_token], "prompt_answer_weights": self.prompt_answer_weights if self.prompt_answer_weights is not None else [], "prompt_answer_bias": self.prompt_answer_bias or [0.0 for _ in self.embedding_model.id_to_token], "prompt_answer_start_weights": self.prompt_answer_start_weights if self.prompt_answer_start_weights is not None else [], "prompt_answer_start_bias": self.prompt_answer_start_bias or [0.0 for _ in self.embedding_model.id_to_token], "trace_token_weights": self.trace_token_weights or [1.0 for _ in self.embedding_model.id_to_token], "preference_bias": self.preference_bias or [0.0 for _ in self.embedding_model.id_to_token], "state_offset": self.state_offset or [0.0 for _ in range(self._combined_state_width())], "associative_keys": self.associative_keys, "associative_key_norms": self.associative_key_norms_array if self.associative_key_norms_array is not None else self.associative_key_norms or [], "associative_values": self.associative_values, "answer_keys": self.answer_keys if self.answer_keys is not None else [], "answer_key_norms": self.answer_key_norms_array if self.answer_key_norms_array is not None else self.answer_key_norms or [], "answer_similarity_keys": self.answer_similarity_keys_array if self.answer_similarity_keys_array is not None else [], "answer_similarity_key_norms": self.answer_similarity_key_norms_array if self.answer_similarity_key_norms_array is not None else [], "answer_values": self.answer_values if self.answer_values is not None else [], "answer_start_keys": self.answer_start_keys if self.answer_start_keys is not None else [], "answer_start_key_norms": self.answer_start_key_norms_array if self.answer_start_key_norms_array is not None else self.answer_start_key_norms or [], "answer_start_similarity_keys": self.answer_start_similarity_keys_array if self.answer_start_similarity_keys_array is not None 
else [], "answer_start_similarity_key_norms": self.answer_start_similarity_key_norms_array if self.answer_start_similarity_key_norms_array is not None else [], "answer_start_values": self.answer_start_values if self.answer_start_values is not None else [], "answer_sequence_keys": self.answer_sequence_keys if self.answer_sequence_keys is not None else [], "answer_sequence_key_norms": self.answer_sequence_key_norms_array if self.answer_sequence_key_norms_array is not None else self.answer_sequence_key_norms or [], "answer_sequence_similarity_keys": self.answer_sequence_similarity_keys_array if self.answer_sequence_similarity_keys_array is not None else [], "answer_sequence_similarity_key_norms": self.answer_sequence_similarity_key_norms_array if self.answer_sequence_similarity_key_norms_array is not None else [], "answer_sequence_prompt_tokens": self.answer_sequence_prompt_tokens if self.answer_sequence_prompt_tokens is not None else [], "answer_sequence_tokens": self.answer_sequence_tokens if self.answer_sequence_tokens is not None else [], "answer_fingerprint_hashes": self._answer_fingerprint_tensor(), **transition_tensors, } write_safetensor_file(path, tensors, metadata=metadata) @classmethod def load(cls, path: str | Path) -> "ReframrModel": checkpoint_path = Path(path) checkpoint = read_safetensor_file( checkpoint_path, arrays=np is not None and checkpoint_path.stat().st_size > 10_000_000, ) metadata = checkpoint.metadata config = ReframrConfig.from_dict(json.loads(metadata["config"])) model = cls(config) model.tokenizer = NativeTokenizer.from_dict(json.loads(metadata["tokenizer"])) id_to_token = [str(token) for token in json.loads(metadata["embedding_id_to_token"])] embedding_table = checkpoint.tensors["embedding_table"] if np is not None and hasattr(embedding_table, "shape"): embeddings = embedding_table.astype(RUNTIME_ARRAY_DTYPE, copy=False) else: embeddings = [[float(value) for value in row] for row in embedding_table] model.embedding_model = EmbeddingModel( token_to_id={token: index for index, token in enumerate(id_to_token)}, id_to_token=id_to_token, embeddings=embeddings, ppmi_matrix=[], ) model.memory_units = [ AnalyticalMemoryUnit(model.config.state_dim, timescale) for timescale in model.config.timescales ] model.ternary_scale = float(checkpoint.tensors["ternary_scale"][0]) model.ternary_mask = [int(value) for value in checkpoint.tensors["ternary_mask"]] readout_tensor = checkpoint.tensors["readout_weights"] model.readout_weights = ( readout_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False) if np is not None and hasattr(readout_tensor, "shape") else [[float(value) for value in row] for row in readout_tensor] ) readout_bias_tensor = checkpoint.tensors.get("readout_bias", []) model.readout_bias = [ float(value) for value in ( readout_bias_tensor.tolist() if hasattr(readout_bias_tensor, "tolist") else readout_bias_tensor ) ] if not model.readout_bias: model.readout_bias = [0.0 for _ in id_to_token] prompt_answer_tensor = checkpoint.tensors.get("prompt_answer_weights", []) model.prompt_answer_weights = ( prompt_answer_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False) if np is not None and hasattr(prompt_answer_tensor, "shape") and len(prompt_answer_tensor.shape) == 2 else [[float(value) for value in row] for row in prompt_answer_tensor] ) prompt_answer_bias_tensor = checkpoint.tensors.get("prompt_answer_bias", []) model.prompt_answer_bias = [ float(value) for value in ( prompt_answer_bias_tensor.tolist() if hasattr(prompt_answer_bias_tensor, "tolist") else prompt_answer_bias_tensor ) ] 
if not model.prompt_answer_bias: model.prompt_answer_bias = [0.0 for _ in id_to_token] prompt_answer_start_tensor = checkpoint.tensors.get("prompt_answer_start_weights", []) model.prompt_answer_start_weights = ( prompt_answer_start_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False) if np is not None and hasattr(prompt_answer_start_tensor, "shape") and len(prompt_answer_start_tensor.shape) == 2 else [[float(value) for value in row] for row in prompt_answer_start_tensor] ) prompt_answer_start_bias_tensor = checkpoint.tensors.get("prompt_answer_start_bias", []) model.prompt_answer_start_bias = [ float(value) for value in ( prompt_answer_start_bias_tensor.tolist() if hasattr(prompt_answer_start_bias_tensor, "tolist") else prompt_answer_start_bias_tensor ) ] if not model.prompt_answer_start_bias: model.prompt_answer_start_bias = [0.0 for _ in id_to_token] trace_weight_tensor = checkpoint.tensors.get("trace_token_weights", []) model.trace_token_weights = [ float(value) for value in ( trace_weight_tensor.tolist() if hasattr(trace_weight_tensor, "tolist") else trace_weight_tensor ) ] if not model.trace_token_weights: model.trace_token_weights = [ 1.0 if token in TOOL_PROTOCOL_TOKENS else 0.0 if token in model.tokenizer.special_tokens else 1.0 for token in id_to_token ] preference_bias_tensor = checkpoint.tensors.get("preference_bias", []) model.preference_bias = [ float(value) for value in ( preference_bias_tensor.tolist() if hasattr(preference_bias_tensor, "tolist") else preference_bias_tensor ) ] if not model.preference_bias: model.preference_bias = [0.0 for _ in id_to_token] state_offset_tensor = checkpoint.tensors.get("state_offset", []) model.state_offset = [ float(value) for value in ( state_offset_tensor.tolist() if hasattr(state_offset_tensor, "tolist") else state_offset_tensor ) ] if not model.state_offset: model.state_offset = [0.0 for _ in range(model._combined_state_width())] def _runtime_vector_tensor(name: str) -> object | None: tensor = checkpoint.tensors.get(name, []) if np is not None and hasattr(tensor, "shape"): if len(tensor.shape) == 1 and int(tensor.shape[0]) > 0: return tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False) return None values = tensor.tolist() if hasattr(tensor, "tolist") else tensor return [float(value) for value in values] if values else None def _runtime_matrix_tensor(name: str) -> object | None: tensor = checkpoint.tensors.get(name, []) if ( np is not None and hasattr(tensor, "shape") and len(tensor.shape) == 2 and int(tensor.shape[0]) > 0 ): return tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False) return None associative_tensor = checkpoint.tensors.get("associative_keys", []) model.associative_keys = ( associative_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False) if np is not None and hasattr(associative_tensor, "shape") else [[float(value) for value in row] for row in associative_tensor] ) cached_associative_key_norms = _runtime_vector_tensor("associative_key_norms") if cached_associative_key_norms is not None: model.associative_key_norms = cached_associative_key_norms elif np is not None and hasattr(model.associative_keys, "shape"): model.associative_key_norms = None else: model.associative_key_norms = [norm(key) for key in model.associative_keys] raw_associative_values = checkpoint.tensors.get("associative_values", []) model.associative_values = [ int(value) for value in ( raw_associative_values.tolist() if hasattr(raw_associative_values, "tolist") else raw_associative_values ) ] answer_tensor = checkpoint.tensors.get("answer_keys", []) if np is not None and 
hasattr(answer_tensor, "shape"): model.answer_keys = ( answer_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False) if len(answer_tensor.shape) == 2 else [] ) else: model.answer_keys = [[float(value) for value in row] for row in answer_tensor] if ( np is not None and hasattr(model.answer_keys, "shape") and len(model.answer_keys.shape) == 2 ): model.answer_key_norms = _runtime_vector_tensor("answer_key_norms") else: model.answer_key_norms = ( _runtime_vector_tensor("answer_key_norms") or [norm(key) for key in model.answer_keys] ) raw_answer_values = checkpoint.tensors.get("answer_values", []) model.answer_values = [ int(value) for value in ( raw_answer_values.tolist() if hasattr(raw_answer_values, "tolist") else raw_answer_values ) ] answer_start_tensor = checkpoint.tensors.get("answer_start_keys", []) if np is not None and hasattr(answer_start_tensor, "shape"): model.answer_start_keys = ( answer_start_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False) if len(answer_start_tensor.shape) == 2 else [] ) else: model.answer_start_keys = [ [float(value) for value in row] for row in answer_start_tensor ] if ( np is not None and hasattr(model.answer_start_keys, "shape") and len(model.answer_start_keys.shape) == 2 ): model.answer_start_key_norms = _runtime_vector_tensor("answer_start_key_norms") else: model.answer_start_key_norms = ( _runtime_vector_tensor("answer_start_key_norms") or [norm(key) for key in model.answer_start_keys] ) raw_answer_start_values = checkpoint.tensors.get("answer_start_values", []) model.answer_start_values = [ int(value) for value in ( raw_answer_start_values.tolist() if hasattr(raw_answer_start_values, "tolist") else raw_answer_start_values ) ] answer_sequence_tensor = checkpoint.tensors.get("answer_sequence_keys", []) if np is not None and hasattr(answer_sequence_tensor, "shape"): model.answer_sequence_keys = ( answer_sequence_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False) if len(answer_sequence_tensor.shape) == 2 else [] ) else: model.answer_sequence_keys = [ [float(value) for value in row] for row in answer_sequence_tensor ] if ( np is not None and hasattr(model.answer_sequence_keys, "shape") and len(model.answer_sequence_keys.shape) == 2 ): model.answer_sequence_key_norms = _runtime_vector_tensor("answer_sequence_key_norms") else: model.answer_sequence_key_norms = ( _runtime_vector_tensor("answer_sequence_key_norms") or [norm(key) for key in model.answer_sequence_keys] ) raw_answer_sequence_prompt_tokens = checkpoint.tensors.get("answer_sequence_prompt_tokens", []) if np is not None and hasattr(raw_answer_sequence_prompt_tokens, "shape"): model.answer_sequence_prompt_tokens = raw_answer_sequence_prompt_tokens.astype(int, copy=False) else: model.answer_sequence_prompt_tokens = [ [int(value) for value in row] for row in raw_answer_sequence_prompt_tokens ] raw_answer_sequence_tokens = checkpoint.tensors.get("answer_sequence_tokens", []) if np is not None and hasattr(raw_answer_sequence_tokens, "shape"): model.answer_sequence_tokens = raw_answer_sequence_tokens.astype(int, copy=False) else: model.answer_sequence_tokens = [ [int(value) for value in row] for row in raw_answer_sequence_tokens ] model.answer_sequence_token_id_rows = None raw_fingerprints = checkpoint.tensors.get("answer_fingerprint_hashes", []) model.answer_fingerprint_hashes = model._coerce_answer_fingerprint_hashes( raw_fingerprints ) model.answer_fingerprint_token_lengths = None model.answer_fingerprint_token_sequences_by_length = None if not model.answer_fingerprint_hashes: model._refresh_answer_fingerprint_hashes() 
model.answer_similarity_keys_array = _runtime_matrix_tensor("answer_similarity_keys") model.answer_similarity_key_norms_array = _runtime_vector_tensor("answer_similarity_key_norms") model.answer_start_similarity_keys_array = _runtime_matrix_tensor("answer_start_similarity_keys") model.answer_start_similarity_key_norms_array = _runtime_vector_tensor("answer_start_similarity_key_norms") model.answer_sequence_similarity_keys_array = _runtime_matrix_tensor("answer_sequence_similarity_keys") model.answer_sequence_similarity_key_norms_array = _runtime_vector_tensor("answer_sequence_similarity_key_norms") model.transition_id_tables = model._deserialize_transition_id_tables_from_tensors( checkpoint.tensors ) if model.transition_id_tables is not None: model.transition_tables = {order: {} for order in sorted(TRANSITION_ORDERS)} else: model.transition_tables = model._deserialize_transition_tables( json.loads(metadata.get("transition_tables", "{}")) ) model._refresh_numeric_caches() return model def _collect_training_examples( self, tokens: list[str], ) -> tuple[list[Vector], list[Vector], list[int]]: assert self.embedding_model is not None if np is not None: hidden_states = [ np.zeros(self.config.state_dim, dtype=np.float64) for _ in self.config.timescales ] context_traces = [ np.zeros(self.config.embedding_dim, dtype=np.float64) for _ in self.config.timescales ] zero_embedding: Vector | object = np.zeros(self.config.embedding_dim, dtype=np.float64) else: hidden_states = [zeros_vector(self.config.state_dim) for _ in self.config.timescales] context_traces = [zeros_vector(self.config.embedding_dim) for _ in self.config.timescales] zero_embedding = zeros_vector(self.config.embedding_dim) states: list[Vector] = [] labels: list[Vector] = [] label_ids: list[int] = [] token_ids = [ self.embedding_model.token_to_id.get(token, -1) for token in tokens ] example_count = max(0, len(tokens) - 1) stride = 1 if self.config.max_training_examples and example_count > self.config.max_training_examples: stride = max( 1, (example_count + self.config.max_training_examples - 1) // self.config.max_training_examples, ) for index in range(len(tokens) - 1): token = tokens[index] token_id = token_ids[index] embedding = ( self.embedding_model.embeddings[token_id] if token_id >= 0 else zero_embedding ) trace_embedding = self._trace_embedding_from_token_id(embedding, token_id) hidden_states, context_traces, combined_state = self._step_hidden_states_from_embedding( hidden_states, context_traces, embedding, trace_embedding=trace_embedding, ) if stride > 1 and index % stride != 0 and index != len(tokens) - 2: continue states.append(combined_state) next_token_id = token_ids[index + 1] labels.append(self._one_hot_from_id(next_token_id)) label_ids.append(next_token_id) if self.config.max_training_examples and len(states) > self.config.max_training_examples: states = states[: self.config.max_training_examples] labels = labels[: self.config.max_training_examples] label_ids = label_ids[: self.config.max_training_examples] return states, labels, label_ids def _is_punctuation_piece(self, piece: str) -> bool: return bool(piece) and all(character in string.punctuation for character in piece) def _encode_context(self, tokens: list[str]) -> Vector: return self._masked_decode_state(self._build_decode_state(tokens)) def _build_decode_state(self, tokens: list[str]) -> DecodeState: assert self.memory_units is not None state = DecodeState( hidden_states=( [ np.zeros(self.config.state_dim, dtype=np.float64) for _ in self.config.timescales ] if np is 
not None else [zeros_vector(self.config.state_dim) for _ in self.config.timescales] ), context_traces=( [ np.zeros(self.config.embedding_dim, dtype=np.float64) for _ in self.config.timescales ] if np is not None else [zeros_vector(self.config.embedding_dim) for _ in self.config.timescales] ), combined_state=self._zero_combined_state(), context_tokens=[], ) for token in tokens: self._advance_decode_state(state, token) self._apply_sparse_context_anchor(state) return state def _advance_decode_state(self, state: DecodeState, token: str) -> DecodeState: next_hidden_states, next_context_traces, combined_state = self._step_hidden_states( state.hidden_states, state.context_traces, token, ) state.hidden_states = next_hidden_states state.context_traces = next_context_traces state.combined_state = combined_state state.context_tokens.append(token) if token == "": state.answer_anchor_state = combined_state.copy() if hasattr(combined_state, "copy") else combined_state[:] state.answer_matches = None state.answer_start_matches = None state.answer_sequence_matches = None state.prompt_answer_prior = None state.prompt_answer_start_prior = None return state def _apply_sparse_context_anchor(self, state: DecodeState) -> None: if ( np is None or self.embedding_model is None or state.answer_anchor_state is None or not state.context_tokens ): return answer_index = _last_index(state.context_tokens, "") if answer_index is None or answer_index <= 0: return context_ids = self._long_context_sparse_token_ids(state.context_tokens[:answer_index]) if len(context_ids) < SPARSE_CONTEXT_MIN_TOKENS: return query_id = context_ids[-1] embeddings = np.asarray(self.embedding_model.embeddings, dtype=np.float32) if embeddings.ndim != 2 or embeddings.shape[0] == 0: return selector = HashedSparseAttention( embeddings, k_neighbors=min(SPARSE_CONTEXT_TOP_K, len(context_ids)), hash_bits=SPARSE_CONTEXT_HASH_BITS, probe_radius=SPARSE_CONTEXT_PROBE_RADIUS, candidate_multiplier=SPARSE_CONTEXT_CANDIDATE_MULTIPLIER, ) token_ids = np.asarray(context_ids, dtype=np.int64) selector.build_context_index(token_ids) selection = selector.select_positions_cached(query_id) if not selection.positions: return selected_ids = token_ids[np.asarray(selection.positions, dtype=np.int64)] selected_embeddings = embeddings[selected_ids] scores = np.asarray(selection.scores, dtype=np.float32) scores -= float(scores.max()) weights = np.exp(scores) weights /= max(float(weights.sum()), 1e-8) sparse_embedding = weights @ selected_embeddings blended_anchor = self._blend_sparse_embedding_into_combined_state( state.answer_anchor_state, sparse_embedding, state_dim=self.config.state_dim, embedding_dim=self.config.embedding_dim, timescale_count=len(self.config.timescales), blend=SPARSE_CONTEXT_TRACE_BLEND, ) state.answer_anchor_state = blended_anchor if state.context_tokens and state.context_tokens[-1] == "": state.combined_state = blended_anchor.copy() state.answer_matches = None state.answer_start_matches = None state.answer_sequence_matches = None state.prompt_answer_prior = None state.prompt_answer_start_prior = None def _long_context_sparse_token_ids(self, tokens: Sequence[str]) -> list[int]: assert self.embedding_model is not None special_tokens = self.tokenizer.special_tokens if self.tokenizer is not None else set() ids: list[int] = [] for token in tokens: if token in special_tokens and token not in TOOL_PROTOCOL_TOKENS: continue token_id = self._token_id_for_token(token) if token_id >= 0: ids.append(token_id) return ids @staticmethod def 
_blend_sparse_embedding_into_combined_state( combined_state: Vector, sparse_embedding: object, *, state_dim: int, embedding_dim: int, timescale_count: int, blend: float, ) -> Vector: if np is None: return combined_state state_array = np.asarray(combined_state, dtype=np.float32).copy() sparse_array = np.asarray(sparse_embedding, dtype=np.float32) if sparse_array.shape[0] != embedding_dim: return combined_state block_width = state_dim + embedding_dim expected_width = block_width * timescale_count if state_array.shape[0] != expected_width: return combined_state alpha = min(1.0, max(0.0, float(blend))) for block_index in range(timescale_count): trace_start = block_index * block_width + state_dim trace_end = trace_start + embedding_dim state_array[trace_start:trace_end] = ( (1.0 - alpha) * state_array[trace_start:trace_end] + alpha * sparse_array ) return state_array.tolist() def _masked_decode_state(self, state: DecodeState) -> Vector: assert self.ternary_mask is not None return apply_ternary_mask(state.combined_state, self.ternary_mask, self.ternary_scale) def _masked_combined_state(self, combined_state: Vector) -> Vector: assert self.ternary_mask is not None return apply_ternary_mask(combined_state, self.ternary_mask, self.ternary_scale) def _masked_decode_state_array(self, state: DecodeState) -> object: assert np is not None if self.ternary_mask_array is None: return np.asarray(self._masked_decode_state(state), dtype=RUNTIME_ARRAY_DTYPE) return ( np.asarray(state.combined_state, dtype=RUNTIME_ARRAY_DTYPE) * self.ternary_scale * self.ternary_mask_array ) def _masked_combined_state_array(self, combined_state: Vector) -> object: assert np is not None if self.ternary_mask_array is None: return np.asarray(self._masked_combined_state(combined_state), dtype=RUNTIME_ARRAY_DTYPE) return ( np.asarray(combined_state, dtype=RUNTIME_ARRAY_DTYPE) * self.ternary_scale * self.ternary_mask_array ) def _center_state_vector(self, state: Vector) -> Vector: if not self.state_offset or len(self.state_offset) != len(state): return state return [value - self.state_offset[index] for index, value in enumerate(state)] def _center_state_array(self, state: object) -> object: assert np is not None state_array = np.asarray(state, dtype=RUNTIME_ARRAY_DTYPE) if self.state_offset_array is None or self.state_offset_array.shape != state_array.shape: return state_array return state_array - self.state_offset_array def _zero_combined_state(self) -> Vector: return [0.0 for _ in range(self._combined_state_width())] def _combined_state_width(self) -> int: return (self.config.state_dim + self.config.embedding_dim) * len(self.config.timescales) def _derive_trace_token_weights_from_counts(self, token_counts: dict[str, float]) -> Vector: assert self.embedding_model is not None assert self.tokenizer is not None counts = [ float(token_counts.get(token, 0.0)) for token in self.embedding_model.id_to_token ] positive_counts = sorted(value for value in counts if value > 0.0) reference = ( positive_counts[len(positive_counts) // 2] if positive_counts else 1.0 ) weights: Vector = [] for token, count in zip(self.embedding_model.id_to_token, counts): if token in TOOL_PROTOCOL_TOKENS: weights.append(1.0) elif token in self.tokenizer.special_tokens: weights.append(0.0) elif count <= 0.0: weights.append(1.0) else: weight = (reference / count) ** 0.75 weights.append(max(0.08, min(4.8, weight))) return weights def _token_id_for_token(self, token: str) -> int: assert self.embedding_model is not None token_id = 
self.embedding_model.token_to_id.get(token) if token_id is None and token.lower() != token: token_id = self.embedding_model.token_to_id.get(token.lower()) return int(token_id) if token_id is not None else -1 def _trace_embedding_from_token_id( self, embedding: Vector | object, token_id: int, ) -> Vector | object: if token_id < 0: return embedding if self.trace_embedding_table_array is not None: return self.trace_embedding_table_array[token_id] weight = self.trace_token_weights[token_id] if self.trace_token_weights is not None else 1.0 dimension = self.config.embedding_dim if hasattr(embedding, "shape"): trace_embedding = embedding * weight for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES: bucket = (token_id * bucket_multiplier + bucket_offset) % dimension sign = 1.0 if ((token_id * sign_multiplier + sign_offset) & 1) == 0 else -1.0 trace_embedding[bucket] += weight * TRACE_IDENTITY_SCALE * sign return trace_embedding trace_values = [float(value) * weight for value in embedding] for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES: bucket = (token_id * bucket_multiplier + bucket_offset) % dimension sign = 1.0 if ((token_id * sign_multiplier + sign_offset) & 1) == 0 else -1.0 trace_values[bucket] += weight * TRACE_IDENTITY_SCALE * sign return trace_values def _build_trace_embedding_table_array(self, embedding_array: object) -> object | None: if np is None or self.trace_token_weights is None: return None values = np.asarray(embedding_array, dtype=np.float64) if values.size == 0 or len(values.shape) != 2: return None weights = np.asarray(self.trace_token_weights, dtype=np.float64) if weights.shape[0] != values.shape[0]: return None trace_values = values * weights[:, None] if values.shape[1] <= 0: return trace_values token_ids = np.arange(values.shape[0], dtype=np.int64) for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES: buckets = ((token_ids * bucket_multiplier + bucket_offset) % values.shape[1]).astype( np.int64, copy=False, ) signs = np.where( ((token_ids * sign_multiplier + sign_offset) & 1) == 0, 1.0, -1.0, ) np.add.at(trace_values, (token_ids, buckets), weights * TRACE_IDENTITY_SCALE * signs) return trace_values def _runtime_key_norms_array( self, key_array: object | None, key_norms: list[float] | None, ) -> object | None: assert np is not None if key_norms is not None and len(key_norms) > 0: return np.asarray(key_norms, dtype=RUNTIME_ARRAY_DTYPE) if key_array is None: return None keys = np.asarray(key_array, dtype=RUNTIME_ARRAY_DTYPE) if len(keys.shape) != 2 or keys.shape[0] == 0: return None return np.linalg.norm(keys, axis=1).astype(RUNTIME_ARRAY_DTYPE, copy=False) def _runtime_vector_cache(self, cached: object | None, length: int) -> object | None: assert np is not None if cached is None or not hasattr(cached, "shape"): return None array = np.asarray(cached, dtype=RUNTIME_ARRAY_DTYPE) if len(array.shape) != 1 or int(array.shape[0]) != int(length): return None return array def _runtime_matrix_cache( self, cached: object | None, rows: int, width: int, ) -> object | None: assert np is not None if cached is None or not hasattr(cached, "shape"): return None array = np.asarray(cached, dtype=RUNTIME_ARRAY_DTYPE) if ( len(array.shape) != 2 or int(array.shape[0]) != int(rows) or int(array.shape[1]) != int(width) ): return None return array def _refresh_numeric_caches(self) -> None: if np is None: self.ternary_mask_array = None self.readout_weights_array = None 
self.readout_bias_array = None self.prompt_answer_weights_array = None self.prompt_answer_bias_array = None self.prompt_answer_start_weights_array = None self.prompt_answer_start_bias_array = None self.trace_token_weights_array = None self.trace_embedding_table_array = None self.preference_bias_array = None self.preference_valid_mask_array = None self.state_offset_array = None self.associative_keys_array = None self.associative_key_norms_array = None self.associative_values_array = None self.associative_valid_mask_array = None self.answer_keys_array = None self.answer_key_norms_array = None self.answer_similarity_keys_array = None self.answer_similarity_key_norms_array = None self.answer_similarity_mask_array = None self.answer_values_array = None self.answer_valid_mask_array = None self.answer_start_keys_array = None self.answer_start_key_norms_array = None self.answer_start_similarity_keys_array = None self.answer_start_similarity_key_norms_array = None self.answer_start_values_array = None self.answer_start_valid_mask_array = None self.answer_sequence_keys_array = None self.answer_sequence_key_norms_array = None self.answer_sequence_similarity_keys_array = None self.answer_sequence_similarity_key_norms_array = None self.answer_sequence_prompt_tokens_array = None self.answer_sequence_tokens_array = None self.answer_sequence_prompt_weight_maps = None self.answer_sequence_prompt_weight_norms = None self.answer_sequence_prompt_bigram_sets = None self.answer_sequence_prompt_trigram_sets = None self.answer_sequence_prompt_number_sets = None self.answer_sequence_prompt_inverted_index = None self._refresh_answer_sequence_prompt_overlap_cache() self.prompt_overlap_valid_token_mask_array = None return cached_associative_key_norms_array = self.associative_key_norms_array cached_answer_key_norms_array = self.answer_key_norms_array cached_answer_similarity_keys_array = self.answer_similarity_keys_array cached_answer_similarity_key_norms_array = self.answer_similarity_key_norms_array cached_answer_start_key_norms_array = self.answer_start_key_norms_array cached_answer_start_similarity_keys_array = self.answer_start_similarity_keys_array cached_answer_start_similarity_key_norms_array = self.answer_start_similarity_key_norms_array cached_answer_sequence_key_norms_array = self.answer_sequence_key_norms_array cached_answer_sequence_similarity_keys_array = self.answer_sequence_similarity_keys_array cached_answer_sequence_similarity_key_norms_array = self.answer_sequence_similarity_key_norms_array self.ternary_mask_array = ( np.asarray(self.ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) if self.ternary_mask is not None else None ) self.readout_weights_array = ( np.asarray(self.readout_weights, dtype=RUNTIME_ARRAY_DTYPE) if self.readout_weights is not None else None ) self.readout_bias_array = ( np.asarray(self.readout_bias, dtype=RUNTIME_ARRAY_DTYPE) if self.readout_bias is not None else None ) self.prompt_answer_weights_array = ( np.asarray(self.prompt_answer_weights, dtype=RUNTIME_ARRAY_DTYPE) if self.prompt_answer_weights is not None and len(self.prompt_answer_weights) > 0 else None ) self.prompt_answer_bias_array = ( np.asarray(self.prompt_answer_bias, dtype=RUNTIME_ARRAY_DTYPE) if self.prompt_answer_bias is not None else None ) self.prompt_answer_start_weights_array = ( np.asarray(self.prompt_answer_start_weights, dtype=RUNTIME_ARRAY_DTYPE) if self.prompt_answer_start_weights is not None and len(self.prompt_answer_start_weights) > 0 else None ) self.prompt_answer_start_bias_array = ( 
np.asarray(self.prompt_answer_start_bias, dtype=RUNTIME_ARRAY_DTYPE) if self.prompt_answer_start_bias is not None else None ) self.trace_token_weights_array = ( np.asarray(self.trace_token_weights, dtype=RUNTIME_ARRAY_DTYPE) if self.trace_token_weights is not None else None ) trace_embedding_table = ( self._build_trace_embedding_table_array(self.embedding_model.embeddings) if self.embedding_model is not None and self.trace_token_weights is not None else None ) self.trace_embedding_table_array = ( trace_embedding_table.astype(RUNTIME_ARRAY_DTYPE, copy=False) if trace_embedding_table is not None else None ) self.preference_bias_array = ( np.asarray(self.preference_bias, dtype=RUNTIME_ARRAY_DTYPE) if self.preference_bias is not None else None ) self.preference_valid_mask_array = ( np.asarray( [ self._eligible_preference_token(token) for token in self.embedding_model.id_to_token ], dtype=bool, ) if self.embedding_model is not None and self.tokenizer is not None else None ) self.state_offset_array = ( np.asarray(self.state_offset, dtype=RUNTIME_ARRAY_DTYPE) if self.state_offset is not None else None ) self.associative_keys_array = ( np.asarray(self.associative_keys, dtype=RUNTIME_ARRAY_DTYPE) if self.associative_keys is not None and len(self.associative_keys) > 0 else None ) associative_key_norms_cache = ( self._runtime_vector_cache( cached_associative_key_norms_array, int(self.associative_keys_array.shape[0]), ) if self.associative_keys_array is not None else None ) self.associative_key_norms_array = ( associative_key_norms_cache if associative_key_norms_cache is not None else self._runtime_key_norms_array( self.associative_keys_array, self.associative_key_norms, ) ) self.associative_values_array = ( np.asarray(self.associative_values, dtype=np.int64) if self.associative_values is not None and len(self.associative_values) > 0 else None ) self.associative_valid_mask_array = ( self.associative_values_array >= 0 if self.associative_values_array is not None else None ) self.answer_keys_array = ( np.asarray(self.answer_keys, dtype=RUNTIME_ARRAY_DTYPE) if self.answer_keys is not None and len(self.answer_keys) > 0 else None ) answer_key_norms_cache = ( self._runtime_vector_cache( cached_answer_key_norms_array, int(self.answer_keys_array.shape[0]), ) if self.answer_keys_array is not None else None ) self.answer_key_norms_array = ( answer_key_norms_cache if answer_key_norms_cache is not None else self._runtime_key_norms_array( self.answer_keys_array, self.answer_key_norms, ) ) self.answer_similarity_keys_array = None self.answer_similarity_key_norms_array = None self.answer_similarity_mask_array = None if self.answer_keys_array is not None and len(self.answer_keys_array.shape) == 2: width = int(self.answer_keys_array.shape[1]) block_width = self.config.state_dim + self.config.embedding_dim expected_width = block_width * len(self.config.timescales) if block_width > 0 and width == expected_width: mask = np.zeros(width, dtype=RUNTIME_ARRAY_DTYPE) for scale_index in range(len(self.config.timescales)): start = scale_index * block_width + self.config.state_dim end = start + self.config.embedding_dim mask[start:end] = 1.0 self.answer_similarity_mask_array = mask answer_similarity_keys_cache = self._runtime_matrix_cache( cached_answer_similarity_keys_array, int(self.answer_keys_array.shape[0]), width, ) answer_similarity_key_norms_cache = self._runtime_vector_cache( cached_answer_similarity_key_norms_array, int(self.answer_keys_array.shape[0]), ) if ( answer_similarity_keys_cache is not None and 
answer_similarity_key_norms_cache is not None ): self.answer_similarity_keys_array = answer_similarity_keys_cache self.answer_similarity_key_norms_array = answer_similarity_key_norms_cache else: self.answer_similarity_keys_array = self.answer_keys_array * mask[None, :] self.answer_similarity_key_norms_array = np.linalg.norm( self.answer_similarity_keys_array, axis=1, ).astype(RUNTIME_ARRAY_DTYPE, copy=False) self.answer_values_array = ( np.asarray(self.answer_values, dtype=np.int64) if self.answer_values is not None and len(self.answer_values) > 0 else None ) self.answer_valid_mask_array = ( self.answer_values_array >= 0 if self.answer_values_array is not None else None ) self.answer_start_keys_array = ( np.asarray(self.answer_start_keys, dtype=RUNTIME_ARRAY_DTYPE) if self.answer_start_keys is not None and len(self.answer_start_keys) > 0 else None ) answer_start_key_norms_cache = ( self._runtime_vector_cache( cached_answer_start_key_norms_array, int(self.answer_start_keys_array.shape[0]), ) if self.answer_start_keys_array is not None else None ) self.answer_start_key_norms_array = ( answer_start_key_norms_cache if answer_start_key_norms_cache is not None else self._runtime_key_norms_array( self.answer_start_keys_array, self.answer_start_key_norms, ) ) self.answer_start_similarity_keys_array = None self.answer_start_similarity_key_norms_array = None if ( self.answer_start_keys_array is not None and len(self.answer_start_keys_array.shape) == 2 and self.answer_similarity_mask_array is not None and int(self.answer_start_keys_array.shape[1]) == int(self.answer_similarity_mask_array.shape[0]) ): answer_start_similarity_keys_cache = self._runtime_matrix_cache( cached_answer_start_similarity_keys_array, int(self.answer_start_keys_array.shape[0]), int(self.answer_start_keys_array.shape[1]), ) answer_start_similarity_key_norms_cache = self._runtime_vector_cache( cached_answer_start_similarity_key_norms_array, int(self.answer_start_keys_array.shape[0]), ) if ( answer_start_similarity_keys_cache is not None and answer_start_similarity_key_norms_cache is not None ): self.answer_start_similarity_keys_array = answer_start_similarity_keys_cache self.answer_start_similarity_key_norms_array = answer_start_similarity_key_norms_cache else: self.answer_start_similarity_keys_array = ( self.answer_start_keys_array * self.answer_similarity_mask_array[None, :] ) self.answer_start_similarity_key_norms_array = np.linalg.norm( self.answer_start_similarity_keys_array, axis=1, ).astype(RUNTIME_ARRAY_DTYPE, copy=False) self.answer_start_values_array = ( np.asarray(self.answer_start_values, dtype=np.int64) if self.answer_start_values is not None and len(self.answer_start_values) > 0 else None ) self.answer_start_valid_mask_array = ( self.answer_start_values_array >= 0 if self.answer_start_values_array is not None else None ) self.answer_sequence_keys_array = ( np.asarray(self.answer_sequence_keys, dtype=RUNTIME_ARRAY_DTYPE) if self.answer_sequence_keys is not None and len(self.answer_sequence_keys) > 0 else None ) answer_sequence_key_norms_cache = ( self._runtime_vector_cache( cached_answer_sequence_key_norms_array, int(self.answer_sequence_keys_array.shape[0]), ) if self.answer_sequence_keys_array is not None else None ) self.answer_sequence_key_norms_array = ( answer_sequence_key_norms_cache if answer_sequence_key_norms_cache is not None else self._runtime_key_norms_array( self.answer_sequence_keys_array, self.answer_sequence_key_norms, ) ) self.answer_sequence_similarity_keys_array = None 
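        # Like the answer and answer-start keys above, answer-sequence keys are
        # masked by answer_similarity_mask_array: ones over the embedding_dim
        # trace slice of each (state_dim + embedding_dim) timescale block, zeros
        # over the hidden-state slice, so similarity scoring compares context
        # traces only and ignores the raw reservoir state.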
        self.answer_sequence_similarity_key_norms_array = None
        if (
            self.answer_sequence_keys_array is not None
            and len(self.answer_sequence_keys_array.shape) == 2
            and self.answer_similarity_mask_array is not None
            and int(self.answer_sequence_keys_array.shape[1])
            == int(self.answer_similarity_mask_array.shape[0])
        ):
            answer_sequence_similarity_keys_cache = self._runtime_matrix_cache(
                cached_answer_sequence_similarity_keys_array,
                int(self.answer_sequence_keys_array.shape[0]),
                int(self.answer_sequence_keys_array.shape[1]),
            )
            answer_sequence_similarity_key_norms_cache = self._runtime_vector_cache(
                cached_answer_sequence_similarity_key_norms_array,
                int(self.answer_sequence_keys_array.shape[0]),
            )
            if (
                answer_sequence_similarity_keys_cache is not None
                and answer_sequence_similarity_key_norms_cache is not None
            ):
                self.answer_sequence_similarity_keys_array = answer_sequence_similarity_keys_cache
                self.answer_sequence_similarity_key_norms_array = (
                    answer_sequence_similarity_key_norms_cache
                )
            else:
                self.answer_sequence_similarity_keys_array = (
                    self.answer_sequence_keys_array * self.answer_similarity_mask_array[None, :]
                )
                self.answer_sequence_similarity_key_norms_array = np.linalg.norm(
                    self.answer_sequence_similarity_keys_array,
                    axis=1,
                ).astype(RUNTIME_ARRAY_DTYPE, copy=False)
        self.answer_sequence_tokens_array = (
            np.asarray(self.answer_sequence_tokens, dtype=np.int64)
            if self.answer_sequence_tokens is not None and len(self.answer_sequence_tokens) > 0
            else None
        )
        self.answer_sequence_prompt_tokens_array = (
            np.asarray(self.answer_sequence_prompt_tokens, dtype=np.int64)
            if self.answer_sequence_prompt_tokens is not None
            and len(self.answer_sequence_prompt_tokens) > 0
            else None
        )
        self.prompt_overlap_valid_token_mask_array = None
        # The refresh handles the deferred case itself: when the prompt-token row
        # count exceeds ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT it builds only
        # the inverted index and returns early, so a single unconditional call
        # suffices here.
        self._refresh_answer_sequence_prompt_overlap_cache()

    def _defer_answer_sequence_prompt_overlap_cache(self) -> bool:
        if self.answer_sequence_prompt_tokens is None:
            return False
        try:
            row_count = len(self.answer_sequence_prompt_tokens)
        except TypeError:
            return False
        return (
            row_count > ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT
            and np is not None
            and self.answer_sequence_prompt_tokens_array is not None
        )

    def _prompt_overlap_valid_token_mask(self) -> object | None:
        if np is None or self.embedding_model is None:
            return None
        if (
            self.prompt_overlap_valid_token_mask_array is not None
            and int(self.prompt_overlap_valid_token_mask_array.shape[0])
            == len(self.embedding_model.id_to_token)
        ):
            return self.prompt_overlap_valid_token_mask_array
        mask = np.fromiter(
            (
                not self._should_skip_prompt_overlap_token(token)
                for token in self.embedding_model.id_to_token
            ),
            dtype=bool,
            count=len(self.embedding_model.id_to_token),
        )
        self.prompt_overlap_valid_token_mask_array = mask
        return mask

    def _answer_prompt_row_ids_from_array(
        self,
    ) -> tuple[dict[int, list[int]], list[list[int]] | None] | None:
        if (
            np is None
            or self.answer_sequence_prompt_tokens_array is None
            or self.trace_token_weights is None
            or self.embedding_model is None
        ):
            return None
        rows = np.asarray(self.answer_sequence_prompt_tokens_array, dtype=np.int64)
        if len(rows.shape) != 2 or rows.size == 0:
            return {}, ([] if rows.shape[0] <= ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT else None)
        vocab_size = len(self.trace_token_weights)
        if vocab_size <= 0:
            return {}, ([] if rows.shape[0] <= ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT else None)
        valid_token_mask = self._prompt_overlap_valid_token_mask()
        if valid_token_mask is None:
            return None
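        # What follows is a vectorized inverted-index build: lexsort the
        # (token, row) pairs by token id, drop duplicate pairs, then split at
        # token boundaries so each group holds the rows containing that token.
        # A small worked example (toy values, not real vocabulary ids):
        #
        #     rows = [[3, 5], [5, 9]]  ->  pairs (3,r0), (5,r0), (5,r1), (9,r1)
        #     inverted == {3: [0], 5: [0, 1], 9: [1]}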
bounded = (rows >= 0) & (rows < vocab_size) clipped = np.clip(rows, 0, max(0, vocab_size - 1)) bounded &= valid_token_mask[clipped] row_positions, column_positions = np.nonzero(bounded) if row_positions.size == 0: empty_rows = [[] for _ in range(int(rows.shape[0]))] if rows.shape[0] <= ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT else None return {}, empty_rows token_values = rows[row_positions, column_positions].astype(np.int64, copy=False) order = np.lexsort((row_positions, token_values)) token_values = token_values[order] row_positions = row_positions[order] unique = np.ones(token_values.shape[0], dtype=bool) unique[1:] = (token_values[1:] != token_values[:-1]) | (row_positions[1:] != row_positions[:-1]) token_values = token_values[unique] row_positions = row_positions[unique] boundaries = np.flatnonzero(token_values[1:] != token_values[:-1]) + 1 token_groups = np.split(token_values, boundaries) row_groups = np.split(row_positions, boundaries) inverted = { int(token_group[0]): row_group.astype(np.int64, copy=False).tolist() for token_group, row_group in zip(token_groups, row_groups) if token_group.size } if rows.shape[0] > ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT: return inverted, None row_id_lists: list[list[int]] = [[] for _ in range(int(rows.shape[0]))] for token_id, row_index in zip(token_values.tolist(), row_positions.tolist()): row_id_lists[int(row_index)].append(int(token_id)) return inverted, row_id_lists def _refresh_answer_sequence_prompt_overlap_cache(self) -> None: self.answer_sequence_prompt_weight_maps = None self.answer_sequence_prompt_weight_norms = None self.answer_sequence_prompt_bigram_sets = None self.answer_sequence_prompt_trigram_sets = None self.answer_sequence_prompt_number_sets = None self.answer_sequence_prompt_inverted_index = None self.answer_sequence_prompt_specificity = None if self.answer_sequence_prompt_tokens is None or self.trace_token_weights is None: return array_index = self._answer_prompt_row_ids_from_array() if array_index is not None: inverted, row_id_lists = array_index total_rows = ( int(self.answer_sequence_prompt_tokens_array.shape[0]) if self.answer_sequence_prompt_tokens_array is not None else len(row_id_lists or []) ) else: inverted = {} row_id_lists = [] for row in self.answer_sequence_prompt_tokens: row_values = row.tolist() if hasattr(row, "tolist") else row row_ids: list[int] = [] for raw_token_id in row_values: token_id = int(raw_token_id) if token_id < 0 or token_id >= len(self.trace_token_weights): continue if self.embedding_model is not None and self._should_skip_prompt_overlap_token( self.embedding_model.id_to_token[token_id] ): continue row_ids.append(token_id) sequence_index = len(row_id_lists) for token_id in set(row_ids): inverted.setdefault(token_id, []).append(sequence_index) row_id_lists.append(row_ids) total_rows = len(row_id_lists) specificity = { token_id: self._prompt_overlap_token_specificity(len(indices), total_rows) for token_id, indices in inverted.items() } self.answer_sequence_prompt_inverted_index = inverted self.answer_sequence_prompt_specificity = specificity if total_rows > ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT: return if row_id_lists is None: return weight_maps: list[dict[int, float]] = [] weight_norms: list[float] = [] bigram_sets: list[set[tuple[int, int]]] = [] trigram_sets: list[set[tuple[int, int, int]]] = [] number_sets: list[set[str]] = [] for row_index, row_ids in enumerate(row_id_lists): row_weights: dict[int, float] = {} for token_id in row_ids: row_weights[token_id] = max( 
row_weights.get(token_id, 0.0), float(self.trace_token_weights[token_id]) * specificity.get(token_id, 1.0), ) weight_maps.append(row_weights) weight_norms.append(sum(value * value for value in row_weights.values()) ** 0.5) bigram_sets.append( { (row_ids[index], row_ids[index + 1]) for index in range(len(row_ids) - 1) } ) trigram_sets.append( { (row_ids[index], row_ids[index + 1], row_ids[index + 2]) for index in range(len(row_ids) - 2) } ) raw_row = self.answer_sequence_prompt_tokens[row_index] raw_values = raw_row.tolist() if hasattr(raw_row, "tolist") else raw_row raw_ids = [ int(value) for value in raw_values if 0 <= int(value) < len(self.embedding_model.id_to_token) ] number_sets.append(self._number_strings_from_token_ids(raw_ids)) self.answer_sequence_prompt_weight_maps = weight_maps self.answer_sequence_prompt_weight_norms = weight_norms self.answer_sequence_prompt_bigram_sets = bigram_sets self.answer_sequence_prompt_trigram_sets = trigram_sets self.answer_sequence_prompt_number_sets = number_sets @staticmethod def _prompt_overlap_token_specificity(document_frequency: int, total_documents: int) -> float: if document_frequency <= 0 or total_documents <= 0: return 1.0 coverage = min(1.0, document_frequency / total_documents) return max(0.02, 1.0 - (coverage ** 0.5)) def _number_strings_from_token_ids(self, token_ids: list[int]) -> set[str]: assert self.embedding_model is not None tokens = [ self.embedding_model.id_to_token[token_id] for token_id in token_ids if 0 <= token_id < len(self.embedding_model.id_to_token) ] return self._number_strings_from_tokens(tokens) def _number_strings_from_tokens(self, tokens: list[str]) -> set[str]: numbers: set[str] = set() current = "" for token in tokens: if self.tokenizer is not None and token in self.tokenizer.special_tokens: if current: numbers.add(current) current = "" continue rendered = self._render_token(token) digits = "".join(character for character in rendered if character.isdigit()) starts_number = self._starts_new_word(token) if self.tokenizer is not None else True if digits and starts_number: if current: numbers.add(current) current = digits elif digits and current: current += digits else: if current: numbers.add(current) current = "" if current: numbers.add(current) return numbers @staticmethod def _numeric_prompt_can_match(query_numbers: set[str], row_numbers: set[str]) -> bool: if not query_numbers: return True if not row_numbers: return False return query_numbers.issubset(row_numbers) def _vector_answer_sequence_candidate_indices( self, query_token_ids: object, ) -> list[int] | None: if ( np is None or self.answer_sequence_prompt_tokens_array is None or not hasattr(self.answer_sequence_prompt_tokens_array, "shape") ): return None query_ids = np.asarray(list(query_token_ids), dtype=np.int64) if query_ids.size == 0: return [] prompt_array = self.answer_sequence_prompt_tokens_array if len(prompt_array.shape) != 2 or prompt_array.shape[0] == 0: return None mask = np.isin(prompt_array, query_ids).any(axis=1) return [int(index) for index in np.flatnonzero(mask)] def _vector_answer_sequence_local_frequency( self, token_id: int, candidate_indices: list[int], ) -> int | None: if ( np is None or self.answer_sequence_prompt_tokens_array is None or not hasattr(self.answer_sequence_prompt_tokens_array, "shape") or not candidate_indices ): return None rows = self.answer_sequence_prompt_tokens_array[ np.asarray(candidate_indices, dtype=np.int64) ] return int(np.any(rows == int(token_id), axis=1).sum()) def _apply_readout_fast(self, state: Vector) 
-> Vector: if self.readout_weights_array is None or np is None: assert self.readout_weights is not None centered_state = self._center_state_vector(state) logits = apply_readout(self.readout_weights, centered_state) if self.readout_bias: logits = [ value + self.readout_bias[index] for index, value in enumerate(logits) ] return logits state_array = np.asarray(state, dtype=RUNTIME_ARRAY_DTYPE) if self.state_offset_array is not None and self.state_offset_array.shape == state_array.shape: state_array = state_array - self.state_offset_array logits = self.readout_weights_array @ state_array if self.readout_bias_array is not None and self.readout_bias_array.shape == logits.shape: logits = logits + self.readout_bias_array return logits.tolist() def _apply_readout_array(self, state: object) -> object: assert np is not None assert self.readout_weights_array is not None state_array = np.asarray(state, dtype=RUNTIME_ARRAY_DTYPE) if self.state_offset_array is not None and self.state_offset_array.shape == state_array.shape: state_array = state_array - self.state_offset_array logits = self.readout_weights_array @ state_array if self.readout_bias_array is not None and self.readout_bias_array.shape == logits.shape: logits = logits + self.readout_bias_array return logits def _step_hidden_states( self, hidden_states: list[Vector], context_traces: list[Vector], token: str, ) -> tuple[list[Vector], list[Vector], Vector]: assert self.embedding_model is not None assert self.tokenizer is not None token_id = self._token_id_for_token(token) embedding = self.embedding_model.vector(token) trace_embedding = self._trace_embedding_from_token_id(embedding, token_id) return self._step_hidden_states_from_embedding( hidden_states, context_traces, embedding, trace_embedding=trace_embedding, ) def _step_hidden_states_from_embedding( self, hidden_states: list[Vector], context_traces: list[Vector], embedding: Vector | object, *, trace_embedding: Vector | object | None = None, ) -> tuple[list[Vector], list[Vector], Vector]: assert self.memory_units is not None if trace_embedding is None: trace_embedding = embedding if np is not None and hidden_states and hasattr(hidden_states[0], "shape"): embedding_array = ( embedding if hasattr(embedding, "shape") else np.asarray(embedding, dtype=np.float64) ) trace_embedding_array = ( trace_embedding if hasattr(trace_embedding, "shape") else np.asarray(trace_embedding, dtype=np.float64) ) drive = analytical_embedding_drive_fast(embedding_array, self.config.state_dim) next_states: list[Vector] = [] next_traces: list[Vector] = [] combined_state: Vector = [] for unit, state, trace in zip(self.memory_units, hidden_states, context_traces): next_state = unit.step_vector_fast(state, drive) decay = 1.0 / (1.0 + unit.timescale) next_trace = trace + ((1.0 - decay) * trace_embedding_array) next_states.append(next_state) next_traces.append(next_trace) combined_state.extend(next_state.tolist()) combined_state.extend(next_trace.tolist()) return next_states, next_traces, combined_state embedding_vector = embedding.tolist() if hasattr(embedding, "tolist") else embedding trace_embedding_vector = ( trace_embedding.tolist() if hasattr(trace_embedding, "tolist") else trace_embedding ) drive = analytical_embedding_drive(embedding_vector, self.config.state_dim) next_states: list[Vector] = [] next_traces: list[Vector] = [] combined_state: Vector = [] for unit, state, trace in zip(self.memory_units, hidden_states, context_traces): next_state = unit.step_vector(state, drive) decay = 1.0 / (1.0 + unit.timescale) 
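            # decay = 1/(1 + timescale), so each step adds timescale/(1 + timescale)
            # of the trace embedding: a timescale-4 unit accumulates 0.8 of every
            # token while a timescale-1 unit accumulates only 0.5, i.e. slower
            # units weight context history more heavily. The trace itself is never
            # decayed here; it is a weighted running sum, not a leaky average.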
    def _step_hidden_states(
        self,
        hidden_states: list[Vector],
        context_traces: list[Vector],
        token: str,
    ) -> tuple[list[Vector], list[Vector], Vector]:
        assert self.embedding_model is not None
        assert self.tokenizer is not None
        token_id = self._token_id_for_token(token)
        embedding = self.embedding_model.vector(token)
        trace_embedding = self._trace_embedding_from_token_id(embedding, token_id)
        return self._step_hidden_states_from_embedding(
            hidden_states,
            context_traces,
            embedding,
            trace_embedding=trace_embedding,
        )

    def _step_hidden_states_from_embedding(
        self,
        hidden_states: list[Vector],
        context_traces: list[Vector],
        embedding: Vector | object,
        *,
        trace_embedding: Vector | object | None = None,
    ) -> tuple[list[Vector], list[Vector], Vector]:
        assert self.memory_units is not None
        if trace_embedding is None:
            trace_embedding = embedding
        if np is not None and hidden_states and hasattr(hidden_states[0], "shape"):
            embedding_array = (
                embedding
                if hasattr(embedding, "shape")
                else np.asarray(embedding, dtype=np.float64)
            )
            trace_embedding_array = (
                trace_embedding
                if hasattr(trace_embedding, "shape")
                else np.asarray(trace_embedding, dtype=np.float64)
            )
            drive = analytical_embedding_drive_fast(embedding_array, self.config.state_dim)
            next_states: list[Vector] = []
            next_traces: list[Vector] = []
            combined_state: Vector = []
            for unit, state, trace in zip(self.memory_units, hidden_states, context_traces):
                next_state = unit.step_vector_fast(state, drive)
                decay = 1.0 / (1.0 + unit.timescale)
                next_trace = trace + ((1.0 - decay) * trace_embedding_array)
                next_states.append(next_state)
                next_traces.append(next_trace)
                combined_state.extend(next_state.tolist())
                combined_state.extend(next_trace.tolist())
            return next_states, next_traces, combined_state
        embedding_vector = embedding.tolist() if hasattr(embedding, "tolist") else embedding
        trace_embedding_vector = (
            trace_embedding.tolist() if hasattr(trace_embedding, "tolist") else trace_embedding
        )
        drive = analytical_embedding_drive(embedding_vector, self.config.state_dim)
        next_states: list[Vector] = []
        next_traces: list[Vector] = []
        combined_state: Vector = []
        for unit, state, trace in zip(self.memory_units, hidden_states, context_traces):
            next_state = unit.step_vector(state, drive)
            decay = 1.0 / (1.0 + unit.timescale)
            next_trace = [
                previous + ((1.0 - decay) * value)
                for previous, value in zip(trace, trace_embedding_vector)
            ]
            next_states.append(next_state)
            next_traces.append(next_trace)
            combined_state.extend(next_state)
            combined_state.extend(next_trace)
        return next_states, next_traces, combined_state

    def _one_hot(self, token: str) -> Vector:
        assert self.embedding_model is not None
        return self._one_hot_from_id(self.embedding_model.token_to_id.get(token, -1))

    def _one_hot_from_id(self, token_id: int) -> Vector:
        assert self.embedding_model is not None
        vector = [0.0 for _ in self.embedding_model.id_to_token]
        if token_id >= 0:
            vector[token_id] = 1.0
        return vector
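    # Editor's illustrative sketch (hypothetical helper): each context trace is a
    # leaky accumulator over token embeddings, with a per-unit decay derived from
    # the unit's timescale: trace' = trace + (1 - 1 / (1 + timescale)) * embedding.
    @staticmethod
    def _sketch_trace_update(trace: list[float], embedding: list[float], timescale: float) -> list[float]:
        decay = 1.0 / (1.0 + timescale)  # large timescales keep more of the new signal
        return [previous + ((1.0 - decay) * value) for previous, value in zip(trace, embedding)]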
    def _blend_probabilities(
        self,
        base: Vector,
        answer: Vector,
        associative: Vector,
        transition: Vector,
        copy: Vector,
        source_evidence: Vector,
        preference: Vector,
        *,
        transition_order: int | None,
        generated_count: int = 0,
        answer_locked: bool = False,
        answer_guided_start: bool = False,
        copy_guided_start: bool = False,
    ) -> tuple[Vector, dict[str, float]]:
        base_weight = FAST_BASE_BLEND
        answer_weight = FAST_ANSWER_BLEND
        associative_weight = FAST_ASSOCIATIVE_BLEND
        transition_weight = FAST_TRANSITION_BLEND
        copy_weight = FAST_COPY_BLEND
        source_evidence_weight = FAST_SOURCE_EVIDENCE_BLEND
        preference_weight = FAST_PREFERENCE_BLEND
        source_grounded = any(value > 0.0 for value in source_evidence)
        if answer_locked:
            base_weight *= 0.005
            answer_weight *= 250.0
            associative_weight *= 0.05
            transition_weight *= 0.005
            copy_weight *= 0.005
            source_evidence_weight *= 0.05
            preference_weight *= 0.05
        elif answer_guided_start:
            base_weight *= 0.45
            answer_weight *= 3.1
            associative_weight *= 0.2
            transition_weight *= 0.35
            copy_weight *= 0.2
            source_evidence_weight *= 1.1
            preference_weight *= 0.2
        elif copy_guided_start:
            base_weight *= 0.55
            answer_weight *= 0.35
            associative_weight *= 0.4
            transition_weight *= 0.35
            copy_weight *= 4.5
            preference_weight *= 0.6
        elif generated_count > 0:
            answer_weight *= 0.32
            transition_weight *= 2.0
            copy_weight *= 0.75
            source_evidence_weight *= 0.85
            if source_grounded:
                base_weight *= 0.45
                answer_weight *= 0.35
                associative_weight *= 0.50
                transition_weight *= 0.25
                copy_weight *= 0.50
                source_evidence_weight *= 3.50
        if source_grounded:
            base_weight *= 0.60
            answer_weight *= 0.35
            associative_weight *= 0.50
            transition_weight *= 0.80
            copy_weight *= 0.20
            source_evidence_weight *= 1.80
        else:
            source_evidence_weight = 0.0
        if transition_order is None:
            answer_weight *= 1.1
            associative_weight *= 0.75
            copy_weight += 0.02
        elif transition_order <= 2:
            answer_weight *= 1.15
            associative_weight *= 0.65
            transition_weight *= 0.55
            copy_weight += 0.01
        elif transition_order >= 5:
            transition_weight *= 1.25
        sources: list[tuple[str, float, Vector]] = [("base", base_weight, base)]
        if any(value > 0.0 for value in answer):
            sources.append(("answer", answer_weight, answer))
        if any(value > 0.0 for value in associative):
            sources.append(("associative", associative_weight, associative))
        if any(value > 0.0 for value in transition):
            sources.append(("transition", transition_weight, transition))
        if any(value > 0.0 for value in copy):
            sources.append(("copy", copy_weight, copy))
        if any(value > 0.0 for value in source_evidence):
            sources.append(("source_evidence", source_evidence_weight, source_evidence))
        if any(value > 0.0 for value in preference):
            sources.append(("preference", preference_weight, preference))
        total_weight = sum(weight for _, weight, _ in sources)
        blended = [0.0 for _ in base]
        blend_weights: dict[str, float] = {}
        for name, weight, source in sources:
            normalized_weight = weight / total_weight if total_weight else 0.0
            blend_weights[name] = normalized_weight
            for index, value in enumerate(source):
                blended[index] += normalized_weight * value
        return _normalize_vector(blended), blend_weights

    def _blend_probability_arrays(
        self,
        base: object,
        answer: object,
        associative: object,
        transition: object,
        copy: object,
        source_evidence: object,
        preference: object,
        *,
        transition_order: int | None,
        generated_count: int = 0,
        answer_locked: bool = False,
        answer_guided_start: bool = False,
        copy_guided_start: bool = False,
    ) -> tuple[object, dict[str, float]]:
        assert np is not None
        base_weight = FAST_BASE_BLEND
        answer_weight = FAST_ANSWER_BLEND
        associative_weight = FAST_ASSOCIATIVE_BLEND
        transition_weight = FAST_TRANSITION_BLEND
        copy_weight = FAST_COPY_BLEND
        source_evidence_weight = FAST_SOURCE_EVIDENCE_BLEND
        preference_weight = FAST_PREFERENCE_BLEND
        source_grounded = bool(np.any(source_evidence > 0.0))
        if answer_locked:
            base_weight *= 0.005
            answer_weight *= 250.0
            associative_weight *= 0.05
            transition_weight *= 0.005
            copy_weight *= 0.005
            source_evidence_weight *= 0.05
            preference_weight *= 0.05
        elif answer_guided_start:
            base_weight *= 0.45
            answer_weight *= 3.1
            associative_weight *= 0.2
            transition_weight *= 0.35
            copy_weight *= 0.2
            source_evidence_weight *= 1.1
            preference_weight *= 0.2
        elif copy_guided_start:
            base_weight *= 0.55
            answer_weight *= 0.35
            associative_weight *= 0.4
            transition_weight *= 0.35
            copy_weight *= 4.5
            preference_weight *= 0.6
        elif generated_count > 0:
            answer_weight *= 0.32
            transition_weight *= 2.0
            copy_weight *= 0.75
            source_evidence_weight *= 0.85
            if source_grounded:
                base_weight *= 0.45
                answer_weight *= 0.35
                associative_weight *= 0.50
                transition_weight *= 0.25
                copy_weight *= 0.50
                source_evidence_weight *= 3.50
        if source_grounded:
            base_weight *= 0.60
            answer_weight *= 0.35
            associative_weight *= 0.50
            transition_weight *= 0.80
            copy_weight *= 0.20
            source_evidence_weight *= 1.80
        else:
            source_evidence_weight = 0.0
        if transition_order is None:
            answer_weight *= 1.1
            associative_weight *= 0.75
            copy_weight += 0.02
        elif transition_order <= 2:
            answer_weight *= 1.15
            associative_weight *= 0.65
            transition_weight *= 0.55
            copy_weight += 0.01
        elif transition_order >= 5:
            transition_weight *= 1.25
        sources: list[tuple[str, float, object]] = [("base", base_weight, base)]
        if np.any(answer > 0.0):
            sources.append(("answer", answer_weight, answer))
        if np.any(associative > 0.0):
            sources.append(("associative", associative_weight, associative))
        if np.any(transition > 0.0):
            sources.append(("transition", transition_weight, transition))
        if np.any(copy > 0.0):
            sources.append(("copy", copy_weight, copy))
        if np.any(source_evidence > 0.0):
            sources.append(("source_evidence", source_evidence_weight, source_evidence))
        if np.any(preference > 0.0):
            sources.append(("preference", preference_weight, preference))
        total_weight = sum(weight for _, weight, _ in sources)
        blended = np.zeros_like(base, dtype=np.float64)
        blend_weights: dict[str, float] = {}
        for name, weight, source in sources:
            normalized_weight = weight / total_weight if total_weight else 0.0
            blend_weights[name] = normalized_weight
            blended += normalized_weight * source
        total = float(blended.sum())
        if total <= 0.0:
            return base, blend_weights
        return blended / total, blend_weights
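    # Editor's illustrative sketch (hypothetical helper): after gating, the active
    # priors are combined as a convex mixture; weights are renormalized over only
    # the sources that fired, so dropping a source never leaks probability mass.
    @staticmethod
    def _sketch_blend(
        sources: list[tuple[str, float, list[float]]],
    ) -> tuple[list[float], dict[str, float]]:
        total_weight = sum(weight for _, weight, _ in sources)
        size = len(sources[0][2]) if sources else 0
        blended = [0.0] * size
        blend_weights: dict[str, float] = {}
        for name, weight, prior in sources:
            share = weight / total_weight if total_weight else 0.0
            blend_weights[name] = share
            for index, value in enumerate(prior):
                blended[index] += share * value
        return _normalize_vector(blended), blend_weights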
    def _score_associative_matches(
        self,
        state: Vector,
        *,
        limit: int = ASSOCIATIVE_TOP_K,
    ) -> list[tuple[float, int, int]]:
        if (
            self.associative_keys is None
            or self.associative_values is None
            or len(self.associative_keys) == 0
            or len(self.associative_values) == 0
        ):
            return []
        if (
            np is not None
            and self.associative_keys_array is not None
            and self.associative_key_norms_array is not None
            and self.associative_values_array is not None
            and self.associative_valid_mask_array is not None
            and limit > 0
        ):
            state_array = self._center_state_array(state).astype(
                self.associative_keys_array.dtype, copy=False
            )
            state_norm = float(np.linalg.norm(state_array))
            if state_norm == 0.0:
                return []
            numerators = self.associative_keys_array @ state_array
            denominators = self.associative_key_norms_array * state_norm
            valid_mask = self.associative_valid_mask_array & (denominators > 0.0)
            if np.any(valid_mask):
                scores = np.zeros_like(numerators, dtype=self.associative_keys_array.dtype)
                np.divide(numerators, denominators, out=scores, where=valid_mask)
                positive_positions = np.flatnonzero(valid_mask & (scores > 0.0))
                if positive_positions.size:
                    selected_positions = positive_positions
                    if positive_positions.size > limit:
                        partition = np.argpartition(scores[positive_positions], -limit)[-limit:]
                        selected_positions = positive_positions[partition]
                    ordered_positions = selected_positions[
                        np.argsort(scores[selected_positions])[::-1]
                    ]
                    return [
                        (
                            float(scores[position]),
                            int(self.associative_values_array[position]),
                            int(position),
                        )
                        for position in ordered_positions
                    ]
        if self.associative_key_norms is None or len(self.associative_key_norms) == 0:
            return []
        state = self._center_state_vector(state)
        state_norm = norm(state)
        if state_norm == 0.0:
            return []
        scored: list[tuple[float, int, int]] = []
        for example_index, (key, key_norm, token_id) in enumerate(
            zip(self.associative_keys, self.associative_key_norms, self.associative_values)
        ):
            if token_id < 0:
                continue
            denominator = state_norm * key_norm
            if denominator == 0.0:
                continue
            similarity = dot(state, key) / denominator
            if similarity > 0.0:
                scored.append((similarity, token_id, example_index))
        scored.sort(key=lambda item: item[0], reverse=True)
        return scored[:limit]

    def _associative_prior_from_matches(
        self,
        matches: list[tuple[float, int, int]],
    ) -> Vector:
        assert self.embedding_model is not None
        if not matches:
            return [0.0 for _ in self.embedding_model.id_to_token]
        prior = [0.0 for _ in self.embedding_model.id_to_token]
        for similarity, token_id, _ in matches[:ASSOCIATIVE_TOP_K]:
            prior[token_id] += similarity
        return _normalize_vector(prior)

    def _associative_prior(self, state: Vector) -> Vector:
        return self._associative_prior_from_matches(self._score_associative_matches(state))
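    # Editor's illustrative sketch (hypothetical helper): the vectorized scorer is
    # cosine similarity against every key at once, with np.argpartition pulling the
    # top-k before a final small sort, avoiding a full-array sort.
    @staticmethod
    def _sketch_cosine_topk(keys: object, query: object, limit: int) -> list[tuple[float, int]]:
        assert np is not None
        key_matrix = np.asarray(keys, dtype=np.float64)
        query_vector = np.asarray(query, dtype=np.float64)
        query_norm = float(np.linalg.norm(query_vector))
        key_norms = np.linalg.norm(key_matrix, axis=1)
        if query_norm == 0.0:
            return []
        scores = np.zeros(key_matrix.shape[0], dtype=np.float64)
        valid = key_norms > 0.0
        np.divide(key_matrix @ query_vector, key_norms * query_norm, out=scores, where=valid)
        positions = np.flatnonzero(valid & (scores > 0.0))
        if positions.size > limit:
            positions = positions[np.argpartition(scores[positions], -limit)[-limit:]]
        ordered = positions[np.argsort(scores[positions])[::-1]]
        return [(float(scores[position]), int(position)) for position in ordered]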
    def _score_answer_matches(
        self,
        answer_anchor_state: Vector | None,
        *,
        limit: int = ANSWER_TOP_K,
    ) -> list[tuple[float, int, int]]:
        return self._score_prompt_anchor_matches(
            answer_anchor_state,
            self.answer_keys,
            self.answer_key_norms,
            self.answer_values,
            self.answer_keys_array,
            self.answer_key_norms_array,
            self.answer_values_array,
            self.answer_valid_mask_array,
            self.answer_similarity_keys_array,
            self.answer_similarity_key_norms_array,
            self.answer_similarity_mask_array,
            limit=limit,
        )

    def _score_answer_start_matches(
        self,
        answer_anchor_state: Vector | None,
        *,
        limit: int = ANSWER_START_TOP_K,
    ) -> list[tuple[float, int, int]]:
        matches = self._score_prompt_anchor_matches(
            answer_anchor_state,
            self.answer_start_keys,
            self.answer_start_key_norms,
            self.answer_start_values,
            self.answer_start_keys_array,
            self.answer_start_key_norms_array,
            self.answer_start_values_array,
            self.answer_start_valid_mask_array,
            self.answer_start_similarity_keys_array,
            self.answer_start_similarity_key_norms_array,
            self.answer_similarity_mask_array,
            limit=limit,
        )
        if matches:
            return matches
        return self._score_prompt_anchor_matches(
            answer_anchor_state,
            self.answer_start_keys,
            self.answer_start_key_norms,
            self.answer_start_values,
            self.answer_start_keys_array,
            self.answer_start_key_norms_array,
            self.answer_start_values_array,
            self.answer_start_valid_mask_array,
            None,
            None,
            None,
            limit=limit,
        )

    def _score_answer_sequence_matches(
        self,
        answer_anchor_state: Vector | None,
        context_tokens: list[str],
        *,
        limit: int = ANSWER_START_TOP_K,
    ) -> list[tuple[float, int, int]]:
        if (
            answer_anchor_state is None
            or self.answer_sequence_keys is None
            or self.answer_sequence_key_norms is None
            or self.answer_sequence_tokens is None
        ):
            return []
        values = list(range(len(self.answer_sequence_tokens)))
        values_array = np.arange(len(values), dtype=np.int64) if np is not None else None
        anchor_matches = self._score_prompt_anchor_matches(
            answer_anchor_state,
            self.answer_sequence_keys,
            self.answer_sequence_key_norms,
            values,
            self.answer_sequence_keys_array,
            self.answer_sequence_key_norms_array,
            values_array,
            values_array >= 0 if values_array is not None else None,
            self.answer_sequence_similarity_keys_array,
            self.answer_sequence_similarity_key_norms_array,
            self.answer_similarity_mask_array,
            limit=max(limit * 4, limit),
        )
        overlap_scores = self._answer_sequence_prompt_overlap_scores(context_tokens)
        if overlap_scores is None:
            return anchor_matches[:limit]
        if not overlap_scores:
            return []
        best_overlap = max(overlap_scores.values()) if overlap_scores else 0.0
        overlap_floor = max(0.16, best_overlap * 0.90)
        focused_overlap_scores = {
            sequence_index: overlap
            for sequence_index, overlap in overlap_scores.items()
            if overlap >= overlap_floor
        }
        if not focused_overlap_scores:
            focused_overlap_scores = overlap_scores
        focused_indices = set(focused_overlap_scores)
        merged: dict[int, float] = {}
        for similarity, sequence_index, _ in anchor_matches:
            if sequence_index not in focused_indices:
                continue
            merged[sequence_index] = max(merged.get(sequence_index, 0.0), 0.20 * similarity)
        for sequence_index, overlap in focused_overlap_scores.items():
            merged[sequence_index] = merged.get(sequence_index, 0.0) + (0.80 * overlap)
        ranked = [
            (score, sequence_index, sequence_index)
            for sequence_index, score in merged.items()
            if score > 0.0
        ]
        ranked.sort(key=lambda item: item[0], reverse=True)
        return ranked[:limit]
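    # Editor's illustrative sketch (hypothetical helper): sequence candidates are
    # ranked by a fixed 20/80 mix of anchor-state similarity and lexical prompt
    # overlap, after keeping only rows within 90% of the best overlap.
    @staticmethod
    def _sketch_merge_rankings(
        anchor: dict[int, float],
        overlap: dict[int, float],
    ) -> list[tuple[float, int]]:
        if not overlap:
            return []
        floor = max(0.16, max(overlap.values()) * 0.90)
        focused = {index: value for index, value in overlap.items() if value >= floor} or overlap
        merged: dict[int, float] = {}
        for index, similarity in anchor.items():
            if index in focused:
                merged[index] = max(merged.get(index, 0.0), 0.20 * similarity)
        for index, value in focused.items():
            merged[index] = merged.get(index, 0.0) + (0.80 * value)
        return sorted(((score, index) for index, score in merged.items() if score > 0.0), reverse=True)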
    def _answer_sequence_prompt_overlap_scores(
        self,
        context_tokens: list[str],
    ) -> dict[int, float] | None:
        if (
            self.embedding_model is None
            or self.answer_sequence_prompt_tokens is None
            or self.trace_token_weights is None
        ):
            return None
        answer_boundary = _last_index(context_tokens, "")
        prompt_tokens = (
            context_tokens[:answer_boundary] if answer_boundary is not None else context_tokens
        )
        if (
            self.answer_sequence_prompt_specificity is None
            and not self._defer_answer_sequence_prompt_overlap_cache()
        ):
            self._refresh_answer_sequence_prompt_overlap_cache()
        specificity_map = self.answer_sequence_prompt_specificity or {}
        query_weights: dict[int, float] = {}
        query_specificity: dict[int, float] = {}
        query_segment_multipliers: dict[int, float] = {}
        query_content_weight = 0.0
        query_ids: list[int] = []
        primary_query_ids: list[int] = []
        inside_tool_evidence = False
        prompt_segment_index = 0
        for token in prompt_tokens:
            if token in {"", ""}:
                inside_tool_evidence = True
                continue
            if token == "":
                inside_tool_evidence = False
                continue
            if self.tokenizer is not None and token in self.tokenizer.special_tokens:
                continue
            if self._is_structural_punctuation_token(token):
                prompt_segment_index += 1
                continue
            if self._should_skip_prompt_overlap_token(token):
                continue
            token_id = self.embedding_model.token_to_id.get(token)
            if token_id is None:
                continue
            query_ids.append(token_id)
            specificity = specificity_map.get(token_id, 1.0)
            evidence_multiplier = 0.35 if inside_tool_evidence else 1.0
            segment_multiplier = evidence_multiplier / (1.0 + prompt_segment_index)
            weight = specificity * segment_multiplier
            query_weights[token_id] = max(query_weights.get(token_id, 0.0), weight)
            query_specificity[token_id] = max(query_specificity.get(token_id, 0.0), specificity)
            query_segment_multipliers[token_id] = max(
                query_segment_multipliers.get(token_id, 0.0),
                segment_multiplier,
            )
            if not inside_tool_evidence:
                primary_query_ids.append(token_id)
            if specificity >= 0.20:
                query_content_weight += weight
        if not query_weights:
            return None
        full_query_token_ids = set(query_ids)
        primary_query_token_ids = set(primary_query_ids)
        has_tool_evidence = any(token in {"", ""} for token in prompt_tokens)
        query_norm = sum(value * value for value in query_weights.values()) ** 0.5
        if query_norm <= 0.0:
            return None
        query_bigrams = {
            (query_ids[index], query_ids[index + 1]) for index in range(len(query_ids) - 1)
        }
        query_trigrams = {
            (query_ids[index], query_ids[index + 1], query_ids[index + 2])
            for index in range(len(query_ids) - 2)
        }
        query_numbers = self._number_strings_from_tokens(prompt_tokens)

        def ordered_ngram_score(
            query_grams: set[tuple[int, ...]],
            row_grams: set[tuple[int, ...]],
        ) -> float:
            if not query_grams or not row_grams:
                return 0.0
            overlap = len(query_grams & row_grams)
            if overlap <= 0:
                return 0.0
            return overlap / ((len(query_grams) * len(row_grams)) ** 0.5)

        def prompt_length_fit(row_token_count: int) -> float:
            query_token_count = len(full_query_token_ids)
            if query_token_count <= 0 or row_token_count <= 0:
                return 1.0
            if row_token_count <= query_token_count:
                return 1.0
            extra_fraction = (row_token_count - query_token_count) / row_token_count
            return max(0.25, 1.0 - extra_fraction)

        cached_maps = self.answer_sequence_prompt_weight_maps
        cached_norms = self.answer_sequence_prompt_weight_norms
        cached_bigrams = self.answer_sequence_prompt_bigram_sets
        cached_trigrams = self.answer_sequence_prompt_trigram_sets
        cached_numbers = self.answer_sequence_prompt_number_sets
        cached_index = self.answer_sequence_prompt_inverted_index
        if (
            cached_maps is not None
            and cached_norms is not None
            and cached_bigrams is not None
            and cached_trigrams is not None
            and cached_numbers is not None
            and len(cached_maps) == len(self.answer_sequence_prompt_tokens)
        ):
            candidate_indices: set[int] | range
            if cached_index is not None:
                candidates: set[int] = set()
                ranked_query_ids = sorted(
                    query_weights,
                    key=lambda token_id: specificity_map.get(token_id, 1.0),
                    reverse=True,
                )
                distinctive_query_ids = [
                    token_id
                    for token_id in ranked_query_ids
                    if specificity_map.get(token_id, 1.0) >= 0.75
                ] or ranked_query_ids[:4]
                for token_id in distinctive_query_ids:
                    candidates.update(cached_index.get(token_id, ()))
                candidate_indices = candidates if candidates else range(len(cached_maps))
            else:
                candidate_indices = range(len(cached_maps))
            candidate_indices = list(candidate_indices)
            if cached_index is not None and candidate_indices:
                candidate_set = set(candidate_indices)
                local_query_weights: dict[int, float] = {}
                local_query_specificity: dict[int, float] = {}
                local_query_content_weight = 0.0
                for token_id in query_weights:
                    local_frequency = len(candidate_set & set(cached_index.get(token_id, ())))
                    if local_frequency <= 0:
                        continue
                    specificity = self._prompt_overlap_token_specificity(
                        local_frequency,
                        len(candidate_indices),
                    )
                    weight = specificity * query_segment_multipliers.get(token_id, 1.0)
                    local_query_weights[token_id] = weight
                    local_query_specificity[token_id] = specificity
                    if specificity >= 0.20:
                        local_query_content_weight += weight
                local_query_norm = sum(
                    value * value for value in local_query_weights.values()
                ) ** 0.5
                if local_query_norm > 0.0:
                    query_weights = local_query_weights
                    query_specificity = local_query_specificity
                    query_norm = local_query_norm
            scores: dict[int, float] = {}
            for sequence_index in candidate_indices:
                row_weights = cached_maps[sequence_index]
                if not row_weights:
                    continue
                if query_numbers and not self._numeric_prompt_can_match(
                    query_numbers,
                    cached_numbers[sequence_index],
                ):
                    continue
                matched_content_weight = sum(
                    query_weights[token_id]
                    for token_id in query_weights.keys() & row_weights.keys()
                    if query_specificity.get(token_id, 0.0) >= 0.20
                )
                row_token_coverage = len(query_weights.keys() & row_weights.keys()) / max(
                    1, len(row_weights)
                )
                full_query_coverage = len(full_query_token_ids & row_weights.keys()) / max(
                    1, len(full_query_token_ids)
                )
                primary_query_coverage = len(primary_query_token_ids & row_weights.keys()) / max(
                    1, len(primary_query_token_ids)
                )
                if (
                    has_tool_evidence
                    and len(primary_query_token_ids) >= 3
                    and primary_query_coverage < 0.45
                    and row_token_coverage < 0.75
                ):
                    continue
                partial_query_floor = 0.60 if len(full_query_token_ids) < 8 else 0.50
                if (
                    len(full_query_token_ids) >= 5
                    and full_query_coverage <= partial_query_floor
                    and row_token_coverage < 0.75
                ):
                    continue
                if (
                    len(full_query_token_ids) >= 12
                    and full_query_coverage < 0.45
                    and row_token_coverage <= 0.75
                ):
                    continue
                if (
                    query_content_weight > 0.0
                    and matched_content_weight / query_content_weight < 0.40
                    and row_token_coverage < 0.75
                    and full_query_coverage < 0.60
                ):
                    continue
                query_coverage = (
                    matched_content_weight / query_content_weight
                    if query_content_weight > 0.0
                    else row_token_coverage
                )
                numerator = sum(
                    query_weights[token_id] * row_weights[token_id]
                    for token_id in query_weights.keys() & row_weights.keys()
                )
                if numerator <= 0.0:
                    continue
                row_norm = cached_norms[sequence_index]
                if row_norm <= 0.0:
                    continue
                token_score = numerator / (query_norm * row_norm)
                bigram_score = ordered_ngram_score(query_bigrams, cached_bigrams[sequence_index])
                trigram_score = ordered_ngram_score(query_trigrams, cached_trigrams[sequence_index])
                scores[sequence_index] = (
                    (0.35 * token_score)
                    + (0.35 * query_coverage)
                    + (0.15 * bigram_score)
                    + (0.15 * trigram_score)
                ) * prompt_length_fit(len(row_weights))
            return scores
        vector_candidate_indices: list[int] | None = None
        if cached_index is not None:
            candidate_set: set[int] = set()
            ranked_query_ids = sorted(
                query_weights,
                key=lambda token_id: specificity_map.get(token_id, 1.0),
                reverse=True,
            )
            distinctive_query_ids = [
                token_id
                for token_id in ranked_query_ids
                if specificity_map.get(token_id, 1.0) >= 0.75
            ] or ranked_query_ids[:4]
            for token_id in distinctive_query_ids:
                candidate_set.update(cached_index.get(token_id, ()))
            if not candidate_set:
                for token_id in ranked_query_ids:
                    candidate_set.update(cached_index.get(token_id, ()))
                    if candidate_set:
                        break
            if not candidate_set:
                candidate_indices = range(len(self.answer_sequence_prompt_tokens))
            else:
                candidate_indices = sorted(candidate_set)
                local_query_weights: dict[int, float] = {}
                local_query_specificity: dict[int, float] = {}
                local_query_content_weight = 0.0
                candidate_count = len(candidate_indices)
                for token_id in query_weights:
                    local_frequency = len(candidate_set & set(cached_index.get(token_id, ())))
                    if local_frequency <= 0:
                        continue
                    specificity = self._prompt_overlap_token_specificity(
                        local_frequency,
                        candidate_count,
                    )
                    local_query_weights[token_id] = specificity * query_segment_multipliers.get(
                        token_id, 1.0
                    )
                    local_query_specificity[token_id] = specificity
                    if specificity >= 0.20:
                        local_query_content_weight += local_query_weights[token_id]
                local_query_norm = sum(
                    value * value for value in local_query_weights.values()
                ) ** 0.5
                if local_query_norm > 0.0:
                    query_weights = local_query_weights
                    query_specificity = local_query_specificity
                    query_norm = local_query_norm
        elif self._defer_answer_sequence_prompt_overlap_cache():
            vector_candidate_indices = self._vector_answer_sequence_candidate_indices(
                query_weights.keys()
            )
            if vector_candidate_indices is not None:
                if not vector_candidate_indices:
                    return {}
                candidate_indices = vector_candidate_indices
                local_query_weights = {}
                local_query_specificity = {}
                local_query_content_weight = 0.0
                candidate_count = len(vector_candidate_indices)
                for token_id in query_weights:
                    local_frequency = self._vector_answer_sequence_local_frequency(
                        token_id,
                        vector_candidate_indices,
                    )
                    if local_frequency is None or local_frequency <= 0:
                        continue
                    specificity = self._prompt_overlap_token_specificity(
                        local_frequency,
                        candidate_count,
                    )
                    local_query_weights[token_id] = specificity * query_segment_multipliers.get(
                        token_id, 1.0
                    )
                    local_query_specificity[token_id] = specificity
                    if specificity >= 0.20:
                        local_query_content_weight += local_query_weights[token_id]
                local_query_norm = sum(
                    value * value for value in local_query_weights.values()
                ) ** 0.5
                if local_query_norm > 0.0:
                    query_weights = local_query_weights
                    query_specificity = local_query_specificity
                    query_norm = local_query_norm
            else:
                candidate_indices = range(len(self.answer_sequence_prompt_tokens))
        valid_token_mask = self._prompt_overlap_valid_token_mask()
        scores: dict[int, float] = {}
        for sequence_index in candidate_indices:
            row = self.answer_sequence_prompt_tokens[sequence_index]
            row_values = row.tolist() if hasattr(row, "tolist") else row
            row_weights: dict[int, float] = {}
            row_ids: list[int] = []
            raw_row_ids: list[int] = []
            for raw_token_id in row_values:
                token_id = int(raw_token_id)
                if token_id < 0 or token_id >= len(self.trace_token_weights):
                    continue
                raw_row_ids.append(token_id)
                if valid_token_mask is not None:
                    if token_id >= len(valid_token_mask) or not bool(valid_token_mask[token_id]):
                        continue
                elif self._should_skip_prompt_overlap_token(
                    self.embedding_model.id_to_token[token_id]
                ):
                    continue
                row_ids.append(token_id)
                row_weights[token_id] = max(
                    row_weights.get(token_id, 0.0),
                    specificity_map.get(token_id, 1.0),
                )
            if not row_weights:
                continue
            if query_numbers and not self._numeric_prompt_can_match(
                query_numbers,
                self._number_strings_from_token_ids(raw_row_ids),
            ):
                continue
            matched_content_weight = sum(
                query_weights[token_id]
                for token_id in query_weights.keys() & row_weights.keys()
                if query_specificity.get(token_id, 0.0) >= 0.20
            )
            row_token_coverage = len(query_weights.keys() & row_weights.keys()) / max(
                1, len(row_weights)
            )
            full_query_coverage = len(full_query_token_ids & row_weights.keys()) / max(
                1, len(full_query_token_ids)
            )
            primary_query_coverage = len(primary_query_token_ids & row_weights.keys()) / max(
                1, len(primary_query_token_ids)
            )
            if (
                has_tool_evidence
                and len(primary_query_token_ids) >= 3
                and primary_query_coverage < 0.45
                and row_token_coverage < 0.75
            ):
                continue
            partial_query_floor = 0.60 if len(full_query_token_ids) < 8 else 0.30
            if (
                len(full_query_token_ids) >= 5
                and full_query_coverage <= partial_query_floor
                and row_token_coverage < 0.75
            ):
                continue
            if (
                len(full_query_token_ids) >= 12
                and full_query_coverage < 0.25
                and row_token_coverage <= 0.75
            ):
                continue
            if (
                query_content_weight > 0.0
                and matched_content_weight / query_content_weight < 0.25
                and row_token_coverage < 0.75
                and full_query_coverage < 0.60
            ):
                continue
            query_coverage = (
                matched_content_weight / query_content_weight
                if query_content_weight > 0.0
                else row_token_coverage
            )
            numerator = sum(
                query_weights[token_id] * row_weights[token_id]
                for token_id in query_weights.keys() & row_weights.keys()
            )
            if numerator <= 0.0:
                continue
            row_norm = sum(value * value for value in row_weights.values()) ** 0.5
            if row_norm > 0.0:
                token_score = numerator / (query_norm * row_norm)
                row_bigrams = {
                    (row_ids[index], row_ids[index + 1]) for index in range(len(row_ids) - 1)
                }
                row_trigrams = {
                    (row_ids[index], row_ids[index + 1], row_ids[index + 2])
                    for index in range(len(row_ids) - 2)
                }
                bigram_score = ordered_ngram_score(query_bigrams, row_bigrams)
                trigram_score = ordered_ngram_score(query_trigrams, row_trigrams)
                scores[sequence_index] = (
                    (0.35 * token_score)
                    + (0.35 * query_coverage)
                    + (0.15 * bigram_score)
                    + (0.15 * trigram_score)
                ) * prompt_length_fit(len(row_weights))
        return scores
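    # Editor's illustrative sketch (hypothetical helper): the per-row token score
    # above is a cosine over sparse id->weight maps; it combines with coverage and
    # the order-sensitive n-gram overlaps in the fixed 0.35/0.35/0.15/0.15 mix.
    @staticmethod
    def _sketch_sparse_cosine(query: dict[int, float], row: dict[int, float]) -> float:
        shared = query.keys() & row.keys()
        numerator = sum(query[token_id] * row[token_id] for token_id in shared)
        query_norm = sum(value * value for value in query.values()) ** 0.5
        row_norm = sum(value * value for value in row.values()) ** 0.5
        if numerator <= 0.0 or query_norm <= 0.0 or row_norm <= 0.0:
            return 0.0
        return numerator / (query_norm * row_norm)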
    def _score_prompt_anchor_matches(
        self,
        answer_anchor_state: Vector | None,
        keys: object | None,
        key_norms_list: object | None,
        values: object | None,
        keys_array: object | None,
        key_norms_array: object | None,
        values_array: object | None,
        valid_mask_array: object | None,
        similarity_keys_array: object | None,
        similarity_key_norms_array: object | None,
        similarity_mask_array: object | None,
        *,
        limit: int,
    ) -> list[tuple[float, int, int]]:
        if (
            answer_anchor_state is None
            or keys is None
            or key_norms_list is None
            or values is None
        ):
            return []
        if (
            np is not None
            and keys_array is not None
            and key_norms_array is not None
            and values_array is not None
            and valid_mask_array is not None
            and limit > 0
        ):
            key_array = keys_array
            key_norms = key_norms_array
            if (
                similarity_keys_array is not None
                and similarity_key_norms_array is not None
                and similarity_mask_array is not None
            ):
                state_array = self._center_state_array(
                    self._masked_combined_state_array(answer_anchor_state)
                ).astype(keys_array.dtype, copy=False)
                state_array = state_array * similarity_mask_array
                key_array = similarity_keys_array
                key_norms = similarity_key_norms_array
            else:
                state_array = self._center_state_array(answer_anchor_state).astype(
                    keys_array.dtype,
                    copy=False,
                )
            state_norm = float(np.linalg.norm(state_array))
            if state_norm == 0.0:
                return []
            numerators = key_array @ state_array
            denominators = key_norms * state_norm
            valid_mask = valid_mask_array & (denominators > 0.0)
            if np.any(valid_mask):
                scores = np.zeros_like(numerators, dtype=key_array.dtype)
                np.divide(numerators, denominators, out=scores, where=valid_mask)
                positive_positions = np.flatnonzero(valid_mask & (scores > 0.0))
                if positive_positions.size:
                    selected_positions = positive_positions
                    if positive_positions.size > limit:
                        partition = np.argpartition(scores[positive_positions], -limit)[-limit:]
                        selected_positions = positive_positions[partition]
                    ordered_positions = selected_positions[
                        np.argsort(scores[selected_positions])[::-1]
                    ]
                    return [
                        (
                            float(scores[position]),
                            int(values_array[position]),
                            int(position),
                        )
                        for position in ordered_positions
                    ]
        if similarity_mask_array is not None:
            state = self._center_state_vector(self._masked_combined_state(answer_anchor_state))
        else:
            state = self._center_state_vector(answer_anchor_state)
        state_norm = norm(state)
        if state_norm == 0.0:
            return []
        scored: list[tuple[float, int, int]] = []
        for example_index, (key, key_norm, token_id) in enumerate(
            zip(keys, key_norms_list, values)
        ):
            if token_id < 0:
                continue
            denominator = state_norm * key_norm
            if denominator == 0.0:
                continue
            similarity = dot(state, key) / denominator
            if similarity > 0.0:
                scored.append((similarity, token_id, example_index))
        scored.sort(key=lambda item: item[0], reverse=True)
        return scored[:limit]

    def _answer_prior_from_matches(
        self,
        matches: list[tuple[float, int, int]],
        generated_tokens: list[str],
    ) -> Vector:
        assert self.embedding_model is not None
        if not matches:
            return [0.0 for _ in self.embedding_model.id_to_token]
        prior = [0.0 for _ in self.embedding_model.id_to_token]
        generated_ids = {
            self.embedding_model.token_to_id[token]
            for token in generated_tokens
            if token in self.embedding_model.token_to_id
        }
        for similarity, token_id, _ in matches[:ANSWER_TOP_K]:
            token = self.embedding_model.id_to_token[token_id]
            if not self._allowed_generation_token(token, generated_tokens):
                continue
            if token_id in generated_ids:
                prior[token_id] += similarity * 0.35
            else:
                prior[token_id] += similarity
        return _normalize_vector(prior)

    def _answer_start_matches_from_sequences(
        self,
        matches: list[tuple[float, int, int]],
    ) -> list[tuple[float, int, int]]:
        if not matches or self.answer_sequence_tokens is None:
            return []
        start_matches: list[tuple[float, int, int]] = []
        for similarity, sequence_index, example_index in matches[:ANSWER_START_TOP_K]:
            if sequence_index >= len(self.answer_sequence_tokens):
                continue
            row = self.answer_sequence_tokens[sequence_index]
            token_ids = [
                int(value)
                for value in (row.tolist() if hasattr(row, "tolist") else row)
                if int(value) >= 0
            ]
            if token_ids:
                start_matches.append((similarity, token_ids[0], example_index))
        return start_matches

    def _answer_sequence_prior_from_matches(
        self,
        matches: list[tuple[float, int, int]],
        generated_tokens: list[str],
        *,
        temperature: float = 0.0,
    ) -> Vector:
        assert self.embedding_model is not None
        if not matches or self.answer_sequence_tokens is None:
            return [0.0 for _ in self.embedding_model.id_to_token]
        generated_ids = [
            self.embedding_model.token_to_id[token]
            for token in generated_tokens
            if token in self.embedding_model.token_to_id
        ]
        prior = [0.0 for _ in self.embedding_model.id_to_token]
        best_similarity = matches[0][0]
        if best_similarity >= 0.9:
            floor_delta = 0.14 if temperature >= ANSWER_SEQUENCE_CREATIVE_TEMPERATURE else 0.02
            match_floor = best_similarity - floor_delta
        else:
            match_floor = 0.0
        for similarity, sequence_index, _ in matches[:ANSWER_START_TOP_K]:
            if similarity < ANSWER_SEQUENCE_MATCH_FLOOR:
                continue
            if similarity < match_floor:
                continue
            token_ids = self._answer_sequence_token_row(sequence_index)
            if not token_ids:
                continue
            next_token_id = self._next_sequence_token_id(token_ids, generated_ids)
            if next_token_id is None:
                continue
            token = self.embedding_model.id_to_token[next_token_id]
            if self._allowed_answer_sequence_token(token, generated_tokens):
                prior[next_token_id] += max(1e-9, similarity - match_floor)
        return _normalize_vector(prior)
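    # Editor's illustrative sketch (hypothetical helper): the sequence prior walks
    # each matched row and credits only the single token that would extend the
    # generated prefix, so probability mass stays on-path.
    @staticmethod
    def _sketch_next_on_path(row_ids: list[int], generated_ids: list[int]) -> int | None:
        if not generated_ids:
            return row_ids[0] if row_ids else None
        if len(generated_ids) >= len(row_ids):
            return None  # row fully consumed
        if row_ids[: len(generated_ids)] != generated_ids:
            return None  # prefix diverged from this row
        return row_ids[len(generated_ids)]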
    def _answer_sequence_token_row(self, sequence_index: int) -> list[int]:
        if sequence_index < 0 or self.answer_sequence_tokens is None:
            return []
        if self.answer_sequence_token_id_rows is not None:
            if sequence_index >= len(self.answer_sequence_token_id_rows):
                return []
            return self.answer_sequence_token_id_rows[sequence_index]
        if (
            np is not None
            and hasattr(self.answer_sequence_tokens, "shape")
            and len(self.answer_sequence_tokens.shape) == 2
        ):
            if sequence_index >= int(self.answer_sequence_tokens.shape[0]):
                return []
            row = np.asarray(self.answer_sequence_tokens[sequence_index])
            return [int(value) for value in row.tolist() if int(value) >= 0]
        try:
            row = self.answer_sequence_tokens[sequence_index]
        except (IndexError, TypeError):
            return []
        return self._answer_token_ids_from_row(row)

    def _filter_avoided_answer_sequence_matches(
        self,
        matches: list[tuple[float, int, int]] | None,
        avoid_token_sequences: Sequence[Sequence[str]] | None,
    ) -> list[tuple[float, int, int]]:
        if (
            not matches
            or not avoid_token_sequences
            or self.embedding_model is None
            or self.answer_sequence_tokens is None
        ):
            return list(matches or [])
        token_to_id = self.embedding_model.token_to_id
        avoided_id_sequences: set[tuple[int, ...]] = set()
        for sequence in avoid_token_sequences:
            ids: list[int] = []
            for token in sequence:
                token_id = token_to_id.get(token)
                if token_id is None:
                    ids = []
                    break
                ids.append(token_id)
            if ids:
                avoided_id_sequences.add(tuple(ids))
        if not avoided_id_sequences:
            return list(matches)
        sequence_rows = self._answer_sequence_token_rows()
        filtered: list[tuple[float, int, int]] = []
        for match in matches:
            _, sequence_index, _ = match
            if sequence_index >= len(sequence_rows):
                filtered.append(match)
                continue
            if tuple(sequence_rows[sequence_index]) in avoided_id_sequences:
                continue
            filtered.append(match)
        return filtered

    def _answer_sequence_token_rows(self) -> list[list[int]]:
        if self.answer_sequence_token_id_rows is not None:
            return self.answer_sequence_token_id_rows
        rows: list[list[int]] = []
        if (
            np is not None
            and self.answer_sequence_tokens is not None
            and hasattr(self.answer_sequence_tokens, "shape")
            and len(self.answer_sequence_tokens.shape) == 2
        ):
            token_rows = np.asarray(self.answer_sequence_tokens).tolist()
            rows = [[int(value) for value in row if int(value) >= 0] for row in token_rows]
        elif self.answer_sequence_tokens is not None:
            for row in self.answer_sequence_tokens:
                rows.append(self._answer_token_ids_from_row(row))
        self.answer_sequence_token_id_rows = rows
        return rows

    @staticmethod
    def _answer_token_ids_from_row(row: object) -> list[int]:
        values = row.tolist() if hasattr(row, "tolist") else row
        if not isinstance(values, list):
            return []
        return [int(value) for value in values if int(value) >= 0]

    @staticmethod
    def _answer_fingerprint_from_token_ids(token_ids: list[int]) -> tuple[int, ...]:
        payload = ",".join(str(token_id) for token_id in token_ids).encode("ascii")
        digest = hashlib.blake2s(
            payload,
            digest_size=ANSWER_FINGERPRINT_WORDS * 4,
        ).digest()
        return tuple(
            int.from_bytes(
                digest[index * 4 : (index + 1) * 4],
                "little",
                signed=True,
            )
            for index in range(ANSWER_FINGERPRINT_WORDS)
        )

    def _refresh_answer_fingerprint_hashes(self) -> None:
        hashes: set[tuple[int, ...]] = set()
        lengths: set[int] = set()
        sequences_by_length: dict[int, set[tuple[int, ...]]] = {}
        if self.answer_sequence_tokens is not None:
            for token_ids in self._answer_sequence_token_rows():
                if token_ids:
                    token_length = len(token_ids)
                    lengths.add(token_length)
                    sequences_by_length.setdefault(token_length, set()).add(tuple(token_ids))
                    hashes.add(self._answer_fingerprint_from_token_ids(token_ids))
        self.answer_fingerprint_hashes = hashes
        self.answer_fingerprint_token_lengths = lengths
        self.answer_fingerprint_token_sequences_by_length = sequences_by_length

    def _answer_fingerprint_tensor(self) -> list[list[int]]:
        if self.answer_fingerprint_hashes is None:
            self._refresh_answer_fingerprint_hashes()
        return [
            list(fingerprint)
            for fingerprint in sorted(self.answer_fingerprint_hashes or set())
        ]
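    # Editor's illustrative sketch (hypothetical helper): an answer fingerprint is
    # a fixed-width blake2s digest of the comma-joined token ids, unpacked into
    # signed 32-bit words so it can be stored in an integer tensor.
    @staticmethod
    def _sketch_fingerprint(token_ids: list[int], words: int = 4) -> tuple[int, ...]:
        payload = ",".join(str(token_id) for token_id in token_ids).encode("ascii")
        digest = hashlib.blake2s(payload, digest_size=words * 4).digest()
        return tuple(
            int.from_bytes(digest[index * 4 : (index + 1) * 4], "little", signed=True)
            for index in range(words)
        )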
    @staticmethod
    def _coerce_answer_fingerprint_hashes(raw_fingerprints: object) -> set[tuple[int, ...]]:
        rows = raw_fingerprints.tolist() if hasattr(raw_fingerprints, "tolist") else raw_fingerprints
        hashes: set[tuple[int, ...]] = set()
        if not isinstance(rows, list):
            return hashes
        for row in rows:
            values = row.tolist() if hasattr(row, "tolist") else row
            if not isinstance(values, list):
                continue
            fingerprint = tuple(int(value) for value in values)
            if len(fingerprint) == ANSWER_FINGERPRINT_WORDS:
                hashes.add(fingerprint)
        return hashes

    def _answer_fingerprint_lengths(self) -> set[int]:
        if self.answer_fingerprint_token_lengths is not None:
            return self.answer_fingerprint_token_lengths
        lengths: set[int] = set()
        if (
            np is not None
            and self.answer_sequence_tokens is not None
            and hasattr(self.answer_sequence_tokens, "shape")
            and len(self.answer_sequence_tokens.shape) == 2
        ):
            token_matrix = np.asarray(self.answer_sequence_tokens)
            length_values = np.sum(token_matrix >= 0, axis=1)
            lengths = {
                int(length) for length in np.unique(length_values).tolist() if int(length) > 0
            }
        elif self.answer_sequence_tokens is not None:
            for token_ids in self._answer_sequence_token_rows():
                if token_ids:
                    lengths.add(len(token_ids))
        self.answer_fingerprint_token_lengths = lengths
        return lengths

    def _use_runtime_fingerprint_blacklist(self) -> bool:
        if (
            np is None
            or self.answer_sequence_tokens is None
            or not hasattr(self.answer_sequence_tokens, "shape")
            or len(self.answer_sequence_tokens.shape) != 2
        ):
            return False
        return int(self.answer_sequence_tokens.shape[0]) > ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT

    def _answer_fingerprint_token_sequence_sets(self) -> dict[int, set[tuple[int, ...]]]:
        if self.answer_fingerprint_token_sequences_by_length is not None:
            return self.answer_fingerprint_token_sequences_by_length
        sequences_by_length: dict[int, set[tuple[int, ...]]] = {}
        lengths: set[int] = set()
        if self.answer_sequence_tokens is not None:
            for token_ids in self._answer_sequence_token_rows():
                if token_ids:
                    token_length = len(token_ids)
                    lengths.add(token_length)
                    sequences_by_length.setdefault(token_length, set()).add(tuple(token_ids))
        self.answer_fingerprint_token_lengths = lengths
        self.answer_fingerprint_token_sequences_by_length = sequences_by_length
        return sequences_by_length

    def _token_ids_for_generated_tokens(self, generated_tokens: Sequence[str]) -> list[int] | None:
        if self.embedding_model is None:
            return None
        token_ids: list[int] = []
        for token in generated_tokens:
            token_id = self.embedding_model.token_to_id.get(token)
            if token_id is None:
                return None
            token_ids.append(token_id)
        return token_ids

    def _would_complete_blacklisted_answer(
        self,
        generated_tokens: list[str],
        candidate: str,
    ) -> bool:
        generated_token_ids = self._token_ids_for_generated_tokens(generated_tokens)
        return self._would_complete_blacklisted_answer_ids(generated_token_ids, candidate)
    def _would_complete_blacklisted_answer_ids(
        self,
        generated_token_ids: Sequence[int] | None,
        candidate: str,
    ) -> bool:
        if (
            self.embedding_model is None
            or not self.answer_fingerprint_hashes
            or candidate not in self.embedding_model.token_to_id
            or generated_token_ids is None
        ):
            return False
        candidate_id = self.embedding_model.token_to_id[candidate]
        if self._is_terminal_punctuation_text(self._render_token(candidate)):
            return False
        candidate_length = len(generated_token_ids) + 1
        if self._use_runtime_fingerprint_blacklist():
            lengths = self._answer_fingerprint_lengths()
            if lengths and candidate_length not in lengths:
                return False
            token_ids = [*generated_token_ids, candidate_id]
            if not token_ids:
                return False
            return self._answer_fingerprint_from_token_ids(token_ids) in self.answer_fingerprint_hashes
        sequence_sets = self._answer_fingerprint_token_sequence_sets()
        candidate_sequences = sequence_sets.get(candidate_length)
        if candidate_sequences is not None:
            return (*generated_token_ids, candidate_id) in candidate_sequences
        if self.answer_sequence_tokens is not None:
            return False
        lengths = self._answer_fingerprint_lengths()
        if lengths and candidate_length not in lengths:
            return False
        token_ids = [*generated_token_ids, candidate_id]
        if not token_ids:
            return False
        return self._answer_fingerprint_from_token_ids(token_ids) in self.answer_fingerprint_hashes

    def _would_follow_blacklisted_answer_prefix_ids(
        self,
        generated_token_ids: Sequence[int] | None,
        candidate: str,
        *,
        minimum_prefix_length: int = ANSWER_REPLAY_PREFIX_MIN_TOKENS,
    ) -> bool:
        if (
            self.embedding_model is None
            or self.answer_sequence_tokens is None
            or candidate not in self.embedding_model.token_to_id
            or generated_token_ids is None
        ):
            return False
        candidate_id = self.embedding_model.token_to_id[candidate]
        candidate_path = (*generated_token_ids, candidate_id)
        if len(candidate_path) < minimum_prefix_length:
            return False
        prefix_sets = self._answer_sequence_prefix_sets(minimum_prefix_length)
        return candidate_path in prefix_sets.get(len(candidate_path), set())

    def _answer_sequence_prefix_sets(
        self,
        minimum_prefix_length: int = ANSWER_REPLAY_PREFIX_MIN_TOKENS,
    ) -> dict[int, set[tuple[int, ...]]]:
        cached = self.answer_sequence_prefixes_by_length
        if cached is not None:
            return cached
        prefixes: dict[int, set[tuple[int, ...]]] = {}
        for token_ids in self._answer_sequence_token_rows():
            for length in range(minimum_prefix_length, len(token_ids) + 1):
                prefixes.setdefault(length, set()).add(tuple(token_ids[:length]))
        self.answer_sequence_prefixes_by_length = prefixes
        return prefixes

    def _avoid_text_token_sequences(
        self,
        avoid_texts: Sequence[str] | None,
    ) -> list[list[str]]:
        if not avoid_texts or self.tokenizer is None:
            return []
        sequences: list[list[str]] = []
        seen: set[tuple[str, ...]] = set()
        for text in avoid_texts:
            if not isinstance(text, str) or not text.strip():
                continue
            tokens = [
                token
                for token in self.tokenizer.encode(text)
                if token not in self.tokenizer.special_tokens
            ]
            key = tuple(tokens)
            if tokens and key not in seen:
                seen.add(key)
                sequences.append(tokens)
        return sequences

    @staticmethod
    def _runtime_generation_history_key(context: str) -> str:
        return " ".join(context.split()).casefold()

    @staticmethod
    def _runtime_history_enabled(context: str, *, temperature: float) -> bool:
        if temperature < ANSWER_REPLAY_PREFIX_TEMPERATURE:
            return False
        lowered = context.casefold()
        return "" not in lowered and "" not in lowered

    def _runtime_avoid_texts(
        self,
        context: str,
        avoid_texts: Sequence[str] | None,
        *,
        temperature: float,
    ) -> list[str]:
        combined: list[str] = []
        seen: set[str] = set()
        for text in avoid_texts or ():
            cleaned = " ".join(str(text).split())
            if cleaned and cleaned not in seen:
                combined.append(cleaned)
                seen.add(cleaned)
        if not self._runtime_history_enabled(context, temperature=temperature):
            return combined
        history = self.runtime_generation_history.get(
            self._runtime_generation_history_key(context),
            [],
        )
        for text in history:
            cleaned = " ".join(str(text).split())
            if cleaned and cleaned not in seen:
                combined.append(cleaned)
                seen.add(cleaned)
        return combined

    def _remember_runtime_generation(
        self,
        context: str,
        generated_text: str,
        *,
        temperature: float,
    ) -> None:
        if not self._runtime_history_enabled(context, temperature=temperature):
            return
        cleaned = " ".join(generated_text.split())
        if not cleaned:
            return
        key = self._runtime_generation_history_key(context)
        history = [
            existing
            for existing in self.runtime_generation_history.get(key, [])
            if existing != cleaned
        ]
        history.append(cleaned)
        self.runtime_generation_history[key] = history[-RUNTIME_GENERATION_HISTORY_LIMIT:]
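    # Editor's illustrative sketch (hypothetical helper): runtime history is a
    # bounded per-prompt dedup list -- re-adding an entry moves it to the end, and
    # only the most recent `limit` entries survive.
    @staticmethod
    def _sketch_remember(history: dict[str, list[str]], key: str, text: str, limit: int = 8) -> None:
        entries = [existing for existing in history.get(key, []) if existing != text]
        entries.append(text)
        history[key] = entries[-limit:]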
    @staticmethod
    def _would_follow_avoided_sequence(
        generated_tokens: list[str],
        candidate: str,
        avoid_token_sequences: Sequence[Sequence[str]] | None,
    ) -> bool:
        if not avoid_token_sequences:
            return False
        prefix_length = len(generated_tokens) + 1
        if prefix_length < AVOID_SEQUENCE_MIN_TOKENS:
            return False
        candidate_path = [*generated_tokens, candidate]
        for sequence in avoid_token_sequences:
            if prefix_length <= len(sequence) and list(sequence[:prefix_length]) == candidate_path:
                return True
        return False

    def _should_stop_answer_sequence(
        self,
        decode_state: DecodeState,
        generated_tokens: list[str],
    ) -> bool:
        matches = decode_state.answer_sequence_matches
        if matches is None:
            matches = self._score_answer_sequence_matches(
                decode_state.answer_anchor_state,
                decode_state.context_tokens,
            )
        return self._answer_sequence_is_complete(generated_tokens, matches)

    def _should_stop_after_answer_path_drift(
        self,
        decode_state: DecodeState,
        generated_tokens: list[str],
    ) -> bool:
        matches = decode_state.answer_sequence_matches
        if matches is None:
            matches = self._score_answer_sequence_matches(
                decode_state.answer_anchor_state,
                decode_state.context_tokens,
            )
        if not matches or matches[0][0] < ANSWER_SEQUENCE_MATCH_FLOOR:
            return False
        if self._answer_sequence_has_continuation(generated_tokens, matches):
            return False
        if self._generated_answer_ends_terminal_sentence(generated_tokens):
            return True
        return self._generated_word_count(generated_tokens) >= 14

    def _generated_answer_ends_terminal_sentence(self, generated_tokens: list[str]) -> bool:
        if not generated_tokens:
            return False
        rendered = self._render_token(generated_tokens[-1])
        if not self._is_terminal_punctuation_text(rendered):
            return False
        return self._generated_word_count(generated_tokens) > 0

    def _answer_decode_has_continuation(
        self,
        decode_state: DecodeState,
        generated_tokens: list[str],
    ) -> bool:
        matches = decode_state.answer_sequence_matches
        if matches is None:
            matches = self._score_answer_sequence_matches(
                decode_state.answer_anchor_state,
                decode_state.context_tokens,
            )
        return self._answer_sequence_has_continuation(generated_tokens, matches)

    def _answer_sequence_is_complete(
        self,
        generated_tokens: list[str],
        matches: list[tuple[float, int, int]],
    ) -> bool:
        if (
            self.embedding_model is None
            or self.answer_sequence_tokens is None
            or not generated_tokens
            or not matches
        ):
            return False
        generated_ids = [
            self.embedding_model.token_to_id[token]
            for token in generated_tokens
            if token in self.embedding_model.token_to_id
        ]
        if not generated_ids:
            return False
        for similarity, sequence_index, _ in matches[:ANSWER_START_TOP_K]:
            if similarity < ANSWER_SEQUENCE_MATCH_FLOOR or sequence_index >= len(
                self.answer_sequence_tokens
            ):
                continue
            row = self.answer_sequence_tokens[sequence_index]
            token_ids = [
                int(value)
                for value in (row.tolist() if hasattr(row, "tolist") else row)
                if int(value) >= 0
            ]
            if not token_ids:
                continue
            if len(generated_ids) >= len(token_ids) and generated_ids[: len(token_ids)] == token_ids:
                return True
            if (
                self.answer_fingerprint_hashes
                and len(generated_ids) + 1 == len(token_ids)
                and generated_ids == token_ids[: len(generated_ids)]
                and self._answer_fingerprint_from_token_ids(token_ids)
                in self.answer_fingerprint_hashes
            ):
                generated_tail = self._render_token(generated_tokens[-1])
                if self._is_structural_punctuation_text(
                    generated_tail
                ) and not self._is_terminal_punctuation_text(generated_tail):
                    continue
                final_token = self.embedding_model.id_to_token[token_ids[-1]]
                if self._is_terminal_punctuation_text(self._render_token(final_token)):
                    continue
                return True
        return False
    def _answer_sequence_has_continuation(
        self,
        generated_tokens: list[str],
        matches: list[tuple[float, int, int]],
    ) -> bool:
        if (
            self.embedding_model is None
            or self.answer_sequence_tokens is None
            or not generated_tokens
            or not matches
        ):
            return False
        generated_ids = [
            self.embedding_model.token_to_id[token]
            for token in generated_tokens
            if token in self.embedding_model.token_to_id
        ]
        if not generated_ids:
            return False
        for similarity, sequence_index, _ in matches[:ANSWER_START_TOP_K]:
            if similarity < ANSWER_SEQUENCE_MATCH_FLOOR or sequence_index >= len(
                self.answer_sequence_tokens
            ):
                continue
            row = self.answer_sequence_tokens[sequence_index]
            token_ids = [
                int(value)
                for value in (row.tolist() if hasattr(row, "tolist") else row)
                if int(value) >= 0
            ]
            if not token_ids:
                continue
            next_token_id = self._next_sequence_token_id(token_ids, generated_ids)
            if next_token_id is None:
                continue
            token = self.embedding_model.id_to_token[next_token_id]
            if self._allowed_answer_sequence_token(token, generated_tokens):
                return True
        return False

    def _next_sequence_token_id(
        self,
        token_ids: list[int],
        generated_ids: list[int],
    ) -> int | None:
        if not generated_ids:
            return token_ids[0]
        if len(generated_ids) >= len(token_ids):
            return None
        if token_ids[: len(generated_ids)] != generated_ids:
            return None
        return token_ids[len(generated_ids)]

    def _transition_prior(self, context_tokens: list[str]) -> Vector:
        prior, _ = self._transition_prior_with_order(context_tokens)
        return prior

    def _transition_prior_with_order(
        self,
        context_tokens: list[str],
    ) -> tuple[Vector, int | None]:
        assert self.embedding_model is not None
        if self.transition_id_tables:
            for order in TRANSITION_ORDERS:
                if len(context_tokens) < order:
                    continue
                key_ids: list[int] = []
                for token in context_tokens[-order:]:
                    token_id = self.embedding_model.token_to_id.get(token)
                    if token_id is None:
                        key_ids = []
                        break
                    key_ids.append(token_id)
                if not key_ids:
                    continue
                transitions = self._transition_tensor_lookup(order, key_ids)
                if transitions is None:
                    transitions = self.transition_id_tables.get(order, {}).get(tuple(key_ids))
                if not transitions:
                    continue
                next_token_ids, probabilities = transitions
                prior = [0.0 for _ in self.embedding_model.id_to_token]
                for token_id, probability in zip(next_token_ids, probabilities):
                    token_index = int(token_id)
                    if 0 <= token_index < len(prior):
                        prior[token_index] = float(probability)
                return _normalize_vector(prior), order
        if not self.transition_tables:
            return [0.0 for _ in self.embedding_model.id_to_token], None
        for order in TRANSITION_ORDERS:
            if len(context_tokens) < order:
                continue
            key = tuple(context_tokens[-order:])
            transitions = self.transition_tables.get(order, {}).get(key)
            if not transitions:
                continue
            prior = [0.0 for _ in self.embedding_model.id_to_token]
            for token, probability in transitions.items():
                token_id = self.embedding_model.token_to_id.get(token)
                if token_id is not None:
                    prior[token_id] = probability
            return _normalize_vector(prior), order
        return [0.0 for _ in self.embedding_model.id_to_token], None
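    # Editor's illustrative sketch (hypothetical helper): transition lookup is a
    # stub backoff over TRANSITION_ORDERS (highest order first); the first table
    # that has seen the exact context suffix wins.
    @staticmethod
    def _sketch_backoff_lookup(
        tables: dict[int, dict[tuple[str, ...], dict[str, float]]],
        context: list[str],
    ) -> tuple[dict[str, float] | None, int | None]:
        for order in TRANSITION_ORDERS:
            if len(context) < order:
                continue
            transitions = tables.get(order, {}).get(tuple(context[-order:]))
            if transitions:
                return transitions, order
        return None, None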
    def _transition_prior_array_with_order(
        self,
        context_tokens: list[str],
    ) -> tuple[object, int | None]:
        assert np is not None
        assert self.embedding_model is not None
        prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
        if self.transition_id_tables:
            for order in TRANSITION_ORDERS:
                if len(context_tokens) < order:
                    continue
                key_ids: list[int] = []
                for token in context_tokens[-order:]:
                    token_id = self.embedding_model.token_to_id.get(token)
                    if token_id is None:
                        key_ids = []
                        break
                    key_ids.append(token_id)
                if not key_ids:
                    continue
                transitions = self._transition_tensor_lookup(order, key_ids)
                if transitions is None:
                    transitions = self.transition_id_tables.get(order, {}).get(tuple(key_ids))
                if not transitions:
                    continue
                next_token_ids, probabilities = transitions
                token_ids_array = np.asarray(next_token_ids, dtype=np.int64)
                probabilities_array = np.asarray(probabilities, dtype=np.float64)
                valid = (
                    (token_ids_array >= 0)
                    & (token_ids_array < len(self.embedding_model.id_to_token))
                    & (probabilities_array > 0.0)
                )
                if np.any(valid):
                    prior[token_ids_array[valid]] = probabilities_array[valid]
                total = float(prior.sum())
                if total > 0.0:
                    prior /= total
                return prior, order
            return prior, None
        if not self.transition_tables:
            return prior, None
        for order in TRANSITION_ORDERS:
            if len(context_tokens) < order:
                continue
            key = tuple(context_tokens[-order:])
            transitions = self.transition_tables.get(order, {}).get(key)
            if not transitions:
                continue
            for token, probability in transitions.items():
                token_id = self.embedding_model.token_to_id.get(token)
                if token_id is not None:
                    prior[token_id] = probability
            total = float(prior.sum())
            if total > 0.0:
                prior /= total
            return prior, order
        return prior, None

    def _copy_prior(self, context_tokens: list[str]) -> Vector:
        assert self.embedding_model is not None
        assert self.tokenizer is not None
        prior = [0.0 for _ in self.embedding_model.id_to_token]
        decay = 0.82
        answer_start = None
        for index in range(len(context_tokens) - 1, -1, -1):
            if context_tokens[index] == "":
                answer_start = index + 1
                break
        source_tokens = (
            context_tokens[: max(0, answer_start - 1)]
            if answer_start is not None
            else context_tokens
        )
        if not source_tokens:
            return prior
        for distance, token in enumerate(reversed(source_tokens)):
            if token in self.tokenizer.special_tokens:
                continue
            if not self._eligible_copy_token(token):
                continue
            token_id = self.embedding_model.token_to_id.get(token)
            if token_id is None:
                continue
            prior[token_id] += (decay**distance) * self._copy_token_distinctiveness(token)
        return _normalize_vector(prior)

    def _copy_prior_array(self, context_tokens: list[str]) -> object:
        assert np is not None
        assert self.embedding_model is not None
        assert self.tokenizer is not None
        prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
        decay = 0.82
        answer_start = None
        for index in range(len(context_tokens) - 1, -1, -1):
            if context_tokens[index] == "":
                answer_start = index + 1
                break
        source_tokens = (
            context_tokens[: max(0, answer_start - 1)]
            if answer_start is not None
            else context_tokens
        )
        for distance, token in enumerate(reversed(source_tokens)):
            if token in self.tokenizer.special_tokens:
                continue
            if not self._eligible_copy_token(token):
                continue
            token_id = self.embedding_model.token_to_id.get(token)
            if token_id is None:
                continue
            prior[token_id] += (decay**distance) * self._copy_token_distinctiveness(token)
        total = float(prior.sum())
        if total > 0.0:
            prior /= total
        return prior
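    # Editor's illustrative sketch (hypothetical helper): the copy prior is an
    # exponential recency kernel over prompt tokens -- weight = 0.82 ** distance,
    # optionally scaled by distinctiveness, then normalized to a distribution.
    @staticmethod
    def _sketch_copy_prior(token_ids: list[int], vocab_size: int, decay: float = 0.82) -> list[float]:
        prior = [0.0] * vocab_size
        for distance, token_id in enumerate(reversed(token_ids)):
            if 0 <= token_id < vocab_size:
                prior[token_id] += decay ** distance
        return _normalize_vector(prior)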
    def _copy_token_distinctiveness(self, token: str) -> float:
        rendered = self._render_token(token).strip()
        if not rendered:
            return 0.0
        letters = sum(character.isalpha() for character in rendered)
        digits = sum(character.isdigit() for character in rendered)
        symbols = sum(
            not character.isalnum() and not character.isspace() for character in rendered
        )
        score = 1.0
        if any(character.isupper() for character in rendered) and letters:
            score += 0.8
        if digits:
            score += 0.9
        if symbols:
            score += 0.5
        if len(rendered) >= 4:
            score += 0.2
        return score

    def _prompt_copy_evidence_is_distinctive(self, context_tokens: list[str]) -> bool:
        answer_start = None
        for index in range(len(context_tokens) - 1, -1, -1):
            if context_tokens[index] == "":
                answer_start = index
                break
        prompt_tokens = context_tokens[:answer_start] if answer_start is not None else context_tokens
        for token in prompt_tokens:
            if self.tokenizer is not None and token in self.tokenizer.special_tokens:
                continue
            rendered = self._render_token(token).strip()
            if any(character.isdigit() for character in rendered):
                return True
            if sum(character.isupper() for character in rendered) >= 2:
                return True
        return False

    def _source_evidence_prior(
        self,
        context_tokens: list[str],
        generated_tokens: list[str] | None = None,
    ) -> Vector:
        assert self.embedding_model is not None
        prior = [0.0 for _ in self.embedding_model.id_to_token]
        for token_id, weight in self._source_evidence_token_weights(
            context_tokens,
            generated_tokens or [],
        ).items():
            if 0 <= token_id < len(prior):
                prior[token_id] += weight
        return _normalize_vector(prior)

    def _source_evidence_prior_array(
        self,
        context_tokens: list[str],
        generated_tokens: list[str] | None = None,
    ) -> object:
        assert np is not None
        assert self.embedding_model is not None
        prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
        for token_id, weight in self._source_evidence_token_weights(
            context_tokens,
            generated_tokens or [],
        ).items():
            if 0 <= token_id < prior.size:
                prior[token_id] += weight
        total = float(prior.sum())
        if total > 0.0:
            prior /= total
        return prior
    def _source_evidence_token_weights(
        self,
        context_tokens: list[str],
        generated_tokens: list[str],
    ) -> dict[int, float]:
        if self.embedding_model is None or self.tokenizer is None:
            return {}
        segments = self._source_evidence_segments(context_tokens)
        if not segments:
            return {}
        generated_ids = [
            self.embedding_model.token_to_id[token]
            for token in generated_tokens
            if token in self.embedding_model.token_to_id
        ]
        first_source_index = _first_index(context_tokens, "")
        query_tokens = (
            context_tokens[:first_source_index]
            if first_source_index is not None
            else context_tokens
        )
        query_token_ids = {
            self.embedding_model.token_to_id[token]
            for token in query_tokens
            if token in self.embedding_model.token_to_id
            and token not in self.tokenizer.special_tokens
            and self._eligible_copy_token(token)
        }
        weights: dict[int, float] = {}

        def add_token(token: str, weight: float, *, allow_piece: bool = False) -> None:
            if token in self.tokenizer.special_tokens:
                return
            if not allow_piece and not self._allowed_generation_token(token, generated_tokens):
                return
            if allow_piece:
                rendered = self._render_token(token)
                if not rendered or not rendered.strip():
                    return
            elif not self._eligible_copy_token(token):
                return
            token_id = self.embedding_model.token_to_id.get(token)
            if token_id is None:
                return
            weights[token_id] = weights.get(token_id, 0.0) + weight

        for segment_tokens, segment_weight, segment_role in segments[-6:]:
            if generated_ids and segment_role != "snippet":
                continue
            token_ids = [
                self.embedding_model.token_to_id[token]
                for token in segment_tokens
                if token in self.embedding_model.token_to_id
            ]
            aligned = False
            if generated_ids and token_ids:
                max_suffix = min(8, len(generated_ids), len(token_ids))
                for suffix_length in range(max_suffix, 0, -1):
                    suffix = generated_ids[-suffix_length:]
                    for index in range(len(token_ids) - suffix_length):
                        if token_ids[index : index + suffix_length] != suffix:
                            continue
                        next_token_id = token_ids[index + suffix_length]
                        next_token = self.embedding_model.id_to_token[next_token_id]
                        add_token(
                            next_token,
                            segment_weight * (3.0 + suffix_length),
                            allow_piece=True,
                        )
                        aligned = True
                    if aligned:
                        break
            if aligned:
                continue
            content_rank = 0
            anchor_seen = False
            segment_has_query_anchor = any(token_id in query_token_ids for token_id in token_ids)
            for token in segment_tokens:
                rendered = self._render_token(token)
                if "://" in rendered or rendered.casefold().startswith("http"):
                    continue
                if not self._eligible_copy_token(token):
                    continue
                token_id = self.embedding_model.token_to_id.get(token)
                if token_id is None:
                    continue
                if segment_has_query_anchor:
                    in_query = token_id in query_token_ids
                    if in_query:
                        weight = segment_weight * 0.42
                        anchor_seen = True
                    elif anchor_seen:
                        weight = segment_weight * 2.10
                    else:
                        weight = segment_weight * 0.32
                elif content_rank == 0:
                    weight = segment_weight * 4.0
                elif content_rank == 1:
                    weight = segment_weight * 1.35
                else:
                    weight = segment_weight * 0.65
                weight *= 0.94 ** min(content_rank, 24)
                add_token(token, weight)
                content_rank += 1
        return weights

    def _source_evidence_segments(self, context_tokens: list[str]) -> list[tuple[list[str], float, str]]:
        if self.tokenizer is None:
            return []
        answer_boundary = _last_index(context_tokens, "")
        upper_bound = answer_boundary if answer_boundary is not None else len(context_tokens)
        boundary_tokens = {"", "", "", "", ""}
        segments: list[tuple[list[str], float, str]] = []
        index = 0
        while index < upper_bound:
            if context_tokens[index] != "":
                index += 1
                continue
            start = index + 1
            end = start
            while (
                end < upper_bound
                and context_tokens[end] not in boundary_tokens
                and self._render_token(context_tokens[end]) != "\n"
            ):
                end += 1
            source_tokens = context_tokens[start:end]
            pipe_positions = [
                position
                for position, token in enumerate(source_tokens)
                if self._render_token(token).strip() == "|"
            ]
            if pipe_positions:
                snippet_tokens = source_tokens[pipe_positions[-1] + 1 :]
                if snippet_tokens:
                    segments.append((snippet_tokens, 1.0, "snippet"))
            elif source_tokens:
                segments.append((source_tokens, 0.90, "snippet"))
            index = end + 1
        return segments
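    # Editor's illustrative sketch (hypothetical helper): evidence alignment slides
    # the longest suffix of the generated ids over a snippet and, on a hit, returns
    # the snippet token that follows; longer suffixes earn larger weights above.
    @staticmethod
    def _sketch_align_suffix(generated_ids: list[int], snippet_ids: list[int], max_suffix: int = 8) -> int | None:
        longest = min(max_suffix, len(generated_ids), len(snippet_ids))
        for suffix_length in range(longest, 0, -1):
            suffix = generated_ids[-suffix_length:]
            for index in range(len(snippet_ids) - suffix_length):
                if snippet_ids[index : index + suffix_length] == suffix:
                    return snippet_ids[index + suffix_length]
        return None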
    def _source_evidence_is_complete(
        self,
        context_tokens: list[str],
        generated_tokens: list[str],
    ) -> bool:
        if (
            self.embedding_model is None
            or self.tokenizer is None
            or self._generated_word_count(generated_tokens) < 5
        ):
            return False
        generated_ids = [
            self.embedding_model.token_to_id[token]
            for token in generated_tokens
            if token in self.embedding_model.token_to_id
        ]
        if not generated_ids:
            return False
        for segment_tokens, _, segment_role in self._source_evidence_segments(context_tokens):
            if segment_role != "snippet":
                continue
            segment_ids = [
                self.embedding_model.token_to_id[token]
                for token in segment_tokens
                if token in self.embedding_model.token_to_id
            ]
            if len(generated_ids) > len(segment_ids):
                continue
            max_suffix = min(12, len(generated_ids), len(segment_ids))
            for suffix_length in range(max_suffix, 4, -1):
                suffix_ids = generated_ids[-suffix_length:]
                for start in range(len(segment_ids) - suffix_length + 1):
                    if segment_ids[start : start + suffix_length] != suffix_ids:
                        continue
                    next_index = start + suffix_length
                    if next_index >= len(segment_ids):
                        return True
                    next_token = self.embedding_model.id_to_token[segment_ids[next_index]]
                    if self._source_punctuation_continues_numeric_span(
                        segment_ids,
                        next_index,
                    ):
                        return False
                    if self._is_terminal_punctuation_text(self._render_token(next_token)):
                        return True
        return False

    def _source_evidence_has_continuation(
        self,
        context_tokens: list[str],
        generated_tokens: list[str],
    ) -> bool:
        if self.embedding_model is None or not generated_tokens:
            return False
        generated_ids = [
            self.embedding_model.token_to_id[token]
            for token in generated_tokens
            if token in self.embedding_model.token_to_id
        ]
        if not generated_ids:
            return False
        for segment_tokens, _, segment_role in self._source_evidence_segments(context_tokens):
            if segment_role != "snippet":
                continue
            segment_ids = [
                self.embedding_model.token_to_id[token]
                for token in segment_tokens
                if token in self.embedding_model.token_to_id
            ]
            max_suffix = min(12, len(generated_ids), len(segment_ids))
            for suffix_length in range(max_suffix, 0, -1):
                suffix_ids = generated_ids[-suffix_length:]
                for start in range(len(segment_ids) - suffix_length + 1):
                    if segment_ids[start : start + suffix_length] != suffix_ids:
                        continue
                    next_index = start + suffix_length
                    if next_index >= len(segment_ids):
                        return False
                    if self._source_punctuation_continues_numeric_span(
                        segment_ids,
                        next_index,
                    ) or self._source_punctuation_continues_numeric_span(
                        segment_ids,
                        next_index - 1,
                    ):
                        return True
                    next_token = self.embedding_model.id_to_token[segment_ids[next_index]]
                    return not self._is_terminal_punctuation_text(
                        self._render_token(next_token)
                    )
        return False

    def _source_evidence_next_token(
        self,
        context_tokens: list[str],
        generated_tokens: list[str],
    ) -> str | None:
        if self.embedding_model is None:
            return None
        for segment_tokens, _, segment_role in self._source_evidence_segments(context_tokens):
            if segment_role != "snippet" or not segment_tokens:
                continue
            if not generated_tokens:
                return segment_tokens[0]
            segment_ids = [
                self.embedding_model.token_to_id[token]
                for token in segment_tokens
                if token in self.embedding_model.token_to_id
            ]
            generated_ids = [
                self.embedding_model.token_to_id[token]
                for token in generated_tokens
                if token in self.embedding_model.token_to_id
            ]
            if not segment_ids or not generated_ids:
                continue
            max_suffix = min(12, len(generated_ids), len(segment_ids))
            for suffix_length in range(max_suffix, 0, -1):
                suffix_ids = generated_ids[-suffix_length:]
                for start in range(len(segment_ids) - suffix_length + 1):
                    if segment_ids[start : start + suffix_length] != suffix_ids:
                        continue
                    next_index = start + suffix_length
                    if next_index < len(segment_ids):
                        return self.embedding_model.id_to_token[segment_ids[next_index]]
        return None

    def _source_punctuation_continues_numeric_span(
        self,
        segment_ids: list[int],
        punctuation_index: int,
    ) -> bool:
        if self.embedding_model is None:
            return False
        if punctuation_index <= 0 or punctuation_index + 1 >= len(segment_ids):
            return False
        punctuation_text = self._render_token(
            self.embedding_model.id_to_token[segment_ids[punctuation_index]]
        ).strip()
        if not self._is_structural_punctuation_text(punctuation_text):
            return False
        previous_text = self._render_token(
            self.embedding_model.id_to_token[segment_ids[punctuation_index - 1]]
        )
        next_text = self._render_token(
            self.embedding_model.id_to_token[segment_ids[punctuation_index + 1]]
        )
        return any(character.isdigit() for character in previous_text) and any(
            character.isdigit() for character in next_text
        )
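    # Minimal sketch (illustrative; not part of the original API): the core
    # suffix-alignment step shared by the _source_evidence_* helpers above,
    # extracted on its own. Longest suffixes are tried first, so the most
    # specific match in the snippet wins.
    @staticmethod
    def _sketch_longest_suffix_continuation(
        segment_ids: list[int],
        generated_ids: list[int],
        max_suffix: int = 12,
    ) -> int | None:
        """Return the id following the longest suffix of generated_ids found in segment_ids."""
        limit = min(max_suffix, len(generated_ids), len(segment_ids))
        for suffix_length in range(limit, 0, -1):
            suffix = generated_ids[-suffix_length:]
            for start in range(len(segment_ids) - suffix_length + 1):
                if segment_ids[start : start + suffix_length] == suffix:
                    next_index = start + suffix_length
                    # None means the match ended exactly at the snippet's end.
                    return segment_ids[next_index] if next_index < len(segment_ids) else None
        return None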
    def _preference_prior(self) -> Vector:
        assert self.embedding_model is not None
        if not self.preference_bias or not any(value != 0.0 for value in self.preference_bias):
            return [0.0 for _ in self.embedding_model.id_to_token]
        eligible_indices = [
            index
            for index, token in enumerate(self.embedding_model.id_to_token)
            if self.preference_bias[index] > 0.0 and self._eligible_preference_token(token)
        ]
        if not eligible_indices:
            return [0.0 for _ in self.embedding_model.id_to_token]
        eligible_probabilities = self._calibrated_softmax(
            [self.preference_bias[index] for index in eligible_indices]
        )
        prior = [0.0 for _ in self.embedding_model.id_to_token]
        for index, probability in zip(eligible_indices, eligible_probabilities):
            prior[index] = probability
        return prior

    def _preference_prior_array(self) -> object:
        assert np is not None
        assert self.embedding_model is not None
        if self.preference_bias_array is None or not np.any(self.preference_bias_array != 0.0):
            return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
        if self.preference_valid_mask_array is None or not np.any(self.preference_valid_mask_array):
            return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
        positive_mask = self.preference_bias_array > 0.0
        active_mask = self.preference_valid_mask_array & positive_mask
        if not np.any(active_mask):
            return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
        prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
        prior[active_mask] = self._calibrated_softmax_array(
            self.preference_bias_array[active_mask]
        )
        return prior

    def _eligible_preference_token(self, token: str) -> bool:
        assert self.tokenizer is not None
        if token == self.tokenizer.unk_token or token in self.tokenizer.special_tokens:
            return False
        if not self._starts_new_word(token):
            return False
        rendered = self._render_token(token)
        if not rendered.strip() or self._is_punctuation_piece(rendered):
            return False
        alphanumeric = "".join(character for character in rendered if character.isalnum())
        return len(alphanumeric) >= 1

    def _build_transition_tables(
        self,
        tokens: list[str],
    ) -> dict[int, dict[tuple[str, ...], dict[str, float]]]:
        counts: dict[int, dict[tuple[str, ...], dict[str, int]]] = {
            order: {} for order in sorted(TRANSITION_ORDERS)
        }
        for order in sorted(TRANSITION_ORDERS):
            for index in range(order - 1, len(tokens) - 1):
                key = tuple(tokens[index - order + 1 : index + 1])
                nxt = tokens[index + 1]
                bucket = counts[order].setdefault(key, {})
                bucket[nxt] = bucket.get(nxt, 0) + 1
        probabilities: dict[int, dict[tuple[str, ...], dict[str, float]]] = {
            order: {} for order in sorted(TRANSITION_ORDERS)
        }
        for order, mapping in counts.items():
            items = list(mapping.items())
            items.sort(key=lambda item: (-sum(item[1].values()), item[0]))
            if (
                self.config.max_transition_contexts_per_order is not None
                and self.config.max_transition_contexts_per_order >= 0
            ):
                items = items[: self.config.max_transition_contexts_per_order]
            for key, bucket in items:
                next_items = sorted(bucket.items(), key=lambda item: (-item[1], item[0]))
                if self.config.max_transition_next_tokens > 0:
                    next_items = next_items[: self.config.max_transition_next_tokens]
                total = sum(value for _, value in next_items)
                if total <= 0:
                    continue
                probabilities[order][key] = {
                    token: value / total for token, value in next_items
                }
        return probabilities

    def _transition_table_tensors(self) -> dict[str, object]:
        assert self.embedding_model is not None
        if self.transition_tensor_cache is not None:
            return {
                "transition_orders": self.transition_tensor_cache["orders"],
                "transition_key_offsets": self.transition_tensor_cache["key_offsets"],
                "transition_key_token_ids": self.transition_tensor_cache["key_token_ids"],
                "transition_next_offsets": self.transition_tensor_cache["next_offsets"],
                "transition_next_token_ids": self.transition_tensor_cache["next_token_ids"],
                "transition_next_probabilities": self.transition_tensor_cache["next_probabilities"],
            }
        if not self.transition_tables:
            return {
                "transition_orders": [],
                "transition_key_offsets": [0],
                "transition_key_token_ids": [],
                "transition_next_offsets": [0],
                "transition_next_token_ids": [],
                "transition_next_probabilities": [],
            }
        token_to_id = self.embedding_model.token_to_id
        orders: list[int] = []
        key_offsets: list[int] = [0]
        key_token_ids: list[int] = []
        next_offsets: list[int] = [0]
        next_token_ids: list[int] = []
        next_probabilities: list[float] = []
        for order in sorted(self.transition_tables):
            mapping = self.transition_tables.get(order, {})
            for key, transitions in mapping.items():
                key_ids = [token_to_id.get(token, -1) for token in key]
                if len(key_ids) != order or any(token_id < 0 for token_id in key_ids):
                    continue
                next_items = [
                    (token_to_id[token], float(probability))
                    for token, probability in transitions.items()
                    if token in token_to_id and probability > 0.0
                ]
                if not next_items:
                    continue
                orders.append(order)
                key_token_ids.extend(key_ids)
                key_offsets.append(len(key_token_ids))
                for token_id, probability in next_items:
                    next_token_ids.append(token_id)
                    next_probabilities.append(probability)
                next_offsets.append(len(next_token_ids))
        return {
            "transition_orders": orders,
            "transition_key_offsets": key_offsets,
            "transition_key_token_ids": key_token_ids,
            "transition_next_offsets": next_offsets,
            "transition_next_token_ids": next_token_ids,
            "transition_next_probabilities": next_probabilities,
        }

    def _deserialize_transition_id_tables_from_tensors(
        self,
        tensors: dict[str, object],
    ) -> dict[int, dict[tuple[int, ...], tuple[object, object]]] | None:
        required = (
            "transition_orders",
            "transition_key_offsets",
            "transition_key_token_ids",
            "transition_next_offsets",
            "transition_next_token_ids",
            "transition_next_probabilities",
        )
        if any(name not in tensors for name in required):
            return None

        def _as_sequence(name: str) -> object:
            value = tensors.get(name, [])
            return value if hasattr(value, "shape") else list(value)

        orders = _as_sequence("transition_orders")
        key_offsets = _as_sequence("transition_key_offsets")
        key_token_ids = _as_sequence("transition_key_token_ids")
        next_offsets = _as_sequence("transition_next_offsets")
        next_token_ids = _as_sequence("transition_next_token_ids")
        next_probabilities = _as_sequence("transition_next_probabilities")
        row_count = len(orders)
        if row_count == 0:
            return {order: {} for order in sorted(TRANSITION_ORDERS)}
        if len(key_offsets) != row_count + 1 or len(next_offsets) != row_count + 1:
            return None
        if np is not None and hasattr(orders, "shape"):
            self.transition_tensor_cache = {
                "orders": orders,
                "key_offsets": key_offsets,
                "key_token_ids": key_token_ids,
                "next_offsets": next_offsets,
                "next_token_ids": next_token_ids,
                "next_probabilities": next_probabilities,
                "order_spans": {},
            }
            self.transition_built_orders = set()
            return {order: {} for order in sorted(TRANSITION_ORDERS)}
        tables: dict[int, dict[tuple[int, ...], tuple[object, object]]] = {
            order: {} for order in sorted(TRANSITION_ORDERS)
        }
        for index in range(row_count):
            order = int(orders[index])
            key_start = int(key_offsets[index])
            key_end = int(key_offsets[index + 1])
            next_start = int(next_offsets[index])
            next_end = int(next_offsets[index + 1])
            key = tuple(int(token_id) for token_id in key_token_ids[key_start:key_end])
            if len(key) != order or next_end <= next_start:
                continue
            tables.setdefault(order, {})[key] = (
                next_token_ids[next_start:next_end],
                next_probabilities[next_start:next_end],
            )
        return tables
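    # Layout example (illustrative): two order-2 contexts stored in the flat
    # tensors above. With keys ("a", "b") -> {"c": 0.75, "d": 0.25} and
    # ("b", "c") -> {"a": 1.0}:
    #   transition_orders             = [2, 2]
    #   transition_key_offsets        = [0, 2, 4]
    #   transition_key_token_ids      = [id_a, id_b, id_b, id_c]
    #   transition_next_offsets       = [0, 2, 3]
    #   transition_next_token_ids     = [id_c, id_d, id_a]
    #   transition_next_probabilities = [0.75, 0.25, 1.0]
    # Row i's key spans key_token_ids[key_offsets[i]:key_offsets[i + 1]] and
    # its continuations span the matching next_* slices, CSR style.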
    def _serialize_transition_tables(self) -> dict[str, dict[str, dict[str, float]]]:
        assert self.transition_tables is not None
        return {
            str(order): {
                _encode_ngram_key(key): value for key, value in mapping.items()
            }
            for order, mapping in self.transition_tables.items()
        }

    def _deserialize_transition_tables(
        self,
        payload: dict[str, dict[str, dict[str, float]]],
    ) -> dict[int, dict[tuple[str, ...], dict[str, float]]]:
        tables: dict[int, dict[tuple[str, ...], dict[str, float]]] = {
            order: {} for order in sorted(TRANSITION_ORDERS)
        }
        for order_text, mapping in payload.items():
            order = int(order_text)
            tables[order] = {
                _decode_ngram_key(key): {
                    str(token): float(probability)
                    for token, probability in value.items()
                }
                for key, value in mapping.items()
            }
        return tables

    def _transition_tensor_order_span(self, order: int) -> tuple[int, int] | None:
        if np is None or self.transition_tensor_cache is None:
            return None
        spans = self.transition_tensor_cache.get("order_spans")
        if isinstance(spans, dict) and order in spans:
            return spans[order]
        orders = self.transition_tensor_cache["orders"]
        positions = np.flatnonzero(orders == order)
        span = (
            (int(positions[0]), int(positions[-1]) + 1) if positions.size else None
        )
        if isinstance(spans, dict):
            spans[order] = span
        return span

    def _transition_tensor_lookup(
        self,
        order: int,
        key_ids: list[int],
    ) -> tuple[object, object] | None:
        if (
            np is None
            or self.transition_tensor_cache is None
            or len(key_ids) != order
        ):
            return None
        span = self._transition_tensor_order_span(order)
        if span is None:
            return None
        row_start, row_end = span
        key_offsets = self.transition_tensor_cache["key_offsets"]
        key_token_ids = self.transition_tensor_cache["key_token_ids"]
        next_offsets = self.transition_tensor_cache["next_offsets"]
        next_token_ids = self.transition_tensor_cache["next_token_ids"]
        next_probabilities = self.transition_tensor_cache["next_probabilities"]
        key_start = int(key_offsets[row_start])
        key_end = int(key_offsets[row_end])
        key_block = np.asarray(key_token_ids[key_start:key_end], dtype=np.int64)
        row_count = row_end - row_start
        if row_count <= 0 or key_block.size != row_count * order:
            return None
        keys = key_block.reshape(row_count, order)
        query = np.asarray(key_ids, dtype=np.int64)
        matches = np.flatnonzero(np.all(keys == query[None, :], axis=1))
        if not matches.size:
            return None
        row = row_start + int(matches[0])
        next_start = int(next_offsets[row])
        next_end = int(next_offsets[row + 1])
        if next_end <= next_start:
            return None
        return (
            next_token_ids[next_start:next_end],
            next_probabilities[next_start:next_end],
        )

    def _eligible_copy_token(self, token: str) -> bool:
        rendered = self._render_token(token)
        if not rendered.strip():
            return False
        if self._is_punctuation_piece(rendered):
            return False
        if not self._starts_new_word(token):
            return False
        alphanumeric = "".join(character for character in rendered if character.isalnum())
        return len(alphanumeric) >= 2

    def _allowed_generation_token(
        self,
        token: str,
        generated_tokens: list[str],
        context_tokens: list[str] | None = None,
    ) -> bool:
        return self._allowed_generation_token_with_meta(
            token,
            self._generation_token_meta(token),
            generated_tokens,
            context_tokens,
        )

    def _allowed_generation_token_with_meta(
        self,
        token: str,
        meta: GenerationTokenMeta,
        generated_tokens: list[str],
        context_tokens: list[str] | None = None,
    ) -> bool:
        assert self.embedding_model is not None
        assert self.tokenizer is not None
        if token == self.tokenizer.unk_token:
            return False
        if token in self.tokenizer.special_tokens:
            return self._allowed_tool_protocol_token(
                token,
                generated_tokens=generated_tokens,
                context_tokens=context_tokens or [],
            )
        if len(self.embedding_model.id_to_token) < 1024:
            return True
        if meta.rendered == "\n":
            return bool(generated_tokens)
        if not meta.stripped:
            return False
        if meta.word_joiner:
            return (
                self._can_attach_word_joiner(generated_tokens)
                or self._can_start_line_with_word_joiner(token, generated_tokens)
            )
        if meta.structural_punctuation:
            return bool(generated_tokens) or self._can_start_answer_with_structural_punctuation(token)
        if meta.structural_symbol:
            return bool(generated_tokens) or meta.starts_new_word
        if not meta.starts_new_word:
            if not generated_tokens:
                return False
            previous_rendered = self._render_token(generated_tokens[-1])
            return (
                bool(previous_rendered)
                and any(character.isalnum() for character in previous_rendered)
                and bool(meta.alphanumeric)
            )
        return len(meta.alphanumeric) >= 1 or not meta.punctuation_piece

    @staticmethod
    def _allowed_tool_protocol_token(
        token: str,
        *,
        generated_tokens: list[str],
        context_tokens: list[str],
    ) -> bool:
        if token not in TOOL_PROTOCOL_TOKENS:
            return False
        # NOTE: several special-token literals in this method were elided in
        # the source. "<tool_call>", "<tool_result>", and "<final>" in the
        # first branch are assumed reconstructions inferred from nearby
        # identifiers; the literals in the two branches below could not be
        # recovered and are left empty as in the source.
        if token == "<tool_call>":
            return (
                ReframrModel._context_requests_tool_call(context_tokens)
                and "<tool_call>" not in generated_tokens
                and "<tool_result>" not in generated_tokens
                and "<final>" not in generated_tokens
            )
        if token in {"", ""}:
            return False
        if token == "":
            return (
                "" in context_tokens
                or "" in context_tokens
                or "" in context_tokens
            )
        return True

    @staticmethod
    def _context_requests_tool_call(context_tokens: list[str]) -> bool:
        rendered_terms: list[str] = []
        for token in context_tokens:
            if token in TOOL_PROTOCOL_TOKENS or token.startswith("<"):
                continue
            normalized = token.replace("▁", " ").strip().casefold()
            if not normalized:
                continue
            rendered_terms.append(normalized)
            pieces = {
                "".join(
                    character
                    for character in piece
                    if character.isalnum() or character in {"-", "."}
                )
                for piece in normalized.split()
            }
            if pieces & TOOL_CALL_CONTEXT_TERMS:
                return True
        joined = " ".join(rendered_terms)
        compact = "".join(rendered_terms)
        return any(
            term in joined or term.replace("-", "") in compact
            for term in TOOL_CALL_CONTEXT_TERMS
        )

    def _would_repeat_recent_pattern(
        self,
        candidate: str,
        generated_tokens: list[str],
        recent_rendered_words: list[str] | None = None,
    ) -> bool:
        if len(generated_tokens) >= 2 and generated_tokens[-1] == candidate and generated_tokens[-2] == candidate:
            return True
        if len(generated_tokens) >= 2:
            trigram = tuple(generated_tokens[-2:] + [candidate])
            recent_tokens = generated_tokens[-12:]
            for index in range(max(0, len(recent_tokens) - 4)):
                if tuple(recent_tokens[index : index + 3]) == trigram:
                    return True
        rendered_words = recent_rendered_words
        if rendered_words is None:
            rendered_words = self._recent_rendered_words(generated_tokens)
        candidate_meta = self._generation_token_meta(candidate)
        candidate_word = candidate_meta.rendered.casefold()
        if (
            rendered_words
            and candidate_meta.starts_new_word
            and any(character.isalnum() for character in candidate_word)
        ):
            candidate_bigram = (rendered_words[-1], candidate_word)
            recent_window = rendered_words[-10:]
            recent_bigrams = {
                (recent_window[index], recent_window[index + 1])
                for index in range(len(recent_window) - 1)
            }
            if candidate_bigram in recent_bigrams:
                return True
            if (
                len(candidate_word) > 2
                and rendered_words[-10:].count(candidate_word) >= 2
                and not candidate_meta.common_connector
            ):
                return True
        return False

    @staticmethod
    def _is_inside_tool_protocol_continuation(generated_tokens: list[str]) -> bool:
        return any(token in TOOL_PROTOCOL_TOKENS for token in generated_tokens[-6:])
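    # Worked example (schematic token strings for illustration): with
    # generated_tokens ending ..., "the", "cat", "sat", "the", "cat" and
    # candidate "sat", the trigram ("the", "cat", "sat") already occurs in
    # the last 12 tokens, so _would_repeat_recent_pattern returns True and
    # the candidate is dropped before scoring.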
    def _would_repeat_recent_phrase(
        self,
        candidate: str,
        generated_tokens: list[str],
        *,
        recent_rendered_words: list[str] | None = None,
    ) -> bool:
        if not self._starts_new_word(candidate):
            return False
        rendered_words = list(
            recent_rendered_words
            if recent_rendered_words is not None
            else self._recent_rendered_words(generated_tokens)
        )
        candidate_word = self._render_token(candidate).casefold()
        if not any(character.isalnum() for character in candidate_word):
            return False
        rendered_words.append(candidate_word)
        recent_window = rendered_words[-48:]
        for span in range(4, min(8, len(recent_window)) + 1):
            suffix = tuple(recent_window[-span:])
            earlier = recent_window[:-span]
            for index in range(len(earlier) - span + 1):
                if tuple(earlier[index : index + span]) == suffix:
                    return True
        return False

    def _recent_phrase_repeat_candidate_words(
        self,
        recent_rendered_words: list[str],
    ) -> set[str]:
        repeat_candidates: set[str] = set()
        base_window = recent_rendered_words[-47:]
        max_span = min(8, len(base_window) + 1)
        if max_span < 4:
            return repeat_candidates
        for span in range(4, max_span + 1):
            prefix_length = span - 1
            suffix_prefix = tuple(base_window[-prefix_length:])
            earlier_length = len(base_window) - prefix_length
            if earlier_length < span:
                continue
            for index in range(earlier_length - span + 1):
                earlier_segment = base_window[index : index + span]
                if tuple(earlier_segment[:-1]) == suffix_prefix:
                    candidate_word = earlier_segment[-1]
                    if any(character.isalnum() for character in candidate_word):
                        repeat_candidates.add(candidate_word)
        return repeat_candidates

    def _recent_rendered_words(self, generated_tokens: list[str]) -> list[str]:
        rendered_words: list[str] = []
        for token in generated_tokens:
            if not self._starts_new_word(token):
                continue
            rendered = self._render_token(token).casefold()
            if any(character.isalnum() for character in rendered):
                rendered_words.append(rendered)
        return rendered_words

    def _select_generation_token(
        self,
        distribution: dict[str, float],
        *,
        context_tokens: list[str] | None = None,
        generated_tokens: list[str] | None = None,
        temperature: float = DEFAULT_GENERATION_TEMPERATURE,
        top_k: int = DEFAULT_GENERATION_TOP_K,
        top_p: float = DEFAULT_GENERATION_TOP_P,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        preserve_dominant_candidates: bool = False,
        avoid_token_sequences: Sequence[Sequence[str]] | None = None,
    ) -> str:
        assert self.tokenizer is not None
        generated_tokens = generated_tokens or []
        candidates = self._prepare_generation_candidates(
            distribution,
            context_tokens=context_tokens or [],
            generated_tokens=generated_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            preserve_dominant_candidates=preserve_dominant_candidates,
            avoid_token_sequences=avoid_token_sequences,
        )
        if candidates:
            return self._sample_generation_candidate(
                candidates,
                context_tokens=context_tokens or [],
                generated_tokens=generated_tokens,
                stochastic=temperature > 0.0,
                preserve_dominant_candidates=preserve_dominant_candidates,
            )
        for token, _ in sorted(distribution.items(), key=lambda item: item[1], reverse=True):
            if token in self.tokenizer.special_tokens and token not in TOOL_PROTOCOL_TOKENS:
                continue
            if token == self.tokenizer.unk_token:
                continue
            if not self._allowed_generation_token(token, generated_tokens, context_tokens or []):
                continue
            if self._would_complete_blacklisted_answer(generated_tokens, token):
                continue
            return token
        return ""
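    # Worked example (illustrative) for _would_repeat_recent_phrase above: the
    # candidate word is appended and the last 48 rendered words are scanned
    # for any 4-to-8 word suffix that already occurred earlier in the window.
    # Finishing "... as a result of the storm ... as a result of the" with
    # "storm" recreates the earlier 6-word phrase, so the candidate is
    # rejected as a phrase loop.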
    def _select_generation_token_from_array(
        self,
        probabilities: object,
        *,
        context_tokens: list[str],
        generated_tokens: list[str],
        temperature: float = DEFAULT_GENERATION_TEMPERATURE,
        top_k: int = DEFAULT_GENERATION_TOP_K,
        top_p: float = DEFAULT_GENERATION_TOP_P,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        preserve_dominant_candidates: bool = False,
        avoid_token_sequences: Sequence[Sequence[str]] | None = None,
    ) -> str:
        assert np is not None
        assert self.tokenizer is not None
        assert self.embedding_model is not None
        values = np.asarray(probabilities, dtype=np.float64)
        if values.size == 0:
            return ""
        first_pool_size = min(values.size, max(top_k, 64))
        if first_pool_size <= 0:
            first_pool_size = min(values.size, 64)
        expanded_pool_size = min(values.size, max(top_k * 4, 64))
        pool_sizes: list[int] = []
        for pool_size in (first_pool_size, expanded_pool_size, values.size):
            if pool_size > 0 and pool_size not in pool_sizes:
                pool_sizes.append(pool_size)
        for pool_size in pool_sizes:
            if pool_size < values.size:
                candidate_indices = np.argpartition(values, -pool_size)[-pool_size:]
                candidate_indices = candidate_indices[np.argsort(values[candidate_indices])[::-1]]
            else:
                candidate_indices = np.argsort(values)[::-1]
            distribution: dict[str, float] = {}
            for raw_index in candidate_indices:
                index = int(raw_index)
                score = float(values[index])
                if score <= 0.0:
                    continue
                token = self.embedding_model.id_to_token[index]
                if token == self.tokenizer.unk_token or (
                    token in self.tokenizer.special_tokens
                    and token not in TOOL_PROTOCOL_TOKENS
                ):
                    continue
                distribution[token] = score
            selected = self._select_generation_token(
                distribution,
                context_tokens=context_tokens,
                generated_tokens=generated_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                preserve_dominant_candidates=preserve_dominant_candidates,
                avoid_token_sequences=avoid_token_sequences,
            )
            if selected:
                return selected
        return ""

    def _prepare_generation_candidates(
        self,
        distribution: dict[str, float],
        *,
        context_tokens: list[str] | None = None,
        generated_tokens: list[str],
        temperature: float,
        top_k: int,
        top_p: float,
        repetition_penalty: float,
        preserve_dominant_candidates: bool = False,
        avoid_token_sequences: Sequence[Sequence[str]] | None = None,
    ) -> list[tuple[str, float]]:
        assert self.tokenizer is not None
        assert self.embedding_model is not None
        context_tokens = context_tokens or []
        generated_word_count = self._generated_word_count(generated_tokens)
        clause_words = self._words_since_clause_break(generated_tokens)
        recent_rendered_words = self._recent_rendered_words(generated_tokens)
        generated_token_ids = self._token_ids_for_generated_tokens(generated_tokens)
        inside_tool_protocol = self._is_inside_tool_protocol_continuation(generated_tokens)
        phrase_repeat_candidate_words = (
            self._recent_phrase_repeat_candidate_words(recent_rendered_words)
            if generated_word_count >= MIN_COMPLETE_ANSWER_WORDS and not inside_tool_protocol
            else set()
        )
        prompt_content_tokens = [
            token
            for token in context_tokens
            if token not in self.tokenizer.special_tokens
            and self._generation_token_meta(token).starts_new_word
            and self._generation_token_meta(token).alphanumeric
            and not self._generation_token_meta(token).punctuation_piece
        ]
        initial_prompt_content_token = (
            prompt_content_tokens[0] if len(prompt_content_tokens) > 1 else None
        )
        best_probability = max(distribution.values(), default=0.0)
        has_uppercase_start_candidate = any(
            probability > 0.0
            and self._generation_token_meta(token).starts_new_word
            and self._generation_token_meta(token).rendered[:1].isupper()
            for token, probability in distribution.items()
        )
        adjusted: list[tuple[str, float]] = []
        for token, probability in sorted(distribution.items(), key=lambda item: item[1], reverse=True):
            if token in self.tokenizer.special_tokens and token not in TOOL_PROTOCOL_TOKENS:
                continue
            if token == self.tokenizer.unk_token or probability <= 0.0:
                continue
            meta = self._generation_token_meta(token)
            allowed_by_general_filter = self._allowed_generation_token_with_meta(
                token,
                meta,
                generated_tokens,
                context_tokens,
            )
            if not allowed_by_general_filter:
                dominant_learned_continuation = (
                    preserve_dominant_candidates
                    and best_probability > 0.0
                    and probability >= best_probability * 0.99
                    and self._allowed_answer_sequence_token(token, generated_tokens)
                )
                if not dominant_learned_continuation:
                    continue
            if self._would_complete_blacklisted_answer_ids(generated_token_ids, token):
                continue
            repeats_recent_pattern = self._would_repeat_recent_pattern(
                token,
                generated_tokens,
                recent_rendered_words=recent_rendered_words,
            )
            hard_phrase_loop = (
                generated_word_count >= MIN_COMPLETE_ANSWER_WORDS
                and not inside_tool_protocol
                and meta.starts_new_word
                and meta.rendered.casefold() in phrase_repeat_candidate_words
            )
            if hard_phrase_loop:
                continue
            if repeats_recent_pattern:
                dominant_candidate_allowed = (
                    preserve_dominant_candidates
                    and best_probability > 0.0
                    and probability >= best_probability * 0.80
                )
                if not dominant_candidate_allowed:
                    continue
            score = probability
            if (
                temperature >= ANSWER_REPLAY_PREFIX_TEMPERATURE
                and not inside_tool_protocol
                and self._would_follow_blacklisted_answer_prefix_ids(
                    generated_token_ids,
                    token,
                )
            ):
                score *= ANSWER_REPLAY_PREFIX_PENALTY
            if (
                temperature > 0.0
                and self._would_follow_avoided_sequence(
                    generated_tokens,
                    token,
                    avoid_token_sequences,
                )
            ):
                score *= 0.12
            rendered = meta.rendered
            punctuation_token = meta.structural_punctuation
            starts_new_word = meta.starts_new_word
            alphanumeric = meta.alphanumeric
            if (
                not generated_tokens
                and initial_prompt_content_token is not None
                and token == initial_prompt_content_token
            ):
                dominant_answer_candidate = (
                    preserve_dominant_candidates
                    and best_probability > 0.0
                    and probability >= best_probability * 0.80
                )
                if not dominant_answer_candidate:
                    continue
            if (
                not generated_tokens
                and temperature > 0.0
                and has_uppercase_start_candidate
                and starts_new_word
                and rendered[:1].islower()
                and best_probability > 0.0
                and probability < best_probability * 0.85
            ):
                continue
            if generated_tokens and starts_new_word and alphanumeric:
                previous_alphanumeric = self._generation_token_meta(
                    generated_tokens[-1]
                ).alphanumeric
                if previous_alphanumeric.casefold() == alphanumeric.casefold():
                    continue
            common_connector = meta.common_connector
            if (
                starts_new_word
                and len(alphanumeric) == 1
                and not common_connector
            ):
                score *= 0.08
            recent_count = generated_tokens[-12:].count(token)
            if recent_count > 0 and not common_connector:
                score /= repetition_penalty ** (2 * recent_count)
            if generated_tokens and token == generated_tokens[-1]:
                score /= repetition_penalty**3
            if generated_tokens and token in generated_tokens[-4:] and not common_connector:
                score *= 0.35
            if generated_tokens and not starts_new_word and self._starts_new_word(generated_tokens[-1]):
                score *= 0.08
            if not generated_tokens and punctuation_token:
                if best_probability <= 0.0 or probability < best_probability * 0.80:
                    score *= 0.01
            elif not generated_tokens and not starts_new_word:
                score *= 0.02
            if (
                not generated_tokens
                and temperature > 0.0
                and has_uppercase_start_candidate
                and starts_new_word
                and rendered[:1].islower()
            ):
                score *= 0.03
            if punctuation_token:
                if generated_tokens and self._is_structural_punctuation_token(generated_tokens[-1]):
                    score *= 0.05
                if clause_words >= 6:
                    score *= 1.0 + min(1.4, 0.18 * (clause_words - 5))
                elif generated_word_count >= 12:
                    score *= 1.1
            if score > 0.0:
                adjusted.append((token, score))
        if not adjusted:
            return []
        adjusted.sort(key=lambda item: item[1], reverse=True)
        if preserve_dominant_candidates:
            top_score = adjusted[0][1]
            second_score = adjusted[1][1] if len(adjusted) > 1 else 0.0
            if top_score >= 0.5 and (
                second_score <= 0.0
                or top_score >= second_score * 1.2
                or top_score - second_score >= 0.08
            ):
                return [(adjusted[0][0], 1.0)]
        effective_top_k = top_k
        if (
            temperature >= CREATIVE_EARLY_POOL_TEMPERATURE
            and generated_word_count < CREATIVE_EARLY_POOL_WORD_LIMIT
            and not inside_tool_protocol
            and top_k > CREATIVE_EARLY_POOL_MAX
        ):
            effective_top_k = CREATIVE_EARLY_POOL_MAX
        if effective_top_k > 0:
            adjusted = adjusted[:effective_top_k]
        if 0.0 < top_p < 1.0:
            kept: list[tuple[str, float]] = []
            cumulative = 0.0
            total = sum(score for _, score in adjusted)
            for token, score in adjusted:
                normalized = score / total if total else 0.0
                kept.append((token, score))
                cumulative += normalized
                if cumulative >= top_p:
                    break
            adjusted = kept
        if temperature <= 0.0:
            return [(adjusted[0][0], 1.0)]
        exponent = 1.0 / temperature
        tempered = [
            (token, score**exponent)
            for token, score in adjusted
            if score > 0.0
        ]
        total = sum(score for _, score in tempered)
        if total <= 0.0:
            return []
        return [(token, score / total) for token, score in tempered]

    def _sample_generation_candidate(
        self,
        candidates: list[tuple[str, float]],
        *,
        context_tokens: list[str],
        generated_tokens: list[str],
        stochastic: bool = False,
        preserve_dominant_candidates: bool = False,
    ) -> str:
        if not candidates:
            return ""
        if len(candidates) == 1:
            return candidates[0][0]
        top_probability = candidates[0][1]
        second_probability = candidates[1][1]
        top_has_clear_half_majority = top_probability >= 0.5 and (
            second_probability <= 0.0 or top_probability - second_probability >= 0.02
        )
        if preserve_dominant_candidates and top_has_clear_half_majority:
            return candidates[0][0]
        decisive_stochastic_winner = stochastic and (
            top_probability >= 0.985
            or (
                top_probability >= 0.96
                and second_probability > 0.0
                and top_probability >= second_probability * 20.0
            )
            or (
                top_probability >= 0.90
                and second_probability > 0.0
                and top_probability >= second_probability * 40.0
            )
            or (
                top_probability >= 0.90
                and top_probability - second_probability >= 0.75
            )
        )
        decisive_deterministic_winner = not stochastic and (
            top_has_clear_half_majority
            or (second_probability > 0.0 and top_probability >= second_probability * 2.5)
            or (
                top_probability >= 0.08
                and second_probability > 0.0
                and top_probability >= second_probability * 1.35
            )
        )
        if decisive_stochastic_winner or decisive_deterministic_winner:
            return candidates[0][0]
        if stochastic:
            threshold = random.random()
        else:
            seed_payload = "\u0002".join([*context_tokens, "", *generated_tokens, str(len(candidates))])
            seed = int.from_bytes(hashlib.sha256(seed_payload.encode("utf-8")).digest()[:8], "big")
            threshold = random.Random(seed).random()
        cumulative = 0.0
        for token, probability in candidates:
            cumulative += probability
            if threshold <= cumulative:
                return token
        return candidates[-1][0]

    def _top_entries_from_vector(
        self,
        values: Vector,
        limit: int,
    ) -> list[dict[str, object]]:
        if limit <= 0:
            return []
        ranked = sorted(
            enumerate(values),
            key=lambda item: item[1],
            reverse=True,
        )
        return [
            self._token_entry(index, probability)
            for index, probability in ranked[:limit]
            if probability > 0.0
        ]
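    # Minimal sketch (illustrative; not part of the original API): the
    # temperature / top-k / top-p shaping at the tail of
    # _prepare_generation_candidates, reduced to its arithmetic core. The
    # defaults mirror DEFAULT_GENERATION_TEMPERATURE / _TOP_K / _TOP_P.
    @staticmethod
    def _sketch_temper_and_truncate(
        scores: list[tuple[str, float]],
        temperature: float = 0.82,
        top_k: int = 24,
        top_p: float = 0.92,
    ) -> list[tuple[str, float]]:
        # Rank and keep the top-k candidates.
        ranked = sorted(scores, key=lambda item: item[1], reverse=True)[:top_k]
        total = sum(score for _, score in ranked)
        # Nucleus cut: keep candidates until their normalized mass reaches top_p.
        kept: list[tuple[str, float]] = []
        cumulative = 0.0
        for token, score in ranked:
            kept.append((token, score))
            cumulative += score / total if total else 0.0
            if cumulative >= top_p:
                break
        # Temperature: raise to 1/T and renormalize (T < 1 sharpens, T > 1 flattens).
        tempered = [(token, score ** (1.0 / temperature)) for token, score in kept if score > 0.0]
        norm = sum(score for _, score in tempered)
        return [(token, score / norm) for token, score in tempered] if norm > 0.0 else []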
    def _token_entry(
        self,
        index: int,
        probability: float,
    ) -> dict[str, object]:
        assert self.embedding_model is not None
        token = self.embedding_model.id_to_token[index]
        return {
            "token": token,
            "text": self._render_token(token),
            "probability": probability,
        }

    def _build_reasoning_summary(
        self,
        transition_order: int | None,
        blend_weights: dict[str, float],
    ) -> str:
        dominant_source = max(blend_weights.items(), key=lambda item: item[1])[0] if blend_weights else "base"
        if transition_order is not None:
            transition_message = f" Transition prior is using order-{transition_order} context."
        else:
            transition_message = " Transition prior found no matching n-gram."
        return (
            "Generation is running on analytical state, recurrent traces, and corpus-derived token transitions."
            f"{transition_message}"
            f" Dominant blend source: {dominant_source}."
        )

    def _generated_word_count(self, tokens: list[str]) -> int:
        count = 0
        for token in tokens:
            rendered = self._render_token(token)
            if not any(character.isalnum() for character in rendered):
                continue
            if self._starts_new_word(token) or count == 0:
                count += 1
        return count

    def _is_structural_punctuation_text(self, text: str) -> bool:
        if len(text) != 1:
            return False
        if self._is_word_joiner_text(text):
            return False
        category = unicodedata.category(text)
        return category.startswith("P")

    def _is_structural_punctuation_token(self, token: str) -> bool:
        return self._is_structural_punctuation_text(self._render_token(token))

    def _is_structural_symbol_token(self, token: str) -> bool:
        rendered = self._render_token(token)
        return len(rendered) == 1 and unicodedata.category(rendered).startswith("S")

    def _is_word_joiner_token(self, token: str) -> bool:
        return self._is_word_joiner_text(self._render_token(token))

    def _is_word_joiner_text(self, text: str) -> bool:
        if len(text) != 1:
            return False
        category = unicodedata.category(text)
        if category in ("Pc", "Pd", "Lm"):
            return True
        name = unicodedata.name(text, "")
        return "APOSTROPHE" in name or (
            "SINGLE" in name and "QUOTATION MARK" in name
        )

    def _can_start_line_with_word_joiner(self, token: str, generated_tokens: list[str]) -> bool:
        rendered = self._render_token(token)
        if len(rendered) != 1 or unicodedata.category(rendered) != "Pd":
            return False
        if not self._starts_new_word(token):
            return False
        return not generated_tokens or self._render_token(generated_tokens[-1]) == "\n"

    def _can_start_answer_with_structural_punctuation(self, token: str) -> bool:
        rendered = self._render_token(token)
        if len(rendered) != 1 or not self._starts_new_word(token):
            return False
        return unicodedata.category(rendered) in ("Ps", "Pi")

    def _is_common_connector_token(self, token: str) -> bool:
        rendered = self._render_token(token)
        return rendered.isalpha() and len(rendered) == 1 and rendered.islower()

    def _can_attach_word_joiner(self, generated_tokens: list[str]) -> bool:
        if not generated_tokens:
            return False
        rendered = self._render_token(generated_tokens[-1])
        if not rendered:
            return False
        if any(character.isalnum() for character in rendered):
            return True
        if len(rendered) != 1:
            return False
        return unicodedata.category(rendered) in ("Ps", "Pi")

    def _words_since_clause_break(self, tokens: list[str]) -> int:
        assert self.tokenizer is not None
        words = 0
        for token in reversed(tokens):
            if token in self.tokenizer.special_tokens:
                continue
            rendered = self._render_token(token)
            if self._is_structural_punctuation_text(rendered):
                break
            if self._starts_new_word(token) and not self._is_punctuation_piece(rendered):
                words += 1
        return words

    def _should_stop_generation(self, generated_tokens: list[str]) -> bool:
        if not generated_tokens:
            return False
        if not self._is_terminal_punctuation_text(self._render_token(generated_tokens[-1])):
            return False
        word_count = self._generated_word_count(generated_tokens)
        if word_count >= MIN_COMPLETE_ANSWER_WORDS:
            return True
        return (
            word_count >= MIN_COMPLETE_MULTI_SENTENCE_WORDS
            and self._terminal_sentence_count(generated_tokens) >= 2
        )

    def _terminal_sentence_count(self, tokens: list[str]) -> int:
        return sum(
            1
            for token in tokens
            if self._is_terminal_punctuation_text(self._render_token(token))
        )

    def _is_terminal_punctuation_text(self, text: str) -> bool:
        stripped = text.strip()
        if not stripped:
            return False
        terminal_character = stripped[-1]
        if not self._is_structural_punctuation_text(terminal_character):
            return False
        return not self._is_word_joiner_text(terminal_character)

    def _should_skip_prompt_overlap_token(self, token: str) -> bool:
        rendered = self._render_token(token)
        if not rendered.strip():
            return True
        if (
            self.embedding_model is not None
            and len(self.embedding_model.id_to_token) >= 1024
            and not self._starts_new_word(token)
        ):
            return True
        if self._is_structural_punctuation_text(rendered):
            return True
        return rendered.strip().casefold() in PROMPT_ENVELOPE_TERMS

    def _starts_new_word(self, token: str) -> bool:
        assert self.tokenizer is not None
        if token in self.tokenizer.special_tokens:
            return True
        if token.startswith(self.tokenizer.word_prefix):
            return True
        return len(token) == 1 and not token.isalnum() and not self._is_word_joiner_token(token)

    def _generation_token_meta(self, token: str) -> GenerationTokenMeta:
        cache = self.generation_token_meta_cache
        if cache is None:
            cache = {}
            self.generation_token_meta_cache = cache
        cached = cache.get(token)
        if cached is not None:
            return cached
        rendered = self._render_token(token)
        meta = GenerationTokenMeta(
            rendered=rendered,
            stripped=rendered.strip(),
            starts_new_word=self._starts_new_word(token),
            punctuation_piece=self._is_punctuation_piece(rendered),
            structural_punctuation=self._is_structural_punctuation_token(token),
            structural_symbol=self._is_structural_symbol_token(token),
            word_joiner=self._is_word_joiner_token(token),
            alphanumeric="".join(character for character in rendered if character.isalnum()),
            common_connector=self._is_common_connector_token(token),
        )
        cache[token] = meta
        return meta

    def _decode_tokens(self, tokens: list[str]) -> str:
        assert self.tokenizer is not None
        return self.tokenizer.decode(
            tokens,
            preserve_special_tokens=TOOL_PROTOCOL_TOKENS,
        )
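    # Unicode category examples (illustrative) for the helpers above:
    # unicodedata.category("-") == "Pd" and unicodedata.category("_") == "Pc",
    # so both count as word joiners; "'" is "Po" but its name contains
    # APOSTROPHE, so it joins too; "(" is "Ps", an opening bracket that may
    # begin an answer; "." and "," are "Po", i.e. structural punctuation; and
    # unicodedata.bidirectional(",") == unicodedata.bidirectional(":") == "CS",
    # which is what _is_separator_punctuation below keys on.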
    @staticmethod
    def _normalize_generated_tool_protocol_text(text: str, *, context: str | None = None) -> str:
        # NOTE: the marker and boundary literals below were elided in the
        # source. "<tool_call>", "</tool_call>", "<tool_result>", and
        # "<final>" are assumed reconstructions inferred from the
        # surrounding identifiers and the TOOL_PROTOCOL_TOKENS usage.
        marker = "<tool_call>"
        call_index = text.find(marker)
        if call_index < 0:
            return text
        cleaned = text[:]
        for boundary in ("</tool_call>", "<tool_result>", "<final>"):
            boundary_index = cleaned.find(boundary, call_index + len(marker))
            if boundary_index >= 0:
                cleaned = cleaned[:boundary_index].rstrip()
        second_call_index = cleaned.find(marker, call_index + len(marker))
        if second_call_index >= 0:
            cleaned = cleaned[:second_call_index].rstrip()
        brace_start = cleaned.find("{", call_index)
        if brace_start < 0:
            return cleaned.strip()
        depth = 0
        in_string = False
        escaped = False
        last_top_level_comma: int | None = None
        for index in range(brace_start, len(cleaned)):
            character = cleaned[index]
            if escaped:
                escaped = False
                continue
            if in_string and character == "\\":
                escaped = True
                continue
            if character == '"':
                in_string = not in_string
                continue
            if in_string:
                continue
            if character == "{":
                depth += 1
                continue
            if character == "}":
                depth -= 1
                if depth <= 0:
                    candidate = cleaned[: index + 1].strip()
                    return ReframrModel._repair_tool_call_payload_if_needed(
                        candidate,
                        context=context,
                    )
                continue
            if character == "," and depth == 1:
                last_top_level_comma = index
        if depth > 0:
            if last_top_level_comma is not None:
                candidate = cleaned[:last_top_level_comma].rstrip() + "}"
                return ReframrModel._repair_tool_call_payload_if_needed(
                    candidate,
                    context=context,
                )
            candidate = cleaned.rstrip() + "}"
            return ReframrModel._repair_tool_call_payload_if_needed(
                candidate,
                context=context,
            )
        return ReframrModel._repair_tool_call_payload_if_needed(
            cleaned.strip(),
            context=context,
        )

    @staticmethod
    def _repair_tool_call_payload_if_needed(text: str, *, context: str | None = None) -> str:
        marker = "<tool_call>"  # assumed literal; elided in the source
        if not text.startswith(marker):
            return text
        brace_start = text.find("{", len(marker))
        if brace_start < 0:
            return text
        tool_name = text[len(marker) : brace_start].strip()
        payload_text = text[brace_start:].strip()
        try:
            payload = json.loads(payload_text)
            if isinstance(payload, dict) and tool_name == "web.search":
                repaired_query = ReframrModel._repair_search_query_from_context_if_weak(
                    str(payload.get("query", "")),
                    context,
                )
                if repaired_query is not None:
                    payload["query"] = repaired_query
                return f"{marker} {tool_name} {json.dumps(payload, ensure_ascii=False)}"
            return text
        except (TypeError, json.JSONDecodeError):
            pass
        body = payload_text.strip()
        if body.startswith("{"):
            body = body[1:]
        if body.endswith("}"):
            body = body[:-1]
        body = " ".join(body.replace('"', "").split())
        if not tool_name or not body:
            return text
        if tool_name == "web.search":
            payload = {
                "query": ReframrModel._repair_search_query_from_context_if_weak(
                    body,
                    context,
                )
                or body
            }
        else:
            payload = {"input": body}
        return f"{marker} {tool_name} {json.dumps(payload, ensure_ascii=False)}"

    @staticmethod
    def _repair_search_query_from_context_if_weak(
        query: str,
        context: str | None,
    ) -> str | None:
        cleaned_query = " ".join(query.replace("{", " ").replace("}", " ").split())
        normalized_words = [
            word.strip(" \t\r\n:,.;!?\"'()[]{}").casefold()
            for word in cleaned_query.split()
            if word.strip(" \t\r\n:,.;!?\"'()[]{}")
        ]
        unique_content_words = {
            word
            for word in normalized_words
            if word not in {"query", "web.search", "tool_call"}
        }
        lowered_query = cleaned_query.casefold()
        weak = (
            len(unique_content_words) < 3
            or lowered_query.startswith("query:")
            or "web.search" in lowered_query
            or any(
                marker in lowered_query
                # the first two literals are assumed; elided in the source
                for marker in ("<tool_call>", "<tool_result>", "according to")
            )
        )
        if not weak:
            return None
        context_query = ReframrModel._search_query_from_context(context or "")
        return context_query or None

    @staticmethod
    def _search_query_from_context(context: str) -> str:
        if not context:
            return ""
        # assumed literals; the split markers were elided in the source but
        # are implied by the variable names below
        before_tool_result = context.split("<tool_result>", 1)[0]
        before_final = before_tool_result.split("<final>", 1)[0]
        lines = [line.strip() for line in before_final.splitlines() if line.strip()]
        if not lines:
            lines = [before_final.strip()]
        latest_user = ""
        for line in lines:
            lowered = line.casefold()
            if lowered.startswith("user:"):
                latest_user = line.split(":", 1)[1].strip()
            elif lowered.startswith("question:"):
                latest_user = line.split(":", 1)[1].strip()
        if not latest_user:
            latest_user = lines[-1]
        for prefix in ("User:", "Question:", "Prompt:", "Context:"):
            if latest_user.casefold().startswith(prefix.casefold()):
                latest_user = latest_user[len(prefix) :].strip()
        cleaned = " ".join(latest_user.split())
        return cleaned.strip(" \t\r\n\"'")
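    # Worked example (illustrative, assuming the elided call marker is
    # "<tool_call>" as reconstructed above): a truncated generation such as
    #   '<tool_call> web.search {"query": "fed rate decision", "fresh'
    # never closes its outer brace, so the scanner exits with depth > 0,
    # keeps the text up to the last top-level comma, and reseals it as
    #   '<tool_call> web.search {"query": "fed rate decision"}'
    # which then parses cleanly in _repair_tool_call_payload_if_needed.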
    @staticmethod
    def _finalize_generated_text(text: str) -> str:
        stripped = text.rstrip()
        if not stripped:
            return stripped
        if stripped.startswith("<tool_call>"):  # assumed literal; elided in the source
            return stripped
        stripped = ReframrModel._remove_separator_punctuation_before_boundary(stripped)
        if stripped and ReframrModel._is_separator_punctuation(stripped[-1:]):
            stripped = stripped[:-1].rstrip()
        if not stripped:
            return stripped
        if (
            ReframrModel._is_surface_punctuation(stripped[:1])
            or ReframrModel._is_surface_punctuation(stripped[-1:])
        ):
            return stripped
        if any(character.isalnum() for character in stripped[-8:]):
            return f"{stripped}."
        return stripped

    @staticmethod
    def _remove_separator_punctuation_before_boundary(text: str) -> str:
        cleaned: list[str] = []
        for character in text:
            if (
                ReframrModel._is_separator_punctuation(character)
                and cleaned
                and ReframrModel._is_separator_punctuation(cleaned[-1])
            ):
                cleaned.pop()
            cleaned.append(character)
        return "".join(cleaned)

    @staticmethod
    def _is_surface_punctuation(character: str) -> bool:
        return len(character) == 1 and unicodedata.category(character).startswith("P")

    @staticmethod
    def _is_separator_punctuation(character: str) -> bool:
        return (
            ReframrModel._is_surface_punctuation(character)
            and unicodedata.bidirectional(character) == "CS"
        )

    def _render_token(self, token: str) -> str:
        assert self.tokenizer is not None
        if token.startswith(self.tokenizer.word_prefix):
            return token[len(self.tokenizer.word_prefix) :]
        return token

    def _require_fit(self) -> None:
        if (
            self.tokenizer is None
            or self.embedding_model is None
            or self.memory_units is None
            or self.readout_weights is None
            or self.ternary_mask is None
            or self.associative_keys is None
            or (
                self.associative_key_norms is None
                and self.associative_key_norms_array is None
            )
            or self.associative_values is None
            or self.transition_tables is None
        ):
            raise RuntimeError("Call fit() before using the REFRAMR model.")

    def _ensure_numeric_caches(self) -> None:
        if np is None:
            return
        if self.readout_weights_array is None:
            self._refresh_numeric_caches()