| import json |
| import hashlib |
| import random |
| import site |
| import string |
| import sys |
| import unicodedata |
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from collections.abc import Sequence |
|
|
| _VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor" |
| for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"): |
| if _vendor_path.exists(): |
| vendor_text = str(_vendor_path) |
| if vendor_text not in sys.path: |
| sys.path.insert(0, vendor_text) |
|
|
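| # numpy is optional: try the plain import first, then retry after appending |
| # the user site-packages (covers `pip install --user` layouts), and fall back |
| # to the pure-Python code paths (np = None) when it is genuinely absent. |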
| try: |
| import numpy as np |
| except ModuleNotFoundError: |
| user_site = site.getusersitepackages() |
| if user_site and user_site not in sys.path: |
| sys.path.append(user_site) |
| try: |
| import numpy as np |
| except ModuleNotFoundError: |
| np = None |
|
|
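| # Treat an importable module that lacks `asarray` (e.g. a stub shadowing |
| # numpy on sys.path) as if numpy were absent. |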
| if np is not None and not hasattr(np, "asarray"): |
| np = None |
|
|
| from .checkpoint import read_safetensor_file, write_safetensor_file |
| from .config import ReframrConfig |
| from .embeddings import EmbeddingModel, fit_ppmi_embedding_from_tokens |
| from .hippo import AnalyticalMemoryUnit, analytical_embedding_drive, analytical_embedding_drive_fast |
| from .linalg import Vector, dot, mean, norm, softmax, zeros_vector |
| from .reservoir import apply_readout, ridge_regression_readout |
| from .reasoning import TOOL_PROTOCOL_TOKENS, reasoning_prefix |
| from .sparse_context import HashedSparseAttention |
| from .ternary import apply_ternary_mask, derive_ternary_mask_from_states |
| from .tokenizer import NativeTokenizer |
|
|
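| # Blend weights for the two decode paths. The unprefixed constants drive the |
| # plain-Python scorer; the FAST_* variants appear intended for the NumPy |
| # array path (_generate_text_fast). The groups are tuned independently and |
| # are not required to sum to exactly 1. |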
| ASSOCIATIVE_BLEND = 0.42 |
| TRANSITION_BLEND = 0.08 |
| COPY_BLEND = 0.04 |
| BASE_BLEND = 0.34 |
| FAST_ASSOCIATIVE_BLEND = 0.06 |
| FAST_TRANSITION_BLEND = 0.14 |
| FAST_COPY_BLEND = 0.12 |
| FAST_BASE_BLEND = 0.72 |
| FAST_PREFERENCE_BLEND = 0.15 |
| FAST_ANSWER_BLEND = 0.16 |
| FAST_SOURCE_EVIDENCE_BLEND = 0.52 |
| PROMPT_READOUT_LOGIT_ZSCORE_SCALE = 0.48 |
| PROMPT_START_READOUT_CONFIDENCE_FLOOR = 0.45 |
| ASSOCIATIVE_TOP_K = 12 |
| ANSWER_TOP_K = 48 |
| ANSWER_START_TOP_K = 32 |
| MIN_COMPLETE_ANSWER_WORDS = 6 |
| MIN_COMPLETE_MULTI_SENTENCE_WORDS = 4 |
| ANSWER_SEQUENCE_MATCH_FLOOR = 0.27 |
| ANSWER_START_CONFIDENCE_FLOOR = 0.45 |
| ANSWER_START_MATCH_SUPPORT_FLOOR = 0.18 |
| ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR = 0.45 |
| ANSWER_SEQUENCE_LOCK_FLOOR = 0.55 |
| ANSWER_SEQUENCE_SPIKE_CONFIDENCE = 0.80 |
| READOUT_LOGIT_ZSCORE_SCALE = 0.22 |
| TRACE_IDENTITY_SCALE = 0.78 |
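| # Fixed mixing coefficients, presumably for deterministic trace-identity |
| # hashing. Most entries are classic LCG multiplier/increment pairs (e.g. |
| # 1103515245/12345 from ANSI C rand, 1664525/1013904223 from Numerical |
| # Recipes, 214013/2531011 from MSVC, 48271 MINSTD, 16807 Park-Miller) mixed |
| # with large primes. |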
| TRACE_IDENTITY_HASHES = ( |
| (1103515245, 12345, 214013, 2531011), |
| (1664525, 1013904223, 22695477, 1), |
| (69069, 362437, 134775813, 17), |
| (134775813, 97, 1103515245, 31), |
| (22695477, 911, 1664525, 73), |
| (214013, 2531011, 69069, 19), |
| (48271, 0, 69621, 11), |
| (16807, 37, 40692, 101), |
| (279470273, 173, 1299709, 53), |
| (39916801, 29, 2147483629, 7), |
| ) |
| PROMPT_ENVELOPE_TERMS = frozenset( |
| {"system", "instruction", "user", "human", "assistant", "question", "answer"} |
| ) |
| NGRAM_KEY_SEPARATOR = "\u0001" |
| TRANSITION_ORDERS = (10, 8, 6, 5, 4, 3, 2, 1) |
| DEFAULT_GENERATION_TEMPERATURE = 0.82 |
| DEFAULT_GENERATION_TOP_K = 24 |
| DEFAULT_GENERATION_TOP_P = 0.92 |
| DEFAULT_REPETITION_PENALTY = 1.18 |
| ANSWER_SEQUENCE_MAX_TOKENS = 192 |
| ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT = 8192 |
| ANSWER_SEQUENCE_VARIATION_TEMPERATURE = 0.65 |
| ANSWER_SEQUENCE_VARIATION_MATCH_LIMIT = 4 |
| ANSWER_SEQUENCE_CREATIVE_TEMPERATURE = 1.10 |
| ANSWER_REPLAY_PREFIX_TEMPERATURE = 0.95 |
| ANSWER_REPLAY_PREFIX_MIN_TOKENS = 64 |
| ANSWER_REPLAY_PREFIX_PENALTY = 0.18 |
| CREATIVE_EARLY_POOL_TEMPERATURE = 1.05 |
| CREATIVE_EARLY_POOL_WORD_LIMIT = 6 |
| CREATIVE_EARLY_POOL_MAX = 8 |
| TOOL_CALL_CONTEXT_TERMS = frozenset( |
| { |
| "current", |
| "latest", |
| "today", |
| "yesterday", |
| "tonight", |
| "now", |
| "fresh", |
| "recent", |
| "web", |
| "search", |
| "real-time", |
| "price", |
| "weather", |
| "election", |
| "news", |
| "official", |
| "result", |
| "live", |
| } |
| ) |
| RUNTIME_GENERATION_HISTORY_LIMIT = 8 |
| AVOID_SEQUENCE_MIN_TOKENS = 6 |
| WORD_COMPLETION_OVERFLOW_TOKENS = 16 |
| ANSWER_FINGERPRINT_WORDS = 4 |
| SPARSE_CONTEXT_MIN_TOKENS = 16 |
| SPARSE_CONTEXT_TOP_K = 64 |
| SPARSE_CONTEXT_HASH_BITS = 12 |
| SPARSE_CONTEXT_PROBE_RADIUS = 1 |
| SPARSE_CONTEXT_CANDIDATE_MULTIPLIER = 16 |
| SPARSE_CONTEXT_TRACE_BLEND = 0.35 |
| RUNTIME_ARRAY_DTYPE = np.float32 if np is not None else None |
|
|
|
|
| @dataclass(frozen=True, slots=True) |
| class CharacterCountFact: |
| character: str |
| word: str |
| count: int |
| surface_seed: int |
| focused: bool |
|
|
|
|
| @dataclass(frozen=True, slots=True) |
| class GenerationTokenMeta: |
| rendered: str |
| stripped: str |
| starts_new_word: bool |
| punctuation_piece: bool |
| structural_punctuation: bool |
| structural_symbol: bool |
| word_joiner: bool |
| alphanumeric: str |
| common_connector: bool |
|
|
|
|
| def _normalize_vector(values: Vector) -> Vector: |
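| """Scale values by their sum; if the sum is not positive, return all zeros. |
| |
| >>> _normalize_vector([2.0, 1.0, 1.0]) |
| [0.5, 0.25, 0.25] |
| >>> _normalize_vector([0.0, 0.0]) |
| [0.0, 0.0] |
| """ |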
| total = sum(values) |
| if total <= 0.0: |
| return [0.0 for _ in values] |
| return [value / total for value in values] |
|
|
|
|
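| # N-gram keys are flattened with an unprintable separator so tuples can be |
| # used as plain dict keys. Round trip (illustrative): |
| #   >>> _decode_ngram_key(_encode_ngram_key(("the", "cat"))) |
| #   ('the', 'cat') |
| # Note that _decode_ngram_key drops empty parts, so empty tokens do not |
| # survive the round trip. |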
| def _encode_ngram_key(tokens: tuple[str, ...]) -> str: |
| return NGRAM_KEY_SEPARATOR.join(tokens) |
|
|
|
|
| def _decode_ngram_key(key: str) -> tuple[str, ...]: |
| return tuple(part for part in key.split(NGRAM_KEY_SEPARATOR) if part) |
|
|
|
|
| def _last_index(values: list[str], target: str) -> int | None: |
| for index in range(len(values) - 1, -1, -1): |
| if values[index] == target: |
| return index |
| return None |
|
|
|
|
| def _first_index(values: list[str], target: str) -> int | None: |
| for index, value in enumerate(values): |
| if value == target: |
| return index |
| return None |
|
|
|
|
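| # Mutable per-prompt decode cache: the match lists and prompt priors below |
| # start as None and are filled lazily on the first scoring step, then reused |
| # for every subsequent token of the same generation. |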
| @dataclass(slots=True) |
| class DecodeState: |
| hidden_states: list[Vector] |
| context_traces: list[Vector] |
| combined_state: Vector |
| context_tokens: list[str] |
| answer_anchor_state: Vector | None = None |
| answer_matches: list[tuple[float, int, int]] | None = None |
| answer_start_matches: list[tuple[float, int, int]] | None = None |
| answer_sequence_matches: list[tuple[float, int, int]] | None = None |
| prompt_answer_prior: object | None = None |
| prompt_answer_start_prior: object | None = None |
|
|
|
|
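| # Fields suffixed with _array hold NumPy mirrors of their list twins, typed |
| # as `object` so the module still imports when numpy is unavailable; they |
| # are rebuilt by _refresh_numeric_caches(). |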
| @dataclass(slots=True) |
| class ReframrModel: |
| config: ReframrConfig |
| tokenizer: NativeTokenizer | None = None |
| embedding_model: EmbeddingModel | None = None |
| memory_units: list[AnalyticalMemoryUnit] | None = None |
| ternary_scale: float = 1.0 |
| ternary_mask: list[int] | None = None |
| ternary_mask_array: object | None = None |
| readout_weights: list[list[float]] | None = None |
| readout_weights_array: object | None = None |
| readout_bias: Vector | None = None |
| readout_bias_array: object | None = None |
| prompt_answer_weights: list[list[float]] | None = None |
| prompt_answer_weights_array: object | None = None |
| prompt_answer_bias: Vector | None = None |
| prompt_answer_bias_array: object | None = None |
| prompt_answer_start_weights: list[list[float]] | None = None |
| prompt_answer_start_weights_array: object | None = None |
| prompt_answer_start_bias: Vector | None = None |
| prompt_answer_start_bias_array: object | None = None |
| trace_token_weights: Vector | None = None |
| trace_token_weights_array: object | None = None |
| trace_embedding_table_array: object | None = None |
| preference_bias: Vector | None = None |
| preference_bias_array: object | None = None |
| preference_valid_mask_array: object | None = None |
| state_offset: Vector | None = None |
| state_offset_array: object | None = None |
| associative_keys: list[Vector] | None = None |
| associative_keys_array: object | None = None |
| associative_key_norms: list[float] | None = None |
| associative_key_norms_array: object | None = None |
| associative_values: list[int] | None = None |
| associative_values_array: object | None = None |
| associative_valid_mask_array: object | None = None |
| answer_keys: list[Vector] | None = None |
| answer_keys_array: object | None = None |
| answer_key_norms: list[float] | None = None |
| answer_key_norms_array: object | None = None |
| answer_similarity_keys_array: object | None = None |
| answer_similarity_key_norms_array: object | None = None |
| answer_similarity_mask_array: object | None = None |
| answer_values: list[int] | None = None |
| answer_values_array: object | None = None |
| answer_valid_mask_array: object | None = None |
| answer_start_keys: list[Vector] | None = None |
| answer_start_keys_array: object | None = None |
| answer_start_key_norms: list[float] | None = None |
| answer_start_key_norms_array: object | None = None |
| answer_start_similarity_keys_array: object | None = None |
| answer_start_similarity_key_norms_array: object | None = None |
| answer_start_values: list[int] | None = None |
| answer_start_values_array: object | None = None |
| answer_start_valid_mask_array: object | None = None |
| answer_sequence_keys: list[Vector] | None = None |
| answer_sequence_keys_array: object | None = None |
| answer_sequence_key_norms: list[float] | None = None |
| answer_sequence_key_norms_array: object | None = None |
| answer_sequence_similarity_keys_array: object | None = None |
| answer_sequence_similarity_key_norms_array: object | None = None |
| answer_sequence_prompt_tokens: list[list[int]] | None = None |
| answer_sequence_prompt_tokens_array: object | None = None |
| answer_sequence_tokens: list[list[int]] | None = None |
| answer_sequence_tokens_array: object | None = None |
| answer_sequence_token_id_rows: list[list[int]] | None = None |
| answer_sequence_prompt_weight_maps: list[dict[int, float]] | None = None |
| answer_sequence_prompt_weight_norms: list[float] | None = None |
| answer_sequence_prompt_bigram_sets: list[set[tuple[int, int]]] | None = None |
| answer_sequence_prompt_trigram_sets: list[set[tuple[int, int, int]]] | None = None |
| answer_sequence_prompt_number_sets: list[set[str]] | None = None |
| answer_sequence_prompt_inverted_index: dict[int, list[int]] | None = None |
| answer_sequence_prompt_specificity: dict[int, float] | None = None |
| prompt_overlap_valid_token_mask_array: object | None = None |
| answer_fingerprint_hashes: set[tuple[int, ...]] | None = None |
| answer_fingerprint_token_lengths: set[int] | None = None |
| answer_fingerprint_token_sequences_by_length: dict[int, set[tuple[int, ...]]] | None = None |
| answer_sequence_prefixes_by_length: dict[int, set[tuple[int, ...]]] | None = None |
| transition_tables: dict[int, dict[tuple[str, ...], dict[str, float]]] | None = None |
| transition_id_tables: dict[int, dict[tuple[int, ...], tuple[object, object]]] | None = None |
| transition_tensor_cache: dict[str, object] | None = None |
| transition_built_orders: set[int] | None = None |
| generation_token_meta_cache: dict[str, GenerationTokenMeta] | None = None |
| runtime_generation_history: dict[str, list[str]] = field(default_factory=dict, repr=False) |
|
|
| def fit(self, text: str) -> "ReframrModel": |
| self.generation_token_meta_cache = None |
| self.answer_sequence_prefixes_by_length = None |
| self.tokenizer = NativeTokenizer.train( |
| text, |
| vocab_size=self.config.tokenizer_vocab_size, |
| min_pair_frequency=self.config.tokenizer_min_pair_frequency, |
| lowercase=self.config.lowercase, |
| ) |
| tokens = self.tokenizer.encode(text) |
| if len(tokens) < 2: |
| raise ValueError("REFRAMR needs at least two tokens to derive a next-token readout.") |
|
|
| self.embedding_model = fit_ppmi_embedding_from_tokens( |
| tokens, |
| embedding_dim=self.config.embedding_dim, |
| window_size=self.config.window_size, |
| min_frequency=self.config.min_frequency, |
| max_vocab=self.config.max_vocab, |
| required_tokens=self.tokenizer.vocab, |
| ) |
| self.memory_units = [ |
| AnalyticalMemoryUnit(self.config.state_dim, timescale) |
| for timescale in self.config.timescales |
| ] |
| token_counts: dict[str, float] = {} |
| for token in tokens: |
| token_counts[token] = token_counts.get(token, 0.0) + 1.0 |
| self.trace_token_weights = self._derive_trace_token_weights_from_counts(token_counts) |
|
|
| raw_states, targets, target_ids = self._collect_training_examples(tokens) |
| self.ternary_scale, self.ternary_mask = derive_ternary_mask_from_states(raw_states) |
| analytical_states = [ |
| apply_ternary_mask(state, self.ternary_mask, self.ternary_scale) |
| for state in raw_states |
| ] |
| self.associative_keys = [state[:] for state in analytical_states] |
| self.associative_key_norms = [norm(state) for state in analytical_states] |
| self.associative_values = target_ids[:] |
| self.answer_keys = [] |
| self.answer_key_norms = [] |
| self.answer_values = [] |
| self.answer_start_keys = [] |
| self.answer_start_key_norms = [] |
| self.answer_start_values = [] |
| self.answer_sequence_keys = [] |
| self.answer_sequence_key_norms = [] |
| self.answer_sequence_prompt_tokens = [] |
| self.answer_sequence_tokens = [] |
| self.prompt_answer_weights = [] |
| self.prompt_answer_bias = [0.0 for _ in self.embedding_model.id_to_token] |
| self.prompt_answer_start_weights = [] |
| self.prompt_answer_start_bias = [0.0 for _ in self.embedding_model.id_to_token] |
| self.transition_tables = self._build_transition_tables(tokens) |
| self._fit_answer_memory_from_text(text) |
| self._refresh_answer_fingerprint_hashes() |
| self.readout_weights = ridge_regression_readout( |
| analytical_states, |
| targets, |
| regularization=self.config.regularization, |
| ) |
| self.readout_bias = [0.0 for _ in self.embedding_model.id_to_token] |
| self.preference_bias = [0.0 for _ in self.embedding_model.id_to_token] |
| self.state_offset = [0.0 for _ in analytical_states[0]] if analytical_states else [] |
| self._refresh_numeric_caches() |
| return self |
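| # Illustrative fit/generate round trip (a sketch; assumes ReframrConfig is |
| # default-constructible with workable hyperparameters): |
| #   model = ReframrModel(config=ReframrConfig()) |
| #   model.fit("what is reframr? <answer> a tiny analytical language model") |
| #   print(model.generate_text("what is reframr?")) |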
|
|
| def _fit_answer_memory_from_text(self, text: str) -> None: |
| assert self.tokenizer is not None |
| assert self.embedding_model is not None |
| if ( |
| self.answer_keys is None |
| or self.answer_key_norms is None |
| or self.answer_values is None |
| or self.answer_start_keys is None |
| or self.answer_start_key_norms is None |
| or self.answer_start_values is None |
| or self.answer_sequence_keys is None |
| or self.answer_sequence_key_norms is None |
| or self.answer_sequence_prompt_tokens is None |
| or self.answer_sequence_tokens is None |
| ): |
| return |
|
|
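| # Supervised lines take the form "<prompt text> <answer> <answer text>". |
| # Each one contributes an answer-start key plus full prompt/answer id |
| # sequences, right-padded with -1 sentinels up to ANSWER_SEQUENCE_MAX_TOKENS. |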
| for line in text.splitlines(): |
| if "<answer>" not in line: |
| continue |
| prompt_text, answer_text = line.split("<answer>", 1) |
| prompt_text = prompt_text.strip() |
| answer_text = answer_text.strip() |
| if not prompt_text or not answer_text: |
| continue |
|
|
| prompt_tokens = self.tokenizer.encode(prompt_text) + ["<answer>"] |
| answer_tokens = [ |
| token |
| for token in self.tokenizer.encode(answer_text) |
| if token in self.embedding_model.token_to_id |
| and ( |
| token not in self.tokenizer.special_tokens |
| or token in TOOL_PROTOCOL_TOKENS |
| ) |
| ] |
| if not prompt_tokens or not answer_tokens: |
| continue |
|
|
| key = self._encode_context(prompt_tokens) |
| key_norm = norm(key) |
| if key_norm <= 0.0: |
| continue |
|
|
| answer_ids = [ |
| self.embedding_model.token_to_id[token] |
| for token in answer_tokens[:ANSWER_SEQUENCE_MAX_TOKENS] |
| ] |
| prompt_ids = [ |
| self.embedding_model.token_to_id[token] |
| for token in prompt_tokens[:ANSWER_SEQUENCE_MAX_TOKENS] |
| if token in self.embedding_model.token_to_id |
| and ( |
| token not in self.tokenizer.special_tokens |
| or token in TOOL_PROTOCOL_TOKENS |
| ) |
| ] |
| if not answer_ids: |
| continue |
|
|
| self.answer_keys.append(key[:]) |
| self.answer_key_norms.append(key_norm) |
| self.answer_values.append(answer_ids[0]) |
| self.answer_start_keys.append(key[:]) |
| self.answer_start_key_norms.append(key_norm) |
| self.answer_start_values.append(answer_ids[0]) |
| self.answer_sequence_keys.append(key[:]) |
| self.answer_sequence_key_norms.append(key_norm) |
| self.answer_sequence_prompt_tokens.append( |
| prompt_ids |
| + [-1 for _ in range(ANSWER_SEQUENCE_MAX_TOKENS - len(prompt_ids))] |
| ) |
| self.answer_sequence_tokens.append( |
| answer_ids |
| + [-1 for _ in range(ANSWER_SEQUENCE_MAX_TOKENS - len(answer_ids))] |
| ) |
|
|
| def predict_next_distribution( |
| self, |
| context: str, |
| *, |
| reasoning_mode: str | None = None, |
| ) -> dict[str, float]: |
| self._require_fit() |
| assert self.tokenizer is not None |
| assert self.embedding_model is not None |
| probabilities = self.predict_next_token_distribution( |
| context, |
| reasoning_mode=reasoning_mode, |
| ) |
| distribution: dict[str, float] = {} |
| for token, probability in probabilities.items(): |
| rendered = self._render_token(token) |
| distribution[rendered] = distribution.get(rendered, 0.0) + probability |
| return distribution |
|
|
| def predict_next_token_distribution( |
| self, |
| context: str, |
| *, |
| reasoning_mode: str | None = None, |
| ) -> dict[str, float]: |
| self._require_fit() |
| assert self.tokenizer is not None |
| assert self.embedding_model is not None |
| assert self.readout_weights is not None |
|
|
| active_mode = reasoning_mode or self.config.default_reasoning_profile |
| context_tokens = reasoning_prefix(active_mode) + self.tokenizer.encode(context) |
| return self._predict_next_token_distribution_from_tokens(context_tokens) |
|
|
| def generate_text( |
| self, |
| context: str, |
| *, |
| max_tokens: int = 64, |
| reasoning_mode: str | None = None, |
| temperature: float = 0.0, |
| top_k: int = DEFAULT_GENERATION_TOP_K, |
| top_p: float = DEFAULT_GENERATION_TOP_P, |
| repetition_penalty: float = DEFAULT_REPETITION_PENALTY, |
| avoid_texts: Sequence[str] | None = None, |
| ) -> str: |
| character_count_response = self._character_count_response( |
| context, |
| temperature=temperature, |
| ) |
| if character_count_response is not None: |
| return character_count_response |
| self._require_fit() |
| self._ensure_numeric_caches() |
| assert self.tokenizer is not None |
| runtime_avoid_texts = self._runtime_avoid_texts( |
| context, |
| avoid_texts, |
| temperature=temperature, |
| ) |
| avoid_token_sequences = self._avoid_text_token_sequences(runtime_avoid_texts) |
| if ( |
| np is not None |
| and self.readout_weights_array is not None |
| and self.embedding_model is not None |
| and len(self.embedding_model.id_to_token) >= 1024 |
| ): |
| generated_text = self._generate_text_fast( |
| context, |
| max_tokens=max_tokens, |
| reasoning_mode=reasoning_mode, |
| temperature=temperature, |
| top_k=top_k, |
| top_p=top_p, |
| repetition_penalty=repetition_penalty, |
| avoid_token_sequences=avoid_token_sequences, |
| ) |
| self._remember_runtime_generation( |
| context, |
| generated_text, |
| temperature=temperature, |
| ) |
| return generated_text |
|
|
| active_mode = reasoning_mode or self.config.default_reasoning_profile |
| _, context_tokens = self._generation_prompt_tokens(context, active_mode) |
| decode_state = self._build_decode_state(context_tokens) |
| generated_tokens: list[str] = [] |
| for _ in range(max_tokens): |
| distribution, _ = self._score_next_token_from_state( |
| decode_state, |
| include_trace=False, |
| generated_tokens=generated_tokens, |
| temperature=temperature, |
| avoid_token_sequences=avoid_token_sequences, |
| ) |
| forced_source_token = self._source_evidence_next_token( |
| decode_state.context_tokens, |
| generated_tokens, |
| ) |
| next_token = forced_source_token or self._select_generation_token( |
| distribution, |
| context_tokens=decode_state.context_tokens, |
| generated_tokens=generated_tokens, |
| temperature=temperature, |
| top_k=top_k, |
| top_p=top_p, |
| repetition_penalty=repetition_penalty, |
| avoid_token_sequences=avoid_token_sequences, |
| preserve_dominant_candidates=( |
| self._answer_decode_has_continuation(decode_state, generated_tokens) |
| or self._source_evidence_has_continuation( |
| decode_state.context_tokens, |
| generated_tokens, |
| ) |
| ), |
| ) |
| if not next_token: |
| break |
| generated_tokens.append(next_token) |
| self._advance_decode_state(decode_state, next_token) |
| if self._should_stop_answer_sequence(decode_state, generated_tokens): |
| break |
| if self._should_stop_after_answer_path_drift(decode_state, generated_tokens): |
| break |
| if self._source_evidence_is_complete(decode_state.context_tokens, generated_tokens): |
| break |
| if ( |
| self._should_stop_generation(generated_tokens) |
| and not self._answer_decode_has_continuation(decode_state, generated_tokens) |
| and not self._source_evidence_has_continuation( |
| decode_state.context_tokens, |
| generated_tokens, |
| ) |
| ): |
| break |
| overflow_budget = max(WORD_COMPLETION_OVERFLOW_TOKENS, max_tokens) |
| while generated_tokens and overflow_budget > 0: |
| has_answer_continuation = self._answer_decode_has_continuation( |
| decode_state, |
| generated_tokens, |
| ) |
| has_source_continuation = self._source_evidence_has_continuation( |
| decode_state.context_tokens, |
| generated_tokens, |
| ) |
| if ( |
| self._starts_new_word(generated_tokens[-1]) |
| and not has_answer_continuation |
| and not has_source_continuation |
| ): |
| break |
| distribution, _ = self._score_next_token_from_state( |
| decode_state, |
| include_trace=False, |
| generated_tokens=generated_tokens, |
| temperature=temperature, |
| avoid_token_sequences=avoid_token_sequences, |
| ) |
| forced_source_token = self._source_evidence_next_token( |
| decode_state.context_tokens, |
| generated_tokens, |
| ) |
| next_token = forced_source_token or self._select_generation_token( |
| distribution, |
| context_tokens=decode_state.context_tokens, |
| generated_tokens=generated_tokens, |
| temperature=temperature, |
| top_k=top_k, |
| top_p=top_p, |
| repetition_penalty=repetition_penalty, |
| avoid_token_sequences=avoid_token_sequences, |
| preserve_dominant_candidates=has_answer_continuation |
| or has_source_continuation, |
| ) |
| if not next_token: |
| break |
| if ( |
| self._starts_new_word(next_token) |
| and not has_answer_continuation |
| and not has_source_continuation |
| ): |
| break |
| generated_tokens.append(next_token) |
| self._advance_decode_state(decode_state, next_token) |
| overflow_budget -= 1 |
| generated_text = self._finalize_generated_text( |
| self._normalize_generated_tool_protocol_text( |
| self._decode_tokens(generated_tokens), |
| context=context, |
| ) |
| ) |
| self._remember_runtime_generation( |
| context, |
| generated_text, |
| temperature=temperature, |
| ) |
| return generated_text |
|
|
| @staticmethod |
| def _character_count_fact(context: str) -> CharacterCountFact | None: |
| normalized = unicodedata.normalize("NFKC", context).strip() |
| tokens = ReframrModel._character_count_word_tokens(normalized) |
| if not tokens: |
| return None |
| lowered = [token.casefold() for token in tokens] |
| count_terms = {"count", "counts", "counting", "many"} |
| unit_terms = {"character", "characters", "letter", "letters"} |
| if not any(token in count_terms for token in lowered): |
| return None |
| if not any(token in unit_terms for token in lowered) and "count" not in lowered: |
| return None |
|
|
| filler_terms = {"a", "an", "the", "single", "one", "please"} |
| word_markers = {"in", "inside"} |
| char_index = ReframrModel._character_count_target_index( |
| lowered, |
| unit_terms=unit_terms, |
| filler_terms=filler_terms, |
| ) |
| word_index = ReframrModel._character_count_word_index( |
| lowered, |
| char_index=char_index, |
| filler_terms=filler_terms, |
| word_markers=word_markers, |
| ) |
| if char_index is None or word_index is None: |
| return None |
| character = tokens[char_index] |
| word = tokens[word_index] |
| if len(character) != 1 or not word: |
| return None |
| order_offset = 0 if char_index < word_index else 1 |
| surface_seed = ((char_index + 1) * 7 + (word_index + 1) * 3 + len(tokens) + order_offset) % 4 |
| structural_terms = ( |
| count_terms |
| | unit_terms |
| | filler_terms |
| | word_markers |
| | { |
| "for", |
| "of", |
| "to", |
| "how", |
| "do", |
| "does", |
| "there", |
| "are", |
| "is", |
| "appear", |
| "appears", |
| "times", |
| "word", |
| } |
| ) |
| extra_content_tokens = [ |
| token |
| for index, token in enumerate(lowered) |
| if index not in {char_index, word_index} |
| and token not in structural_terms |
| ] |
| return CharacterCountFact( |
| character=character, |
| word=word, |
| count=word.casefold().count(character.casefold()), |
| surface_seed=surface_seed, |
| focused=not extra_content_tokens, |
| ) |
|
|
| @staticmethod |
| def _character_count_word_tokens(text: str) -> list[str]: |
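| """Split text into runs of alphanumeric characters. |
| |
| >>> ReframrModel._character_count_word_tokens("how many r's in strawberry") |
| ['how', 'many', 'r', 's', 'in', 'strawberry'] |
| """ |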
| tokens: list[str] = [] |
| current: list[str] = [] |
| for character in text: |
| if character != "_" and character.isalnum(): |
| current.append(character) |
| continue |
| if current: |
| tokens.append("".join(current)) |
| current = [] |
| if current: |
| tokens.append("".join(current)) |
| return tokens |
|
|
| @staticmethod |
| def _character_count_target_index( |
| tokens: list[str], |
| *, |
| unit_terms: set[str], |
| filler_terms: set[str], |
| ) -> int | None: |
| for index, token in enumerate(tokens): |
| if token not in unit_terms: |
| continue |
| for adjacent in (index - 1, index + 1): |
| if 0 <= adjacent < len(tokens) and len(tokens[adjacent]) == 1: |
| return adjacent |
| before = ReframrModel._nearest_content_index(tokens, index - 1, -1, filler_terms) |
| after = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms) |
| for candidate in (before, after): |
| if candidate is not None and len(tokens[candidate]) == 1: |
| return candidate |
| for index, token in enumerate(tokens): |
| if token not in {"count", "counts", "counting"}: |
| continue |
| candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms) |
| if candidate is not None and tokens[candidate] in unit_terms: |
| candidate = ReframrModel._nearest_content_index(tokens, candidate + 1, 1, filler_terms) |
| if candidate is not None and len(tokens[candidate]) == 1: |
| return candidate |
| return None |
|
|
| @staticmethod |
| def _character_count_word_index( |
| tokens: list[str], |
| *, |
| char_index: int | None, |
| filler_terms: set[str], |
| word_markers: set[str], |
| ) -> int | None: |
| for index, token in enumerate(tokens): |
| if token != "word": |
| continue |
| candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms) |
| if candidate is not None and candidate != char_index and len(tokens[candidate]) > 1: |
| return candidate |
| for index, token in enumerate(tokens): |
| if token not in word_markers: |
| continue |
| candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms) |
| if candidate is not None and tokens[candidate] == "word": |
| candidate = ReframrModel._nearest_content_index(tokens, candidate + 1, 1, filler_terms) |
| if candidate is not None and candidate != char_index and len(tokens[candidate]) > 1: |
| return candidate |
| skipped_terms = { |
| "how", |
| "many", |
| "do", |
| "does", |
| "count", |
| "counts", |
| "counting", |
| "letter", |
| "letters", |
| "character", |
| "characters", |
| "word", |
| "there", |
| "are", |
| "is", |
| "appear", |
| "appears", |
| "times", |
| } | filler_terms | word_markers |
| for index in range(len(tokens) - 1, -1, -1): |
| if index == char_index: |
| continue |
| if len(tokens[index]) <= 1 or tokens[index] in skipped_terms: |
| continue |
| return index |
| return None |
|
|
| @staticmethod |
| def _nearest_content_index( |
| tokens: list[str], |
| start: int, |
| direction: int, |
| skipped_terms: set[str], |
| ) -> int | None: |
| index = start |
| while 0 <= index < len(tokens): |
| if tokens[index] not in skipped_terms: |
| return index |
| index += direction |
| return None |
|
|
| @classmethod |
| def _character_count_response(cls, context: str, *, temperature: float = 0.0) -> str | None: |
| fact = cls._character_count_fact(context) |
| if fact is None: |
| return None |
| if not fact.focused: |
| return None |
| return cls._render_character_count_fact(fact, temperature=temperature) |
|
|
| @staticmethod |
| def _render_character_count_fact(fact: CharacterCountFact, *, temperature: float = 0.0) -> str: |
| character_label = f"'{fact.character}'" |
| word_label = f"'{fact.word}'" |
| character_noun = "character" if fact.count == 1 else "characters" |
| return f"{word_label} has {fact.count} {character_label} {character_noun}." |
|
|
| @classmethod |
| def _runtime_source_grounded_response(cls, context: str) -> str | None: |
| # Hook for answering directly from <source> records; currently a no-op, so |
| # callers always fall through to normal decoding. |
| return None |
|
|
| @classmethod |
| def _runtime_source_records(cls, context: str) -> list[tuple[str, str, str]]: |
| records: list[tuple[str, str, str]] = [] |
| marker = "<source>" |
| search_from = 0 |
| while True: |
| source_start = context.find(marker, search_from) |
| if source_start < 0: |
| break |
| content_start = source_start + len(marker) |
| content_end = cls._runtime_source_record_end(context, content_start) |
| raw_record = context[content_start:content_end].strip() |
| record = cls._parse_runtime_source_record(raw_record) |
| if record is not None: |
| records.append(record) |
| search_from = max(content_end, content_start + 1) |
| return records |
|
|
| @staticmethod |
| def _runtime_source_record_end(context: str, start: int) -> int: |
| boundaries = [ |
| position |
| for marker in ( |
| "\n", |
| "<source>", |
| "<tool_call>", |
| "<tool_result>", |
| "<final>", |
| "<answer>", |
| "<reason>", |
| ) |
| if (position := context.find(marker, start)) >= 0 |
| ] |
| return min(boundaries) if boundaries else len(context) |
|
|
| @staticmethod |
| def _parse_runtime_source_record(raw_record: str) -> tuple[str, str, str] | None: |
| if not raw_record: |
| return None |
| pieces = [piece.strip() for piece in raw_record.split("|", 2)] |
| if len(pieces) >= 3: |
| title, url, snippet = pieces[0], pieces[1], pieces[2] |
| else: |
| title, url, snippet = "the provided source", "", pieces[-1] |
| title = ReframrModel._clean_runtime_source_field(title) or "the provided source" |
| url = ReframrModel._clean_runtime_source_field(url) |
| snippet = ReframrModel._clean_runtime_source_field(snippet) |
| if not snippet: |
| return None |
| return title, url, snippet |
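| # e.g. "Title | https://example.com | snippet text" parses to the triple |
| # ("Title", "https://example.com", "snippet text"); with fewer than three |
| # "|"-separated fields only the last piece is kept as the snippet and the |
| # title falls back to "the provided source". |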
|
|
| @staticmethod |
| def _clean_runtime_source_field(text: str) -> str: |
| normalized = unicodedata.normalize("NFKC", text) |
| cleaned = " ".join(normalized.split()) |
| return cleaned.strip(" \t\r\n|") |
|
|
| def _generate_text_fast( |
| self, |
| context: str, |
| *, |
| max_tokens: int, |
| reasoning_mode: str | None, |
| temperature: float, |
| top_k: int, |
| top_p: float, |
| repetition_penalty: float, |
| avoid_token_sequences: Sequence[Sequence[str]] | None = None, |
| ) -> str: |
| assert self.tokenizer is not None |
|
|
| active_mode = reasoning_mode or self.config.default_reasoning_profile |
| _, context_tokens = self._generation_prompt_tokens(context, active_mode) |
| decode_state = self._build_decode_state(context_tokens) |
| generated_tokens: list[str] = [] |
| for _ in range(max_tokens): |
| probabilities, _ = self._score_next_token_array_from_state( |
| decode_state, |
| include_associative=not generated_tokens, |
| generated_tokens=generated_tokens, |
| temperature=temperature, |
| avoid_token_sequences=avoid_token_sequences, |
| ) |
| forced_source_token = self._source_evidence_next_token( |
| decode_state.context_tokens, |
| generated_tokens, |
| ) |
| next_token = forced_source_token or self._select_generation_token_from_array( |
| probabilities, |
| context_tokens=decode_state.context_tokens, |
| generated_tokens=generated_tokens, |
| temperature=temperature, |
| top_k=top_k, |
| top_p=top_p, |
| repetition_penalty=repetition_penalty, |
| avoid_token_sequences=avoid_token_sequences, |
| preserve_dominant_candidates=( |
| self._answer_decode_has_continuation(decode_state, generated_tokens) |
| or self._source_evidence_has_continuation( |
| decode_state.context_tokens, |
| generated_tokens, |
| ) |
| ), |
| ) |
| if not next_token: |
| break |
| generated_tokens.append(next_token) |
| self._advance_decode_state(decode_state, next_token) |
| if self._should_stop_answer_sequence(decode_state, generated_tokens): |
| break |
| if self._should_stop_after_answer_path_drift(decode_state, generated_tokens): |
| break |
| if self._source_evidence_is_complete(decode_state.context_tokens, generated_tokens): |
| break |
| if ( |
| self._should_stop_generation(generated_tokens) |
| and not self._answer_decode_has_continuation(decode_state, generated_tokens) |
| and not self._source_evidence_has_continuation( |
| decode_state.context_tokens, |
| generated_tokens, |
| ) |
| ): |
| break |
|
|
| overflow_budget = max(WORD_COMPLETION_OVERFLOW_TOKENS, max_tokens) |
| while generated_tokens and overflow_budget > 0: |
| has_answer_continuation = self._answer_decode_has_continuation( |
| decode_state, |
| generated_tokens, |
| ) |
| has_source_continuation = self._source_evidence_has_continuation( |
| decode_state.context_tokens, |
| generated_tokens, |
| ) |
| if ( |
| self._starts_new_word(generated_tokens[-1]) |
| and not has_answer_continuation |
| and not has_source_continuation |
| ): |
| break |
| probabilities, _ = self._score_next_token_array_from_state( |
| decode_state, |
| include_associative=False, |
| generated_tokens=generated_tokens, |
| temperature=temperature, |
| avoid_token_sequences=avoid_token_sequences, |
| ) |
| forced_source_token = self._source_evidence_next_token( |
| decode_state.context_tokens, |
| generated_tokens, |
| ) |
| next_token = forced_source_token or self._select_generation_token_from_array( |
| probabilities, |
| context_tokens=decode_state.context_tokens, |
| generated_tokens=generated_tokens, |
| temperature=temperature, |
| top_k=top_k, |
| top_p=top_p, |
| repetition_penalty=repetition_penalty, |
| avoid_token_sequences=avoid_token_sequences, |
| preserve_dominant_candidates=has_answer_continuation |
| or has_source_continuation, |
| ) |
| if not next_token: |
| break |
| if ( |
| self._starts_new_word(next_token) |
| and not has_answer_continuation |
| and not has_source_continuation |
| ): |
| break |
| generated_tokens.append(next_token) |
| self._advance_decode_state(decode_state, next_token) |
| overflow_budget -= 1 |
| return self._finalize_generated_text( |
| self._normalize_generated_tool_protocol_text( |
| self._decode_tokens(generated_tokens), |
| context=context, |
| ) |
| ) |
|
|
| def trace_next_token( |
| self, |
| context: str, |
| *, |
| reasoning_mode: str | None = None, |
| top_k: int = 5, |
| ) -> dict[str, object]: |
| self._require_fit() |
| assert self.tokenizer is not None |
|
|
| active_mode = reasoning_mode or self.config.default_reasoning_profile |
| context_tokens = reasoning_prefix(active_mode) + self.tokenizer.encode(context) |
| _, trace = self._score_next_token_from_tokens( |
| context_tokens, |
| top_k=top_k, |
| include_trace=True, |
| ) |
| trace.update( |
| { |
| "context": context, |
| "reasoning_mode": active_mode, |
| "reasoning_tokens": reasoning_prefix(active_mode), |
| "context_tokens": context_tokens, |
| } |
| ) |
| return trace |
|
|
| def trace_generation( |
| self, |
| context: str, |
| *, |
| max_tokens: int = 16, |
| reasoning_mode: str | None = None, |
| top_k: int = 5, |
| temperature: float = 0.0, |
| top_p: float = DEFAULT_GENERATION_TOP_P, |
| repetition_penalty: float = DEFAULT_REPETITION_PENALTY, |
| ) -> dict[str, object]: |
| character_count_response = self._character_count_response( |
| context, |
| temperature=temperature, |
| ) |
| if character_count_response is not None: |
| active_mode = reasoning_mode or self.config.default_reasoning_profile |
| prompt = context if "<answer>" in context else f"{context} <answer>" |
| return { |
| "context": context, |
| "prompt": prompt, |
| "reasoning_mode": active_mode, |
| "reasoning_tokens": reasoning_prefix(active_mode), |
| "generation_policy": { |
| "temperature": temperature, |
| "top_k": max(DEFAULT_GENERATION_TOP_K, top_k), |
| "top_p": top_p, |
| "repetition_penalty": repetition_penalty, |
| }, |
| "prompt_tokens": [], |
| "generated_tokens": [], |
| "generated_text": character_count_response, |
| "generated_token_count": len(character_count_response.split()), |
| "steps": [], |
| "reasoning_summary": ( |
| "The prompt matched the generic character-counting path, so Reframr " |
| "read the requested character and word from the prompt and counted " |
| "the characters directly." |
| ), |
| } |
| self._require_fit() |
| assert self.tokenizer is not None |
|
|
| active_mode = reasoning_mode or self.config.default_reasoning_profile |
| prompt, context_tokens = self._generation_prompt_tokens(context, active_mode) |
| decode_state = self._build_decode_state(context_tokens) |
| prompt_tokens = decode_state.context_tokens[:] |
| generated_tokens: list[str] = [] |
| steps: list[dict[str, object]] = [] |
|
|
| for step_index in range(1, max_tokens + 1): |
| distribution, trace = self._score_next_token_from_state( |
| decode_state, |
| top_k=top_k, |
| include_trace=True, |
| generated_tokens=generated_tokens, |
| temperature=temperature, |
| ) |
| forced_source_token = self._source_evidence_next_token( |
| decode_state.context_tokens, |
| generated_tokens, |
| ) |
| next_token = forced_source_token or self._select_generation_token( |
| distribution, |
| context_tokens=decode_state.context_tokens, |
| generated_tokens=generated_tokens, |
| temperature=temperature, |
| top_k=max(DEFAULT_GENERATION_TOP_K, top_k), |
| top_p=top_p, |
| repetition_penalty=repetition_penalty, |
| preserve_dominant_candidates=( |
| self._answer_decode_has_continuation(decode_state, generated_tokens) |
| or self._source_evidence_has_continuation( |
| decode_state.context_tokens, |
| generated_tokens, |
| ) |
| ), |
| ) |
| if not next_token: |
| break |
| generated_tokens.append(next_token) |
| self._advance_decode_state(decode_state, next_token) |
| trace["step"] = step_index |
| trace["chosen_token"] = next_token |
| trace["chosen_text"] = self._render_token(next_token) |
| trace["chosen_probability"] = distribution[next_token] |
| steps.append(trace) |
| if self._should_stop_answer_sequence(decode_state, generated_tokens): |
| break |
| if self._should_stop_after_answer_path_drift(decode_state, generated_tokens): |
| break |
| if self._source_evidence_is_complete(decode_state.context_tokens, generated_tokens): |
| break |
| if ( |
| self._should_stop_generation(generated_tokens) |
| and not self._answer_decode_has_continuation(decode_state, generated_tokens) |
| and not self._source_evidence_has_continuation( |
| decode_state.context_tokens, |
| generated_tokens, |
| ) |
| ): |
| break |
|
|
| overflow_budget = max(WORD_COMPLETION_OVERFLOW_TOKENS, max_tokens) |
| while generated_tokens and overflow_budget > 0: |
| has_answer_continuation = self._answer_decode_has_continuation( |
| decode_state, |
| generated_tokens, |
| ) |
| has_source_continuation = self._source_evidence_has_continuation( |
| decode_state.context_tokens, |
| generated_tokens, |
| ) |
| if ( |
| self._starts_new_word(generated_tokens[-1]) |
| and not has_answer_continuation |
| and not has_source_continuation |
| ): |
| break |
| distribution, trace = self._score_next_token_from_state( |
| decode_state, |
| top_k=top_k, |
| include_trace=True, |
| generated_tokens=generated_tokens, |
| temperature=temperature, |
| ) |
| forced_source_token = self._source_evidence_next_token( |
| decode_state.context_tokens, |
| generated_tokens, |
| ) |
| next_token = forced_source_token or self._select_generation_token( |
| distribution, |
| context_tokens=decode_state.context_tokens, |
| generated_tokens=generated_tokens, |
| temperature=temperature, |
| top_k=max(DEFAULT_GENERATION_TOP_K, top_k), |
| top_p=top_p, |
| repetition_penalty=repetition_penalty, |
| preserve_dominant_candidates=has_answer_continuation |
| or has_source_continuation, |
| ) |
| if not next_token: |
| break |
| if ( |
| self._starts_new_word(next_token) |
| and not has_answer_continuation |
| and not has_source_continuation |
| ): |
| break |
| generated_tokens.append(next_token) |
| self._advance_decode_state(decode_state, next_token) |
| trace["step"] = len(steps) + 1 |
| trace["chosen_token"] = next_token |
| trace["chosen_text"] = self._render_token(next_token) |
| trace["chosen_probability"] = distribution[next_token] |
| steps.append(trace) |
| if self._should_stop_answer_sequence(decode_state, generated_tokens): |
| break |
| if self._should_stop_after_answer_path_drift(decode_state, generated_tokens): |
| break |
| overflow_budget -= 1 |
|
|
| return { |
| "context": context, |
| "prompt": prompt, |
| "reasoning_mode": active_mode, |
| "reasoning_tokens": reasoning_prefix(active_mode), |
| "generation_policy": { |
| "temperature": temperature, |
| "top_k": max(DEFAULT_GENERATION_TOP_K, top_k), |
| "top_p": top_p, |
| "repetition_penalty": repetition_penalty, |
| }, |
| "prompt_tokens": prompt_tokens, |
| "generated_tokens": generated_tokens, |
| "generated_text": self._finalize_generated_text( |
| self._normalize_generated_tool_protocol_text( |
| self._decode_tokens(generated_tokens), |
| context=context, |
| ) |
| ), |
| "generated_token_count": len(generated_tokens), |
| "steps": steps, |
| } |
|
|
| def _generation_prompt_tokens(self, context: str, active_mode: str) -> tuple[str, list[str]]: |
| assert self.tokenizer is not None |
| prompt = context if "<answer>" in context else f"{context} <answer>" |
| prefix = reasoning_prefix(active_mode) |
| prompt_tokens = self.tokenizer.encode(prompt) |
| if ( |
| "<answer>" in prompt_tokens |
| and "<reason>" not in prompt_tokens |
| and "<reason>" not in prefix |
| ): |
| prompt_tokens = ["<reason>"] + prompt_tokens |
| return prompt, prefix + prompt_tokens |
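| # e.g. context "why is the sky blue?" becomes the prompt "why is the sky |
| # blue? <answer>", and a leading "<reason>" token is injected when neither |
| # the encoded prompt nor the reasoning prefix already carries one. |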
|
|
| def _predict_next_token_distribution_from_tokens( |
| self, |
| context_tokens: list[str], |
| ) -> dict[str, float]: |
| decode_state = self._build_decode_state(context_tokens) |
| return self._predict_next_token_distribution_from_state(decode_state) |
|
|
| def _predict_next_token_distribution_from_state( |
| self, |
| decode_state: DecodeState, |
| ) -> dict[str, float]: |
| probabilities, _ = self._score_next_token_from_state( |
| decode_state, |
| include_trace=False, |
| ) |
| return probabilities |
|
|
| @staticmethod |
| def _answer_memory_is_confident( |
| *, |
| answer_sequence_match_confidence: float, |
| answer_start_confidence: float, |
| generated_count: int, |
| ) -> bool: |
| if generated_count > 0: |
| return answer_sequence_match_confidence >= ANSWER_SEQUENCE_MATCH_FLOOR |
| if answer_sequence_match_confidence >= ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR: |
| return True |
| if answer_sequence_match_confidence >= ANSWER_SEQUENCE_MATCH_FLOOR: |
| return True |
| if answer_start_confidence >= ANSWER_START_CONFIDENCE_FLOOR + ANSWER_SEQUENCE_MATCH_FLOOR: |
| return True |
| return ( |
| answer_sequence_match_confidence >= ANSWER_START_MATCH_SUPPORT_FLOOR |
| and answer_start_confidence >= ANSWER_START_CONFIDENCE_FLOOR |
| and answer_start_confidence <= answer_sequence_match_confidence + ANSWER_START_CONFIDENCE_FLOOR |
| ) |
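| # Worked example for _answer_memory_is_confident: mid-generation |
| # (generated_count > 0) each step only needs the best sequence match to stay |
| # at or above ANSWER_SEQUENCE_MATCH_FLOOR (0.27); at the first step the |
| # distributed-lock floor (0.45), the plain match floor, or a sufficiently |
| # strong answer-start match can each qualify on their own. |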
|
|
| @staticmethod |
| def _answer_sequence_should_lock( |
| *, |
| answer_sequence_confidence: float, |
| answer_sequence_match_confidence: float, |
| has_answer_sequence_prior: bool, |
| ) -> bool: |
| if not has_answer_sequence_prior or answer_sequence_confidence <= 0.0: |
| return False |
| if answer_sequence_match_confidence >= ANSWER_SEQUENCE_LOCK_FLOOR: |
| return True |
| if ( |
| answer_sequence_match_confidence >= ANSWER_SEQUENCE_MATCH_FLOOR |
| and answer_sequence_confidence >= 0.30 |
| and answer_sequence_confidence <= 0.65 |
| ): |
| return True |
| return ( |
| answer_sequence_match_confidence >= ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR |
| and answer_sequence_confidence <= ANSWER_SEQUENCE_SPIKE_CONFIDENCE |
| ) |
|
|
| def _prompt_start_readout_is_confident( |
| self, |
| prior: object, |
| tokens: Sequence[str] | None = None, |
| ) -> bool: |
| if self.tokenizer is None: |
| return False |
| if tokens is None: |
| if self.embedding_model is None: |
| return False |
| tokens = self.embedding_model.id_to_token |
| values = prior.tolist() if hasattr(prior, "tolist") else list(prior) |
| if not values or not tokens: |
| return False |
| limit = min(len(values), len(tokens)) |
| if limit <= 0: |
| return False |
| best_index = max(range(limit), key=lambda index: float(values[index])) |
| best_probability = float(values[best_index]) |
| if best_probability < PROMPT_START_READOUT_CONFIDENCE_FLOOR: |
| return False |
| meta = self._generation_token_meta(tokens[best_index]) |
| return ( |
| meta.starts_new_word |
| and bool(meta.alphanumeric) |
| and not meta.structural_punctuation |
| and not meta.structural_symbol |
| ) |
|
|
| def _locked_answer_sequence_matches( |
| self, |
| matches: list[tuple[float, int, int]], |
| *, |
| generated_tokens: list[str], |
| temperature: float, |
| answer_sequence_confidence: float, |
| answer_sequence_match_confidence: float, |
| ) -> list[tuple[float, int, int]]: |
| if not matches: |
| return [] |
| if generated_tokens: |
| aligned_matches = [ |
| match |
| for match in matches[:ANSWER_START_TOP_K] |
| if self._answer_sequence_match_has_continuation( |
| match, |
| generated_tokens, |
| ) |
| ] |
| return aligned_matches[:ANSWER_SEQUENCE_VARIATION_MATCH_LIMIT] or matches[:1] |
| best_similarity = matches[0][0] |
| near_match_floor = max(ANSWER_SEQUENCE_MATCH_FLOOR, best_similarity - 0.08) |
| varied = [ |
| match |
| for match in matches[:ANSWER_SEQUENCE_VARIATION_MATCH_LIMIT] |
| if match[0] >= near_match_floor |
| ] |
| if ( |
| temperature < ANSWER_SEQUENCE_VARIATION_TEMPERATURE |
| and answer_sequence_match_confidence >= ANSWER_SEQUENCE_LOCK_FLOOR |
| and len(varied) <= 1 |
| ): |
| return matches[:1] |
| return varied or matches[:1] |
|
|
| @staticmethod |
| def _answer_sequence_matches_are_ambiguous( |
| matches: Sequence[tuple[float, int, int]], |
| ) -> bool: |
| if len(matches) < 2: |
| return False |
| best_similarity = float(matches[0][0]) |
| if best_similarity < ANSWER_SEQUENCE_MATCH_FLOOR: |
| return False |
| near_match_floor = max(ANSWER_SEQUENCE_MATCH_FLOOR, best_similarity - 0.08) |
| return any( |
| float(match[0]) >= near_match_floor |
| for match in matches[1:ANSWER_SEQUENCE_VARIATION_MATCH_LIMIT] |
| ) |
|
|
| def _answer_sequence_match_has_continuation( |
| self, |
| match: tuple[float, int, int], |
| generated_tokens: list[str], |
| ) -> bool: |
| if ( |
| self.embedding_model is None |
| or self.answer_sequence_tokens is None |
| or not generated_tokens |
| ): |
| return False |
| similarity, sequence_index, _ = match |
| if similarity < ANSWER_SEQUENCE_MATCH_FLOOR or sequence_index >= len(self.answer_sequence_tokens): |
| return False |
| generated_ids = [ |
| self.embedding_model.token_to_id[token] |
| for token in generated_tokens |
| if token in self.embedding_model.token_to_id |
| ] |
| if not generated_ids: |
| return False |
| row = self.answer_sequence_tokens[sequence_index] |
| token_ids = [ |
| int(value) |
| for value in (row.tolist() if hasattr(row, "tolist") else row) |
| if int(value) >= 0 |
| ] |
| if not token_ids: |
| return False |
| next_token_id = self._next_sequence_token_id(token_ids, generated_ids) |
| if next_token_id is None: |
| return False |
| token = self.embedding_model.id_to_token[next_token_id] |
| return self._allowed_answer_sequence_token(token, generated_tokens) |
|
|
| def _allowed_answer_sequence_token( |
| self, |
| token: str, |
| generated_tokens: list[str], |
| ) -> bool: |
| assert self.tokenizer is not None |
| if token == self.tokenizer.unk_token: |
| return False |
| if token in self.tokenizer.special_tokens: |
| return self._allowed_generation_token(token, generated_tokens) |
| return True |
|
|
| def _should_relax_answer_sequence_memory( |
| self, |
| matches: list[tuple[float, int, int]], |
| answer_sequence_prior: Sequence[float], |
| *, |
| generated_tokens: list[str], |
| temperature: float, |
| ) -> bool: |
| if temperature < ANSWER_SEQUENCE_CREATIVE_TEMPERATURE or not matches: |
| return False |
| if self._is_inside_tool_protocol_continuation(generated_tokens): |
| return False |
| if self._answer_sequence_prior_prefers_tool_protocol(answer_sequence_prior): |
| return False |
| return True |
|
|
| def _answer_sequence_prior_prefers_tool_protocol( |
| self, |
| answer_sequence_prior: Sequence[float], |
| ) -> bool: |
| if self.embedding_model is None or not answer_sequence_prior: |
| return False |
| best_index = -1 |
| best_value = 0.0 |
| for index, value in enumerate(answer_sequence_prior): |
| if value > best_value: |
| best_index = index |
| best_value = float(value) |
| return ( |
| best_index >= 0 |
| and best_index < len(self.embedding_model.id_to_token) |
| and best_value > 0.0 |
| and self.embedding_model.id_to_token[best_index] in TOOL_PROTOCOL_TOKENS |
| ) |
|
|
| @staticmethod |
| def _answer_start_blend_weights( |
| *, |
| answer_sequence_match_confidence: float, |
| temperature: float = 0.0, |
| ) -> dict[str, float]: |
| if temperature >= ANSWER_SEQUENCE_CREATIVE_TEMPERATURE: |
| return { |
| "prompt_answer_start": 0.46, |
| "prompt_answer": 0.24, |
| "answer_sequence": 0.10, |
| "answer_start": 0.20, |
| } |
| if answer_sequence_match_confidence >= ANSWER_SEQUENCE_LOCK_FLOOR: |
| return { |
| "prompt_answer_start": 0.35, |
| "prompt_answer": 0.10, |
| "answer_sequence": 0.45, |
| "answer_start": 0.10, |
| } |
| if answer_sequence_match_confidence >= 0.40: |
| return { |
| "prompt_answer_start": 0.25, |
| "prompt_answer": 0.12, |
| "answer_sequence": 0.53, |
| "answer_start": 0.10, |
| } |
| return { |
| "prompt_answer_start": 0.08, |
| "prompt_answer": 0.10, |
| "answer_sequence": 0.02, |
| "answer_start": 0.80, |
| } |
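| # Each weight set above sums to 1.0; branches are checked from the creative |
| # high-temperature mix down to the default answer-start-dominant mix. |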
|
|
| def _score_next_token_from_tokens( |
| self, |
| context_tokens: list[str], |
| *, |
| top_k: int = 5, |
| include_trace: bool = True, |
| ) -> tuple[dict[str, float], dict[str, object]]: |
| decode_state = self._build_decode_state(context_tokens) |
| return self._score_next_token_from_state( |
| decode_state, |
| top_k=top_k, |
| include_trace=include_trace, |
| ) |
|
|
| def _score_next_token_from_state( |
| self, |
| decode_state: DecodeState, |
| *, |
| top_k: int = 5, |
| include_trace: bool = True, |
| generated_tokens: list[str] | None = None, |
| temperature: float = 0.0, |
| avoid_token_sequences: Sequence[Sequence[str]] | None = None, |
| ) -> tuple[dict[str, float], dict[str, object]]: |
| assert self.embedding_model is not None |
| assert self.readout_weights is not None |
| generated_tokens = generated_tokens or [] |
|
|
| state = self._masked_decode_state(decode_state) |
| logits = self._apply_readout_fast(state) |
| base_probabilities = self._calibrated_softmax(logits) |
| if decode_state.answer_matches is None: |
| decode_state.answer_matches = self._score_answer_matches( |
| decode_state.answer_anchor_state, |
| limit=max(ANSWER_TOP_K, top_k) if include_trace else ANSWER_TOP_K, |
| ) |
| answer_matches = decode_state.answer_matches |
| if decode_state.answer_start_matches is None: |
| decode_state.answer_start_matches = self._score_answer_start_matches( |
| decode_state.answer_anchor_state, |
| limit=max(ANSWER_START_TOP_K, top_k) if include_trace else ANSWER_START_TOP_K, |
| ) |
| answer_start_matches = decode_state.answer_start_matches |
| if decode_state.answer_sequence_matches is None: |
| decode_state.answer_sequence_matches = self._score_answer_sequence_matches( |
| decode_state.answer_anchor_state, |
| decode_state.context_tokens, |
| limit=max(ANSWER_START_TOP_K, top_k) if include_trace else ANSWER_START_TOP_K, |
| ) |
| answer_sequence_matches = self._filter_avoided_answer_sequence_matches( |
| decode_state.answer_sequence_matches, |
| avoid_token_sequences, |
| ) |
| if not answer_start_matches and answer_sequence_matches: |
| answer_start_matches = self._answer_start_matches_from_sequences( |
| answer_sequence_matches |
| ) |
| decode_state.answer_start_matches = answer_start_matches |
| answer_prior = self._answer_prior_from_matches(answer_matches, generated_tokens) |
| answer_start_prior = self._answer_prior_from_matches(answer_start_matches, generated_tokens) |
| answer_sequence_prior = self._answer_sequence_prior_from_matches( |
| answer_sequence_matches, |
| generated_tokens, |
| temperature=temperature, |
| ) |
| answer_sequence_confidence = max(answer_sequence_prior) if answer_sequence_prior else 0.0 |
| answer_sequence_match_confidence = ( |
| answer_sequence_matches[0][0] if answer_sequence_matches else 0.0 |
| ) |
| answer_start_confidence = answer_start_matches[0][0] if answer_start_matches else 0.0 |
| prompt_copy_is_distinctive = ( |
| not generated_tokens |
| and self._prompt_copy_evidence_is_distinctive(decode_state.context_tokens) |
| ) |
| answer_memory_confident = self._answer_memory_is_confident( |
| answer_sequence_match_confidence=answer_sequence_match_confidence, |
| answer_start_confidence=answer_start_confidence, |
| generated_count=len(generated_tokens), |
| ) |
| if prompt_copy_is_distinctive and not answer_sequence_matches: |
| answer_memory_confident = False |
| has_answer_sequence_prior = any(value > 0.0 for value in answer_sequence_prior) |
| if not answer_memory_confident: |
| zero_prior = [0.0 for _ in self.embedding_model.id_to_token] |
| answer_prior = zero_prior |
| answer_start_prior = zero_prior |
| answer_sequence_prior = zero_prior |
| answer_sequence_confidence = 0.0 |
| has_answer_sequence_prior = False |
| answer_locked = self._answer_sequence_should_lock( |
| answer_sequence_confidence=answer_sequence_confidence, |
| answer_sequence_match_confidence=answer_sequence_match_confidence, |
| has_answer_sequence_prior=has_answer_sequence_prior, |
| ) or ( |
| bool(generated_tokens) |
| and temperature < ANSWER_SEQUENCE_CREATIVE_TEMPERATURE |
| and self._answer_sequence_has_continuation( |
| generated_tokens, |
| answer_sequence_matches, |
| ) |
| ) |
| if self._should_relax_answer_sequence_memory( |
| answer_sequence_matches, |
| answer_sequence_prior, |
| generated_tokens=generated_tokens, |
| temperature=temperature, |
| ): |
| answer_locked = False |
| if decode_state.prompt_answer_prior is None: |
| decode_state.prompt_answer_prior = self._prompt_answer_readout_prior( |
| decode_state.answer_anchor_state, |
| start=False, |
| ) |
| prompt_answer_prior = decode_state.prompt_answer_prior |
| prompt_answer_start_prior = ( |
| decode_state.prompt_answer_start_prior |
| if not generated_tokens |
| else [0.0 for _ in self.embedding_model.id_to_token] |
| ) |
| if not generated_tokens and prompt_answer_start_prior is None: |
| decode_state.prompt_answer_start_prior = self._prompt_answer_readout_prior( |
| decode_state.answer_anchor_state, |
| start=True, |
| ) |
| prompt_answer_start_prior = decode_state.prompt_answer_start_prior |
| prompt_start_readout_confident = ( |
| not generated_tokens |
| and prompt_answer_start_prior is not None |
| and self._prompt_start_readout_is_confident(prompt_answer_start_prior) |
| ) |
| prompt_readout_supported = answer_memory_confident and ( |
| answer_sequence_match_confidence >= ANSWER_SEQUENCE_MATCH_FLOOR |
| or answer_start_confidence >= ANSWER_START_CONFIDENCE_FLOOR |
| ) |
| if prompt_start_readout_confident: |
| prompt_readout_supported = True |
| if not prompt_readout_supported: |
| prompt_answer_prior = [0.0 for _ in self.embedding_model.id_to_token] |
| prompt_answer_start_prior = [0.0 for _ in self.embedding_model.id_to_token] |
| use_answer_start = ( |
| not generated_tokens |
| and ( |
| any(value > 0.0 for value in answer_start_prior) |
| or any(value > 0.0 for value in prompt_answer_start_prior) |
| ) |
| ) |
| if answer_locked: |
| locked_matches = self._locked_answer_sequence_matches( |
| answer_sequence_matches, |
| generated_tokens=generated_tokens, |
| temperature=temperature, |
| answer_sequence_confidence=answer_sequence_confidence, |
| answer_sequence_match_confidence=answer_sequence_match_confidence, |
| ) |
| answer_sequence_prior = self._answer_sequence_prior_from_matches( |
| locked_matches, |
| generated_tokens, |
| temperature=temperature, |
| ) |
| answer_prior = answer_sequence_prior |
| elif use_answer_start: |
| start_blend = self._answer_start_blend_weights( |
| answer_sequence_match_confidence=answer_sequence_match_confidence, |
| temperature=temperature, |
| ) |
| answer_prior = self._weighted_prior_sum( |
| [ |
| (start_blend["prompt_answer_start"], prompt_answer_start_prior), |
| (start_blend["prompt_answer"], prompt_answer_prior), |
| (start_blend["answer_sequence"], answer_sequence_prior), |
| (start_blend["answer_start"], answer_start_prior), |
| ], |
| ) |
| elif any(value > 0.0 for value in answer_sequence_prior): |
| sequence_weight = ( |
| 0.10 |
| if temperature >= ANSWER_SEQUENCE_CREATIVE_TEMPERATURE |
| else 0.30 |
| ) |
| answer_prior = self._weighted_prior_sum( |
| [ |
| (0.55, prompt_answer_prior), |
| (sequence_weight, answer_sequence_prior), |
| (0.20, answer_prior), |
| ], |
| ) |
| elif any(value > 0.0 for value in prompt_answer_prior): |
| answer_prior = self._weighted_prior_sum( |
| [ |
| (0.65, prompt_answer_prior), |
| (0.35, answer_prior), |
| ], |
| ) |
| answer_guided = ( |
| max(answer_prior) >= 0.08 |
| if answer_prior |
| else False |
| ) |
| associative_matches = ( |
| [] |
| if use_answer_start or answer_guided |
| else self._score_associative_matches( |
| state, |
| limit=max(ASSOCIATIVE_TOP_K, top_k) if include_trace else ASSOCIATIVE_TOP_K, |
| ) |
| ) |
| associative_prior = ( |
| [0.0 for _ in self.embedding_model.id_to_token] |
| if use_answer_start or answer_guided |
| else self._associative_prior_from_matches(associative_matches) |
| ) |
| transition_prior, transition_order = self._transition_prior_with_order(decode_state.context_tokens) |
| copy_prior = self._copy_prior(decode_state.context_tokens) |
| source_evidence_prior = self._source_evidence_prior( |
| decode_state.context_tokens, |
| generated_tokens, |
| ) |
| preference_prior = self._preference_prior() |
| probabilities, blend_weights = self._blend_probabilities( |
| base_probabilities, |
| answer_prior, |
| associative_prior, |
| transition_prior, |
| copy_prior, |
| source_evidence_prior, |
| preference_prior, |
| transition_order=transition_order, |
| generated_count=len(generated_tokens), |
| answer_locked=answer_locked, |
| answer_guided_start=use_answer_start, |
| copy_guided_start=prompt_copy_is_distinctive, |
| ) |
| probabilities = self._focus_answer_start_probabilities( |
| probabilities, |
| answer_sequence_prior, |
| generated_tokens=generated_tokens, |
| answer_memory_confident=answer_memory_confident, |
| has_answer_sequence_prior=has_answer_sequence_prior, |
| sequence_focus_allowed=answer_sequence_match_confidence >= 0.40 or answer_locked, |
| temperature=temperature, |
| ) |
| distribution = { |
| token: probabilities[index] |
| for index, token in enumerate(self.embedding_model.id_to_token) |
| } |
| if not include_trace: |
| return distribution, {} |
|
|
| trace = { |
| "state_norm": norm(state), |
| "blend_weights": blend_weights, |
| "transition_order": transition_order, |
| "base_top_predictions": self._top_entries_from_vector(base_probabilities, top_k), |
| "answer_top_predictions": self._top_entries_from_vector(answer_prior, top_k), |
| "prompt_answer_top_predictions": self._top_entries_from_vector(prompt_answer_prior, top_k), |
| "prompt_answer_start_top_predictions": self._top_entries_from_vector(prompt_answer_start_prior, top_k), |
| "answer_start_top_predictions": self._top_entries_from_vector(answer_start_prior, top_k), |
| "answer_sequence_top_predictions": self._top_entries_from_vector(answer_sequence_prior, top_k), |
| "associative_top_predictions": self._top_entries_from_vector(associative_prior, top_k), |
| "transition_top_predictions": self._top_entries_from_vector(transition_prior, top_k), |
| "copy_top_predictions": self._top_entries_from_vector(copy_prior, top_k), |
| "source_evidence_top_predictions": self._top_entries_from_vector(source_evidence_prior, top_k), |
| "preference_top_predictions": self._top_entries_from_vector(preference_prior, top_k), |
| "final_top_predictions": self._top_entries_from_vector(probabilities, top_k), |
| "associative_matches": [ |
| { |
| "example_index": example_index, |
| "similarity": similarity, |
| **self._token_entry(token_id, similarity), |
| } |
| for similarity, token_id, example_index in associative_matches[:top_k] |
| ], |
| "answer_matches": [ |
| { |
| "example_index": example_index, |
| "similarity": similarity, |
| **self._token_entry(token_id, similarity), |
| } |
| for similarity, token_id, example_index in answer_matches[:top_k] |
| ], |
| "answer_start_matches": [ |
| { |
| "example_index": example_index, |
| "similarity": similarity, |
| **self._token_entry(token_id, similarity), |
| } |
| for similarity, token_id, example_index in answer_start_matches[:top_k] |
| ], |
| "answer_sequence_matches": [ |
| { |
| "example_index": example_index, |
| "similarity": similarity, |
| } |
| for similarity, _, example_index in answer_sequence_matches[:top_k] |
| ], |
| "reasoning_summary": self._build_reasoning_summary( |
| transition_order, |
| blend_weights, |
| ), |
| } |
| return distribution, trace |
|
|
| def _score_next_token_array_from_state( |
| self, |
| decode_state: DecodeState, |
| *, |
| include_associative: bool, |
| generated_tokens: list[str] | None = None, |
| temperature: float = 0.0, |
| avoid_token_sequences: Sequence[Sequence[str]] | None = None, |
| ) -> tuple[object, dict[str, float]]: |
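| """NumPy fast path mirroring the list-based next-token scorer: blends the
| calibrated readout with answer, sequence, associative, transition, copy,
| source-evidence, and preference priors and returns the probability array
| together with the blend weights."""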
| assert np is not None |
| assert self.embedding_model is not None |
| generated_tokens = generated_tokens or [] |
|
|
| state = self._masked_decode_state_array(decode_state) |
| logits = self._apply_readout_array(state) |
| base_probabilities = self._calibrated_softmax_array(logits) |
| if decode_state.answer_matches is None: |
| decode_state.answer_matches = self._score_answer_matches(decode_state.answer_anchor_state) |
| answer_prior = np.asarray( |
| self._answer_prior_from_matches( |
| decode_state.answer_matches, |
| generated_tokens, |
| ), |
| dtype=np.float64, |
| ) |
| if decode_state.answer_sequence_matches is None: |
| decode_state.answer_sequence_matches = self._score_answer_sequence_matches( |
| decode_state.answer_anchor_state, |
| decode_state.context_tokens, |
| ) |
| answer_sequence_matches = self._filter_avoided_answer_sequence_matches( |
| decode_state.answer_sequence_matches, |
| avoid_token_sequences, |
| ) |
| if not decode_state.answer_start_matches and answer_sequence_matches: |
| decode_state.answer_start_matches = self._answer_start_matches_from_sequences( |
| answer_sequence_matches |
| ) |
| answer_sequence_prior = np.asarray( |
| self._answer_sequence_prior_from_matches( |
| answer_sequence_matches, |
| generated_tokens, |
| temperature=temperature, |
| ), |
| dtype=np.float64, |
| ) |
| answer_sequence_confidence = ( |
| float(answer_sequence_prior.max()) if answer_sequence_prior.size else 0.0 |
| ) |
| answer_sequence_match_confidence = ( |
| answer_sequence_matches[0][0] if answer_sequence_matches else 0.0 |
| ) |
| if not generated_tokens and decode_state.answer_start_matches is None: |
| decode_state.answer_start_matches = self._score_answer_start_matches( |
| decode_state.answer_anchor_state |
| ) |
| answer_start_confidence = ( |
| decode_state.answer_start_matches[0][0] |
| if not generated_tokens and decode_state.answer_start_matches |
| else 0.0 |
| ) |
| prompt_copy_is_distinctive = ( |
| not generated_tokens |
| and self._prompt_copy_evidence_is_distinctive(decode_state.context_tokens) |
| ) |
| answer_memory_confident = self._answer_memory_is_confident( |
| answer_sequence_match_confidence=answer_sequence_match_confidence, |
| answer_start_confidence=answer_start_confidence, |
| generated_count=len(generated_tokens), |
| ) |
| if prompt_copy_is_distinctive and not answer_sequence_matches: |
| answer_memory_confident = False |
| has_answer_sequence_prior = bool(np.any(answer_sequence_prior > 0.0)) |
| if not answer_memory_confident: |
| answer_prior = np.zeros_like(base_probabilities) |
| answer_sequence_prior = np.zeros_like(base_probabilities) |
| answer_sequence_confidence = 0.0 |
| has_answer_sequence_prior = False |
| answer_locked = self._answer_sequence_should_lock( |
| answer_sequence_confidence=answer_sequence_confidence, |
| answer_sequence_match_confidence=answer_sequence_match_confidence, |
| has_answer_sequence_prior=has_answer_sequence_prior, |
| ) or ( |
| bool(generated_tokens) |
| and temperature < ANSWER_SEQUENCE_CREATIVE_TEMPERATURE |
| and self._answer_sequence_has_continuation( |
| generated_tokens, |
| answer_sequence_matches, |
| ) |
| ) |
| if self._should_relax_answer_sequence_memory( |
| answer_sequence_matches, |
| answer_sequence_prior.tolist(), |
| generated_tokens=generated_tokens, |
| temperature=temperature, |
| ): |
| answer_locked = False |
| if decode_state.prompt_answer_prior is None: |
| decode_state.prompt_answer_prior = self._prompt_answer_readout_prior_array( |
| decode_state.answer_anchor_state, |
| start=False, |
| ) |
| prompt_answer_prior = decode_state.prompt_answer_prior |
| prompt_answer_start_prior = np.zeros_like(base_probabilities) |
| use_answer_start = False |
| if answer_locked: |
| locked_matches = self._locked_answer_sequence_matches( |
| answer_sequence_matches, |
| generated_tokens=generated_tokens, |
| temperature=temperature, |
| answer_sequence_confidence=answer_sequence_confidence, |
| answer_sequence_match_confidence=answer_sequence_match_confidence, |
| ) |
| answer_sequence_prior = np.asarray( |
| self._answer_sequence_prior_from_matches( |
| locked_matches, |
| generated_tokens, |
| temperature=temperature, |
| ), |
| dtype=np.float64, |
| ) |
| answer_prior = answer_sequence_prior |
| elif not generated_tokens: |
| if decode_state.prompt_answer_start_prior is None: |
| decode_state.prompt_answer_start_prior = self._prompt_answer_readout_prior_array( |
| decode_state.answer_anchor_state, |
| start=True, |
| ) |
| prompt_answer_start_prior = ( |
| decode_state.prompt_answer_start_prior |
| if decode_state.prompt_answer_start_prior is not None |
| else np.zeros_like(base_probabilities) |
| ) |
| prompt_start_readout_confident = self._prompt_start_readout_is_confident( |
| prompt_answer_start_prior |
| ) |
| prompt_readout_supported = answer_memory_confident and ( |
| answer_sequence_match_confidence >= ANSWER_SEQUENCE_MATCH_FLOOR |
| or answer_start_confidence >= ANSWER_START_CONFIDENCE_FLOOR |
| ) |
| if prompt_start_readout_confident: |
| prompt_readout_supported = True |
| if not prompt_readout_supported: |
| prompt_answer_prior = np.zeros_like(base_probabilities) |
| prompt_answer_start_prior = np.zeros_like(base_probabilities) |
| answer_start_prior = np.asarray( |
| self._answer_prior_from_matches( |
| decode_state.answer_start_matches, |
| generated_tokens, |
| ), |
| dtype=np.float64, |
| ) |
| if not answer_memory_confident: |
| answer_start_prior = np.zeros_like(base_probabilities) |
| if np.any(answer_start_prior > 0.0) or np.any(prompt_answer_start_prior > 0.0): |
| start_blend = self._answer_start_blend_weights( |
| answer_sequence_match_confidence=answer_sequence_match_confidence, |
| temperature=temperature, |
| ) |
| answer_prior = self._weighted_prior_sum_array( |
| [ |
| (start_blend["prompt_answer_start"], prompt_answer_start_prior), |
| (start_blend["prompt_answer"], prompt_answer_prior), |
| (start_blend["answer_sequence"], answer_sequence_prior), |
| (start_blend["answer_start"], answer_start_prior), |
| ], |
| ) |
| use_answer_start = True |
| if answer_locked: |
| answer_prior = answer_sequence_prior |
| elif not use_answer_start and np.any(answer_sequence_prior > 0.0): |
| sequence_weight = ( |
| 0.10 |
| if temperature >= ANSWER_SEQUENCE_CREATIVE_TEMPERATURE |
| else 0.30 |
| ) |
| answer_prior = self._weighted_prior_sum_array( |
| [ |
| (0.55, prompt_answer_prior), |
| (sequence_weight, answer_sequence_prior), |
| (0.20, answer_prior), |
| ], |
| ) |
| elif not use_answer_start and np.any(prompt_answer_prior > 0.0): |
| answer_prior = self._weighted_prior_sum_array( |
| [ |
| (0.65, prompt_answer_prior), |
| (0.35, answer_prior), |
| ], |
| ) |
| answer_guided = bool(answer_prior.size and float(np.max(answer_prior)) >= 0.08) |
| if include_associative and not use_answer_start and not answer_guided: |
| associative_prior = np.asarray( |
| self._associative_prior_from_matches( |
| self._score_associative_matches(state) |
| ), |
| dtype=np.float64, |
| ) |
| else: |
| associative_prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) |
| transition_prior, transition_order = self._transition_prior_array_with_order( |
| decode_state.context_tokens |
| ) |
| copy_prior = self._copy_prior_array(decode_state.context_tokens) |
| source_evidence_prior = self._source_evidence_prior_array( |
| decode_state.context_tokens, |
| generated_tokens, |
| ) |
| preference_prior = self._preference_prior_array() |
| probabilities, blend_weights = self._blend_probability_arrays( |
| base_probabilities, |
| answer_prior, |
| associative_prior, |
| transition_prior, |
| copy_prior, |
| source_evidence_prior, |
| preference_prior, |
| transition_order=transition_order, |
| generated_count=len(generated_tokens), |
| answer_locked=answer_locked, |
| answer_guided_start=use_answer_start, |
| ) |
| probabilities = self._focus_answer_start_probability_array( |
| probabilities, |
| answer_sequence_prior, |
| generated_tokens=generated_tokens, |
| answer_memory_confident=answer_memory_confident, |
| has_answer_sequence_prior=has_answer_sequence_prior, |
| sequence_focus_allowed=answer_sequence_match_confidence >= 0.40 or answer_locked, |
| temperature=temperature, |
| ) |
| return probabilities, blend_weights |
|
|
| @staticmethod |
| def _focus_answer_start_probabilities( |
| probabilities: Vector, |
| answer_sequence_prior: Vector, |
| *, |
| generated_tokens: list[str], |
| answer_memory_confident: bool, |
| has_answer_sequence_prior: bool, |
| sequence_focus_allowed: bool | None = None, |
| temperature: float = 0.0, |
| ) -> Vector: |
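| """At the first generated token, damp probabilities of tokens outside the
| answer-sequence prior to 2% of their mass and renormalize; skipped at
| creative temperatures or without confident answer memory."""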
| if sequence_focus_allowed is None: |
| sequence_focus_allowed = has_answer_sequence_prior |
| if temperature >= ANSWER_SEQUENCE_CREATIVE_TEMPERATURE: |
| return probabilities |
| if ( |
| generated_tokens |
| or not answer_memory_confident |
| or not has_answer_sequence_prior |
| or not sequence_focus_allowed |
| ): |
| return probabilities |
| if not probabilities or not answer_sequence_prior: |
| return probabilities |
| focused = [
| probability
| if index < len(answer_sequence_prior) and answer_sequence_prior[index] > 0.0
| else probability * 0.02
| for index, probability in enumerate(probabilities)
| ]
| total = sum(focused) |
| if total <= 0.0: |
| return probabilities |
| return [value / total for value in focused] |
|
|
| @staticmethod |
| def _focus_answer_start_probability_array( |
| probabilities: object, |
| answer_sequence_prior: object, |
| *, |
| generated_tokens: list[str], |
| answer_memory_confident: bool, |
| has_answer_sequence_prior: bool, |
| sequence_focus_allowed: bool | None = None, |
| temperature: float = 0.0, |
| ) -> object: |
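| """Array variant of _focus_answer_start_probabilities."""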
| if sequence_focus_allowed is None: |
| sequence_focus_allowed = has_answer_sequence_prior |
| if temperature >= ANSWER_SEQUENCE_CREATIVE_TEMPERATURE: |
| return probabilities |
| if ( |
| np is None |
| or generated_tokens |
| or not answer_memory_confident |
| or not has_answer_sequence_prior |
| or not sequence_focus_allowed |
| ): |
| return probabilities |
| values = np.asarray(probabilities, dtype=np.float64) |
| prior = np.asarray(answer_sequence_prior, dtype=np.float64) |
| if values.size == 0 or prior.size != values.size or not np.any(prior > 0.0): |
| return probabilities |
| focused = values.copy() |
| focused[prior <= 0.0] *= 0.02 |
| total = float(focused.sum()) |
| if total <= 0.0: |
| return probabilities |
| return focused / total |
|
|
| def _calibrated_softmax( |
| self, |
| logits: Vector, |
| *, |
| scale: float = READOUT_LOGIT_ZSCORE_SCALE, |
| ) -> Vector: |
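| """Z-score the logits, scale by `scale`, clip to [-20, 20], then apply
| softmax; falls back to a plain softmax when the logits have near-zero
| spread."""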
| if np is not None: |
| return self._calibrated_softmax_array( |
| np.asarray(logits, dtype=np.float64), |
| scale=scale, |
| ).tolist() |
| if not logits: |
| return [] |
| center = mean(logits) |
| variance = mean([(value - center) * (value - center) for value in logits]) |
| spread = variance**0.5 |
| if spread <= 1e-12: |
| return softmax(logits) |
| calibrated = [ |
| max(-20.0, min(20.0, ((value - center) / spread) * scale)) |
| for value in logits |
| ] |
| return softmax(calibrated) |
|
|
| def _calibrated_softmax_array( |
| self, |
| logits: object, |
| *, |
| scale: float = READOUT_LOGIT_ZSCORE_SCALE, |
| ) -> object: |
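| """Array variant of _calibrated_softmax."""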
| assert np is not None |
| values = np.asarray(logits, dtype=np.float64) |
| if values.size == 0: |
| return values |
| spread = float(values.std())
| if spread > 1e-12:
| values = ((values - float(values.mean())) / spread) * scale
| values = np.clip(values, -20.0, 20.0)
| # Subtracting the max keeps exp() overflow-free and also covers the
| # zero-spread case, where the logits are effectively constant.
| values = values - float(values.max())
| exponentials = np.exp(values) |
| total = float(exponentials.sum()) |
| if total <= 0.0: |
| return np.full(values.shape, 1.0 / max(1, values.size), dtype=np.float64) |
| return exponentials / total |
|
|
| def _weighted_prior_sum(self, sources: list[tuple[float, Vector]]) -> Vector: |
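| """Blend prior vectors with weights normalized over the active
| (positive-weight, non-zero) sources and normalize the merged prior;
| returns a zero vector when no source is active."""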
| assert self.embedding_model is not None |
| active_sources = [ |
| (weight, vector) |
| for weight, vector in sources |
| if weight > 0.0 and any(value > 0.0 for value in vector) |
| ] |
| if not active_sources: |
| return [0.0 for _ in self.embedding_model.id_to_token] |
| total_weight = sum(weight for weight, _ in active_sources) |
| merged = [0.0 for _ in self.embedding_model.id_to_token] |
| for weight, vector in active_sources: |
| normalized_weight = weight / total_weight |
| for index, value in enumerate(vector): |
| merged[index] += normalized_weight * value |
| return _normalize_vector(merged) |
|
|
| def _weighted_prior_sum_array(self, sources: list[tuple[float, object]]) -> object: |
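| """Array variant of _weighted_prior_sum."""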
| assert np is not None |
| assert self.embedding_model is not None |
| converted = (
| (weight, np.asarray(vector, dtype=np.float64)) for weight, vector in sources
| )
| active_sources = [
| (weight, array)
| for weight, array in converted
| if weight > 0.0 and np.any(array > 0.0)
| ]
| if not active_sources: |
| return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) |
| total_weight = sum(weight for weight, _ in active_sources) |
| merged = np.zeros_like(active_sources[0][1], dtype=np.float64) |
| for weight, vector in active_sources: |
| merged += (weight / total_weight) * vector |
| total = float(merged.sum()) |
| if total > 0.0: |
| merged /= total |
| return merged |
|
|
| def _prompt_answer_readout_prior( |
| self, |
| answer_anchor_state: Vector | None, |
| *, |
| start: bool, |
| ) -> Vector: |
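| """Readout prior over answer tokens from the masked, centered answer
| anchor state; `start=True` selects the answer-start readout head."""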
| assert self.embedding_model is not None
| if answer_anchor_state is None:
| return [0.0 for _ in self.embedding_model.id_to_token]
| if np is not None:
| return self._prompt_answer_readout_prior_array(
| answer_anchor_state,
| start=start,
| ).tolist()
| weights = self.prompt_answer_start_weights if start else self.prompt_answer_weights
| bias = self.prompt_answer_start_bias if start else self.prompt_answer_bias
| if not weights:
| return [0.0 for _ in self.embedding_model.id_to_token]
| state = self._center_state_vector(self._masked_combined_state(answer_anchor_state)) |
| logits = apply_readout(weights, state) |
| if bias: |
| logits = [value + bias[index] for index, value in enumerate(logits)] |
| return self._calibrated_softmax( |
| logits, |
| scale=PROMPT_READOUT_LOGIT_ZSCORE_SCALE, |
| ) |
|
|
| def _prompt_answer_readout_prior_array( |
| self, |
| answer_anchor_state: Vector | None, |
| *, |
| start: bool, |
| ) -> object: |
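| """Array variant of _prompt_answer_readout_prior."""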
| assert np is not None |
| assert self.embedding_model is not None |
| if answer_anchor_state is None: |
| return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) |
| weights = ( |
| self.prompt_answer_start_weights_array |
| if start |
| else self.prompt_answer_weights_array |
| ) |
| bias = self.prompt_answer_start_bias_array if start else self.prompt_answer_bias_array |
| if weights is None: |
| return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) |
| state_array = self._center_state_array( |
| self._masked_combined_state_array(answer_anchor_state) |
| ) |
| logits = weights @ state_array |
| if bias is not None and bias.shape == logits.shape: |
| logits = logits + bias |
| return self._calibrated_softmax_array( |
| logits, |
| scale=PROMPT_READOUT_LOGIT_ZSCORE_SCALE, |
| ) |
|
|
| def save(self, path: str | Path) -> None: |
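| """Write the fitted model to a single safetensor checkpoint at `path`:
| config and tokenizer go into the metadata, all weights, keys, and
| transition tables into tensors. Refreshes the answer-fingerprint hashes
| (and, with NumPy available, the numeric caches) first."""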
| self._require_fit() |
| assert self.tokenizer is not None |
| assert self.embedding_model is not None |
| assert self.ternary_mask is not None |
| assert self.readout_weights is not None |
| assert self.associative_keys is not None |
| assert self.associative_values is not None |
| assert self.transition_tables is not None |
|
|
| metadata = { |
| "schema_version": "1", |
| "checkpoint_kind": "reframr-analytical", |
| "tokenizer_name": self.tokenizer.name, |
| "config": json.dumps(self.config.to_dict(), separators=(",", ":")), |
| "tokenizer": json.dumps(self.tokenizer.to_dict(), separators=(",", ":")), |
| "embedding_id_to_token": json.dumps(self.embedding_model.id_to_token, separators=(",", ":")), |
| "tokenizer_vocab_size": str(self.tokenizer.vocab_size), |
| "transition_table_format": "tensor-v1", |
| } |
| self._refresh_answer_fingerprint_hashes() |
| if np is not None: |
| self._refresh_numeric_caches() |
| transition_tensors = self._transition_table_tensors() |
| tensors = { |
| "embedding_table": self.embedding_model.embeddings, |
| "ternary_scale": [self.ternary_scale], |
| "ternary_mask": self.ternary_mask, |
| "readout_weights": self.readout_weights, |
| "readout_bias": self.readout_bias |
| or [0.0 for _ in self.embedding_model.id_to_token], |
| "prompt_answer_weights": self.prompt_answer_weights |
| if self.prompt_answer_weights is not None |
| else [], |
| "prompt_answer_bias": self.prompt_answer_bias |
| or [0.0 for _ in self.embedding_model.id_to_token], |
| "prompt_answer_start_weights": self.prompt_answer_start_weights |
| if self.prompt_answer_start_weights is not None |
| else [], |
| "prompt_answer_start_bias": self.prompt_answer_start_bias |
| or [0.0 for _ in self.embedding_model.id_to_token], |
| "trace_token_weights": self.trace_token_weights |
| or [1.0 for _ in self.embedding_model.id_to_token], |
| "preference_bias": self.preference_bias |
| or [0.0 for _ in self.embedding_model.id_to_token], |
| "state_offset": self.state_offset |
| or [0.0 for _ in range(self._combined_state_width())], |
| "associative_keys": self.associative_keys, |
| "associative_key_norms": self.associative_key_norms_array |
| if self.associative_key_norms_array is not None |
| else self.associative_key_norms or [], |
| "associative_values": self.associative_values, |
| "answer_keys": self.answer_keys if self.answer_keys is not None else [], |
| "answer_key_norms": self.answer_key_norms_array |
| if self.answer_key_norms_array is not None |
| else self.answer_key_norms or [], |
| "answer_similarity_keys": self.answer_similarity_keys_array |
| if self.answer_similarity_keys_array is not None |
| else [], |
| "answer_similarity_key_norms": self.answer_similarity_key_norms_array |
| if self.answer_similarity_key_norms_array is not None |
| else [], |
| "answer_values": self.answer_values if self.answer_values is not None else [], |
| "answer_start_keys": self.answer_start_keys if self.answer_start_keys is not None else [], |
| "answer_start_key_norms": self.answer_start_key_norms_array |
| if self.answer_start_key_norms_array is not None |
| else self.answer_start_key_norms or [], |
| "answer_start_similarity_keys": self.answer_start_similarity_keys_array |
| if self.answer_start_similarity_keys_array is not None |
| else [], |
| "answer_start_similarity_key_norms": self.answer_start_similarity_key_norms_array |
| if self.answer_start_similarity_key_norms_array is not None |
| else [], |
| "answer_start_values": self.answer_start_values if self.answer_start_values is not None else [], |
| "answer_sequence_keys": self.answer_sequence_keys if self.answer_sequence_keys is not None else [], |
| "answer_sequence_key_norms": self.answer_sequence_key_norms_array |
| if self.answer_sequence_key_norms_array is not None |
| else self.answer_sequence_key_norms or [], |
| "answer_sequence_similarity_keys": self.answer_sequence_similarity_keys_array |
| if self.answer_sequence_similarity_keys_array is not None |
| else [], |
| "answer_sequence_similarity_key_norms": self.answer_sequence_similarity_key_norms_array |
| if self.answer_sequence_similarity_key_norms_array is not None |
| else [], |
| "answer_sequence_prompt_tokens": self.answer_sequence_prompt_tokens if self.answer_sequence_prompt_tokens is not None else [], |
| "answer_sequence_tokens": self.answer_sequence_tokens if self.answer_sequence_tokens is not None else [], |
| "answer_fingerprint_hashes": self._answer_fingerprint_tensor(), |
| **transition_tensors, |
| } |
| write_safetensor_file(path, tensors, metadata=metadata) |
|
|
| @classmethod |
| def load(cls, path: str | Path) -> "ReframrModel": |
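| """Restore a model written by `save`; every tensor has a plain-list
| fallback so checkpoints load even when NumPy is unavailable."""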
| checkpoint_path = Path(path) |
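| # Decode tensors straight into arrays only for large checkpoints
| # (> 10 MB) and only when NumPy is available.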
| checkpoint = read_safetensor_file( |
| checkpoint_path, |
| arrays=np is not None and checkpoint_path.stat().st_size > 10_000_000, |
| ) |
| metadata = checkpoint.metadata |
| config = ReframrConfig.from_dict(json.loads(metadata["config"])) |
| model = cls(config) |
| model.tokenizer = NativeTokenizer.from_dict(json.loads(metadata["tokenizer"])) |
| id_to_token = [str(token) for token in json.loads(metadata["embedding_id_to_token"])] |
| embedding_table = checkpoint.tensors["embedding_table"] |
| if np is not None and hasattr(embedding_table, "shape"): |
| embeddings = embedding_table.astype(RUNTIME_ARRAY_DTYPE, copy=False) |
| else: |
| embeddings = [[float(value) for value in row] for row in embedding_table] |
| model.embedding_model = EmbeddingModel( |
| token_to_id={token: index for index, token in enumerate(id_to_token)}, |
| id_to_token=id_to_token, |
| embeddings=embeddings, |
| ppmi_matrix=[], |
| ) |
| model.memory_units = [ |
| AnalyticalMemoryUnit(model.config.state_dim, timescale) |
| for timescale in model.config.timescales |
| ] |
| model.ternary_scale = float(checkpoint.tensors["ternary_scale"][0]) |
| model.ternary_mask = [int(value) for value in checkpoint.tensors["ternary_mask"]] |
| readout_tensor = checkpoint.tensors["readout_weights"] |
| model.readout_weights = ( |
| readout_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False) |
| if np is not None and hasattr(readout_tensor, "shape") |
| else [[float(value) for value in row] for row in readout_tensor] |
| ) |
| readout_bias_tensor = checkpoint.tensors.get("readout_bias", []) |
| model.readout_bias = [ |
| float(value) for value in ( |
| readout_bias_tensor.tolist() |
| if hasattr(readout_bias_tensor, "tolist") |
| else readout_bias_tensor |
| ) |
| ] |
| if not model.readout_bias: |
| model.readout_bias = [0.0 for _ in id_to_token] |
| prompt_answer_tensor = checkpoint.tensors.get("prompt_answer_weights", []) |
| model.prompt_answer_weights = ( |
| prompt_answer_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False) |
| if np is not None |
| and hasattr(prompt_answer_tensor, "shape") |
| and len(prompt_answer_tensor.shape) == 2 |
| else [[float(value) for value in row] for row in prompt_answer_tensor] |
| ) |
| prompt_answer_bias_tensor = checkpoint.tensors.get("prompt_answer_bias", []) |
| model.prompt_answer_bias = [ |
| float(value) for value in ( |
| prompt_answer_bias_tensor.tolist() |
| if hasattr(prompt_answer_bias_tensor, "tolist") |
| else prompt_answer_bias_tensor |
| ) |
| ] |
| if not model.prompt_answer_bias: |
| model.prompt_answer_bias = [0.0 for _ in id_to_token] |
| prompt_answer_start_tensor = checkpoint.tensors.get("prompt_answer_start_weights", []) |
| model.prompt_answer_start_weights = ( |
| prompt_answer_start_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False) |
| if np is not None |
| and hasattr(prompt_answer_start_tensor, "shape") |
| and len(prompt_answer_start_tensor.shape) == 2 |
| else [[float(value) for value in row] for row in prompt_answer_start_tensor] |
| ) |
| prompt_answer_start_bias_tensor = checkpoint.tensors.get("prompt_answer_start_bias", []) |
| model.prompt_answer_start_bias = [ |
| float(value) for value in ( |
| prompt_answer_start_bias_tensor.tolist() |
| if hasattr(prompt_answer_start_bias_tensor, "tolist") |
| else prompt_answer_start_bias_tensor |
| ) |
| ] |
| if not model.prompt_answer_start_bias: |
| model.prompt_answer_start_bias = [0.0 for _ in id_to_token] |
| trace_weight_tensor = checkpoint.tensors.get("trace_token_weights", []) |
| model.trace_token_weights = [ |
| float(value) for value in ( |
| trace_weight_tensor.tolist() |
| if hasattr(trace_weight_tensor, "tolist") |
| else trace_weight_tensor |
| ) |
| ] |
| if not model.trace_token_weights: |
| model.trace_token_weights = [ |
| 1.0
| if token in TOOL_PROTOCOL_TOKENS
| else (0.0 if token in model.tokenizer.special_tokens else 1.0)
| for token in id_to_token |
| ] |
| preference_bias_tensor = checkpoint.tensors.get("preference_bias", []) |
| model.preference_bias = [ |
| float(value) for value in ( |
| preference_bias_tensor.tolist() |
| if hasattr(preference_bias_tensor, "tolist") |
| else preference_bias_tensor |
| ) |
| ] |
| if not model.preference_bias: |
| model.preference_bias = [0.0 for _ in id_to_token] |
| state_offset_tensor = checkpoint.tensors.get("state_offset", []) |
| model.state_offset = [ |
| float(value) for value in ( |
| state_offset_tensor.tolist() |
| if hasattr(state_offset_tensor, "tolist") |
| else state_offset_tensor |
| ) |
| ] |
| if not model.state_offset: |
| model.state_offset = [0.0 for _ in range(model._combined_state_width())] |
|
|
| def _runtime_vector_tensor(name: str) -> object | None: |
| tensor = checkpoint.tensors.get(name, []) |
| if np is not None and hasattr(tensor, "shape"): |
| if len(tensor.shape) == 1 and int(tensor.shape[0]) > 0: |
| return tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False) |
| return None |
| values = tensor.tolist() if hasattr(tensor, "tolist") else tensor |
| return [float(value) for value in values] if values else None |
|
|
| def _runtime_matrix_tensor(name: str) -> object | None: |
| tensor = checkpoint.tensors.get(name, []) |
| if ( |
| np is not None |
| and hasattr(tensor, "shape") |
| and len(tensor.shape) == 2 |
| and int(tensor.shape[0]) > 0 |
| ): |
| return tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False) |
| return None |
|
|
| associative_tensor = checkpoint.tensors.get("associative_keys", []) |
| model.associative_keys = ( |
| associative_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False) |
| if np is not None and hasattr(associative_tensor, "shape") |
| else [[float(value) for value in row] for row in associative_tensor] |
| ) |
| cached_associative_key_norms = _runtime_vector_tensor("associative_key_norms") |
| if cached_associative_key_norms is not None: |
| model.associative_key_norms = cached_associative_key_norms |
| elif np is not None and hasattr(model.associative_keys, "shape"): |
| model.associative_key_norms = None |
| else: |
| model.associative_key_norms = [norm(key) for key in model.associative_keys] |
| raw_associative_values = checkpoint.tensors.get("associative_values", []) |
| model.associative_values = [ |
| int(value) for value in ( |
| raw_associative_values.tolist() |
| if hasattr(raw_associative_values, "tolist") |
| else raw_associative_values |
| ) |
| ] |
| answer_tensor = checkpoint.tensors.get("answer_keys", []) |
| if np is not None and hasattr(answer_tensor, "shape"): |
| model.answer_keys = ( |
| answer_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False) |
| if len(answer_tensor.shape) == 2 |
| else [] |
| ) |
| else: |
| model.answer_keys = [[float(value) for value in row] for row in answer_tensor] |
| if ( |
| np is not None |
| and hasattr(model.answer_keys, "shape") |
| and len(model.answer_keys.shape) == 2 |
| ): |
| model.answer_key_norms = _runtime_vector_tensor("answer_key_norms") |
| else: |
| model.answer_key_norms = ( |
| _runtime_vector_tensor("answer_key_norms") |
| or [norm(key) for key in model.answer_keys] |
| ) |
| raw_answer_values = checkpoint.tensors.get("answer_values", []) |
| model.answer_values = [ |
| int(value) for value in ( |
| raw_answer_values.tolist() |
| if hasattr(raw_answer_values, "tolist") |
| else raw_answer_values |
| ) |
| ] |
| answer_start_tensor = checkpoint.tensors.get("answer_start_keys", []) |
| if np is not None and hasattr(answer_start_tensor, "shape"): |
| model.answer_start_keys = ( |
| answer_start_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False) |
| if len(answer_start_tensor.shape) == 2 |
| else [] |
| ) |
| else: |
| model.answer_start_keys = [ |
| [float(value) for value in row] for row in answer_start_tensor |
| ] |
| if ( |
| np is not None |
| and hasattr(model.answer_start_keys, "shape") |
| and len(model.answer_start_keys.shape) == 2 |
| ): |
| model.answer_start_key_norms = _runtime_vector_tensor("answer_start_key_norms") |
| else: |
| model.answer_start_key_norms = ( |
| _runtime_vector_tensor("answer_start_key_norms") |
| or [norm(key) for key in model.answer_start_keys] |
| ) |
| raw_answer_start_values = checkpoint.tensors.get("answer_start_values", []) |
| model.answer_start_values = [ |
| int(value) for value in ( |
| raw_answer_start_values.tolist() |
| if hasattr(raw_answer_start_values, "tolist") |
| else raw_answer_start_values |
| ) |
| ] |
| answer_sequence_tensor = checkpoint.tensors.get("answer_sequence_keys", []) |
| if np is not None and hasattr(answer_sequence_tensor, "shape"): |
| model.answer_sequence_keys = ( |
| answer_sequence_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False) |
| if len(answer_sequence_tensor.shape) == 2 |
| else [] |
| ) |
| else: |
| model.answer_sequence_keys = [ |
| [float(value) for value in row] for row in answer_sequence_tensor |
| ] |
| if ( |
| np is not None |
| and hasattr(model.answer_sequence_keys, "shape") |
| and len(model.answer_sequence_keys.shape) == 2 |
| ): |
| model.answer_sequence_key_norms = _runtime_vector_tensor("answer_sequence_key_norms") |
| else: |
| model.answer_sequence_key_norms = ( |
| _runtime_vector_tensor("answer_sequence_key_norms") |
| or [norm(key) for key in model.answer_sequence_keys] |
| ) |
| raw_answer_sequence_prompt_tokens = checkpoint.tensors.get("answer_sequence_prompt_tokens", []) |
| if np is not None and hasattr(raw_answer_sequence_prompt_tokens, "shape"): |
| model.answer_sequence_prompt_tokens = raw_answer_sequence_prompt_tokens.astype(np.int64, copy=False)
| else: |
| model.answer_sequence_prompt_tokens = [ |
| [int(value) for value in row] for row in raw_answer_sequence_prompt_tokens |
| ] |
| raw_answer_sequence_tokens = checkpoint.tensors.get("answer_sequence_tokens", []) |
| if np is not None and hasattr(raw_answer_sequence_tokens, "shape"): |
| model.answer_sequence_tokens = raw_answer_sequence_tokens.astype(np.int64, copy=False)
| else: |
| model.answer_sequence_tokens = [ |
| [int(value) for value in row] for row in raw_answer_sequence_tokens |
| ] |
| model.answer_sequence_token_id_rows = None |
| raw_fingerprints = checkpoint.tensors.get("answer_fingerprint_hashes", []) |
| model.answer_fingerprint_hashes = model._coerce_answer_fingerprint_hashes( |
| raw_fingerprints |
| ) |
| model.answer_fingerprint_token_lengths = None |
| model.answer_fingerprint_token_sequences_by_length = None |
| if not model.answer_fingerprint_hashes: |
| model._refresh_answer_fingerprint_hashes() |
| model.answer_similarity_keys_array = _runtime_matrix_tensor("answer_similarity_keys") |
| model.answer_similarity_key_norms_array = _runtime_vector_tensor("answer_similarity_key_norms") |
| model.answer_start_similarity_keys_array = _runtime_matrix_tensor("answer_start_similarity_keys") |
| model.answer_start_similarity_key_norms_array = _runtime_vector_tensor("answer_start_similarity_key_norms") |
| model.answer_sequence_similarity_keys_array = _runtime_matrix_tensor("answer_sequence_similarity_keys") |
| model.answer_sequence_similarity_key_norms_array = _runtime_vector_tensor("answer_sequence_similarity_key_norms") |
| model.transition_id_tables = model._deserialize_transition_id_tables_from_tensors( |
| checkpoint.tensors |
| ) |
| if model.transition_id_tables is not None: |
| model.transition_tables = {order: {} for order in sorted(TRANSITION_ORDERS)} |
| else: |
| model.transition_tables = model._deserialize_transition_tables( |
| json.loads(metadata.get("transition_tables", "{}")) |
| ) |
| model._refresh_numeric_caches() |
| return model |
|
|
| def _collect_training_examples( |
| self, |
| tokens: list[str], |
| ) -> tuple[list[Vector], list[Vector], list[int]]: |
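| """Roll the memory units over `tokens` and emit (combined state, one-hot
| label, label id) triples for training the next-token readout."""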
| assert self.embedding_model is not None |
| if np is not None: |
| hidden_states = [ |
| np.zeros(self.config.state_dim, dtype=np.float64) |
| for _ in self.config.timescales |
| ] |
| context_traces = [ |
| np.zeros(self.config.embedding_dim, dtype=np.float64) |
| for _ in self.config.timescales |
| ] |
| zero_embedding: Vector | object = np.zeros(self.config.embedding_dim, dtype=np.float64) |
| else: |
| hidden_states = [zeros_vector(self.config.state_dim) for _ in self.config.timescales] |
| context_traces = [zeros_vector(self.config.embedding_dim) for _ in self.config.timescales] |
| zero_embedding = zeros_vector(self.config.embedding_dim) |
| states: list[Vector] = [] |
| labels: list[Vector] = [] |
| label_ids: list[int] = [] |
| token_ids = [ |
| self.embedding_model.token_to_id.get(token, -1) |
| for token in tokens |
| ] |
| example_count = max(0, len(tokens) - 1) |
| stride = 1 |
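| # Subsample examples at a uniform stride when the corpus would exceed
| # max_training_examples, always keeping the final transition.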
| if self.config.max_training_examples and example_count > self.config.max_training_examples: |
| stride = max( |
| 1, |
| (example_count + self.config.max_training_examples - 1) // self.config.max_training_examples, |
| ) |
|
|
| for index in range(len(tokens) - 1): |
| token = tokens[index] |
| token_id = token_ids[index] |
| embedding = ( |
| self.embedding_model.embeddings[token_id] |
| if token_id >= 0 |
| else zero_embedding |
| ) |
| trace_embedding = self._trace_embedding_from_token_id(embedding, token_id) |
| hidden_states, context_traces, combined_state = self._step_hidden_states_from_embedding( |
| hidden_states, |
| context_traces, |
| embedding, |
| trace_embedding=trace_embedding, |
| ) |
| if stride > 1 and index % stride != 0 and index != len(tokens) - 2: |
| continue |
| states.append(combined_state) |
| next_token_id = token_ids[index + 1] |
| labels.append(self._one_hot_from_id(next_token_id)) |
| label_ids.append(next_token_id) |
|
|
| if self.config.max_training_examples and len(states) > self.config.max_training_examples: |
| states = states[: self.config.max_training_examples] |
| labels = labels[: self.config.max_training_examples] |
| label_ids = label_ids[: self.config.max_training_examples] |
| return states, labels, label_ids |
|
|
| def _is_punctuation_piece(self, piece: str) -> bool: |
| return bool(piece) and all(character in string.punctuation for character in piece) |
|
|
| def _encode_context(self, tokens: list[str]) -> Vector: |
| return self._masked_decode_state(self._build_decode_state(tokens)) |
|
|
| def _build_decode_state(self, tokens: list[str]) -> DecodeState: |
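| """Construct a DecodeState by replaying `tokens` through the memory
| units, then apply the sparse long-context anchor."""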
| assert self.memory_units is not None |
|
|
| state = DecodeState( |
| hidden_states=( |
| [ |
| np.zeros(self.config.state_dim, dtype=np.float64) |
| for _ in self.config.timescales |
| ] |
| if np is not None |
| else [zeros_vector(self.config.state_dim) for _ in self.config.timescales] |
| ), |
| context_traces=( |
| [ |
| np.zeros(self.config.embedding_dim, dtype=np.float64) |
| for _ in self.config.timescales |
| ] |
| if np is not None |
| else [zeros_vector(self.config.embedding_dim) for _ in self.config.timescales] |
| ), |
| combined_state=self._zero_combined_state(), |
| context_tokens=[], |
| ) |
| for token in tokens: |
| self._advance_decode_state(state, token) |
| self._apply_sparse_context_anchor(state) |
| return state |
|
|
| def _advance_decode_state(self, state: DecodeState, token: str) -> DecodeState: |
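| """Step the decode state by one token; an `<answer>` token snapshots the
| anchor state and invalidates all cached answer matches and priors."""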
| next_hidden_states, next_context_traces, combined_state = self._step_hidden_states( |
| state.hidden_states, |
| state.context_traces, |
| token, |
| ) |
| state.hidden_states = next_hidden_states |
| state.context_traces = next_context_traces |
| state.combined_state = combined_state |
| state.context_tokens.append(token) |
| if token == "<answer>": |
| state.answer_anchor_state = (
| combined_state.copy()
| if hasattr(combined_state, "copy")
| else combined_state[:]
| )
| state.answer_matches = None |
| state.answer_start_matches = None |
| state.answer_sequence_matches = None |
| state.prompt_answer_prior = None |
| state.prompt_answer_start_prior = None |
| return state |
|
|
| def _apply_sparse_context_anchor(self, state: DecodeState) -> None: |
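| """Blend a hashed-sparse-attention summary of the pre-answer context
| into the answer anchor state for sufficiently long prompts; a no-op
| without NumPy, an anchor state, or context tokens."""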
| if ( |
| np is None |
| or self.embedding_model is None |
| or state.answer_anchor_state is None |
| or not state.context_tokens |
| ): |
| return |
| answer_index = _last_index(state.context_tokens, "<answer>") |
| if answer_index is None or answer_index <= 0: |
| return |
| context_ids = self._long_context_sparse_token_ids(state.context_tokens[:answer_index]) |
| if len(context_ids) < SPARSE_CONTEXT_MIN_TOKENS: |
| return |
| query_id = context_ids[-1] |
| embeddings = np.asarray(self.embedding_model.embeddings, dtype=np.float32) |
| if embeddings.ndim != 2 or embeddings.shape[0] == 0: |
| return |
| selector = HashedSparseAttention( |
| embeddings, |
| k_neighbors=min(SPARSE_CONTEXT_TOP_K, len(context_ids)), |
| hash_bits=SPARSE_CONTEXT_HASH_BITS, |
| probe_radius=SPARSE_CONTEXT_PROBE_RADIUS, |
| candidate_multiplier=SPARSE_CONTEXT_CANDIDATE_MULTIPLIER, |
| ) |
| token_ids = np.asarray(context_ids, dtype=np.int64) |
| selector.build_context_index(token_ids) |
| selection = selector.select_positions_cached(query_id) |
| if not selection.positions: |
| return |
| selected_ids = token_ids[np.asarray(selection.positions, dtype=np.int64)] |
| selected_embeddings = embeddings[selected_ids] |
| scores = np.asarray(selection.scores, dtype=np.float32) |
| scores -= float(scores.max()) |
| weights = np.exp(scores) |
| weights /= max(float(weights.sum()), 1e-8) |
| sparse_embedding = weights @ selected_embeddings |
| blended_anchor = self._blend_sparse_embedding_into_combined_state( |
| state.answer_anchor_state, |
| sparse_embedding, |
| state_dim=self.config.state_dim, |
| embedding_dim=self.config.embedding_dim, |
| timescale_count=len(self.config.timescales), |
| blend=SPARSE_CONTEXT_TRACE_BLEND, |
| ) |
| state.answer_anchor_state = blended_anchor |
| if state.context_tokens and state.context_tokens[-1] == "<answer>": |
| state.combined_state = blended_anchor.copy() |
| state.answer_matches = None |
| state.answer_start_matches = None |
| state.answer_sequence_matches = None |
| state.prompt_answer_prior = None |
| state.prompt_answer_start_prior = None |
|
|
| def _long_context_sparse_token_ids(self, tokens: Sequence[str]) -> list[int]: |
| assert self.embedding_model is not None |
| special_tokens = self.tokenizer.special_tokens if self.tokenizer is not None else set() |
| ids: list[int] = [] |
| for token in tokens: |
| if token in special_tokens and token not in TOOL_PROTOCOL_TOKENS: |
| continue |
| token_id = self._token_id_for_token(token) |
| if token_id >= 0: |
| ids.append(token_id) |
| return ids |
|
|
| @staticmethod |
| def _blend_sparse_embedding_into_combined_state( |
| combined_state: Vector, |
| sparse_embedding: object, |
| *, |
| state_dim: int, |
| embedding_dim: int, |
| timescale_count: int, |
| blend: float, |
| ) -> Vector: |
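| """Mix `sparse_embedding` into the trace slice of every timescale block
| of the combined state with weight `blend`; returns the input unchanged
| on any shape mismatch or when NumPy is unavailable."""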
| if np is None: |
| return combined_state |
| state_array = np.asarray(combined_state, dtype=np.float32).copy() |
| sparse_array = np.asarray(sparse_embedding, dtype=np.float32) |
| if sparse_array.shape[0] != embedding_dim: |
| return combined_state |
| block_width = state_dim + embedding_dim |
| expected_width = block_width * timescale_count |
| if state_array.shape[0] != expected_width: |
| return combined_state |
| alpha = min(1.0, max(0.0, float(blend))) |
| for block_index in range(timescale_count): |
| trace_start = block_index * block_width + state_dim |
| trace_end = trace_start + embedding_dim |
| state_array[trace_start:trace_end] = ( |
| (1.0 - alpha) * state_array[trace_start:trace_end] |
| + alpha * sparse_array |
| ) |
| return state_array.tolist() |
|
|
| def _masked_decode_state(self, state: DecodeState) -> Vector: |
| assert self.ternary_mask is not None |
| return apply_ternary_mask(state.combined_state, self.ternary_mask, self.ternary_scale) |
|
|
| def _masked_combined_state(self, combined_state: Vector) -> Vector: |
| assert self.ternary_mask is not None |
| return apply_ternary_mask(combined_state, self.ternary_mask, self.ternary_scale) |
|
|
| def _masked_decode_state_array(self, state: DecodeState) -> object: |
| assert np is not None |
| if self.ternary_mask_array is None: |
| return np.asarray(self._masked_decode_state(state), dtype=RUNTIME_ARRAY_DTYPE) |
| return ( |
| np.asarray(state.combined_state, dtype=RUNTIME_ARRAY_DTYPE) |
| * self.ternary_scale |
| * self.ternary_mask_array |
| ) |
|
|
| def _masked_combined_state_array(self, combined_state: Vector) -> object: |
| assert np is not None |
| if self.ternary_mask_array is None: |
| return np.asarray(self._masked_combined_state(combined_state), dtype=RUNTIME_ARRAY_DTYPE) |
| return ( |
| np.asarray(combined_state, dtype=RUNTIME_ARRAY_DTYPE) |
| * self.ternary_scale |
| * self.ternary_mask_array |
| ) |
|
|
| def _center_state_vector(self, state: Vector) -> Vector: |
| if not self.state_offset or len(self.state_offset) != len(state): |
| return state |
| return [value - self.state_offset[index] for index, value in enumerate(state)] |
|
|
| def _center_state_array(self, state: object) -> object: |
| assert np is not None |
| state_array = np.asarray(state, dtype=RUNTIME_ARRAY_DTYPE) |
| if self.state_offset_array is None or self.state_offset_array.shape != state_array.shape: |
| return state_array |
| return state_array - self.state_offset_array |
|
|
| def _zero_combined_state(self) -> Vector: |
| return [0.0 for _ in range(self._combined_state_width())] |
|
|
| def _combined_state_width(self) -> int: |
| return (self.config.state_dim + self.config.embedding_dim) * len(self.config.timescales) |
|
|
| def _derive_trace_token_weights_from_counts(self, token_counts: dict[str, float]) -> Vector: |
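| """Inverse-frequency trace weights: rarer tokens get more weight via
| (median_count / count) ** 0.75, clamped to [0.08, 4.8]; tool-protocol
| tokens stay at 1.0 and other special tokens are zeroed."""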
| assert self.embedding_model is not None |
| assert self.tokenizer is not None |
| counts = [ |
| float(token_counts.get(token, 0.0)) |
| for token in self.embedding_model.id_to_token |
| ] |
| positive_counts = sorted(value for value in counts if value > 0.0) |
| reference = ( |
| positive_counts[len(positive_counts) // 2] |
| if positive_counts |
| else 1.0 |
| ) |
| weights: Vector = [] |
| for token, count in zip(self.embedding_model.id_to_token, counts): |
| if token in TOOL_PROTOCOL_TOKENS: |
| weights.append(1.0) |
| elif token in self.tokenizer.special_tokens: |
| weights.append(0.0) |
| elif count <= 0.0: |
| weights.append(1.0) |
| else: |
| weight = (reference / count) ** 0.75 |
| weights.append(max(0.08, min(4.8, weight))) |
| return weights |
|
|
| def _token_id_for_token(self, token: str) -> int: |
| assert self.embedding_model is not None |
| token_id = self.embedding_model.token_to_id.get(token) |
| if token_id is None and token.lower() != token: |
| token_id = self.embedding_model.token_to_id.get(token.lower()) |
| return int(token_id) if token_id is not None else -1 |
|
|
| def _trace_embedding_from_token_id( |
| self, |
| embedding: Vector | object, |
| token_id: int, |
| ) -> Vector | object: |
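| """Scale the token embedding by its trace weight and superimpose a
| sparse, signed identity signature hashed from the token id; uses the
| precomputed trace table when one is available."""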
| if token_id < 0: |
| return embedding |
| if self.trace_embedding_table_array is not None: |
| return self.trace_embedding_table_array[token_id] |
| weight = self.trace_token_weights[token_id] if self.trace_token_weights is not None else 1.0 |
| dimension = self.config.embedding_dim |
| if hasattr(embedding, "shape"): |
| trace_embedding = embedding * weight |
| for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES: |
| bucket = (token_id * bucket_multiplier + bucket_offset) % dimension |
| sign = 1.0 if ((token_id * sign_multiplier + sign_offset) & 1) == 0 else -1.0 |
| trace_embedding[bucket] += weight * TRACE_IDENTITY_SCALE * sign |
| return trace_embedding |
| trace_values = [float(value) * weight for value in embedding] |
| for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES: |
| bucket = (token_id * bucket_multiplier + bucket_offset) % dimension |
| sign = 1.0 if ((token_id * sign_multiplier + sign_offset) & 1) == 0 else -1.0 |
| trace_values[bucket] += weight * TRACE_IDENTITY_SCALE * sign |
| return trace_values |
|
|
| def _build_trace_embedding_table_array(self, embedding_array: object) -> object | None: |
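| """Precompute the trace-embedding table (trace weighting plus identity
| signatures) for the whole vocabulary; returns None without NumPy or
| when shapes do not line up."""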
| if np is None or self.trace_token_weights is None: |
| return None |
| values = np.asarray(embedding_array, dtype=np.float64) |
| if values.size == 0 or len(values.shape) != 2: |
| return None |
| weights = np.asarray(self.trace_token_weights, dtype=np.float64) |
| if weights.shape[0] != values.shape[0]: |
| return None |
| trace_values = values * weights[:, None] |
| if values.shape[1] <= 0: |
| return trace_values |
| token_ids = np.arange(values.shape[0], dtype=np.int64) |
| for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES: |
| buckets = ((token_ids * bucket_multiplier + bucket_offset) % values.shape[1]).astype( |
| np.int64, |
| copy=False, |
| ) |
| signs = np.where( |
| ((token_ids * sign_multiplier + sign_offset) & 1) == 0, |
| 1.0, |
| -1.0, |
| ) |
| np.add.at(trace_values, (token_ids, buckets), weights * TRACE_IDENTITY_SCALE * signs) |
| return trace_values |
|
|
| def _runtime_key_norms_array( |
| self, |
| key_array: object | None, |
| key_norms: list[float] | None, |
| ) -> object | None: |
| assert np is not None |
| if key_norms is not None and len(key_norms) > 0: |
| return np.asarray(key_norms, dtype=RUNTIME_ARRAY_DTYPE) |
| if key_array is None: |
| return None |
| keys = np.asarray(key_array, dtype=RUNTIME_ARRAY_DTYPE) |
| if len(keys.shape) != 2 or keys.shape[0] == 0: |
| return None |
| return np.linalg.norm(keys, axis=1).astype(RUNTIME_ARRAY_DTYPE, copy=False) |
|
|
| def _runtime_vector_cache(self, cached: object | None, length: int) -> object | None: |
| assert np is not None |
| if cached is None or not hasattr(cached, "shape"): |
| return None |
| array = np.asarray(cached, dtype=RUNTIME_ARRAY_DTYPE) |
| if len(array.shape) != 1 or int(array.shape[0]) != int(length): |
| return None |
| return array |
|
|
| def _runtime_matrix_cache( |
| self, |
| cached: object | None, |
| rows: int, |
| width: int, |
| ) -> object | None: |
| assert np is not None |
| if cached is None or not hasattr(cached, "shape"): |
| return None |
| array = np.asarray(cached, dtype=RUNTIME_ARRAY_DTYPE) |
| if ( |
| len(array.shape) != 2 |
| or int(array.shape[0]) != int(rows) |
| or int(array.shape[1]) != int(width) |
| ): |
| return None |
| return array |
|
|
| def _refresh_numeric_caches(self) -> None: |
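| """Rebuild the NumPy runtime caches from the list-based model fields,
| reusing previously loaded norm and similarity-key caches whose shapes
| still match; clears every cache and returns early without NumPy."""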
| if np is None: |
| self.ternary_mask_array = None |
| self.readout_weights_array = None |
| self.readout_bias_array = None |
| self.prompt_answer_weights_array = None |
| self.prompt_answer_bias_array = None |
| self.prompt_answer_start_weights_array = None |
| self.prompt_answer_start_bias_array = None |
| self.trace_token_weights_array = None |
| self.trace_embedding_table_array = None |
| self.preference_bias_array = None |
| self.preference_valid_mask_array = None |
| self.state_offset_array = None |
| self.associative_keys_array = None |
| self.associative_key_norms_array = None |
| self.associative_values_array = None |
| self.associative_valid_mask_array = None |
| self.answer_keys_array = None |
| self.answer_key_norms_array = None |
| self.answer_similarity_keys_array = None |
| self.answer_similarity_key_norms_array = None |
| self.answer_similarity_mask_array = None |
| self.answer_values_array = None |
| self.answer_valid_mask_array = None |
| self.answer_start_keys_array = None |
| self.answer_start_key_norms_array = None |
| self.answer_start_similarity_keys_array = None |
| self.answer_start_similarity_key_norms_array = None |
| self.answer_start_values_array = None |
| self.answer_start_valid_mask_array = None |
| self.answer_sequence_keys_array = None |
| self.answer_sequence_key_norms_array = None |
| self.answer_sequence_similarity_keys_array = None |
| self.answer_sequence_similarity_key_norms_array = None |
| self.answer_sequence_prompt_tokens_array = None |
| self.answer_sequence_tokens_array = None |
| self.answer_sequence_prompt_weight_maps = None |
| self.answer_sequence_prompt_weight_norms = None |
| self.answer_sequence_prompt_bigram_sets = None |
| self.answer_sequence_prompt_trigram_sets = None |
| self.answer_sequence_prompt_number_sets = None |
| self.answer_sequence_prompt_inverted_index = None |
| self._refresh_answer_sequence_prompt_overlap_cache() |
| self.prompt_overlap_valid_token_mask_array = None |
| return |
| cached_associative_key_norms_array = self.associative_key_norms_array |
| cached_answer_key_norms_array = self.answer_key_norms_array |
| cached_answer_similarity_keys_array = self.answer_similarity_keys_array |
| cached_answer_similarity_key_norms_array = self.answer_similarity_key_norms_array |
| cached_answer_start_key_norms_array = self.answer_start_key_norms_array |
| cached_answer_start_similarity_keys_array = self.answer_start_similarity_keys_array |
| cached_answer_start_similarity_key_norms_array = self.answer_start_similarity_key_norms_array |
| cached_answer_sequence_key_norms_array = self.answer_sequence_key_norms_array |
| cached_answer_sequence_similarity_keys_array = self.answer_sequence_similarity_keys_array |
| cached_answer_sequence_similarity_key_norms_array = self.answer_sequence_similarity_key_norms_array |
| self.ternary_mask_array = ( |
| np.asarray(self.ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) |
| if self.ternary_mask is not None |
| else None |
| ) |
| self.readout_weights_array = ( |
| np.asarray(self.readout_weights, dtype=RUNTIME_ARRAY_DTYPE) |
| if self.readout_weights is not None |
| else None |
| ) |
| self.readout_bias_array = ( |
| np.asarray(self.readout_bias, dtype=RUNTIME_ARRAY_DTYPE) |
| if self.readout_bias is not None |
| else None |
| ) |
| self.prompt_answer_weights_array = ( |
| np.asarray(self.prompt_answer_weights, dtype=RUNTIME_ARRAY_DTYPE) |
| if self.prompt_answer_weights is not None |
| and len(self.prompt_answer_weights) > 0 |
| else None |
| ) |
| self.prompt_answer_bias_array = ( |
| np.asarray(self.prompt_answer_bias, dtype=RUNTIME_ARRAY_DTYPE) |
| if self.prompt_answer_bias is not None |
| else None |
| ) |
| self.prompt_answer_start_weights_array = ( |
| np.asarray(self.prompt_answer_start_weights, dtype=RUNTIME_ARRAY_DTYPE) |
| if self.prompt_answer_start_weights is not None |
| and len(self.prompt_answer_start_weights) > 0 |
| else None |
| ) |
| self.prompt_answer_start_bias_array = ( |
| np.asarray(self.prompt_answer_start_bias, dtype=RUNTIME_ARRAY_DTYPE) |
| if self.prompt_answer_start_bias is not None |
| else None |
| ) |
| self.trace_token_weights_array = ( |
| np.asarray(self.trace_token_weights, dtype=RUNTIME_ARRAY_DTYPE) |
| if self.trace_token_weights is not None |
| else None |
| ) |
| trace_embedding_table = ( |
| self._build_trace_embedding_table_array(self.embedding_model.embeddings) |
| if self.embedding_model is not None and self.trace_token_weights is not None |
| else None |
| ) |
| self.trace_embedding_table_array = ( |
| trace_embedding_table.astype(RUNTIME_ARRAY_DTYPE, copy=False) |
| if trace_embedding_table is not None |
| else None |
| ) |
| self.preference_bias_array = ( |
| np.asarray(self.preference_bias, dtype=RUNTIME_ARRAY_DTYPE) |
| if self.preference_bias is not None |
| else None |
| ) |
| self.preference_valid_mask_array = ( |
| np.asarray( |
| [ |
| self._eligible_preference_token(token) |
| for token in self.embedding_model.id_to_token |
| ], |
| dtype=bool, |
| ) |
| if self.embedding_model is not None and self.tokenizer is not None |
| else None |
| ) |
| self.state_offset_array = ( |
| np.asarray(self.state_offset, dtype=RUNTIME_ARRAY_DTYPE) |
| if self.state_offset is not None |
| else None |
| ) |
| self.associative_keys_array = ( |
| np.asarray(self.associative_keys, dtype=RUNTIME_ARRAY_DTYPE) |
| if self.associative_keys is not None and len(self.associative_keys) > 0 |
| else None |
| ) |
| associative_key_norms_cache = ( |
| self._runtime_vector_cache( |
| cached_associative_key_norms_array, |
| int(self.associative_keys_array.shape[0]), |
| ) |
| if self.associative_keys_array is not None |
| else None |
| ) |
| self.associative_key_norms_array = ( |
| associative_key_norms_cache |
| if associative_key_norms_cache is not None |
| else self._runtime_key_norms_array( |
| self.associative_keys_array, |
| self.associative_key_norms, |
| ) |
| ) |
| self.associative_values_array = ( |
| np.asarray(self.associative_values, dtype=np.int64) |
| if self.associative_values is not None and len(self.associative_values) > 0 |
| else None |
| ) |
| self.associative_valid_mask_array = ( |
| self.associative_values_array >= 0 |
| if self.associative_values_array is not None |
| else None |
| ) |
| self.answer_keys_array = ( |
| np.asarray(self.answer_keys, dtype=RUNTIME_ARRAY_DTYPE) |
| if self.answer_keys is not None and len(self.answer_keys) > 0 |
| else None |
| ) |
| answer_key_norms_cache = ( |
| self._runtime_vector_cache( |
| cached_answer_key_norms_array, |
| int(self.answer_keys_array.shape[0]), |
| ) |
| if self.answer_keys_array is not None |
| else None |
| ) |
| self.answer_key_norms_array = ( |
| answer_key_norms_cache |
| if answer_key_norms_cache is not None |
| else self._runtime_key_norms_array( |
| self.answer_keys_array, |
| self.answer_key_norms, |
| ) |
| ) |
| self.answer_similarity_keys_array = None |
| self.answer_similarity_key_norms_array = None |
| self.answer_similarity_mask_array = None |
| if self.answer_keys_array is not None and len(self.answer_keys_array.shape) == 2: |
| width = int(self.answer_keys_array.shape[1]) |
| block_width = self.config.state_dim + self.config.embedding_dim |
| expected_width = block_width * len(self.config.timescales) |
| if block_width > 0 and width == expected_width: |
| mask = np.zeros(width, dtype=RUNTIME_ARRAY_DTYPE) |
| for scale_index in range(len(self.config.timescales)): |
| start = scale_index * block_width + self.config.state_dim |
| end = start + self.config.embedding_dim |
| mask[start:end] = 1.0 |
| self.answer_similarity_mask_array = mask |
| answer_similarity_keys_cache = self._runtime_matrix_cache( |
| cached_answer_similarity_keys_array, |
| int(self.answer_keys_array.shape[0]), |
| width, |
| ) |
| answer_similarity_key_norms_cache = self._runtime_vector_cache( |
| cached_answer_similarity_key_norms_array, |
| int(self.answer_keys_array.shape[0]), |
| ) |
| if ( |
| answer_similarity_keys_cache is not None |
| and answer_similarity_key_norms_cache is not None |
| ): |
| self.answer_similarity_keys_array = answer_similarity_keys_cache |
| self.answer_similarity_key_norms_array = answer_similarity_key_norms_cache |
| else: |
| self.answer_similarity_keys_array = self.answer_keys_array * mask[None, :] |
| self.answer_similarity_key_norms_array = np.linalg.norm( |
| self.answer_similarity_keys_array, |
| axis=1, |
| ).astype(RUNTIME_ARRAY_DTYPE, copy=False) |
| self.answer_values_array = ( |
| np.asarray(self.answer_values, dtype=np.int64) |
| if self.answer_values is not None and len(self.answer_values) > 0 |
| else None |
| ) |
| self.answer_valid_mask_array = ( |
| self.answer_values_array >= 0 |
| if self.answer_values_array is not None |
| else None |
| ) |
| self.answer_start_keys_array = ( |
| np.asarray(self.answer_start_keys, dtype=RUNTIME_ARRAY_DTYPE) |
| if self.answer_start_keys is not None and len(self.answer_start_keys) > 0 |
| else None |
| ) |
| answer_start_key_norms_cache = ( |
| self._runtime_vector_cache( |
| cached_answer_start_key_norms_array, |
| int(self.answer_start_keys_array.shape[0]), |
| ) |
| if self.answer_start_keys_array is not None |
| else None |
| ) |
| self.answer_start_key_norms_array = ( |
| answer_start_key_norms_cache |
| if answer_start_key_norms_cache is not None |
| else self._runtime_key_norms_array( |
| self.answer_start_keys_array, |
| self.answer_start_key_norms, |
| ) |
| ) |
| self.answer_start_similarity_keys_array = None |
| self.answer_start_similarity_key_norms_array = None |
| if ( |
| self.answer_start_keys_array is not None |
| and len(self.answer_start_keys_array.shape) == 2 |
| and self.answer_similarity_mask_array is not None |
| and int(self.answer_start_keys_array.shape[1]) == int(self.answer_similarity_mask_array.shape[0]) |
| ): |
| answer_start_similarity_keys_cache = self._runtime_matrix_cache( |
| cached_answer_start_similarity_keys_array, |
| int(self.answer_start_keys_array.shape[0]), |
| int(self.answer_start_keys_array.shape[1]), |
| ) |
| answer_start_similarity_key_norms_cache = self._runtime_vector_cache( |
| cached_answer_start_similarity_key_norms_array, |
| int(self.answer_start_keys_array.shape[0]), |
| ) |
| if ( |
| answer_start_similarity_keys_cache is not None |
| and answer_start_similarity_key_norms_cache is not None |
| ): |
| self.answer_start_similarity_keys_array = answer_start_similarity_keys_cache |
| self.answer_start_similarity_key_norms_array = answer_start_similarity_key_norms_cache |
| else: |
| self.answer_start_similarity_keys_array = ( |
| self.answer_start_keys_array * self.answer_similarity_mask_array[None, :] |
| ) |
| self.answer_start_similarity_key_norms_array = np.linalg.norm( |
| self.answer_start_similarity_keys_array, |
| axis=1, |
| ).astype(RUNTIME_ARRAY_DTYPE, copy=False) |
| self.answer_start_values_array = ( |
| np.asarray(self.answer_start_values, dtype=np.int64) |
| if self.answer_start_values is not None and len(self.answer_start_values) > 0 |
| else None |
| ) |
| self.answer_start_valid_mask_array = ( |
| self.answer_start_values_array >= 0 |
| if self.answer_start_values_array is not None |
| else None |
| ) |
| self.answer_sequence_keys_array = ( |
| np.asarray(self.answer_sequence_keys, dtype=RUNTIME_ARRAY_DTYPE) |
| if self.answer_sequence_keys is not None and len(self.answer_sequence_keys) > 0 |
| else None |
| ) |
| answer_sequence_key_norms_cache = ( |
| self._runtime_vector_cache( |
| cached_answer_sequence_key_norms_array, |
| int(self.answer_sequence_keys_array.shape[0]), |
| ) |
| if self.answer_sequence_keys_array is not None |
| else None |
| ) |
| self.answer_sequence_key_norms_array = ( |
| answer_sequence_key_norms_cache |
| if answer_sequence_key_norms_cache is not None |
| else self._runtime_key_norms_array( |
| self.answer_sequence_keys_array, |
| self.answer_sequence_key_norms, |
| ) |
| ) |
| self.answer_sequence_similarity_keys_array = None |
| self.answer_sequence_similarity_key_norms_array = None |
| if ( |
| self.answer_sequence_keys_array is not None |
| and len(self.answer_sequence_keys_array.shape) == 2 |
| and self.answer_similarity_mask_array is not None |
| and int(self.answer_sequence_keys_array.shape[1]) == int(self.answer_similarity_mask_array.shape[0]) |
| ): |
| answer_sequence_similarity_keys_cache = self._runtime_matrix_cache( |
| cached_answer_sequence_similarity_keys_array, |
| int(self.answer_sequence_keys_array.shape[0]), |
| int(self.answer_sequence_keys_array.shape[1]), |
| ) |
| answer_sequence_similarity_key_norms_cache = self._runtime_vector_cache( |
| cached_answer_sequence_similarity_key_norms_array, |
| int(self.answer_sequence_keys_array.shape[0]), |
| ) |
| if ( |
| answer_sequence_similarity_keys_cache is not None |
| and answer_sequence_similarity_key_norms_cache is not None |
| ): |
| self.answer_sequence_similarity_keys_array = answer_sequence_similarity_keys_cache |
| self.answer_sequence_similarity_key_norms_array = answer_sequence_similarity_key_norms_cache |
| else: |
| self.answer_sequence_similarity_keys_array = ( |
| self.answer_sequence_keys_array * self.answer_similarity_mask_array[None, :] |
| ) |
| self.answer_sequence_similarity_key_norms_array = np.linalg.norm( |
| self.answer_sequence_similarity_keys_array, |
| axis=1, |
| ).astype(RUNTIME_ARRAY_DTYPE, copy=False) |
| self.answer_sequence_tokens_array = ( |
| np.asarray(self.answer_sequence_tokens, dtype=np.int64) |
| if self.answer_sequence_tokens is not None and len(self.answer_sequence_tokens) > 0 |
| else None |
| ) |
| self.answer_sequence_prompt_tokens_array = ( |
| np.asarray(self.answer_sequence_prompt_tokens, dtype=np.int64) |
| if self.answer_sequence_prompt_tokens is not None |
| and len(self.answer_sequence_prompt_tokens) > 0 |
| else None |
| ) |
| self.prompt_overlap_valid_token_mask_array = None |
# _refresh_answer_sequence_prompt_overlap_cache itself skips the eager
# per-row caches when deferral applies, so no branch is needed here.
self._refresh_answer_sequence_prompt_overlap_cache()
|
|
| def _defer_answer_sequence_prompt_overlap_cache(self) -> bool: |
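"""Return True when the eager per-row overlap caches should be skipped:
the prompt-token table exceeds ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT
rows and a NumPy-backed array is available to serve candidate lookups
vectorially instead."""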
| if self.answer_sequence_prompt_tokens is None: |
| return False |
| try: |
| row_count = len(self.answer_sequence_prompt_tokens) |
| except TypeError: |
| return False |
| return ( |
| row_count > ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT |
| and np is not None |
| and self.answer_sequence_prompt_tokens_array is not None |
| ) |
|
|
| def _prompt_overlap_valid_token_mask(self) -> object | None: |
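"""Build and cache a boolean vocabulary mask marking tokens that count
toward prompt overlap; the cache is rebuilt whenever the vocabulary
size changes."""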
| if np is None or self.embedding_model is None: |
| return None |
| if ( |
| self.prompt_overlap_valid_token_mask_array is not None |
| and int(self.prompt_overlap_valid_token_mask_array.shape[0]) == len(self.embedding_model.id_to_token) |
| ): |
| return self.prompt_overlap_valid_token_mask_array |
| mask = np.fromiter( |
| ( |
| not self._should_skip_prompt_overlap_token(token) |
| for token in self.embedding_model.id_to_token |
| ), |
| dtype=bool, |
| count=len(self.embedding_model.id_to_token), |
| ) |
| self.prompt_overlap_valid_token_mask_array = mask |
| return mask |
|
|
| def _answer_prompt_row_ids_from_array(self) -> tuple[dict[int, list[int]], list[list[int]] | None] | None: |
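"""Vectorized construction of the prompt-overlap index.

Returns (inverted, row_id_lists), where inverted maps a token id to the
sequence rows containing it and row_id_lists holds the filtered token
ids per row (None when the table is too large to cache eagerly).
Returns None when the NumPy fast path is unavailable."""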
| if ( |
| np is None |
| or self.answer_sequence_prompt_tokens_array is None |
| or self.trace_token_weights is None |
| or self.embedding_model is None |
| ): |
| return None |
| rows = np.asarray(self.answer_sequence_prompt_tokens_array, dtype=np.int64) |
if len(rows.shape) != 2 or rows.size == 0:
    return {}, ([] if rows.shape[0] <= ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT else None)
vocab_size = len(self.trace_token_weights)
if vocab_size <= 0:
    return {}, ([] if rows.shape[0] <= ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT else None)
| valid_token_mask = self._prompt_overlap_valid_token_mask() |
| if valid_token_mask is None: |
| return None |
| bounded = (rows >= 0) & (rows < vocab_size) |
| clipped = np.clip(rows, 0, max(0, vocab_size - 1)) |
| bounded &= valid_token_mask[clipped] |
| row_positions, column_positions = np.nonzero(bounded) |
| if row_positions.size == 0: |
empty_rows = (
    [[] for _ in range(int(rows.shape[0]))]
    if rows.shape[0] <= ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT
    else None
)
| return {}, empty_rows |
| token_values = rows[row_positions, column_positions].astype(np.int64, copy=False) |
| order = np.lexsort((row_positions, token_values)) |
| token_values = token_values[order] |
| row_positions = row_positions[order] |
| unique = np.ones(token_values.shape[0], dtype=bool) |
| unique[1:] = (token_values[1:] != token_values[:-1]) | (row_positions[1:] != row_positions[:-1]) |
| token_values = token_values[unique] |
| row_positions = row_positions[unique] |
| boundaries = np.flatnonzero(token_values[1:] != token_values[:-1]) + 1 |
| token_groups = np.split(token_values, boundaries) |
| row_groups = np.split(row_positions, boundaries) |
| inverted = { |
| int(token_group[0]): row_group.astype(np.int64, copy=False).tolist() |
| for token_group, row_group in zip(token_groups, row_groups) |
| if token_group.size |
| } |
| if rows.shape[0] > ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT: |
| return inverted, None |
# Rebuild the per-row lists from the original rows so positional order
# and duplicates survive; the bigram/trigram sets built by
# _refresh_answer_sequence_prompt_overlap_cache depend on adjacency,
# which the lexsorted, deduplicated arrays above no longer preserve.
row_id_lists: list[list[int]] = [
    rows[row_index][bounded[row_index]].astype(np.int64, copy=False).tolist()
    for row_index in range(int(rows.shape[0]))
]
return inverted, row_id_lists
|
|
| def _refresh_answer_sequence_prompt_overlap_cache(self) -> None: |
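"""Clear and rebuild the prompt-overlap caches.

The inverted index and per-token specificity are always recomputed; the
per-row weight maps, bigram/trigram sets, and number sets are only
materialized when the row count stays within
ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT."""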
| self.answer_sequence_prompt_weight_maps = None |
| self.answer_sequence_prompt_weight_norms = None |
| self.answer_sequence_prompt_bigram_sets = None |
| self.answer_sequence_prompt_trigram_sets = None |
| self.answer_sequence_prompt_number_sets = None |
| self.answer_sequence_prompt_inverted_index = None |
| self.answer_sequence_prompt_specificity = None |
| if self.answer_sequence_prompt_tokens is None or self.trace_token_weights is None: |
| return |
| array_index = self._answer_prompt_row_ids_from_array() |
| if array_index is not None: |
| inverted, row_id_lists = array_index |
| total_rows = ( |
| int(self.answer_sequence_prompt_tokens_array.shape[0]) |
| if self.answer_sequence_prompt_tokens_array is not None |
| else len(row_id_lists or []) |
| ) |
| else: |
| inverted = {} |
| row_id_lists = [] |
| for row in self.answer_sequence_prompt_tokens: |
| row_values = row.tolist() if hasattr(row, "tolist") else row |
| row_ids: list[int] = [] |
| for raw_token_id in row_values: |
| token_id = int(raw_token_id) |
| if token_id < 0 or token_id >= len(self.trace_token_weights): |
| continue |
| if self.embedding_model is not None and self._should_skip_prompt_overlap_token( |
| self.embedding_model.id_to_token[token_id] |
| ): |
| continue |
| row_ids.append(token_id) |
| sequence_index = len(row_id_lists) |
| for token_id in set(row_ids): |
| inverted.setdefault(token_id, []).append(sequence_index) |
| row_id_lists.append(row_ids) |
| total_rows = len(row_id_lists) |
|
|
| specificity = { |
| token_id: self._prompt_overlap_token_specificity(len(indices), total_rows) |
| for token_id, indices in inverted.items() |
| } |
| self.answer_sequence_prompt_inverted_index = inverted |
| self.answer_sequence_prompt_specificity = specificity |
|
|
| if total_rows > ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT: |
| return |
| if row_id_lists is None: |
| return |
|
|
| weight_maps: list[dict[int, float]] = [] |
| weight_norms: list[float] = [] |
| bigram_sets: list[set[tuple[int, int]]] = [] |
| trigram_sets: list[set[tuple[int, int, int]]] = [] |
| number_sets: list[set[str]] = [] |
| for row_index, row_ids in enumerate(row_id_lists): |
| row_weights: dict[int, float] = {} |
| for token_id in row_ids: |
| row_weights[token_id] = max( |
| row_weights.get(token_id, 0.0), |
| float(self.trace_token_weights[token_id]) * specificity.get(token_id, 1.0), |
| ) |
| weight_maps.append(row_weights) |
| weight_norms.append(sum(value * value for value in row_weights.values()) ** 0.5) |
| bigram_sets.append( |
| { |
| (row_ids[index], row_ids[index + 1]) |
| for index in range(len(row_ids) - 1) |
| } |
| ) |
| trigram_sets.append( |
| { |
| (row_ids[index], row_ids[index + 1], row_ids[index + 2]) |
| for index in range(len(row_ids) - 2) |
| } |
| ) |
| raw_row = self.answer_sequence_prompt_tokens[row_index] |
| raw_values = raw_row.tolist() if hasattr(raw_row, "tolist") else raw_row |
| raw_ids = [ |
| int(value) |
| for value in raw_values |
| if 0 <= int(value) < len(self.embedding_model.id_to_token) |
| ] |
| number_sets.append(self._number_strings_from_token_ids(raw_ids)) |
| self.answer_sequence_prompt_weight_maps = weight_maps |
| self.answer_sequence_prompt_weight_norms = weight_norms |
| self.answer_sequence_prompt_bigram_sets = bigram_sets |
| self.answer_sequence_prompt_trigram_sets = trigram_sets |
| self.answer_sequence_prompt_number_sets = number_sets |
|
|
| @staticmethod |
| def _prompt_overlap_token_specificity(document_frequency: int, total_documents: int) -> float: |
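"""Damp common tokens: specificity decays with the square root of
document coverage and is floored at 0.02. For example, a token that
appears in a quarter of the rows scores 1 - sqrt(0.25) = 0.5."""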
| if document_frequency <= 0 or total_documents <= 0: |
| return 1.0 |
| coverage = min(1.0, document_frequency / total_documents) |
| return max(0.02, 1.0 - (coverage ** 0.5)) |
|
|
| def _number_strings_from_token_ids(self, token_ids: list[int]) -> set[str]: |
| assert self.embedding_model is not None |
| tokens = [ |
| self.embedding_model.id_to_token[token_id] |
| for token_id in token_ids |
| if 0 <= token_id < len(self.embedding_model.id_to_token) |
| ] |
| return self._number_strings_from_tokens(tokens) |
|
|
| def _number_strings_from_tokens(self, tokens: list[str]) -> set[str]: |
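"""Collect whole-number strings by concatenating digit runs across
consecutive subword tokens, flushing the running number at special
tokens and at tokens that begin a new word."""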
| numbers: set[str] = set() |
| current = "" |
| for token in tokens: |
| if self.tokenizer is not None and token in self.tokenizer.special_tokens: |
| if current: |
| numbers.add(current) |
| current = "" |
| continue |
| rendered = self._render_token(token) |
| digits = "".join(character for character in rendered if character.isdigit()) |
| starts_number = self._starts_new_word(token) if self.tokenizer is not None else True |
| if digits and starts_number: |
| if current: |
| numbers.add(current) |
| current = digits |
| elif digits and current: |
| current += digits |
| else: |
| if current: |
| numbers.add(current) |
| current = "" |
| if current: |
| numbers.add(current) |
| return numbers |
|
|
| @staticmethod |
| def _numeric_prompt_can_match(query_numbers: set[str], row_numbers: set[str]) -> bool: |
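"""A row can only match when it contains every number in the query;
queries without numbers match unconditionally. Matching is by exact
string, so {"12"} matches a row containing "12" but not one containing
only "120"."""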
| if not query_numbers: |
| return True |
| if not row_numbers: |
| return False |
| return query_numbers.issubset(row_numbers) |
|
|
| def _vector_answer_sequence_candidate_indices( |
| self, |
| query_token_ids: object, |
| ) -> list[int] | None: |
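"""NumPy candidate pre-filter: keep only the sequence rows that share
at least one token id with the query. Returns an empty list for an
empty query and None when the fast path is unavailable."""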
| if ( |
| np is None |
| or self.answer_sequence_prompt_tokens_array is None |
| or not hasattr(self.answer_sequence_prompt_tokens_array, "shape") |
| ): |
| return None |
| query_ids = np.asarray(list(query_token_ids), dtype=np.int64) |
| if query_ids.size == 0: |
| return [] |
| prompt_array = self.answer_sequence_prompt_tokens_array |
| if len(prompt_array.shape) != 2 or prompt_array.shape[0] == 0: |
| return None |
| mask = np.isin(prompt_array, query_ids).any(axis=1) |
| return [int(index) for index in np.flatnonzero(mask)] |
|
|
| def _vector_answer_sequence_local_frequency( |
| self, |
| token_id: int, |
| candidate_indices: list[int], |
| ) -> int | None: |
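"""Vectorized document frequency: count how many candidate rows contain
token_id, or return None when the fast path is unavailable."""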
| if ( |
| np is None |
| or self.answer_sequence_prompt_tokens_array is None |
| or not hasattr(self.answer_sequence_prompt_tokens_array, "shape") |
| or not candidate_indices |
| ): |
| return None |
| rows = self.answer_sequence_prompt_tokens_array[ |
| np.asarray(candidate_indices, dtype=np.int64) |
| ] |
| return int(np.any(rows == int(token_id), axis=1).sum()) |
|
|
| def _apply_readout_fast(self, state: Vector) -> Vector: |
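"""Compute readout logits W @ (state - offset) + bias as a plain list,
preferring the NumPy path and falling back to the pure-Python kernels."""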
| if self.readout_weights_array is None or np is None: |
| assert self.readout_weights is not None |
| centered_state = self._center_state_vector(state) |
| logits = apply_readout(self.readout_weights, centered_state) |
| if self.readout_bias: |
| logits = [ |
| value + self.readout_bias[index] |
| for index, value in enumerate(logits) |
| ] |
| return logits |
# NumPy fast path: reuse the shared array readout and convert the
# logits back to a plain list.
return self._apply_readout_array(state).tolist()
|
|
| def _apply_readout_array(self, state: object) -> object: |
| assert np is not None |
| assert self.readout_weights_array is not None |
| state_array = np.asarray(state, dtype=RUNTIME_ARRAY_DTYPE) |
| if self.state_offset_array is not None and self.state_offset_array.shape == state_array.shape: |
| state_array = state_array - self.state_offset_array |
| logits = self.readout_weights_array @ state_array |
| if self.readout_bias_array is not None and self.readout_bias_array.shape == logits.shape: |
| logits = logits + self.readout_bias_array |
| return logits |
|
|
| def _step_hidden_states( |
| self, |
| hidden_states: list[Vector], |
| context_traces: list[Vector], |
| token: str, |
| ) -> tuple[list[Vector], list[Vector], Vector]: |
| assert self.embedding_model is not None |
| assert self.tokenizer is not None |
| token_id = self._token_id_for_token(token) |
| embedding = self.embedding_model.vector(token) |
| trace_embedding = self._trace_embedding_from_token_id(embedding, token_id) |
| return self._step_hidden_states_from_embedding( |
| hidden_states, |
| context_traces, |
| embedding, |
| trace_embedding=trace_embedding, |
| ) |
|
|
| def _step_hidden_states_from_embedding( |
| self, |
| hidden_states: list[Vector], |
| context_traces: list[Vector], |
| embedding: Vector | object, |
| *, |
| trace_embedding: Vector | object | None = None, |
| ) -> tuple[list[Vector], list[Vector], Vector]: |
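"""Advance every memory unit by one token.

Each unit integrates the analytical drive for its timescale while its
context trace accumulates the trace embedding scaled by
timescale / (1 + timescale); the combined state concatenates the state
and trace blocks of every unit."""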
| assert self.memory_units is not None |
| if trace_embedding is None: |
| trace_embedding = embedding |
|
|
| if np is not None and hidden_states and hasattr(hidden_states[0], "shape"): |
| embedding_array = ( |
| embedding |
| if hasattr(embedding, "shape") |
| else np.asarray(embedding, dtype=np.float64) |
| ) |
| trace_embedding_array = ( |
| trace_embedding |
| if hasattr(trace_embedding, "shape") |
| else np.asarray(trace_embedding, dtype=np.float64) |
| ) |
| drive = analytical_embedding_drive_fast(embedding_array, self.config.state_dim) |
| next_states: list[Vector] = [] |
| next_traces: list[Vector] = [] |
| combined_state: Vector = [] |
| for unit, state, trace in zip(self.memory_units, hidden_states, context_traces): |
| next_state = unit.step_vector_fast(state, drive) |
| decay = 1.0 / (1.0 + unit.timescale) |
| next_trace = trace + ((1.0 - decay) * trace_embedding_array) |
| next_states.append(next_state) |
| next_traces.append(next_trace) |
| combined_state.extend(next_state.tolist()) |
| combined_state.extend(next_trace.tolist()) |
| return next_states, next_traces, combined_state |
|
|
| embedding_vector = embedding.tolist() if hasattr(embedding, "tolist") else embedding |
| trace_embedding_vector = ( |
| trace_embedding.tolist() |
| if hasattr(trace_embedding, "tolist") |
| else trace_embedding |
| ) |
| drive = analytical_embedding_drive(embedding_vector, self.config.state_dim) |
| next_states: list[Vector] = [] |
| next_traces: list[Vector] = [] |
| combined_state: Vector = [] |
| for unit, state, trace in zip(self.memory_units, hidden_states, context_traces): |
| next_state = unit.step_vector(state, drive) |
| decay = 1.0 / (1.0 + unit.timescale) |
| next_trace = [ |
| previous + ((1.0 - decay) * value) |
| for previous, value in zip(trace, trace_embedding_vector) |
| ] |
| next_states.append(next_state) |
| next_traces.append(next_trace) |
| combined_state.extend(next_state) |
| combined_state.extend(next_trace) |
| return next_states, next_traces, combined_state |
|
|
| def _one_hot(self, token: str) -> Vector: |
| assert self.embedding_model is not None |
| return self._one_hot_from_id(self.embedding_model.token_to_id.get(token, -1)) |
|
|
| def _one_hot_from_id(self, token_id: int) -> Vector: |
| assert self.embedding_model is not None |
| vector = [0.0 for _ in self.embedding_model.id_to_token] |
| if token_id >= 0: |
| vector[token_id] = 1.0 |
| return vector |
|
|
| def _blend_probabilities( |
| self, |
| base: Vector, |
| answer: Vector, |
| associative: Vector, |
| transition: Vector, |
| copy: Vector, |
| source_evidence: Vector, |
| preference: Vector, |
| *, |
| transition_order: int | None, |
| generated_count: int = 0, |
| answer_locked: bool = False, |
| answer_guided_start: bool = False, |
| copy_guided_start: bool = False, |
| ) -> tuple[Vector, dict[str, float]]: |
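"""Blend the candidate distributions with context-dependent weights.

Weights start from the FAST_*_BLEND constants, are rescaled by the
generation mode (answer lock, guided starts, mid-generation), by the
presence of source evidence, and by the transition order, then are
normalized over the non-empty sources; the mixture is renormalized
before being returned alongside the per-source weights."""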
| base_weight = FAST_BASE_BLEND |
| answer_weight = FAST_ANSWER_BLEND |
| associative_weight = FAST_ASSOCIATIVE_BLEND |
| transition_weight = FAST_TRANSITION_BLEND |
| copy_weight = FAST_COPY_BLEND |
| source_evidence_weight = FAST_SOURCE_EVIDENCE_BLEND |
| preference_weight = FAST_PREFERENCE_BLEND |
| source_grounded = any(value > 0.0 for value in source_evidence) |
| if answer_locked: |
| base_weight *= 0.005 |
| answer_weight *= 250.0 |
| associative_weight *= 0.05 |
| transition_weight *= 0.005 |
| copy_weight *= 0.005 |
| source_evidence_weight *= 0.05 |
| preference_weight *= 0.05 |
| elif answer_guided_start: |
| base_weight *= 0.45 |
| answer_weight *= 3.1 |
| associative_weight *= 0.2 |
| transition_weight *= 0.35 |
| copy_weight *= 0.2 |
| source_evidence_weight *= 1.1 |
| preference_weight *= 0.2 |
| elif copy_guided_start: |
| base_weight *= 0.55 |
| answer_weight *= 0.35 |
| associative_weight *= 0.4 |
| transition_weight *= 0.35 |
| copy_weight *= 4.5 |
| preference_weight *= 0.6 |
| elif generated_count > 0: |
| answer_weight *= 0.32 |
| transition_weight *= 2.0 |
| copy_weight *= 0.75 |
| source_evidence_weight *= 0.85 |
| if source_grounded: |
| base_weight *= 0.45 |
| answer_weight *= 0.35 |
| associative_weight *= 0.50 |
| transition_weight *= 0.25 |
| copy_weight *= 0.50 |
| source_evidence_weight *= 3.50 |
|
|
| if source_grounded: |
| base_weight *= 0.60 |
| answer_weight *= 0.35 |
| associative_weight *= 0.50 |
| transition_weight *= 0.80 |
| copy_weight *= 0.20 |
| source_evidence_weight *= 1.80 |
| else: |
| source_evidence_weight = 0.0 |
|
|
| if transition_order is None: |
| answer_weight *= 1.1 |
| associative_weight *= 0.75 |
| copy_weight += 0.02 |
| elif transition_order <= 2: |
| answer_weight *= 1.15 |
| associative_weight *= 0.65 |
| transition_weight *= 0.55 |
| copy_weight += 0.01 |
| elif transition_order >= 5: |
| transition_weight *= 1.25 |
|
|
| sources: list[tuple[str, float, Vector]] = [("base", base_weight, base)] |
| if any(value > 0.0 for value in answer): |
| sources.append(("answer", answer_weight, answer)) |
| if any(value > 0.0 for value in associative): |
| sources.append(("associative", associative_weight, associative)) |
| if any(value > 0.0 for value in transition): |
| sources.append(("transition", transition_weight, transition)) |
| if any(value > 0.0 for value in copy): |
| sources.append(("copy", copy_weight, copy)) |
| if any(value > 0.0 for value in source_evidence): |
| sources.append(("source_evidence", source_evidence_weight, source_evidence)) |
| if any(value > 0.0 for value in preference): |
| sources.append(("preference", preference_weight, preference)) |
|
|
| total_weight = sum(weight for _, weight, _ in sources) |
| blended = [0.0 for _ in base] |
| blend_weights: dict[str, float] = {} |
| for name, weight, source in sources: |
| normalized_weight = weight / total_weight if total_weight else 0.0 |
| blend_weights[name] = normalized_weight |
| for index, value in enumerate(source): |
| blended[index] += normalized_weight * value |
| return _normalize_vector(blended), blend_weights |
|
|
| def _blend_probability_arrays( |
| self, |
| base: object, |
| answer: object, |
| associative: object, |
| transition: object, |
| copy: object, |
| source_evidence: object, |
| preference: object, |
| *, |
| transition_order: int | None, |
| generated_count: int = 0, |
| answer_locked: bool = False, |
| answer_guided_start: bool = False, |
| copy_guided_start: bool = False, |
| ) -> tuple[object, dict[str, float]]: |
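"""NumPy mirror of _blend_probabilities operating on arrays; returns
base unchanged when the blended mixture sums to zero."""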
| assert np is not None |
|
|
| base_weight = FAST_BASE_BLEND |
| answer_weight = FAST_ANSWER_BLEND |
| associative_weight = FAST_ASSOCIATIVE_BLEND |
| transition_weight = FAST_TRANSITION_BLEND |
| copy_weight = FAST_COPY_BLEND |
| source_evidence_weight = FAST_SOURCE_EVIDENCE_BLEND |
| preference_weight = FAST_PREFERENCE_BLEND |
| source_grounded = bool(np.any(source_evidence > 0.0)) |
| if answer_locked: |
| base_weight *= 0.005 |
| answer_weight *= 250.0 |
| associative_weight *= 0.05 |
| transition_weight *= 0.005 |
| copy_weight *= 0.005 |
| source_evidence_weight *= 0.05 |
| preference_weight *= 0.05 |
| elif answer_guided_start: |
| base_weight *= 0.45 |
| answer_weight *= 3.1 |
| associative_weight *= 0.2 |
| transition_weight *= 0.35 |
| copy_weight *= 0.2 |
| source_evidence_weight *= 1.1 |
| preference_weight *= 0.2 |
| elif copy_guided_start: |
| base_weight *= 0.55 |
| answer_weight *= 0.35 |
| associative_weight *= 0.4 |
| transition_weight *= 0.35 |
| copy_weight *= 4.5 |
| preference_weight *= 0.6 |
| elif generated_count > 0: |
| answer_weight *= 0.32 |
| transition_weight *= 2.0 |
| copy_weight *= 0.75 |
| source_evidence_weight *= 0.85 |
| if source_grounded: |
| base_weight *= 0.45 |
| answer_weight *= 0.35 |
| associative_weight *= 0.50 |
| transition_weight *= 0.25 |
| copy_weight *= 0.50 |
| source_evidence_weight *= 3.50 |
| if source_grounded: |
| base_weight *= 0.60 |
| answer_weight *= 0.35 |
| associative_weight *= 0.50 |
| transition_weight *= 0.80 |
| copy_weight *= 0.20 |
| source_evidence_weight *= 1.80 |
| else: |
| source_evidence_weight = 0.0 |
| if transition_order is None: |
| answer_weight *= 1.1 |
| associative_weight *= 0.75 |
| copy_weight += 0.02 |
| elif transition_order <= 2: |
| answer_weight *= 1.15 |
| associative_weight *= 0.65 |
| transition_weight *= 0.55 |
| copy_weight += 0.01 |
| elif transition_order >= 5: |
| transition_weight *= 1.25 |
|
|
| sources: list[tuple[str, float, object]] = [("base", base_weight, base)] |
| if np.any(answer > 0.0): |
| sources.append(("answer", answer_weight, answer)) |
| if np.any(associative > 0.0): |
| sources.append(("associative", associative_weight, associative)) |
| if np.any(transition > 0.0): |
| sources.append(("transition", transition_weight, transition)) |
| if np.any(copy > 0.0): |
| sources.append(("copy", copy_weight, copy)) |
| if np.any(source_evidence > 0.0): |
| sources.append(("source_evidence", source_evidence_weight, source_evidence)) |
| if np.any(preference > 0.0): |
| sources.append(("preference", preference_weight, preference)) |
|
|
| total_weight = sum(weight for _, weight, _ in sources) |
| blended = np.zeros_like(base, dtype=np.float64) |
| blend_weights: dict[str, float] = {} |
| for name, weight, source in sources: |
| normalized_weight = weight / total_weight if total_weight else 0.0 |
| blend_weights[name] = normalized_weight |
| blended += normalized_weight * source |
| total = float(blended.sum()) |
| if total <= 0.0: |
| return base, blend_weights |
| return blended / total, blend_weights |
|
|
| def _score_associative_matches( |
| self, |
| state: Vector, |
| *, |
| limit: int = ASSOCIATIVE_TOP_K, |
| ) -> list[tuple[float, int, int]]: |
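"""Rank associative-memory entries by cosine similarity against the
centered state, returning (score, token_id, example_index) triples; the
NumPy path picks the top `limit` with argpartition, the fallback scans
and sorts in pure Python."""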
| if ( |
| self.associative_keys is None |
| or self.associative_values is None |
| or len(self.associative_keys) == 0 |
| or len(self.associative_values) == 0 |
| ): |
| return [] |
|
|
| if ( |
np is not None
and self.associative_keys_array is not None
| and self.associative_key_norms_array is not None |
| and self.associative_values_array is not None |
| and self.associative_valid_mask_array is not None |
| and limit > 0 |
| ): |
| state_array = self._center_state_array(state).astype(self.associative_keys_array.dtype, copy=False) |
| state_norm = float(np.linalg.norm(state_array)) |
| if state_norm == 0.0: |
| return [] |
| numerators = self.associative_keys_array @ state_array |
| denominators = self.associative_key_norms_array * state_norm |
| valid_mask = self.associative_valid_mask_array & (denominators > 0.0) |
| if np.any(valid_mask): |
| scores = np.zeros_like(numerators, dtype=self.associative_keys_array.dtype) |
| np.divide(numerators, denominators, out=scores, where=valid_mask) |
| positive_positions = np.flatnonzero(valid_mask & (scores > 0.0)) |
| if positive_positions.size: |
| selected_positions = positive_positions |
| if positive_positions.size > limit: |
| partition = np.argpartition(scores[positive_positions], -limit)[-limit:] |
| selected_positions = positive_positions[partition] |
| ordered_positions = selected_positions[np.argsort(scores[selected_positions])[::-1]] |
| return [ |
| ( |
| float(scores[position]), |
| int(self.associative_values_array[position]), |
| int(position), |
| ) |
| for position in ordered_positions |
| ] |
|
|
| if self.associative_key_norms is None or len(self.associative_key_norms) == 0: |
| return [] |
|
|
| state = self._center_state_vector(state) |
| state_norm = norm(state) |
| if state_norm == 0.0: |
| return [] |
|
|
| scored: list[tuple[float, int, int]] = [] |
| for example_index, (key, key_norm, token_id) in enumerate( |
| zip(self.associative_keys, self.associative_key_norms, self.associative_values) |
| ): |
| if token_id < 0: |
| continue |
| denominator = state_norm * key_norm |
| if denominator == 0.0: |
| continue |
| similarity = dot(state, key) / denominator |
| if similarity > 0.0: |
| scored.append((similarity, token_id, example_index)) |
| scored.sort(key=lambda item: item[0], reverse=True) |
| return scored[:limit] |
|
|
| def _associative_prior_from_matches( |
| self, |
| matches: list[tuple[float, int, int]], |
| ) -> Vector: |
| assert self.embedding_model is not None |
| if not matches: |
| return [0.0 for _ in self.embedding_model.id_to_token] |
|
|
| prior = [0.0 for _ in self.embedding_model.id_to_token] |
| for similarity, token_id, _ in matches[:ASSOCIATIVE_TOP_K]: |
| prior[token_id] += similarity |
| return _normalize_vector(prior) |
|
|
| def _associative_prior(self, state: Vector) -> Vector: |
| return self._associative_prior_from_matches(self._score_associative_matches(state)) |
|
|
| def _score_answer_matches( |
| self, |
| answer_anchor_state: Vector | None, |
| *, |
| limit: int = ANSWER_TOP_K, |
| ) -> list[tuple[float, int, int]]: |
| return self._score_prompt_anchor_matches( |
| answer_anchor_state, |
| self.answer_keys, |
| self.answer_key_norms, |
| self.answer_values, |
| self.answer_keys_array, |
| self.answer_key_norms_array, |
| self.answer_values_array, |
| self.answer_valid_mask_array, |
| self.answer_similarity_keys_array, |
| self.answer_similarity_key_norms_array, |
| self.answer_similarity_mask_array, |
| limit=limit, |
| ) |
|
|
| def _score_answer_start_matches( |
| self, |
| answer_anchor_state: Vector | None, |
| *, |
| limit: int = ANSWER_START_TOP_K, |
| ) -> list[tuple[float, int, int]]: |
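"""Score answer-start anchors; if the masked similarity projection
yields no matches, retry against the raw (unmasked) keys."""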
| matches = self._score_prompt_anchor_matches( |
| answer_anchor_state, |
| self.answer_start_keys, |
| self.answer_start_key_norms, |
| self.answer_start_values, |
| self.answer_start_keys_array, |
| self.answer_start_key_norms_array, |
| self.answer_start_values_array, |
| self.answer_start_valid_mask_array, |
| self.answer_start_similarity_keys_array, |
| self.answer_start_similarity_key_norms_array, |
| self.answer_similarity_mask_array, |
| limit=limit, |
| ) |
| if matches: |
| return matches |
| return self._score_prompt_anchor_matches( |
| answer_anchor_state, |
| self.answer_start_keys, |
| self.answer_start_key_norms, |
| self.answer_start_values, |
| self.answer_start_keys_array, |
| self.answer_start_key_norms_array, |
| self.answer_start_values_array, |
| self.answer_start_valid_mask_array, |
| None, |
| None, |
| None, |
| limit=limit, |
| ) |
|
|
| def _score_answer_sequence_matches( |
| self, |
| answer_anchor_state: Vector | None, |
| context_tokens: list[str], |
| *, |
| limit: int = ANSWER_START_TOP_K, |
| ) -> list[tuple[float, int, int]]: |
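"""Rank stored answer sequences for the current prompt.

Anchor-state cosine matches enter at 0.20 weight and lexical prompt
overlap at 0.80, with candidates restricted to rows whose overlap
reaches max(0.16, 0.9 * best overlap); when no overlap cache is
available, the raw anchor matches are returned."""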
| if ( |
| answer_anchor_state is None |
| or self.answer_sequence_keys is None |
| or self.answer_sequence_key_norms is None |
| or self.answer_sequence_tokens is None |
| ): |
| return [] |
| values = list(range(len(self.answer_sequence_tokens))) |
| values_array = np.arange(len(values), dtype=np.int64) if np is not None else None |
| anchor_matches = self._score_prompt_anchor_matches( |
| answer_anchor_state, |
| self.answer_sequence_keys, |
| self.answer_sequence_key_norms, |
| values, |
| self.answer_sequence_keys_array, |
| self.answer_sequence_key_norms_array, |
| values_array, |
(values_array >= 0) if values_array is not None else None,
| self.answer_sequence_similarity_keys_array, |
| self.answer_sequence_similarity_key_norms_array, |
| self.answer_similarity_mask_array, |
| limit=max(limit * 4, limit), |
| ) |
| overlap_scores = self._answer_sequence_prompt_overlap_scores(context_tokens) |
| if overlap_scores is None: |
| return anchor_matches[:limit] |
| if not overlap_scores: |
| return [] |
best_overlap = max(overlap_scores.values())
| overlap_floor = max(0.16, best_overlap * 0.90) |
| focused_overlap_scores = { |
| sequence_index: overlap |
| for sequence_index, overlap in overlap_scores.items() |
| if overlap >= overlap_floor |
| } |
| if not focused_overlap_scores: |
| focused_overlap_scores = overlap_scores |
| focused_indices = set(focused_overlap_scores) |
| merged: dict[int, float] = {} |
| for similarity, sequence_index, _ in anchor_matches: |
| if sequence_index not in focused_indices: |
| continue |
| merged[sequence_index] = max(merged.get(sequence_index, 0.0), 0.20 * similarity) |
| for sequence_index, overlap in focused_overlap_scores.items(): |
| merged[sequence_index] = merged.get(sequence_index, 0.0) + (0.80 * overlap) |
| ranked = [ |
| (score, sequence_index, sequence_index) |
| for sequence_index, score in merged.items() |
| if score > 0.0 |
| ] |
| ranked.sort(key=lambda item: item[0], reverse=True) |
| return ranked[:limit] |
|
|
| def _answer_sequence_prompt_overlap_scores( |
| self, |
| context_tokens: list[str], |
| ) -> dict[int, float] | None: |
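"""Score each stored prompt row against the current prompt tokens.

Rows failing the numeric-match or coverage gates are dropped; surviving
rows score 0.35 * token cosine + 0.35 * query coverage + 0.15 * bigram
+ 0.15 * trigram, scaled by a length-fit penalty. Returns None when
the query carries no usable tokens."""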
| if ( |
| self.embedding_model is None |
| or self.answer_sequence_prompt_tokens is None |
| or self.trace_token_weights is None |
| ): |
| return None |
| answer_boundary = _last_index(context_tokens, "<answer>") |
| prompt_tokens = ( |
| context_tokens[:answer_boundary] |
| if answer_boundary is not None |
| else context_tokens |
| ) |
| if ( |
| self.answer_sequence_prompt_specificity is None |
| and not self._defer_answer_sequence_prompt_overlap_cache() |
| ): |
| self._refresh_answer_sequence_prompt_overlap_cache() |
| specificity_map = self.answer_sequence_prompt_specificity or {} |
| query_weights: dict[int, float] = {} |
| query_specificity: dict[int, float] = {} |
| query_segment_multipliers: dict[int, float] = {} |
| query_content_weight = 0.0 |
| query_ids: list[int] = [] |
| primary_query_ids: list[int] = [] |
| inside_tool_evidence = False |
| prompt_segment_index = 0 |
| for token in prompt_tokens: |
| if token in {"<tool_result>", "<source>"}: |
| inside_tool_evidence = True |
| continue |
| if token == "<final>": |
| inside_tool_evidence = False |
| continue |
| if self.tokenizer is not None and token in self.tokenizer.special_tokens: |
| continue |
| if self._is_structural_punctuation_token(token): |
| prompt_segment_index += 1 |
| continue |
| if self._should_skip_prompt_overlap_token(token): |
| continue |
| token_id = self.embedding_model.token_to_id.get(token) |
| if token_id is None: |
| continue |
| query_ids.append(token_id) |
| specificity = specificity_map.get(token_id, 1.0) |
| evidence_multiplier = 0.35 if inside_tool_evidence else 1.0 |
| segment_multiplier = evidence_multiplier / (1.0 + prompt_segment_index) |
| weight = specificity * segment_multiplier |
| query_weights[token_id] = max( |
| query_weights.get(token_id, 0.0), |
| weight, |
| ) |
| query_specificity[token_id] = max( |
| query_specificity.get(token_id, 0.0), |
| specificity, |
| ) |
| query_segment_multipliers[token_id] = max( |
| query_segment_multipliers.get(token_id, 0.0), |
| segment_multiplier, |
| ) |
| if not inside_tool_evidence: |
| primary_query_ids.append(token_id) |
| if specificity >= 0.20: |
| query_content_weight += weight |
| if not query_weights: |
| return None |
| full_query_token_ids = set(query_ids) |
| primary_query_token_ids = set(primary_query_ids) |
| has_tool_evidence = any(token in {"<tool_result>", "<source>"} for token in prompt_tokens) |
| query_norm = sum(value * value for value in query_weights.values()) ** 0.5 |
| if query_norm <= 0.0: |
| return None |
|
|
| query_bigrams = { |
| (query_ids[index], query_ids[index + 1]) |
| for index in range(len(query_ids) - 1) |
| } |
| query_trigrams = { |
| (query_ids[index], query_ids[index + 1], query_ids[index + 2]) |
| for index in range(len(query_ids) - 2) |
| } |
| query_numbers = self._number_strings_from_tokens(prompt_tokens) |
|
|
| def ordered_ngram_score( |
| query_grams: set[tuple[int, ...]], |
| row_grams: set[tuple[int, ...]], |
| ) -> float: |
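"""Overlap between two n-gram sets, normalized by the geometric mean of
their sizes (2 shared grams out of 4 and 4 gives 0.5)."""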
| if not query_grams or not row_grams: |
| return 0.0 |
| overlap = len(query_grams & row_grams) |
| if overlap <= 0: |
| return 0.0 |
| return overlap / ((len(query_grams) * len(row_grams)) ** 0.5) |
|
|
| def prompt_length_fit(row_token_count: int) -> float: |
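"""Penalize rows much longer than the query: the fit falls with the
fraction of extra tokens, floored at 0.25."""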
| query_token_count = len(full_query_token_ids) |
| if query_token_count <= 0 or row_token_count <= 0: |
| return 1.0 |
| if row_token_count <= query_token_count: |
| return 1.0 |
| extra_fraction = (row_token_count - query_token_count) / row_token_count |
| return max(0.25, 1.0 - extra_fraction) |
|
|
| cached_maps = self.answer_sequence_prompt_weight_maps |
| cached_norms = self.answer_sequence_prompt_weight_norms |
| cached_bigrams = self.answer_sequence_prompt_bigram_sets |
| cached_trigrams = self.answer_sequence_prompt_trigram_sets |
| cached_numbers = self.answer_sequence_prompt_number_sets |
| cached_index = self.answer_sequence_prompt_inverted_index |
| if ( |
| cached_maps is not None |
| and cached_norms is not None |
| and cached_bigrams is not None |
| and cached_trigrams is not None |
| and cached_numbers is not None |
| and len(cached_maps) == len(self.answer_sequence_prompt_tokens) |
| ): |
candidate_indices: set[int] | range | list[int]
| if cached_index is not None: |
| candidates: set[int] = set() |
| ranked_query_ids = sorted( |
| query_weights, |
| key=lambda token_id: specificity_map.get(token_id, 1.0), |
| reverse=True, |
| ) |
| distinctive_query_ids = [ |
| token_id |
| for token_id in ranked_query_ids |
| if specificity_map.get(token_id, 1.0) >= 0.75 |
| ] or ranked_query_ids[:4] |
| for token_id in distinctive_query_ids: |
| candidates.update(cached_index.get(token_id, ())) |
| candidate_indices = candidates if candidates else range(len(cached_maps)) |
| else: |
| candidate_indices = range(len(cached_maps)) |
| candidate_indices = list(candidate_indices) |
| if cached_index is not None and candidate_indices: |
| candidate_set = set(candidate_indices) |
| local_query_weights: dict[int, float] = {} |
| local_query_specificity: dict[int, float] = {} |
| local_query_content_weight = 0.0 |
| for token_id in query_weights: |
| local_frequency = len(candidate_set & set(cached_index.get(token_id, ()))) |
| if local_frequency <= 0: |
| continue |
| specificity = self._prompt_overlap_token_specificity( |
| local_frequency, |
| len(candidate_indices), |
| ) |
| weight = specificity * query_segment_multipliers.get(token_id, 1.0) |
| local_query_weights[token_id] = weight |
| local_query_specificity[token_id] = specificity |
| if specificity >= 0.20: |
| local_query_content_weight += weight |
| local_query_norm = sum(value * value for value in local_query_weights.values()) ** 0.5 |
| if local_query_norm > 0.0: |
| query_weights = local_query_weights |
| query_specificity = local_query_specificity |
| query_norm = local_query_norm |
| scores: dict[int, float] = {} |
| for sequence_index in candidate_indices: |
| row_weights = cached_maps[sequence_index] |
| if not row_weights: |
| continue |
| if query_numbers and not self._numeric_prompt_can_match( |
| query_numbers, |
| cached_numbers[sequence_index], |
| ): |
| continue |
| matched_content_weight = sum( |
| query_weights[token_id] |
| for token_id in query_weights.keys() & row_weights.keys() |
| if query_specificity.get(token_id, 0.0) >= 0.20 |
| ) |
| row_token_coverage = len(query_weights.keys() & row_weights.keys()) / max( |
| 1, |
| len(row_weights), |
| ) |
| full_query_coverage = len(full_query_token_ids & row_weights.keys()) / max( |
| 1, |
| len(full_query_token_ids), |
| ) |
| primary_query_coverage = len(primary_query_token_ids & row_weights.keys()) / max( |
| 1, |
| len(primary_query_token_ids), |
| ) |
| if ( |
| has_tool_evidence |
| and len(primary_query_token_ids) >= 3 |
| and primary_query_coverage < 0.45 |
| and row_token_coverage < 0.75 |
| ): |
| continue |
| partial_query_floor = 0.60 if len(full_query_token_ids) < 8 else 0.50 |
| if ( |
| len(full_query_token_ids) >= 5 |
| and full_query_coverage <= partial_query_floor |
| and row_token_coverage < 0.75 |
| ): |
| continue |
| if ( |
| len(full_query_token_ids) >= 12 |
| and full_query_coverage < 0.45 |
| and row_token_coverage <= 0.75 |
| ): |
| continue |
| if ( |
| query_content_weight > 0.0 |
| and matched_content_weight / query_content_weight < 0.40 |
| and row_token_coverage < 0.75 |
| and full_query_coverage < 0.60 |
| ): |
| continue |
| query_coverage = ( |
| matched_content_weight / query_content_weight |
| if query_content_weight > 0.0 |
| else row_token_coverage |
| ) |
| numerator = sum( |
| query_weights[token_id] * row_weights[token_id] |
| for token_id in query_weights.keys() & row_weights.keys() |
| ) |
| if numerator <= 0.0: |
| continue |
| row_norm = cached_norms[sequence_index] |
| if row_norm <= 0.0: |
| continue |
| token_score = numerator / (query_norm * row_norm) |
| bigram_score = ordered_ngram_score( |
| query_bigrams, |
| cached_bigrams[sequence_index], |
| ) |
| trigram_score = ordered_ngram_score( |
| query_trigrams, |
| cached_trigrams[sequence_index], |
| ) |
| scores[sequence_index] = ( |
| (0.35 * token_score) |
| + (0.35 * query_coverage) |
| + (0.15 * bigram_score) |
| + (0.15 * trigram_score) |
| ) * prompt_length_fit(len(row_weights)) |
| return scores |
|
|
| vector_candidate_indices: list[int] | None = None |
| if cached_index is not None: |
| candidate_set: set[int] = set() |
| ranked_query_ids = sorted( |
| query_weights, |
| key=lambda token_id: specificity_map.get(token_id, 1.0), |
| reverse=True, |
| ) |
| distinctive_query_ids = [ |
| token_id |
| for token_id in ranked_query_ids |
| if specificity_map.get(token_id, 1.0) >= 0.75 |
| ] or ranked_query_ids[:4] |
| for token_id in distinctive_query_ids: |
| candidate_set.update(cached_index.get(token_id, ())) |
| if not candidate_set: |
| for token_id in ranked_query_ids: |
| candidate_set.update(cached_index.get(token_id, ())) |
| if candidate_set: |
| break |
| if not candidate_set: |
| candidate_indices = range(len(self.answer_sequence_prompt_tokens)) |
| else: |
| candidate_indices = sorted(candidate_set) |
| local_query_weights: dict[int, float] = {} |
| local_query_specificity: dict[int, float] = {} |
| local_query_content_weight = 0.0 |
| candidate_count = len(candidate_indices) |
| for token_id in query_weights: |
| local_frequency = len(candidate_set & set(cached_index.get(token_id, ()))) |
| if local_frequency <= 0: |
| continue |
| specificity = self._prompt_overlap_token_specificity( |
| local_frequency, |
| candidate_count, |
| ) |
| local_query_weights[token_id] = specificity * query_segment_multipliers.get(token_id, 1.0) |
| local_query_specificity[token_id] = specificity |
| if specificity >= 0.20: |
| local_query_content_weight += local_query_weights[token_id] |
| local_query_norm = sum(value * value for value in local_query_weights.values()) ** 0.5 |
| if local_query_norm > 0.0: |
| query_weights = local_query_weights |
| query_specificity = local_query_specificity |
| query_norm = local_query_norm |
| elif self._defer_answer_sequence_prompt_overlap_cache(): |
| vector_candidate_indices = self._vector_answer_sequence_candidate_indices( |
| query_weights.keys() |
| ) |
| if vector_candidate_indices is not None: |
| if not vector_candidate_indices: |
| return {} |
| candidate_indices = vector_candidate_indices |
| local_query_weights = {} |
| local_query_specificity = {} |
| local_query_content_weight = 0.0 |
| candidate_count = len(vector_candidate_indices) |
| for token_id in query_weights: |
| local_frequency = self._vector_answer_sequence_local_frequency( |
| token_id, |
| vector_candidate_indices, |
| ) |
| if local_frequency is None or local_frequency <= 0: |
| continue |
| specificity = self._prompt_overlap_token_specificity( |
| local_frequency, |
| candidate_count, |
| ) |
| local_query_weights[token_id] = specificity * query_segment_multipliers.get(token_id, 1.0) |
| local_query_specificity[token_id] = specificity |
| if specificity >= 0.20: |
| local_query_content_weight += local_query_weights[token_id] |
| local_query_norm = sum(value * value for value in local_query_weights.values()) ** 0.5 |
| if local_query_norm > 0.0: |
| query_weights = local_query_weights |
| query_specificity = local_query_specificity |
| query_norm = local_query_norm |
| else: |
| candidate_indices = range(len(self.answer_sequence_prompt_tokens)) |
|
|
| valid_token_mask = self._prompt_overlap_valid_token_mask() |
| scores: dict[int, float] = {} |
| for sequence_index in candidate_indices: |
| row = self.answer_sequence_prompt_tokens[sequence_index] |
| row_values = row.tolist() if hasattr(row, "tolist") else row |
| row_weights: dict[int, float] = {} |
| row_ids: list[int] = [] |
| raw_row_ids: list[int] = [] |
| for raw_token_id in row_values: |
| token_id = int(raw_token_id) |
| if token_id < 0 or token_id >= len(self.trace_token_weights): |
| continue |
| raw_row_ids.append(token_id) |
| if valid_token_mask is not None: |
| if token_id >= len(valid_token_mask) or not bool(valid_token_mask[token_id]): |
| continue |
| elif self._should_skip_prompt_overlap_token( |
| self.embedding_model.id_to_token[token_id] |
| ): |
| continue |
| row_ids.append(token_id) |
| row_weights[token_id] = max( |
| row_weights.get(token_id, 0.0), |
| specificity_map.get(token_id, 1.0), |
| ) |
| if not row_weights: |
| continue |
| if query_numbers and not self._numeric_prompt_can_match( |
| query_numbers, |
| self._number_strings_from_token_ids(raw_row_ids), |
| ): |
| continue |
| matched_content_weight = sum( |
| query_weights[token_id] |
| for token_id in query_weights.keys() & row_weights.keys() |
| if query_specificity.get(token_id, 0.0) >= 0.20 |
| ) |
| row_token_coverage = len(query_weights.keys() & row_weights.keys()) / max( |
| 1, |
| len(row_weights), |
| ) |
| full_query_coverage = len(full_query_token_ids & row_weights.keys()) / max( |
| 1, |
| len(full_query_token_ids), |
| ) |
| primary_query_coverage = len(primary_query_token_ids & row_weights.keys()) / max( |
| 1, |
| len(primary_query_token_ids), |
| ) |
| if ( |
| has_tool_evidence |
| and len(primary_query_token_ids) >= 3 |
| and primary_query_coverage < 0.45 |
| and row_token_coverage < 0.75 |
| ): |
| continue |
| partial_query_floor = 0.60 if len(full_query_token_ids) < 8 else 0.30 |
| if ( |
| len(full_query_token_ids) >= 5 |
| and full_query_coverage <= partial_query_floor |
| and row_token_coverage < 0.75 |
| ): |
| continue |
| if ( |
| len(full_query_token_ids) >= 12 |
| and full_query_coverage < 0.25 |
| and row_token_coverage <= 0.75 |
| ): |
| continue |
| if ( |
| query_content_weight > 0.0 |
| and matched_content_weight / query_content_weight < 0.25 |
| and row_token_coverage < 0.75 |
| and full_query_coverage < 0.60 |
| ): |
| continue |
| query_coverage = ( |
| matched_content_weight / query_content_weight |
| if query_content_weight > 0.0 |
| else row_token_coverage |
| ) |
| numerator = sum( |
| query_weights[token_id] * row_weights[token_id] |
| for token_id in query_weights.keys() & row_weights.keys() |
| ) |
| if numerator <= 0.0: |
| continue |
| row_norm = sum(value * value for value in row_weights.values()) ** 0.5 |
| if row_norm > 0.0: |
| token_score = numerator / (query_norm * row_norm) |
| row_bigrams = { |
| (row_ids[index], row_ids[index + 1]) |
| for index in range(len(row_ids) - 1) |
| } |
| row_trigrams = { |
| (row_ids[index], row_ids[index + 1], row_ids[index + 2]) |
| for index in range(len(row_ids) - 2) |
| } |
| bigram_score = ordered_ngram_score(query_bigrams, row_bigrams) |
| trigram_score = ordered_ngram_score(query_trigrams, row_trigrams) |
| scores[sequence_index] = ( |
| (0.35 * token_score) |
| + (0.35 * query_coverage) |
| + (0.15 * bigram_score) |
| + (0.15 * trigram_score) |
| ) * prompt_length_fit(len(row_weights)) |
| return scores |
|
|
| def _score_prompt_anchor_matches( |
| self, |
| answer_anchor_state: Vector | None, |
| keys: object | None, |
| key_norms_list: object | None, |
| values: object | None, |
| keys_array: object | None, |
| key_norms_array: object | None, |
| values_array: object | None, |
| valid_mask_array: object | None, |
| similarity_keys_array: object | None, |
| similarity_key_norms_array: object | None, |
| similarity_mask_array: object | None, |
| *, |
| limit: int, |
| ) -> list[tuple[float, int, int]]: |
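"""Shared cosine scorer for prompt-anchored key/value tables.

When similarity projections are provided, both the state and the keys
are masked to the embedding-trace blocks before scoring; otherwise the
raw centered state is used. Falls back to a pure-Python scan when the
NumPy arrays are missing."""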
| if ( |
| answer_anchor_state is None |
| or keys is None |
| or key_norms_list is None |
| or values is None |
| ): |
| return [] |
|
|
| if ( |
| np is not None |
| and keys_array is not None |
| and key_norms_array is not None |
| and values_array is not None |
| and valid_mask_array is not None |
| and limit > 0 |
| ): |
| key_array = keys_array |
| key_norms = key_norms_array |
| if ( |
| similarity_keys_array is not None |
| and similarity_key_norms_array is not None |
| and similarity_mask_array is not None |
| ): |
| state_array = self._center_state_array( |
| self._masked_combined_state_array(answer_anchor_state) |
| ).astype(keys_array.dtype, copy=False) |
| state_array = state_array * similarity_mask_array |
| key_array = similarity_keys_array |
| key_norms = similarity_key_norms_array |
| else: |
| state_array = self._center_state_array(answer_anchor_state).astype( |
| keys_array.dtype, |
| copy=False, |
| ) |
| state_norm = float(np.linalg.norm(state_array)) |
| if state_norm == 0.0: |
| return [] |
| numerators = key_array @ state_array |
| denominators = key_norms * state_norm |
| valid_mask = valid_mask_array & (denominators > 0.0) |
| if np.any(valid_mask): |
| scores = np.zeros_like(numerators, dtype=key_array.dtype) |
| np.divide(numerators, denominators, out=scores, where=valid_mask) |
| positive_positions = np.flatnonzero(valid_mask & (scores > 0.0)) |
| if positive_positions.size: |
| selected_positions = positive_positions |
| if positive_positions.size > limit: |
| partition = np.argpartition(scores[positive_positions], -limit)[-limit:] |
| selected_positions = positive_positions[partition] |
| ordered_positions = selected_positions[np.argsort(scores[selected_positions])[::-1]] |
| return [ |
| ( |
| float(scores[position]), |
| int(values_array[position]), |
| int(position), |
| ) |
| for position in ordered_positions |
| ] |
|
|
| if similarity_mask_array is not None: |
| state = self._center_state_vector(self._masked_combined_state(answer_anchor_state)) |
| else: |
| state = self._center_state_vector(answer_anchor_state) |
| state_norm = norm(state) |
| if state_norm == 0.0: |
| return [] |
|
|
| scored: list[tuple[float, int, int]] = [] |
| for example_index, (key, key_norm, token_id) in enumerate( |
| zip(keys, key_norms_list, values) |
| ): |
| if token_id < 0: |
| continue |
| denominator = state_norm * key_norm |
| if denominator == 0.0: |
| continue |
| similarity = dot(state, key) / denominator |
| if similarity > 0.0: |
| scored.append((similarity, token_id, example_index)) |
| scored.sort(key=lambda item: item[0], reverse=True) |
| return scored[:limit] |
|
|
| def _answer_prior_from_matches( |
| self, |
| matches: list[tuple[float, int, int]], |
| generated_tokens: list[str], |
| ) -> Vector: |
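"""Accumulate match similarities into a next-token prior, skipping
disallowed tokens and down-weighting already-generated tokens to 0.35
of their similarity."""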
| assert self.embedding_model is not None |
| if not matches: |
| return [0.0 for _ in self.embedding_model.id_to_token] |
|
|
| prior = [0.0 for _ in self.embedding_model.id_to_token] |
| generated_ids = { |
| self.embedding_model.token_to_id[token] |
| for token in generated_tokens |
| if token in self.embedding_model.token_to_id |
| } |
| for similarity, token_id, _ in matches[:ANSWER_TOP_K]: |
| token = self.embedding_model.id_to_token[token_id] |
| if not self._allowed_generation_token(token, generated_tokens): |
| continue |
| if token_id in generated_ids: |
| prior[token_id] += similarity * 0.35 |
| else: |
| prior[token_id] += similarity |
| return _normalize_vector(prior) |
|
|
| def _answer_start_matches_from_sequences( |
| self, |
| matches: list[tuple[float, int, int]], |
| ) -> list[tuple[float, int, int]]: |
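"""Project sequence matches onto their first token, yielding start-token
matches in the (similarity, token_id, example_index) shape used by the
answer-start scorer."""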
| if not matches or self.answer_sequence_tokens is None: |
| return [] |
| start_matches: list[tuple[float, int, int]] = [] |
| for similarity, sequence_index, example_index in matches[:ANSWER_START_TOP_K]: |
| if sequence_index >= len(self.answer_sequence_tokens): |
| continue |
| row = self.answer_sequence_tokens[sequence_index] |
| token_ids = [ |
| int(value) |
| for value in (row.tolist() if hasattr(row, "tolist") else row) |
| if int(value) >= 0 |
| ] |
| if token_ids: |
| start_matches.append((similarity, token_ids[0], example_index)) |
| return start_matches |
|
|
| def _answer_sequence_prior_from_matches( |
| self, |
| matches: list[tuple[float, int, int]], |
| generated_tokens: list[str], |
| *, |
| temperature: float = 0.0, |
| ) -> Vector: |
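"""Turn sequence matches into a next-token prior.

Matches below ANSWER_SEQUENCE_MATCH_FLOOR are ignored; when the best
similarity reaches 0.9, only matches within a small delta of it (wider
at creative temperatures) contribute, each adding mass to the token its
sequence predicts after the already-generated ids."""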
| assert self.embedding_model is not None |
| if not matches or self.answer_sequence_tokens is None: |
| return [0.0 for _ in self.embedding_model.id_to_token] |
|
|
| generated_ids = [ |
| self.embedding_model.token_to_id[token] |
| for token in generated_tokens |
| if token in self.embedding_model.token_to_id |
| ] |
| prior = [0.0 for _ in self.embedding_model.id_to_token] |
| best_similarity = matches[0][0] |
| if best_similarity >= 0.9: |
| floor_delta = 0.14 if temperature >= ANSWER_SEQUENCE_CREATIVE_TEMPERATURE else 0.02 |
| match_floor = best_similarity - floor_delta |
| else: |
| match_floor = 0.0 |
| for similarity, sequence_index, _ in matches[:ANSWER_START_TOP_K]: |
| if similarity < ANSWER_SEQUENCE_MATCH_FLOOR: |
| continue |
| if similarity < match_floor: |
| continue |
| token_ids = self._answer_sequence_token_row(sequence_index) |
| if not token_ids: |
| continue |
| next_token_id = self._next_sequence_token_id(token_ids, generated_ids) |
| if next_token_id is None: |
| continue |
| token = self.embedding_model.id_to_token[next_token_id] |
| if self._allowed_answer_sequence_token(token, generated_tokens): |
| prior[next_token_id] += max(1e-9, similarity - match_floor) |
| return _normalize_vector(prior) |
|
|
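| # Canonical row accessor for stored answer sequences: prefers the cached |
| # id rows, then a 2-D array fast path, then plain indexing, returning an |
| # empty list for out-of-range indices. |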
| def _answer_sequence_token_row(self, sequence_index: int) -> list[int]: |
| if sequence_index < 0 or self.answer_sequence_tokens is None: |
| return [] |
| if self.answer_sequence_token_id_rows is not None: |
| if sequence_index >= len(self.answer_sequence_token_id_rows): |
| return [] |
| return self.answer_sequence_token_id_rows[sequence_index] |
| if ( |
| np is not None |
| and hasattr(self.answer_sequence_tokens, "shape") |
| and len(self.answer_sequence_tokens.shape) == 2 |
| ): |
| if sequence_index >= int(self.answer_sequence_tokens.shape[0]): |
| return [] |
| row = np.asarray(self.answer_sequence_tokens[sequence_index]) |
| return [int(value) for value in row.tolist() if int(value) >= 0] |
| try: |
| row = self.answer_sequence_tokens[sequence_index] |
| except (IndexError, TypeError): |
| return [] |
| return self._answer_token_ids_from_row(row) |
|
|
| def _filter_avoided_answer_sequence_matches( |
| self, |
| matches: list[tuple[float, int, int]] | None, |
| avoid_token_sequences: Sequence[Sequence[str]] | None, |
| ) -> list[tuple[float, int, int]]: |
| if ( |
| not matches |
| or not avoid_token_sequences |
| or self.embedding_model is None |
| or self.answer_sequence_tokens is None |
| ): |
| return list(matches or []) |
|
|
| token_to_id = self.embedding_model.token_to_id |
| avoided_id_sequences: set[tuple[int, ...]] = set() |
| for sequence in avoid_token_sequences: |
| ids: list[int] = [] |
| for token in sequence: |
| token_id = token_to_id.get(token) |
| if token_id is None: |
| ids = [] |
| break |
| ids.append(token_id) |
| if ids: |
| avoided_id_sequences.add(tuple(ids)) |
| if not avoided_id_sequences: |
| return list(matches) |
|
|
| sequence_rows = self._answer_sequence_token_rows() |
| filtered: list[tuple[float, int, int]] = [] |
| for match in matches: |
| _, sequence_index, _ = match |
| if sequence_index >= len(sequence_rows): |
| filtered.append(match) |
| continue |
| if tuple(sequence_rows[sequence_index]) in avoided_id_sequences: |
| continue |
| filtered.append(match) |
| return filtered |
|
|
| def _answer_sequence_token_rows(self) -> list[list[int]]: |
| if self.answer_sequence_token_id_rows is not None: |
| return self.answer_sequence_token_id_rows |
| rows: list[list[int]] = [] |
| if ( |
| np is not None |
| and self.answer_sequence_tokens is not None |
| and hasattr(self.answer_sequence_tokens, "shape") |
| and len(self.answer_sequence_tokens.shape) == 2 |
| ): |
| token_rows = np.asarray(self.answer_sequence_tokens).tolist() |
| rows = [ |
| [int(value) for value in row if int(value) >= 0] |
| for row in token_rows |
| ] |
| elif self.answer_sequence_tokens is not None: |
| for row in self.answer_sequence_tokens: |
| rows.append(self._answer_token_ids_from_row(row)) |
| self.answer_sequence_token_id_rows = rows |
| return rows |
|
|
| @staticmethod |
| def _answer_token_ids_from_row(row: object) -> list[int]: |
| values = row.tolist() if hasattr(row, "tolist") else row |
| if not isinstance(values, list): |
| return [] |
| return [int(value) for value in values if int(value) >= 0] |
|
|
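| # Hashes a complete answer token-id sequence into a fixed-width tuple of |
| # signed little-endian 32-bit words via blake2s. Illustrative example |
| # (hypothetical ids, not from any stored answer): token_ids=[3, 17, 42] |
| # hashes the payload b"3,17,42". |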
| @staticmethod |
| def _answer_fingerprint_from_token_ids(token_ids: list[int]) -> tuple[int, ...]: |
| payload = ",".join(str(token_id) for token_id in token_ids).encode("ascii") |
| digest = hashlib.blake2s( |
| payload, |
| digest_size=ANSWER_FINGERPRINT_WORDS * 4, |
| ).digest() |
| return tuple( |
| int.from_bytes( |
| digest[index * 4 : (index + 1) * 4], |
| "little", |
| signed=True, |
| ) |
| for index in range(ANSWER_FINGERPRINT_WORDS) |
| ) |
|
|
| def _refresh_answer_fingerprint_hashes(self) -> None: |
| hashes: set[tuple[int, ...]] = set() |
| lengths: set[int] = set() |
| sequences_by_length: dict[int, set[tuple[int, ...]]] = {} |
| if self.answer_sequence_tokens is not None: |
| for token_ids in self._answer_sequence_token_rows(): |
| if token_ids: |
| token_length = len(token_ids) |
| lengths.add(token_length) |
| sequences_by_length.setdefault(token_length, set()).add(tuple(token_ids)) |
| hashes.add(self._answer_fingerprint_from_token_ids(token_ids)) |
| self.answer_fingerprint_hashes = hashes |
| self.answer_fingerprint_token_lengths = lengths |
| self.answer_fingerprint_token_sequences_by_length = sequences_by_length |
|
|
| def _answer_fingerprint_tensor(self) -> list[list[int]]: |
| if self.answer_fingerprint_hashes is None: |
| self._refresh_answer_fingerprint_hashes() |
| return [ |
| list(fingerprint) |
| for fingerprint in sorted(self.answer_fingerprint_hashes or set()) |
| ] |
|
|
| @staticmethod |
| def _coerce_answer_fingerprint_hashes(raw_fingerprints: object) -> set[tuple[int, ...]]: |
| rows = raw_fingerprints.tolist() if hasattr(raw_fingerprints, "tolist") else raw_fingerprints |
| hashes: set[tuple[int, ...]] = set() |
| if not isinstance(rows, list): |
| return hashes |
| for row in rows: |
| values = row.tolist() if hasattr(row, "tolist") else row |
| if not isinstance(values, list): |
| continue |
| fingerprint = tuple(int(value) for value in values) |
| if len(fingerprint) == ANSWER_FINGERPRINT_WORDS: |
| hashes.add(fingerprint) |
| return hashes |
|
|
| def _answer_fingerprint_lengths(self) -> set[int]: |
| if self.answer_fingerprint_token_lengths is not None: |
| return self.answer_fingerprint_token_lengths |
| lengths: set[int] = set() |
| if ( |
| np is not None |
| and self.answer_sequence_tokens is not None |
| and hasattr(self.answer_sequence_tokens, "shape") |
| and len(self.answer_sequence_tokens.shape) == 2 |
| ): |
| token_matrix = np.asarray(self.answer_sequence_tokens) |
| length_values = np.sum(token_matrix >= 0, axis=1) |
| lengths = { |
| int(length) |
| for length in np.unique(length_values).tolist() |
| if int(length) > 0 |
| } |
| elif self.answer_sequence_tokens is not None: |
| for token_ids in self._answer_sequence_token_rows(): |
| if token_ids: |
| lengths.add(len(token_ids)) |
| self.answer_fingerprint_token_lengths = lengths |
| return lengths |
|
|
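| # Heuristic switch: past ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT stored |
| # sequences, eagerly building exact per-length tuple sets is presumably |
| # too expensive, so the hashed fingerprint blacklist is used instead. |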
| def _use_runtime_fingerprint_blacklist(self) -> bool: |
| if ( |
| np is None |
| or self.answer_sequence_tokens is None |
| or not hasattr(self.answer_sequence_tokens, "shape") |
| or len(self.answer_sequence_tokens.shape) != 2 |
| ): |
| return False |
| return int(self.answer_sequence_tokens.shape[0]) > ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT |
|
|
| def _answer_fingerprint_token_sequence_sets(self) -> dict[int, set[tuple[int, ...]]]: |
| if self.answer_fingerprint_token_sequences_by_length is not None: |
| return self.answer_fingerprint_token_sequences_by_length |
| sequences_by_length: dict[int, set[tuple[int, ...]]] = {} |
| lengths: set[int] = set() |
| if self.answer_sequence_tokens is not None: |
| for token_ids in self._answer_sequence_token_rows(): |
| if token_ids: |
| token_length = len(token_ids) |
| lengths.add(token_length) |
| sequences_by_length.setdefault(token_length, set()).add(tuple(token_ids)) |
| self.answer_fingerprint_token_lengths = lengths |
| self.answer_fingerprint_token_sequences_by_length = sequences_by_length |
| return sequences_by_length |
|
|
| def _token_ids_for_generated_tokens(self, generated_tokens: Sequence[str]) -> list[int] | None: |
| if self.embedding_model is None: |
| return None |
| token_ids: list[int] = [] |
| for token in generated_tokens: |
| token_id = self.embedding_model.token_to_id.get(token) |
| if token_id is None: |
| return None |
| token_ids.append(token_id) |
| return token_ids |
|
|
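| # Decode-time verbatim-replay gate: a candidate is rejected when |
| # appending it would reproduce a stored answer exactly, checked against |
| # per-length sequence sets or, for large stores, the hashed fingerprints. |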
| def _would_complete_blacklisted_answer( |
| self, |
| generated_tokens: list[str], |
| candidate: str, |
| ) -> bool: |
| generated_token_ids = self._token_ids_for_generated_tokens(generated_tokens) |
| return self._would_complete_blacklisted_answer_ids(generated_token_ids, candidate) |
|
|
| def _would_complete_blacklisted_answer_ids( |
| self, |
| generated_token_ids: Sequence[int] | None, |
| candidate: str, |
| ) -> bool: |
| if ( |
| self.embedding_model is None |
| or not self.answer_fingerprint_hashes |
| or candidate not in self.embedding_model.token_to_id |
| or generated_token_ids is None |
| ): |
| return False |
| candidate_id = self.embedding_model.token_to_id[candidate] |
| if self._is_terminal_punctuation_text(self._render_token(candidate)): |
| return False |
| candidate_length = len(generated_token_ids) + 1 |
| if self._use_runtime_fingerprint_blacklist(): |
| lengths = self._answer_fingerprint_lengths() |
| if lengths and candidate_length not in lengths: |
| return False |
| token_ids = [*generated_token_ids, candidate_id] |
| return self._answer_fingerprint_from_token_ids(token_ids) in self.answer_fingerprint_hashes |
| sequence_sets = self._answer_fingerprint_token_sequence_sets() |
| candidate_sequences = sequence_sets.get(candidate_length) |
| if candidate_sequences is not None: |
| return (*generated_token_ids, candidate_id) in candidate_sequences |
| if self.answer_sequence_tokens is not None: |
| return False |
| lengths = self._answer_fingerprint_lengths() |
| if lengths and candidate_length not in lengths: |
| return False |
| token_ids = [*generated_token_ids, candidate_id] |
| return self._answer_fingerprint_from_token_ids(token_ids) in self.answer_fingerprint_hashes |
|
|
| def _would_follow_blacklisted_answer_prefix_ids( |
| self, |
| generated_token_ids: Sequence[int] | None, |
| candidate: str, |
| *, |
| minimum_prefix_length: int = ANSWER_REPLAY_PREFIX_MIN_TOKENS, |
| ) -> bool: |
| if ( |
| self.embedding_model is None |
| or self.answer_sequence_tokens is None |
| or candidate not in self.embedding_model.token_to_id |
| or generated_token_ids is None |
| ): |
| return False |
| candidate_id = self.embedding_model.token_to_id[candidate] |
| candidate_path = (*generated_token_ids, candidate_id) |
| if len(candidate_path) < minimum_prefix_length: |
| return False |
| prefix_sets = self._answer_sequence_prefix_sets(minimum_prefix_length) |
| return candidate_path in prefix_sets.get(len(candidate_path), set()) |
|
|
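| # Lazily builds and caches every stored-answer prefix of at least |
| # minimum_prefix_length tokens, grouped by length, so the replay-prefix |
| # check above reduces to one set membership test per candidate. |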
| def _answer_sequence_prefix_sets( |
| self, |
| minimum_prefix_length: int = ANSWER_REPLAY_PREFIX_MIN_TOKENS, |
| ) -> dict[int, set[tuple[int, ...]]]: |
| cached = self.answer_sequence_prefixes_by_length |
| if cached is not None: |
| return cached |
| prefixes: dict[int, set[tuple[int, ...]]] = {} |
| for token_ids in self._answer_sequence_token_rows(): |
| for length in range(minimum_prefix_length, len(token_ids) + 1): |
| prefixes.setdefault(length, set()).add(tuple(token_ids[:length])) |
| self.answer_sequence_prefixes_by_length = prefixes |
| return prefixes |
|
|
| def _avoid_text_token_sequences( |
| self, |
| avoid_texts: Sequence[str] | None, |
| ) -> list[list[str]]: |
| if not avoid_texts or self.tokenizer is None: |
| return [] |
| sequences: list[list[str]] = [] |
| seen: set[tuple[str, ...]] = set() |
| for text in avoid_texts: |
| if not isinstance(text, str) or not text.strip(): |
| continue |
| tokens = [ |
| token |
| for token in self.tokenizer.encode(text) |
| if token not in self.tokenizer.special_tokens |
| ] |
| key = tuple(tokens) |
| if tokens and key not in seen: |
| seen.add(key) |
| sequences.append(tokens) |
| return sequences |
|
|
| @staticmethod |
| def _runtime_generation_history_key(context: str) -> str: |
| return " ".join(context.split()).casefold() |
|
|
| @staticmethod |
| def _runtime_history_enabled(context: str, *, temperature: float) -> bool: |
| if temperature < ANSWER_REPLAY_PREFIX_TEMPERATURE: |
| return False |
| lowered = context.casefold() |
| return "<source>" not in lowered and "<tool_result>" not in lowered |
|
|
| def _runtime_avoid_texts( |
| self, |
| context: str, |
| avoid_texts: Sequence[str] | None, |
| *, |
| temperature: float, |
| ) -> list[str]: |
| combined: list[str] = [] |
| seen: set[str] = set() |
| for text in avoid_texts or (): |
| cleaned = " ".join(str(text).split()) |
| if cleaned and cleaned not in seen: |
| combined.append(cleaned) |
| seen.add(cleaned) |
| if not self._runtime_history_enabled(context, temperature=temperature): |
| return combined |
| history = self.runtime_generation_history.get( |
| self._runtime_generation_history_key(context), |
| [], |
| ) |
| for text in history: |
| cleaned = " ".join(str(text).split()) |
| if cleaned and cleaned not in seen: |
| combined.append(cleaned) |
| seen.add(cleaned) |
| return combined |
|
|
| def _remember_runtime_generation( |
| self, |
| context: str, |
| generated_text: str, |
| *, |
| temperature: float, |
| ) -> None: |
| if not self._runtime_history_enabled(context, temperature=temperature): |
| return |
| cleaned = " ".join(generated_text.split()) |
| if not cleaned: |
| return |
| key = self._runtime_generation_history_key(context) |
| history = [ |
| existing |
| for existing in self.runtime_generation_history.get(key, []) |
| if existing != cleaned |
| ] |
| history.append(cleaned) |
| self.runtime_generation_history[key] = history[-RUNTIME_GENERATION_HISTORY_LIMIT:] |
|
|
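| # Blocks a candidate whose committed path would keep tracing an avoided |
| # sequence once it reaches AVOID_SEQUENCE_MIN_TOKENS tokens; shorter |
| # prefixes stay legal so generation can still diverge early. |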
| @staticmethod |
| def _would_follow_avoided_sequence( |
| generated_tokens: list[str], |
| candidate: str, |
| avoid_token_sequences: Sequence[Sequence[str]] | None, |
| ) -> bool: |
| if not avoid_token_sequences: |
| return False |
| prefix_length = len(generated_tokens) + 1 |
| if prefix_length < AVOID_SEQUENCE_MIN_TOKENS: |
| return False |
| candidate_path = [*generated_tokens, candidate] |
| for sequence in avoid_token_sequences: |
| if prefix_length <= len(sequence) and list(sequence[:prefix_length]) == candidate_path: |
| return True |
| return False |
|
|
| def _should_stop_answer_sequence( |
| self, |
| decode_state: DecodeState, |
| generated_tokens: list[str], |
| ) -> bool: |
| matches = decode_state.answer_sequence_matches |
| if matches is None: |
| matches = self._score_answer_sequence_matches( |
| decode_state.answer_anchor_state, |
| decode_state.context_tokens, |
| ) |
| return self._answer_sequence_is_complete(generated_tokens, matches) |
|
|
| def _should_stop_after_answer_path_drift( |
| self, |
| decode_state: DecodeState, |
| generated_tokens: list[str], |
| ) -> bool: |
| matches = decode_state.answer_sequence_matches |
| if matches is None: |
| matches = self._score_answer_sequence_matches( |
| decode_state.answer_anchor_state, |
| decode_state.context_tokens, |
| ) |
| if not matches or matches[0][0] < ANSWER_SEQUENCE_MATCH_FLOOR: |
| return False |
| if self._answer_sequence_has_continuation(generated_tokens, matches): |
| return False |
| if self._generated_answer_ends_terminal_sentence(generated_tokens): |
| return True |
| return self._generated_word_count(generated_tokens) >= 14 |
|
|
| def _generated_answer_ends_terminal_sentence(self, generated_tokens: list[str]) -> bool: |
| if not generated_tokens: |
| return False |
| rendered = self._render_token(generated_tokens[-1]) |
| if not self._is_terminal_punctuation_text(rendered): |
| return False |
| return self._generated_word_count(generated_tokens) > 0 |
|
|
| def _answer_decode_has_continuation( |
| self, |
| decode_state: DecodeState, |
| generated_tokens: list[str], |
| ) -> bool: |
| matches = decode_state.answer_sequence_matches |
| if matches is None: |
| matches = self._score_answer_sequence_matches( |
| decode_state.answer_anchor_state, |
| decode_state.context_tokens, |
| ) |
| return self._answer_sequence_has_continuation(generated_tokens, matches) |
|
|
| def _answer_sequence_is_complete( |
| self, |
| generated_tokens: list[str], |
| matches: list[tuple[float, int, int]], |
| ) -> bool: |
| if ( |
| self.embedding_model is None |
| or self.answer_sequence_tokens is None |
| or not generated_tokens |
| or not matches |
| ): |
| return False |
| generated_ids = [ |
| self.embedding_model.token_to_id[token] |
| for token in generated_tokens |
| if token in self.embedding_model.token_to_id |
| ] |
| if not generated_ids: |
| return False |
| for similarity, sequence_index, _ in matches[:ANSWER_START_TOP_K]: |
| if similarity < ANSWER_SEQUENCE_MATCH_FLOOR: |
| continue |
| token_ids = self._answer_sequence_token_row(sequence_index) |
| if not token_ids: |
| continue |
| if len(generated_ids) >= len(token_ids) and generated_ids[: len(token_ids)] == token_ids: |
| return True |
| if ( |
| self.answer_fingerprint_hashes |
| and len(generated_ids) + 1 == len(token_ids) |
| and generated_ids == token_ids[: len(generated_ids)] |
| and self._answer_fingerprint_from_token_ids(token_ids) |
| in self.answer_fingerprint_hashes |
| ): |
| generated_tail = self._render_token(generated_tokens[-1]) |
| if self._is_structural_punctuation_text( |
| generated_tail |
| ) and not self._is_terminal_punctuation_text(generated_tail): |
| continue |
| final_token = self.embedding_model.id_to_token[token_ids[-1]] |
| if self._is_terminal_punctuation_text(self._render_token(final_token)): |
| continue |
| return True |
| return False |
|
|
| def _answer_sequence_has_continuation( |
| self, |
| generated_tokens: list[str], |
| matches: list[tuple[float, int, int]], |
| ) -> bool: |
| if ( |
| self.embedding_model is None |
| or self.answer_sequence_tokens is None |
| or not generated_tokens |
| or not matches |
| ): |
| return False |
| generated_ids = [ |
| self.embedding_model.token_to_id[token] |
| for token in generated_tokens |
| if token in self.embedding_model.token_to_id |
| ] |
| if not generated_ids: |
| return False |
| for similarity, sequence_index, _ in matches[:ANSWER_START_TOP_K]: |
| if similarity < ANSWER_SEQUENCE_MATCH_FLOOR: |
| continue |
| token_ids = self._answer_sequence_token_row(sequence_index) |
| if not token_ids: |
| continue |
| next_token_id = self._next_sequence_token_id(token_ids, generated_ids) |
| if next_token_id is None: |
| continue |
| token = self.embedding_model.id_to_token[next_token_id] |
| if self._allowed_answer_sequence_token(token, generated_tokens): |
| return True |
| return False |
|
|
| def _next_sequence_token_id( |
| self, |
| token_ids: list[int], |
| generated_ids: list[int], |
| ) -> int | None: |
| if not generated_ids: |
| return token_ids[0] |
| if len(generated_ids) >= len(token_ids): |
| return None |
| if token_ids[: len(generated_ids)] != generated_ids: |
| return None |
| return token_ids[len(generated_ids)] |
|
|
| def _transition_prior(self, context_tokens: list[str]) -> Vector: |
| prior, _ = self._transition_prior_with_order(context_tokens) |
| return prior |
|
|
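| # Backoff n-gram lookup: orders are tried longest-first per |
| # TRANSITION_ORDERS, preferring the packed tensor cache, then the |
| # id-keyed tables, then the string-keyed tables; the first order that |
| # yields transitions wins. |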
| def _transition_prior_with_order( |
| self, |
| context_tokens: list[str], |
| ) -> tuple[Vector, int | None]: |
| assert self.embedding_model is not None |
| if self.transition_id_tables: |
| for order in TRANSITION_ORDERS: |
| if len(context_tokens) < order: |
| continue |
| key_ids: list[int] = [] |
| for token in context_tokens[-order:]: |
| token_id = self.embedding_model.token_to_id.get(token) |
| if token_id is None: |
| key_ids = [] |
| break |
| key_ids.append(token_id) |
| if not key_ids: |
| continue |
| transitions = self._transition_tensor_lookup(order, key_ids) |
| if transitions is None: |
| transitions = self.transition_id_tables.get(order, {}).get(tuple(key_ids)) |
| if not transitions: |
| continue |
| next_token_ids, probabilities = transitions |
| prior = [0.0 for _ in self.embedding_model.id_to_token] |
| for token_id, probability in zip(next_token_ids, probabilities): |
| token_index = int(token_id) |
| if 0 <= token_index < len(prior): |
| prior[token_index] = float(probability) |
| return _normalize_vector(prior), order |
| if not self.transition_tables: |
| return [0.0 for _ in self.embedding_model.id_to_token], None |
|
|
| for order in TRANSITION_ORDERS: |
| if len(context_tokens) < order: |
| continue |
| key = tuple(context_tokens[-order:]) |
| transitions = self.transition_tables.get(order, {}).get(key) |
| if not transitions: |
| continue |
| prior = [0.0 for _ in self.embedding_model.id_to_token] |
| for token, probability in transitions.items(): |
| token_id = self.embedding_model.token_to_id.get(token) |
| if token_id is not None: |
| prior[token_id] = probability |
| return _normalize_vector(prior), order |
| return [0.0 for _ in self.embedding_model.id_to_token], None |
|
|
| def _transition_prior_array_with_order( |
| self, |
| context_tokens: list[str], |
| ) -> tuple[object, int | None]: |
| assert np is not None |
| assert self.embedding_model is not None |
| prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) |
| if self.transition_id_tables: |
| for order in TRANSITION_ORDERS: |
| if len(context_tokens) < order: |
| continue |
| key_ids: list[int] = [] |
| for token in context_tokens[-order:]: |
| token_id = self.embedding_model.token_to_id.get(token) |
| if token_id is None: |
| key_ids = [] |
| break |
| key_ids.append(token_id) |
| if not key_ids: |
| continue |
| transitions = self._transition_tensor_lookup(order, key_ids) |
| if transitions is None: |
| transitions = self.transition_id_tables.get(order, {}).get(tuple(key_ids)) |
| if not transitions: |
| continue |
| next_token_ids, probabilities = transitions |
| token_ids_array = np.asarray(next_token_ids, dtype=np.int64) |
| probabilities_array = np.asarray(probabilities, dtype=np.float64) |
| valid = ( |
| (token_ids_array >= 0) |
| & (token_ids_array < len(self.embedding_model.id_to_token)) |
| & (probabilities_array > 0.0) |
| ) |
| if np.any(valid): |
| prior[token_ids_array[valid]] = probabilities_array[valid] |
| total = float(prior.sum()) |
| if total > 0.0: |
| prior /= total |
| return prior, order |
| return prior, None |
| if not self.transition_tables: |
| return prior, None |
|
|
| for order in TRANSITION_ORDERS: |
| if len(context_tokens) < order: |
| continue |
| key = tuple(context_tokens[-order:]) |
| transitions = self.transition_tables.get(order, {}).get(key) |
| if not transitions: |
| continue |
| for token, probability in transitions.items(): |
| token_id = self.embedding_model.token_to_id.get(token) |
| if token_id is not None: |
| prior[token_id] = probability |
| total = float(prior.sum()) |
| if total > 0.0: |
| prior /= total |
| return prior, order |
| return prior, None |
|
|
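| # Copy prior: prompt tokens before the final <answer> marker contribute |
| # weight decaying by 0.82 per token of distance from the boundary, scaled |
| # by _copy_token_distinctiveness (case, digits, symbols, length). |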
| def _copy_prior(self, context_tokens: list[str]) -> Vector: |
| assert self.embedding_model is not None |
| assert self.tokenizer is not None |
|
|
| prior = [0.0 for _ in self.embedding_model.id_to_token] |
| decay = 0.82 |
| answer_start = None |
| for index in range(len(context_tokens) - 1, -1, -1): |
| if context_tokens[index] == "<answer>": |
| answer_start = index + 1 |
| break |
| source_tokens = ( |
| context_tokens[: max(0, answer_start - 1)] |
| if answer_start is not None |
| else context_tokens |
| ) |
| if not source_tokens: |
| return prior |
| for distance, token in enumerate(reversed(source_tokens)): |
| if token in self.tokenizer.special_tokens: |
| continue |
| if not self._eligible_copy_token(token): |
| continue |
| token_id = self.embedding_model.token_to_id.get(token) |
| if token_id is None: |
| continue |
| prior[token_id] += (decay**distance) * self._copy_token_distinctiveness(token) |
| return _normalize_vector(prior) |
|
|
| def _copy_prior_array(self, context_tokens: list[str]) -> object: |
| assert np is not None |
| assert self.embedding_model is not None |
| assert self.tokenizer is not None |
|
|
| prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) |
| decay = 0.82 |
| answer_start = None |
| for index in range(len(context_tokens) - 1, -1, -1): |
| if context_tokens[index] == "<answer>": |
| answer_start = index + 1 |
| break |
| source_tokens = ( |
| context_tokens[: max(0, answer_start - 1)] |
| if answer_start is not None |
| else context_tokens |
| ) |
| for distance, token in enumerate(reversed(source_tokens)): |
| if token in self.tokenizer.special_tokens: |
| continue |
| if not self._eligible_copy_token(token): |
| continue |
| token_id = self.embedding_model.token_to_id.get(token) |
| if token_id is None: |
| continue |
| prior[token_id] += (decay**distance) * self._copy_token_distinctiveness(token) |
| total = float(prior.sum()) |
| if total > 0.0: |
| prior /= total |
| return prior |
|
|
| def _copy_token_distinctiveness(self, token: str) -> float: |
| rendered = self._render_token(token).strip() |
| if not rendered: |
| return 0.0 |
| letters = sum(character.isalpha() for character in rendered) |
| digits = sum(character.isdigit() for character in rendered) |
| symbols = sum( |
| not character.isalnum() and not character.isspace() |
| for character in rendered |
| ) |
| score = 1.0 |
| if any(character.isupper() for character in rendered) and letters: |
| score += 0.8 |
| if digits: |
| score += 0.9 |
| if symbols: |
| score += 0.5 |
| if len(rendered) >= 4: |
| score += 0.2 |
| return score |
|
|
| def _prompt_copy_evidence_is_distinctive(self, context_tokens: list[str]) -> bool: |
| answer_start = None |
| for index in range(len(context_tokens) - 1, -1, -1): |
| if context_tokens[index] == "<answer>": |
| answer_start = index |
| break |
| prompt_tokens = context_tokens[:answer_start] if answer_start is not None else context_tokens |
| for token in prompt_tokens: |
| if self.tokenizer is not None and token in self.tokenizer.special_tokens: |
| continue |
| rendered = self._render_token(token).strip() |
| if any(character.isdigit() for character in rendered): |
| return True |
| if sum(character.isupper() for character in rendered) >= 2: |
| return True |
| return False |
|
|
| def _source_evidence_prior( |
| self, |
| context_tokens: list[str], |
| generated_tokens: list[str] | None = None, |
| ) -> Vector: |
| assert self.embedding_model is not None |
| prior = [0.0 for _ in self.embedding_model.id_to_token] |
| for token_id, weight in self._source_evidence_token_weights( |
| context_tokens, |
| generated_tokens or [], |
| ).items(): |
| if 0 <= token_id < len(prior): |
| prior[token_id] += weight |
| return _normalize_vector(prior) |
|
|
| def _source_evidence_prior_array( |
| self, |
| context_tokens: list[str], |
| generated_tokens: list[str] | None = None, |
| ) -> object: |
| assert np is not None |
| assert self.embedding_model is not None |
| prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) |
| for token_id, weight in self._source_evidence_token_weights( |
| context_tokens, |
| generated_tokens or [], |
| ).items(): |
| if 0 <= token_id < prior.size: |
| prior[token_id] += weight |
| total = float(prior.sum()) |
| if total > 0.0: |
| prior /= total |
| return prior |
|
|
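| # Source-evidence weighting: for each recent snippet, first try to align |
| # a suffix of the generated ids (up to 8 tokens) inside the snippet and |
| # boost its literal continuation; failing that, fall back to rank-decayed |
| # weights, rebalanced around query-anchor tokens so the span following an |
| # anchor is favored. |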
| def _source_evidence_token_weights( |
| self, |
| context_tokens: list[str], |
| generated_tokens: list[str], |
| ) -> dict[int, float]: |
| if self.embedding_model is None or self.tokenizer is None: |
| return {} |
| segments = self._source_evidence_segments(context_tokens) |
| if not segments: |
| return {} |
|
|
| generated_ids = [ |
| self.embedding_model.token_to_id[token] |
| for token in generated_tokens |
| if token in self.embedding_model.token_to_id |
| ] |
| first_source_index = _first_index(context_tokens, "<source>") |
| query_tokens = ( |
| context_tokens[:first_source_index] |
| if first_source_index is not None |
| else context_tokens |
| ) |
| query_token_ids = { |
| self.embedding_model.token_to_id[token] |
| for token in query_tokens |
| if token in self.embedding_model.token_to_id |
| and token not in self.tokenizer.special_tokens |
| and self._eligible_copy_token(token) |
| } |
| weights: dict[int, float] = {} |
|
|
| def add_token(token: str, weight: float, *, allow_piece: bool = False) -> None: |
| if token in self.tokenizer.special_tokens: |
| return |
| if not allow_piece and not self._allowed_generation_token(token, generated_tokens): |
| return |
| if allow_piece: |
| rendered = self._render_token(token) |
| if not rendered or not rendered.strip(): |
| return |
| elif not self._eligible_copy_token(token): |
| return |
| token_id = self.embedding_model.token_to_id.get(token) |
| if token_id is None: |
| return |
| weights[token_id] = weights.get(token_id, 0.0) + weight |
|
|
| for segment_tokens, segment_weight, segment_role in segments[-6:]: |
| if generated_ids and segment_role != "snippet": |
| continue |
| token_ids = [ |
| self.embedding_model.token_to_id[token] |
| for token in segment_tokens |
| if token in self.embedding_model.token_to_id |
| ] |
| aligned = False |
| if generated_ids and token_ids: |
| max_suffix = min(8, len(generated_ids), len(token_ids)) |
| for suffix_length in range(max_suffix, 0, -1): |
| suffix = generated_ids[-suffix_length:] |
| for index in range(len(token_ids) - suffix_length): |
| if token_ids[index : index + suffix_length] != suffix: |
| continue |
| next_token_id = token_ids[index + suffix_length] |
| next_token = self.embedding_model.id_to_token[next_token_id] |
| add_token( |
| next_token, |
| segment_weight * (3.0 + suffix_length), |
| allow_piece=True, |
| ) |
| aligned = True |
| if aligned: |
| break |
| if aligned: |
| continue |
|
|
| content_rank = 0 |
| anchor_seen = False |
| segment_has_query_anchor = any(token_id in query_token_ids for token_id in token_ids) |
| for token in segment_tokens: |
| rendered = self._render_token(token) |
| if "://" in rendered or rendered.casefold().startswith("http"): |
| continue |
| if not self._eligible_copy_token(token): |
| continue |
| token_id = self.embedding_model.token_to_id.get(token) |
| if token_id is None: |
| continue |
| if segment_has_query_anchor: |
| in_query = token_id in query_token_ids |
| if in_query: |
| weight = segment_weight * 0.42 |
| anchor_seen = True |
| elif anchor_seen: |
| weight = segment_weight * 2.10 |
| else: |
| weight = segment_weight * 0.32 |
| elif content_rank == 0: |
| weight = segment_weight * 4.0 |
| elif content_rank == 1: |
| weight = segment_weight * 1.35 |
| else: |
| weight = segment_weight * 0.65 |
| weight *= 0.94 ** min(content_rank, 24) |
| add_token(token, weight) |
| content_rank += 1 |
| return weights |
|
|
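| # Splits the context into <source> snippet segments, stopping at protocol |
| # markers or a newline. When a segment contains pipe separators |
| # (presumably a title/url/snippet layout), only the text after the last |
| # separator is kept as the snippet. |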
| def _source_evidence_segments(self, context_tokens: list[str]) -> list[tuple[list[str], float, str]]: |
| if self.tokenizer is None: |
| return [] |
| answer_boundary = _last_index(context_tokens, "<answer>") |
| upper_bound = answer_boundary if answer_boundary is not None else len(context_tokens) |
| boundary_tokens = {"<source>", "<tool_result>", "<tool_call>", "<final>", "<answer>"} |
| segments: list[tuple[list[str], float, str]] = [] |
| index = 0 |
| while index < upper_bound: |
| if context_tokens[index] != "<source>": |
| index += 1 |
| continue |
| start = index + 1 |
| end = start |
| while ( |
| end < upper_bound |
| and context_tokens[end] not in boundary_tokens |
| and self._render_token(context_tokens[end]) != "\n" |
| ): |
| end += 1 |
| source_tokens = context_tokens[start:end] |
| pipe_positions = [ |
| position |
| for position, token in enumerate(source_tokens) |
| if self._render_token(token).strip() == "|" |
| ] |
| if pipe_positions: |
| snippet_tokens = source_tokens[pipe_positions[-1] + 1 :] |
| if snippet_tokens: |
| segments.append((snippet_tokens, 1.0, "snippet")) |
| elif source_tokens: |
| segments.append((source_tokens, 0.90, "snippet")) |
| index = end + 1 |
| return segments |
|
|
| def _source_evidence_is_complete( |
| self, |
| context_tokens: list[str], |
| generated_tokens: list[str], |
| ) -> bool: |
| if ( |
| self.embedding_model is None |
| or self.tokenizer is None |
| or self._generated_word_count(generated_tokens) < 5 |
| ): |
| return False |
| generated_ids = [ |
| self.embedding_model.token_to_id[token] |
| for token in generated_tokens |
| if token in self.embedding_model.token_to_id |
| ] |
| if not generated_ids: |
| return False |
| for segment_tokens, _, segment_role in self._source_evidence_segments(context_tokens): |
| if segment_role != "snippet": |
| continue |
| segment_ids = [ |
| self.embedding_model.token_to_id[token] |
| for token in segment_tokens |
| if token in self.embedding_model.token_to_id |
| ] |
| if len(generated_ids) > len(segment_ids): |
| continue |
| max_suffix = min(12, len(generated_ids), len(segment_ids)) |
| for suffix_length in range(max_suffix, 4, -1): |
| suffix_ids = generated_ids[-suffix_length:] |
| for start in range(len(segment_ids) - suffix_length + 1): |
| if segment_ids[start : start + suffix_length] != suffix_ids: |
| continue |
| next_index = start + suffix_length |
| if next_index >= len(segment_ids): |
| return True |
| next_token = self.embedding_model.id_to_token[segment_ids[next_index]] |
| if self._source_punctuation_continues_numeric_span( |
| segment_ids, |
| next_index, |
| ): |
| return False |
| if self._is_terminal_punctuation_text(self._render_token(next_token)): |
| return True |
| return False |
|
|
| def _source_evidence_has_continuation( |
| self, |
| context_tokens: list[str], |
| generated_tokens: list[str], |
| ) -> bool: |
| if self.embedding_model is None or not generated_tokens: |
| return False |
| generated_ids = [ |
| self.embedding_model.token_to_id[token] |
| for token in generated_tokens |
| if token in self.embedding_model.token_to_id |
| ] |
| if not generated_ids: |
| return False |
| for segment_tokens, _, segment_role in self._source_evidence_segments(context_tokens): |
| if segment_role != "snippet": |
| continue |
| segment_ids = [ |
| self.embedding_model.token_to_id[token] |
| for token in segment_tokens |
| if token in self.embedding_model.token_to_id |
| ] |
| max_suffix = min(12, len(generated_ids), len(segment_ids)) |
| for suffix_length in range(max_suffix, 0, -1): |
| suffix_ids = generated_ids[-suffix_length:] |
| for start in range(len(segment_ids) - suffix_length + 1): |
| if segment_ids[start : start + suffix_length] != suffix_ids: |
| continue |
| next_index = start + suffix_length |
| if next_index >= len(segment_ids): |
| return False |
| if self._source_punctuation_continues_numeric_span( |
| segment_ids, |
| next_index, |
| ) or self._source_punctuation_continues_numeric_span( |
| segment_ids, |
| next_index - 1, |
| ): |
| return True |
| next_token = self.embedding_model.id_to_token[segment_ids[next_index]] |
| return not self._is_terminal_punctuation_text( |
| self._render_token(next_token) |
| ) |
| return False |
|
|
| def _source_evidence_next_token( |
| self, |
| context_tokens: list[str], |
| generated_tokens: list[str], |
| ) -> str | None: |
| if self.embedding_model is None: |
| return None |
| for segment_tokens, _, segment_role in self._source_evidence_segments(context_tokens): |
| if segment_role != "snippet" or not segment_tokens: |
| continue |
| if not generated_tokens: |
| return segment_tokens[0] |
| segment_ids = [ |
| self.embedding_model.token_to_id[token] |
| for token in segment_tokens |
| if token in self.embedding_model.token_to_id |
| ] |
| generated_ids = [ |
| self.embedding_model.token_to_id[token] |
| for token in generated_tokens |
| if token in self.embedding_model.token_to_id |
| ] |
| if not segment_ids or not generated_ids: |
| continue |
| max_suffix = min(12, len(generated_ids), len(segment_ids)) |
| for suffix_length in range(max_suffix, 0, -1): |
| suffix_ids = generated_ids[-suffix_length:] |
| for start in range(len(segment_ids) - suffix_length + 1): |
| if segment_ids[start : start + suffix_length] != suffix_ids: |
| continue |
| next_index = start + suffix_length |
| if next_index < len(segment_ids): |
| return self.embedding_model.id_to_token[segment_ids[next_index]] |
| return None |
|
|
| def _source_punctuation_continues_numeric_span( |
| self, |
| segment_ids: list[int], |
| punctuation_index: int, |
| ) -> bool: |
| if self.embedding_model is None: |
| return False |
| if punctuation_index <= 0 or punctuation_index + 1 >= len(segment_ids): |
| return False |
| punctuation_text = self._render_token( |
| self.embedding_model.id_to_token[segment_ids[punctuation_index]] |
| ).strip() |
| if not self._is_structural_punctuation_text(punctuation_text): |
| return False |
| previous_text = self._render_token( |
| self.embedding_model.id_to_token[segment_ids[punctuation_index - 1]] |
| ) |
| next_text = self._render_token( |
| self.embedding_model.id_to_token[segment_ids[punctuation_index + 1]] |
| ) |
| return any(character.isdigit() for character in previous_text) and any( |
| character.isdigit() for character in next_text |
| ) |
|
|
| def _preference_prior(self) -> Vector: |
| assert self.embedding_model is not None |
| if not self.preference_bias or not any(value != 0.0 for value in self.preference_bias): |
| return [0.0 for _ in self.embedding_model.id_to_token] |
| eligible_indices = [ |
| index |
| for index, token in enumerate(self.embedding_model.id_to_token) |
| if self.preference_bias[index] > 0.0 and self._eligible_preference_token(token) |
| ] |
| if not eligible_indices: |
| return [0.0 for _ in self.embedding_model.id_to_token] |
| eligible_probabilities = self._calibrated_softmax( |
| [self.preference_bias[index] for index in eligible_indices] |
| ) |
| prior = [0.0 for _ in self.embedding_model.id_to_token] |
| for index, probability in zip(eligible_indices, eligible_probabilities): |
| prior[index] = probability |
| return prior |
|
|
| def _preference_prior_array(self) -> object: |
| assert np is not None |
| assert self.embedding_model is not None |
| if self.preference_bias_array is None or not np.any(self.preference_bias_array != 0.0): |
| return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) |
| if self.preference_valid_mask_array is None or not np.any(self.preference_valid_mask_array): |
| return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) |
| positive_mask = self.preference_bias_array > 0.0 |
| active_mask = self.preference_valid_mask_array & positive_mask |
| if not np.any(active_mask): |
| return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) |
| prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) |
| prior[active_mask] = self._calibrated_softmax_array( |
| self.preference_bias_array[active_mask] |
| ) |
| return prior |
|
|
| def _eligible_preference_token(self, token: str) -> bool: |
| assert self.tokenizer is not None |
| if token == self.tokenizer.unk_token or token in self.tokenizer.special_tokens: |
| return False |
| if not self._starts_new_word(token): |
| return False |
| rendered = self._render_token(token) |
| if not rendered.strip() or self._is_punctuation_piece(rendered): |
| return False |
| alphanumeric = "".join(character for character in rendered if character.isalnum()) |
| return len(alphanumeric) >= 1 |
|
|
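| # Counts next-token frequencies for every order in TRANSITION_ORDERS over |
| # the training token stream, then normalizes to probabilities, keeping |
| # only the most frequent contexts and continuations when the config caps |
| # are set. |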
| def _build_transition_tables( |
| self, |
| tokens: list[str], |
| ) -> dict[int, dict[tuple[str, ...], dict[str, float]]]: |
| counts: dict[int, dict[tuple[str, ...], dict[str, int]]] = { |
| order: {} for order in sorted(TRANSITION_ORDERS) |
| } |
| for order in sorted(TRANSITION_ORDERS): |
| for index in range(order - 1, len(tokens) - 1): |
| key = tuple(tokens[index - order + 1 : index + 1]) |
| next_token = tokens[index + 1] |
| bucket = counts[order].setdefault(key, {}) |
| bucket[next_token] = bucket.get(next_token, 0) + 1 |
|
|
| probabilities: dict[int, dict[tuple[str, ...], dict[str, float]]] = { |
| order: {} for order in sorted(TRANSITION_ORDERS) |
| } |
| for order, mapping in counts.items(): |
| items = list(mapping.items()) |
| items.sort(key=lambda item: (-sum(item[1].values()), item[0])) |
| if ( |
| self.config.max_transition_contexts_per_order is not None |
| and self.config.max_transition_contexts_per_order >= 0 |
| ): |
| items = items[: self.config.max_transition_contexts_per_order] |
| for key, bucket in items: |
| next_items = sorted(bucket.items(), key=lambda item: (-item[1], item[0])) |
| if self.config.max_transition_next_tokens > 0: |
| next_items = next_items[: self.config.max_transition_next_tokens] |
| total = sum(value for _, value in next_items) |
| if total <= 0: |
| continue |
| probabilities[order][key] = { |
| token: value / total |
| for token, value in next_items |
| } |
| return probabilities |
|
|
| def _transition_table_tensors(self) -> dict[str, object]: |
| assert self.embedding_model is not None |
| if self.transition_tensor_cache is not None: |
| return { |
| "transition_orders": self.transition_tensor_cache["orders"], |
| "transition_key_offsets": self.transition_tensor_cache["key_offsets"], |
| "transition_key_token_ids": self.transition_tensor_cache["key_token_ids"], |
| "transition_next_offsets": self.transition_tensor_cache["next_offsets"], |
| "transition_next_token_ids": self.transition_tensor_cache["next_token_ids"], |
| "transition_next_probabilities": self.transition_tensor_cache["next_probabilities"], |
| } |
| if not self.transition_tables: |
| return { |
| "transition_orders": [], |
| "transition_key_offsets": [0], |
| "transition_key_token_ids": [], |
| "transition_next_offsets": [0], |
| "transition_next_token_ids": [], |
| "transition_next_probabilities": [], |
| } |
| token_to_id = self.embedding_model.token_to_id |
| orders: list[int] = [] |
| key_offsets: list[int] = [0] |
| key_token_ids: list[int] = [] |
| next_offsets: list[int] = [0] |
| next_token_ids: list[int] = [] |
| next_probabilities: list[float] = [] |
| for order in sorted(self.transition_tables): |
| mapping = self.transition_tables.get(order, {}) |
| for key, transitions in mapping.items(): |
| key_ids = [token_to_id.get(token, -1) for token in key] |
| if len(key_ids) != order or any(token_id < 0 for token_id in key_ids): |
| continue |
| next_items = [ |
| (token_to_id[token], float(probability)) |
| for token, probability in transitions.items() |
| if token in token_to_id and probability > 0.0 |
| ] |
| if not next_items: |
| continue |
| orders.append(order) |
| key_token_ids.extend(key_ids) |
| key_offsets.append(len(key_token_ids)) |
| for token_id, probability in next_items: |
| next_token_ids.append(token_id) |
| next_probabilities.append(probability) |
| next_offsets.append(len(next_token_ids)) |
| return { |
| "transition_orders": orders, |
| "transition_key_offsets": key_offsets, |
| "transition_key_token_ids": key_token_ids, |
| "transition_next_offsets": next_offsets, |
| "transition_next_token_ids": next_token_ids, |
| "transition_next_probabilities": next_probabilities, |
| } |
|
|
| def _deserialize_transition_id_tables_from_tensors( |
| self, |
| tensors: dict[str, object], |
| ) -> dict[int, dict[tuple[int, ...], tuple[object, object]]] | None: |
| required = ( |
| "transition_orders", |
| "transition_key_offsets", |
| "transition_key_token_ids", |
| "transition_next_offsets", |
| "transition_next_token_ids", |
| "transition_next_probabilities", |
| ) |
| if any(name not in tensors for name in required): |
| return None |
|
|
| def _as_sequence(name: str) -> object: |
| value = tensors.get(name, []) |
| return value if hasattr(value, "shape") else list(value) |
|
|
| orders = _as_sequence("transition_orders") |
| key_offsets = _as_sequence("transition_key_offsets") |
| key_token_ids = _as_sequence("transition_key_token_ids") |
| next_offsets = _as_sequence("transition_next_offsets") |
| next_token_ids = _as_sequence("transition_next_token_ids") |
| next_probabilities = _as_sequence("transition_next_probabilities") |
| row_count = len(orders) |
| if row_count == 0: |
| return {order: {} for order in sorted(TRANSITION_ORDERS)} |
| if len(key_offsets) != row_count + 1 or len(next_offsets) != row_count + 1: |
| return None |
| if np is not None and hasattr(orders, "shape"): |
| self.transition_tensor_cache = { |
| "orders": orders, |
| "key_offsets": key_offsets, |
| "key_token_ids": key_token_ids, |
| "next_offsets": next_offsets, |
| "next_token_ids": next_token_ids, |
| "next_probabilities": next_probabilities, |
| "order_spans": {}, |
| } |
| self.transition_built_orders = set() |
| return {order: {} for order in sorted(TRANSITION_ORDERS)} |
| tables: dict[int, dict[tuple[int, ...], tuple[object, object]]] = { |
| order: {} for order in sorted(TRANSITION_ORDERS) |
| } |
| for index in range(row_count): |
| order = int(orders[index]) |
| key_start = int(key_offsets[index]) |
| key_end = int(key_offsets[index + 1]) |
| next_start = int(next_offsets[index]) |
| next_end = int(next_offsets[index + 1]) |
| key = tuple(int(token_id) for token_id in key_token_ids[key_start:key_end]) |
| if len(key) != order or next_end <= next_start: |
| continue |
| tables.setdefault(order, {})[key] = ( |
| next_token_ids[next_start:next_end], |
| next_probabilities[next_start:next_end], |
| ) |
| return tables |
|
|
| def _serialize_transition_tables(self) -> dict[str, dict[str, dict[str, float]]]: |
| assert self.transition_tables is not None |
| return { |
| str(order): { |
| _encode_ngram_key(key): value |
| for key, value in mapping.items() |
| } |
| for order, mapping in self.transition_tables.items() |
| } |
|
|
| def _deserialize_transition_tables( |
| self, |
| payload: dict[str, dict[str, dict[str, float]]], |
| ) -> dict[int, dict[tuple[str, ...], dict[str, float]]]: |
| tables: dict[int, dict[tuple[str, ...], dict[str, float]]] = { |
| order: {} for order in sorted(TRANSITION_ORDERS) |
| } |
| for order_text, mapping in payload.items(): |
| order = int(order_text) |
| tables[order] = { |
| _decode_ngram_key(key): { |
| str(token): float(probability) |
| for token, probability in value.items() |
| } |
| for key, value in mapping.items() |
| } |
| return tables |
|
|
| def _transition_tensor_order_span(self, order: int) -> tuple[int, int] | None: |
| if np is None or self.transition_tensor_cache is None: |
| return None |
| spans = self.transition_tensor_cache.get("order_spans") |
| if isinstance(spans, dict) and order in spans: |
| return spans[order] |
| orders = self.transition_tensor_cache["orders"] |
| positions = np.flatnonzero(orders == order) |
| span = ( |
| (int(positions[0]), int(positions[-1]) + 1) |
| if positions.size |
| else None |
| ) |
| if isinstance(spans, dict): |
| spans[order] = span |
| return span |
|
|
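| # Packed-tensor lookup: locate the row span for the requested order, |
| # reshape its key ids to (row_count, order), vector-compare against the |
| # query, and return the matching row's slice of next ids/probabilities. |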
| def _transition_tensor_lookup( |
| self, |
| order: int, |
| key_ids: list[int], |
| ) -> tuple[object, object] | None: |
| if ( |
| np is None |
| or self.transition_tensor_cache is None |
| or len(key_ids) != order |
| ): |
| return None |
| span = self._transition_tensor_order_span(order) |
| if span is None: |
| return None |
| row_start, row_end = span |
| key_offsets = self.transition_tensor_cache["key_offsets"] |
| key_token_ids = self.transition_tensor_cache["key_token_ids"] |
| next_offsets = self.transition_tensor_cache["next_offsets"] |
| next_token_ids = self.transition_tensor_cache["next_token_ids"] |
| next_probabilities = self.transition_tensor_cache["next_probabilities"] |
| key_start = int(key_offsets[row_start]) |
| key_end = int(key_offsets[row_end]) |
| key_block = np.asarray(key_token_ids[key_start:key_end], dtype=np.int64) |
| row_count = row_end - row_start |
| if row_count <= 0 or key_block.size != row_count * order: |
| return None |
| keys = key_block.reshape(row_count, order) |
| query = np.asarray(key_ids, dtype=np.int64) |
| matches = np.flatnonzero(np.all(keys == query[None, :], axis=1)) |
| if not matches.size: |
| return None |
| row = row_start + int(matches[0]) |
| next_start = int(next_offsets[row]) |
| next_end = int(next_offsets[row + 1]) |
| if next_end <= next_start: |
| return None |
| return ( |
| next_token_ids[next_start:next_end], |
| next_probabilities[next_start:next_end], |
| ) |
|
|
| def _eligible_copy_token(self, token: str) -> bool: |
| rendered = self._render_token(token) |
| if not rendered.strip(): |
| return False |
| if self._is_punctuation_piece(rendered): |
| return False |
| if not self._starts_new_word(token): |
| return False |
| alphanumeric = "".join(character for character in rendered if character.isalnum()) |
| return len(alphanumeric) >= 2 |
|
|
| def _allowed_generation_token( |
| self, |
| token: str, |
| generated_tokens: list[str], |
| context_tokens: list[str] | None = None, |
| ) -> bool: |
| return self._allowed_generation_token_with_meta( |
| token, |
| self._generation_token_meta(token), |
| generated_tokens, |
| context_tokens, |
| ) |
|
|
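| # Generation gating: special tokens defer to the tool-protocol rules, a |
| # newline needs prior output, word-joiner and punctuation pieces must |
| # attach to something, and continuation pieces require an alphanumeric |
| # predecessor; vocabularies under 1024 tokens skip gating entirely. |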
| def _allowed_generation_token_with_meta( |
| self, |
| token: str, |
| meta: GenerationTokenMeta, |
| generated_tokens: list[str], |
| context_tokens: list[str] | None = None, |
| ) -> bool: |
| assert self.embedding_model is not None |
| assert self.tokenizer is not None |
| if token == self.tokenizer.unk_token: |
| return False |
| if token in self.tokenizer.special_tokens: |
| return self._allowed_tool_protocol_token( |
| token, |
| generated_tokens=generated_tokens, |
| context_tokens=context_tokens or [], |
| ) |
| if len(self.embedding_model.id_to_token) < 1024: |
| return True |
| if meta.rendered == "\n": |
| return bool(generated_tokens) |
| if not meta.stripped: |
| return False |
| if meta.word_joiner: |
| return ( |
| self._can_attach_word_joiner(generated_tokens) |
| or self._can_start_line_with_word_joiner(token, generated_tokens) |
| ) |
| if meta.structural_punctuation: |
| return bool(generated_tokens) or self._can_start_answer_with_structural_punctuation(token) |
| if meta.structural_symbol: |
| return bool(generated_tokens) or meta.starts_new_word |
| if not meta.starts_new_word: |
| if not generated_tokens: |
| return False |
| previous_rendered = self._render_token(generated_tokens[-1]) |
| return ( |
| bool(previous_rendered) |
| and any(character.isalnum() for character in previous_rendered) |
| and bool(meta.alphanumeric) |
| ) |
| return len(meta.alphanumeric) >= 1 or not meta.punctuation_piece |
|
|
| @staticmethod |
| def _allowed_tool_protocol_token( |
| token: str, |
| *, |
| generated_tokens: list[str], |
| context_tokens: list[str], |
| ) -> bool: |
| if token not in TOOL_PROTOCOL_TOKENS: |
| return False |
| if token == "<tool_call>": |
| return ( |
| ReframrModel._context_requests_tool_call(context_tokens) |
| and "<tool_call>" not in generated_tokens |
| and "<tool_result>" not in generated_tokens |
| and "<source>" not in generated_tokens |
| ) |
| if token in {"<tool_result>", "<source>"}: |
| return False |
| if token == "<final>": |
| return ( |
| "<tool_result>" in context_tokens |
| or "<source>" in context_tokens |
| or "<final>" in context_tokens |
| ) |
| return True |
|
|
| @staticmethod |
| def _context_requests_tool_call(context_tokens: list[str]) -> bool: |
| rendered_terms: list[str] = [] |
| for token in context_tokens: |
| if token in TOOL_PROTOCOL_TOKENS or token.startswith("<"): |
| continue |
| normalized = token.replace("▁", " ").strip().casefold() |
| if not normalized: |
| continue |
| rendered_terms.append(normalized) |
| pieces = { |
| "".join( |
| character |
| for character in piece |
| if character.isalnum() or character in {"-", "."} |
| ) |
| for piece in normalized.split() |
| } |
| if pieces & TOOL_CALL_CONTEXT_TERMS: |
| return True |
| joined = " ".join(rendered_terms) |
| compact = "".join(rendered_terms) |
| return any( |
| term in joined or term.replace("-", "") in compact |
| for term in TOOL_CALL_CONTEXT_TERMS |
| ) |
|
|
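| # Local repetition guard: rejects a third identical token in a row, any |
| # trigram already present in the last twelve tokens, a word bigram seen |
| # in the last ten rendered words, and non-connector words that already |
| # appeared twice in that window. |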
| def _would_repeat_recent_pattern( |
| self, |
| candidate: str, |
| generated_tokens: list[str], |
| recent_rendered_words: list[str] | None = None, |
| ) -> bool: |
| if len(generated_tokens) >= 2 and generated_tokens[-1] == candidate and generated_tokens[-2] == candidate: |
| return True |
|
|
| if len(generated_tokens) >= 2: |
| trigram = tuple(generated_tokens[-2:] + [candidate]) |
| recent_tokens = generated_tokens[-12:] |
| for index in range(max(0, len(recent_tokens) - 4)): |
| if tuple(recent_tokens[index : index + 3]) == trigram: |
| return True |
|
|
| rendered_words = recent_rendered_words |
| if rendered_words is None: |
| rendered_words = self._recent_rendered_words(generated_tokens) |
| candidate_meta = self._generation_token_meta(candidate) |
| candidate_word = candidate_meta.rendered.casefold() |
| if ( |
| rendered_words |
| and candidate_meta.starts_new_word |
| and any(character.isalnum() for character in candidate_word) |
| ): |
| candidate_bigram = (rendered_words[-1], candidate_word) |
| recent_window = rendered_words[-10:] |
| recent_bigrams = { |
| (recent_window[index], recent_window[index + 1]) |
| for index in range(len(recent_window) - 1) |
| } |
| if candidate_bigram in recent_bigrams: |
| return True |
| if ( |
| len(candidate_word) > 2 |
| and rendered_words[-10:].count(candidate_word) >= 2 |
| and not candidate_meta.common_connector |
| ): |
| return True |
|
|
| return False |
|
|
| @staticmethod |
| def _is_inside_tool_protocol_continuation(generated_tokens: list[str]) -> bool: |
| return any(token in TOOL_PROTOCOL_TOKENS for token in generated_tokens[-6:]) |
|
|
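| # Phrase-loop guard: rejects the candidate when appending its rendered |
| # word would repeat any 4-8 word span already present in the last 48 |
| # rendered words. |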
| def _would_repeat_recent_phrase( |
| self, |
| candidate: str, |
| generated_tokens: list[str], |
| *, |
| recent_rendered_words: list[str] | None = None, |
| ) -> bool: |
| if not self._starts_new_word(candidate): |
| return False |
| rendered_words = list( |
| recent_rendered_words |
| if recent_rendered_words is not None |
| else self._recent_rendered_words(generated_tokens) |
| ) |
| candidate_word = self._render_token(candidate).casefold() |
| if not any(character.isalnum() for character in candidate_word): |
| return False |
| rendered_words.append(candidate_word) |
| recent_window = rendered_words[-48:] |
| for span in range(4, min(8, len(recent_window)) + 1): |
| suffix = tuple(recent_window[-span:]) |
| earlier = recent_window[:-span] |
| for index in range(len(earlier) - span + 1): |
| if tuple(earlier[index : index + span]) == suffix: |
| return True |
| return False |
|
|
| def _recent_phrase_repeat_candidate_words( |
| self, |
| recent_rendered_words: list[str], |
| ) -> set[str]: |
| repeat_candidates: set[str] = set() |
| base_window = recent_rendered_words[-47:] |
| max_span = min(8, len(base_window) + 1) |
| if max_span < 4: |
| return repeat_candidates |
| for span in range(4, max_span + 1): |
| prefix_length = span - 1 |
| suffix_prefix = tuple(base_window[-prefix_length:]) |
| earlier_length = len(base_window) - prefix_length |
| if earlier_length < span: |
| continue |
| for index in range(earlier_length - span + 1): |
| earlier_segment = base_window[index : index + span] |
| if tuple(earlier_segment[:-1]) == suffix_prefix: |
| candidate_word = earlier_segment[-1] |
| if any(character.isalnum() for character in candidate_word): |
| repeat_candidates.add(candidate_word) |
| return repeat_candidates |
|
|
| def _recent_rendered_words(self, generated_tokens: list[str]) -> list[str]: |
| rendered_words: list[str] = [] |
| for token in generated_tokens: |
| if not self._starts_new_word(token): |
| continue |
| rendered = self._render_token(token).casefold() |
| if any(character.isalnum() for character in rendered): |
| rendered_words.append(rendered) |
| return rendered_words |
|
|
| def _select_generation_token( |
| self, |
| distribution: dict[str, float], |
| *, |
| context_tokens: list[str] | None = None, |
| generated_tokens: list[str] | None = None, |
| temperature: float = DEFAULT_GENERATION_TEMPERATURE, |
| top_k: int = DEFAULT_GENERATION_TOP_K, |
| top_p: float = DEFAULT_GENERATION_TOP_P, |
| repetition_penalty: float = DEFAULT_REPETITION_PENALTY, |
| preserve_dominant_candidates: bool = False, |
| avoid_token_sequences: Sequence[Sequence[str]] | None = None, |
| ) -> str: |
| assert self.tokenizer is not None |
| generated_tokens = generated_tokens or [] |
| candidates = self._prepare_generation_candidates( |
| distribution, |
| context_tokens=context_tokens or [], |
| generated_tokens=generated_tokens, |
| temperature=temperature, |
| top_k=top_k, |
| top_p=top_p, |
| repetition_penalty=repetition_penalty, |
| preserve_dominant_candidates=preserve_dominant_candidates, |
| avoid_token_sequences=avoid_token_sequences, |
| ) |
| if candidates: |
| return self._sample_generation_candidate( |
| candidates, |
| context_tokens=context_tokens or [], |
| generated_tokens=generated_tokens, |
| stochastic=temperature > 0.0, |
| preserve_dominant_candidates=preserve_dominant_candidates, |
| ) |
|
|
| for token, _ in sorted(distribution.items(), key=lambda item: item[1], reverse=True): |
| if token in self.tokenizer.special_tokens and token not in TOOL_PROTOCOL_TOKENS: |
| continue |
| if token == self.tokenizer.unk_token: |
| continue |
| if not self._allowed_generation_token(token, generated_tokens, context_tokens or []): |
| continue |
| if self._would_complete_blacklisted_answer(generated_tokens, token): |
| continue |
| return token |
| return "" |
|
|
| def _select_generation_token_from_array( |
| self, |
| probabilities: object, |
| *, |
| context_tokens: list[str], |
| generated_tokens: list[str], |
| temperature: float = DEFAULT_GENERATION_TEMPERATURE, |
| top_k: int = DEFAULT_GENERATION_TOP_K, |
| top_p: float = DEFAULT_GENERATION_TOP_P, |
| repetition_penalty: float = DEFAULT_REPETITION_PENALTY, |
| preserve_dominant_candidates: bool = False, |
| avoid_token_sequences: Sequence[Sequence[str]] | None = None, |
| ) -> str: |
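| """NumPy fast path: pick a token straight from a dense probability vector. |
|
| Tries progressively larger ``argpartition`` pools so the full vocabulary |
| is only sorted when the smaller pools yield no admissible token. |
| """ |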
| assert np is not None |
| assert self.tokenizer is not None |
| assert self.embedding_model is not None |
|
|
| values = np.asarray(probabilities, dtype=np.float64) |
| if values.size == 0: |
| return "" |
| first_pool_size = min(values.size, max(top_k, 64)) |
| expanded_pool_size = min(values.size, max(top_k * 4, 64)) |
| pool_sizes: list[int] = [] |
| for pool_size in (first_pool_size, expanded_pool_size, values.size): |
| if pool_size > 0 and pool_size not in pool_sizes: |
| pool_sizes.append(pool_size) |
|
|
| for pool_size in pool_sizes: |
| if pool_size < values.size: |
| candidate_indices = np.argpartition(values, -pool_size)[-pool_size:] |
| candidate_indices = candidate_indices[np.argsort(values[candidate_indices])[::-1]] |
| else: |
| candidate_indices = np.argsort(values)[::-1] |
|
|
| distribution: dict[str, float] = {} |
| for raw_index in candidate_indices: |
| index = int(raw_index) |
| score = float(values[index]) |
| if score <= 0.0: |
| continue |
| token = self.embedding_model.id_to_token[index] |
| if ( |
| token == self.tokenizer.unk_token |
| or ( |
| token in self.tokenizer.special_tokens |
| and token not in TOOL_PROTOCOL_TOKENS |
| ) |
| ): |
| continue |
| distribution[token] = score |
| selected = self._select_generation_token( |
| distribution, |
| context_tokens=context_tokens, |
| generated_tokens=generated_tokens, |
| temperature=temperature, |
| top_k=top_k, |
| top_p=top_p, |
| repetition_penalty=repetition_penalty, |
| preserve_dominant_candidates=preserve_dominant_candidates, |
| avoid_token_sequences=avoid_token_sequences, |
| ) |
| if selected: |
| return selected |
| return "" |
|
|
| def _prepare_generation_candidates( |
| self, |
| distribution: dict[str, float], |
| *, |
| context_tokens: list[str] | None = None, |
| generated_tokens: list[str], |
| temperature: float, |
| top_k: int, |
| top_p: float, |
| repetition_penalty: float, |
| preserve_dominant_candidates: bool = False, |
| avoid_token_sequences: Sequence[Sequence[str]] | None = None, |
| ) -> list[tuple[str, float]]: |
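| """Filter, penalize, and renormalize candidates ahead of sampling. |
|
| Applies repetition and phrase-loop suppression, first-token heuristics, |
| top-k/top-p truncation, and temperature scaling, returning normalized |
| ``(token, probability)`` pairs (or a single locked pick). |
| """ |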
| assert self.tokenizer is not None |
| assert self.embedding_model is not None |
| context_tokens = context_tokens or [] |
|
|
| generated_word_count = self._generated_word_count(generated_tokens) |
| clause_words = self._words_since_clause_break(generated_tokens) |
| recent_rendered_words = self._recent_rendered_words(generated_tokens) |
| generated_token_ids = self._token_ids_for_generated_tokens(generated_tokens) |
| inside_tool_protocol = self._is_inside_tool_protocol_continuation(generated_tokens) |
| phrase_repeat_candidate_words = ( |
| self._recent_phrase_repeat_candidate_words(recent_rendered_words) |
| if generated_word_count >= MIN_COMPLETE_ANSWER_WORDS and not inside_tool_protocol |
| else set() |
| ) |
| prompt_content_tokens = [ |
| token |
| for token in context_tokens |
| if token not in self.tokenizer.special_tokens |
| and self._generation_token_meta(token).starts_new_word |
| and self._generation_token_meta(token).alphanumeric |
| and not self._generation_token_meta(token).punctuation_piece |
| ] |
| initial_prompt_content_token = ( |
| prompt_content_tokens[0] |
| if len(prompt_content_tokens) > 1 |
| else None |
| ) |
| best_probability = max(distribution.values(), default=0.0) |
| has_uppercase_start_candidate = any( |
| probability > 0.0 |
| and self._generation_token_meta(token).starts_new_word |
| and self._generation_token_meta(token).rendered[:1].isupper() |
| for token, probability in distribution.items() |
| ) |
| adjusted: list[tuple[str, float]] = [] |
| for token, probability in sorted(distribution.items(), key=lambda item: item[1], reverse=True): |
| if token in self.tokenizer.special_tokens and token not in TOOL_PROTOCOL_TOKENS: |
| continue |
| if token == self.tokenizer.unk_token or probability <= 0.0: |
| continue |
| meta = self._generation_token_meta(token) |
| allowed_by_general_filter = self._allowed_generation_token_with_meta( |
| token, |
| meta, |
| generated_tokens, |
| context_tokens, |
| ) |
| if not allowed_by_general_filter: |
| dominant_learned_continuation = ( |
| preserve_dominant_candidates |
| and best_probability > 0.0 |
| and probability >= best_probability * 0.99 |
| and self._allowed_answer_sequence_token(token, generated_tokens) |
| ) |
| if not dominant_learned_continuation: |
| continue |
| if self._would_complete_blacklisted_answer_ids(generated_token_ids, token): |
| continue |
| repeats_recent_pattern = self._would_repeat_recent_pattern( |
| token, |
| generated_tokens, |
| recent_rendered_words=recent_rendered_words, |
| ) |
| hard_phrase_loop = ( |
| generated_word_count >= MIN_COMPLETE_ANSWER_WORDS |
| and not inside_tool_protocol |
| and meta.starts_new_word |
| and meta.rendered.casefold() in phrase_repeat_candidate_words |
| ) |
| if hard_phrase_loop: |
| continue |
| if repeats_recent_pattern: |
| dominant_candidate_allowed = ( |
| preserve_dominant_candidates |
| and best_probability > 0.0 |
| and probability >= best_probability * 0.80 |
| ) |
| if not dominant_candidate_allowed: |
| continue |
|
|
| score = probability |
| if ( |
| temperature >= ANSWER_REPLAY_PREFIX_TEMPERATURE |
| and not inside_tool_protocol |
| and self._would_follow_blacklisted_answer_prefix_ids( |
| generated_token_ids, |
| token, |
| ) |
| ): |
| score *= ANSWER_REPLAY_PREFIX_PENALTY |
| if ( |
| temperature > 0.0 |
| and self._would_follow_avoided_sequence( |
| generated_tokens, |
| token, |
| avoid_token_sequences, |
| ) |
| ): |
| score *= 0.12 |
| rendered = meta.rendered |
| punctuation_token = meta.structural_punctuation |
| starts_new_word = meta.starts_new_word |
| alphanumeric = meta.alphanumeric |
| if ( |
| not generated_tokens |
| and initial_prompt_content_token is not None |
| and token == initial_prompt_content_token |
| ): |
| dominant_answer_candidate = ( |
| preserve_dominant_candidates |
| and best_probability > 0.0 |
| and probability >= best_probability * 0.80 |
| ) |
| if not dominant_answer_candidate: |
| continue |
| if ( |
| not generated_tokens |
| and temperature > 0.0 |
| and has_uppercase_start_candidate |
| and starts_new_word |
| and rendered[:1].islower() |
| and best_probability > 0.0 |
| and probability < best_probability * 0.85 |
| ): |
| continue |
| if generated_tokens and starts_new_word and alphanumeric: |
| previous_alphanumeric = self._generation_token_meta( |
| generated_tokens[-1] |
| ).alphanumeric |
| if previous_alphanumeric.casefold() == alphanumeric.casefold(): |
| continue |
| common_connector = meta.common_connector |
| if ( |
| starts_new_word |
| and len(alphanumeric) == 1 |
| and not common_connector |
| ): |
| score *= 0.08 |
| recent_count = generated_tokens[-12:].count(token) |
| if recent_count > 0 and not common_connector: |
| score /= repetition_penalty ** (2 * recent_count) |
| if generated_tokens and token == generated_tokens[-1]: |
| score /= repetition_penalty**3 |
| if generated_tokens and token in generated_tokens[-4:] and not common_connector: |
| score *= 0.35 |
| if generated_tokens and not starts_new_word and self._starts_new_word(generated_tokens[-1]): |
| score *= 0.08 |
| if not generated_tokens and punctuation_token: |
| if best_probability <= 0.0 or probability < best_probability * 0.80: |
| score *= 0.01 |
| elif not generated_tokens and not starts_new_word: |
| score *= 0.02 |
| if ( |
| not generated_tokens |
| and temperature > 0.0 |
| and has_uppercase_start_candidate |
| and starts_new_word |
| and rendered[:1].islower() |
| ): |
| score *= 0.03 |
| if punctuation_token: |
| if generated_tokens and self._is_structural_punctuation_token(generated_tokens[-1]): |
| score *= 0.05 |
| if clause_words >= 6: |
| score *= 1.0 + min(1.4, 0.18 * (clause_words - 5)) |
| elif generated_word_count >= 12: |
| score *= 1.1 |
| if score > 0.0: |
| adjusted.append((token, score)) |
|
|
| if not adjusted: |
| return [] |
| adjusted.sort(key=lambda item: item[1], reverse=True) |
| if preserve_dominant_candidates: |
| top_score = adjusted[0][1] |
| second_score = adjusted[1][1] if len(adjusted) > 1 else 0.0 |
| if top_score >= 0.5 and ( |
| second_score <= 0.0 |
| or top_score >= second_score * 1.2 |
| or top_score - second_score >= 0.08 |
| ): |
| return [(adjusted[0][0], 1.0)] |
| effective_top_k = top_k |
| if ( |
| temperature >= CREATIVE_EARLY_POOL_TEMPERATURE |
| and generated_word_count < CREATIVE_EARLY_POOL_WORD_LIMIT |
| and not inside_tool_protocol |
| and top_k > CREATIVE_EARLY_POOL_MAX |
| ): |
| effective_top_k = CREATIVE_EARLY_POOL_MAX |
| if effective_top_k > 0: |
| adjusted = adjusted[:effective_top_k] |
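| # Nucleus (top-p) cut: keep the best tokens until their normalized |
| # mass reaches top_p, always retaining at least one candidate. |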
| if 0.0 < top_p < 1.0: |
| kept: list[tuple[str, float]] = [] |
| cumulative = 0.0 |
| total = sum(score for _, score in adjusted) |
| for token, score in adjusted: |
| normalized = score / total if total else 0.0 |
| kept.append((token, score)) |
| cumulative += normalized |
| if cumulative >= top_p: |
| break |
| adjusted = kept |
|
|
| if temperature <= 0.0: |
| return [(adjusted[0][0], 1.0)] |
|
|
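| # Apply temperature by exponentiating scores with 1/T, then renormalize. |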
| exponent = 1.0 / temperature |
| tempered = [ |
| (token, score**exponent) |
| for token, score in adjusted |
| if score > 0.0 |
| ] |
| total = sum(score for _, score in tempered) |
| if total <= 0.0: |
| return [] |
| return [(token, score / total) for token, score in tempered] |
|
|
| def _sample_generation_candidate( |
| self, |
| candidates: list[tuple[str, float]], |
| *, |
| context_tokens: list[str], |
| generated_tokens: list[str], |
| stochastic: bool = False, |
| preserve_dominant_candidates: bool = False, |
| ) -> str: |
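| """Sample from prepared candidates, short-circuiting decisive winners. |
|
| Deterministic calls seed the RNG from a SHA-256 hash of the context and |
| generated tokens so replays of the same state stay reproducible. |
| """ |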
| if not candidates: |
| return "" |
| if len(candidates) == 1: |
| return candidates[0][0] |
| top_probability = candidates[0][1] |
| second_probability = candidates[1][1] |
| top_has_clear_half_majority = top_probability >= 0.5 and ( |
| second_probability <= 0.0 |
| or top_probability - second_probability >= 0.02 |
| ) |
| if preserve_dominant_candidates and top_has_clear_half_majority: |
| return candidates[0][0] |
| decisive_stochastic_winner = stochastic and ( |
| top_probability >= 0.985 |
| or ( |
| top_probability >= 0.96 |
| and second_probability > 0.0 |
| and top_probability >= second_probability * 20.0 |
| ) |
| or ( |
| top_probability >= 0.90 |
| and second_probability > 0.0 |
| and top_probability >= second_probability * 40.0 |
| ) |
| or ( |
| top_probability >= 0.90 |
| and top_probability - second_probability >= 0.75 |
| ) |
| ) |
| decisive_deterministic_winner = not stochastic and ( |
| top_has_clear_half_majority |
| or (second_probability > 0.0 and top_probability >= second_probability * 2.5) |
| or ( |
| top_probability >= 0.08 |
| and second_probability > 0.0 |
| and top_probability >= second_probability * 1.35 |
| ) |
| ) |
| if decisive_stochastic_winner or decisive_deterministic_winner: |
| return candidates[0][0] |
| if stochastic: |
| threshold = random.random() |
| else: |
| seed_payload = "\u0002".join([*context_tokens, "<generated>", *generated_tokens, str(len(candidates))]) |
| seed = int.from_bytes(hashlib.sha256(seed_payload.encode("utf-8")).digest()[:8], "big") |
| threshold = random.Random(seed).random() |
| cumulative = 0.0 |
| for token, probability in candidates: |
| cumulative += probability |
| if threshold <= cumulative: |
| return token |
| return candidates[-1][0] |
|
|
| def _top_entries_from_vector( |
| self, |
| values: Vector, |
| limit: int, |
| ) -> list[dict[str, object]]: |
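| """Return the ``limit`` highest-probability token entries from a dense vector.""" |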
| if limit <= 0: |
| return [] |
| ranked = sorted( |
| enumerate(values), |
| key=lambda item: item[1], |
| reverse=True, |
| ) |
| return [ |
| self._token_entry(index, probability) |
| for index, probability in ranked[:limit] |
| if probability > 0.0 |
| ] |
|
|
| def _token_entry( |
| self, |
| index: int, |
| probability: float, |
| ) -> dict[str, object]: |
| assert self.embedding_model is not None |
| token = self.embedding_model.id_to_token[index] |
| return { |
| "token": token, |
| "text": self._render_token(token), |
| "probability": probability, |
| } |
|
|
| def _build_reasoning_summary( |
| self, |
| transition_order: int | None, |
| blend_weights: dict[str, float], |
| ) -> str: |
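| """Describe the dominant blend source and the transition order in use.""" |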
| dominant_source = max(blend_weights.items(), key=lambda item: item[1])[0] if blend_weights else "base" |
| if transition_order is not None: |
| transition_message = f" Transition prior is using order-{transition_order} context." |
| else: |
| transition_message = " Transition prior found no matching n-gram." |
|
|
| return ( |
| "Generation is running on analytical state, recurrent traces, and corpus-derived token transitions." |
| f"{transition_message}" |
| f" Dominant blend source: {dominant_source}." |
| ) |
|
|
| def _generated_word_count(self, tokens: list[str]) -> int: |
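| """Count rendered words, treating a leading continuation piece as a word.""" |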
| count = 0 |
| for token in tokens: |
| rendered = self._render_token(token) |
| if not any(character.isalnum() for character in rendered): |
| continue |
| if self._starts_new_word(token) or count == 0: |
| count += 1 |
| return count |
|
|
| def _is_structural_punctuation_text(self, text: str) -> bool: |
| if len(text) != 1: |
| return False |
| if self._is_word_joiner_text(text): |
| return False |
| category = unicodedata.category(text) |
| return category.startswith("P") |
|
|
| def _is_structural_punctuation_token(self, token: str) -> bool: |
| return self._is_structural_punctuation_text(self._render_token(token)) |
|
|
| def _is_structural_symbol_token(self, token: str) -> bool: |
| rendered = self._render_token(token) |
| return len(rendered) == 1 and unicodedata.category(rendered).startswith("S") |
|
|
| def _is_word_joiner_token(self, token: str) -> bool: |
| return self._is_word_joiner_text(self._render_token(token)) |
|
|
| def _is_word_joiner_text(self, text: str) -> bool: |
| if len(text) != 1: |
| return False |
| category = unicodedata.category(text) |
| if category in ("Pc", "Pd", "Lm"): |
| return True |
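| # Apostrophes and single quotation marks join word pieces ("don't"). |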
| name = unicodedata.name(text, "") |
| return "APOSTROPHE" in name or ( |
| "SINGLE" in name and "QUOTATION MARK" in name |
| ) |
|
|
| def _can_start_line_with_word_joiner(self, token: str, generated_tokens: list[str]) -> bool: |
| rendered = self._render_token(token) |
| if len(rendered) != 1 or unicodedata.category(rendered) != "Pd": |
| return False |
| if not self._starts_new_word(token): |
| return False |
| return not generated_tokens or self._render_token(generated_tokens[-1]) == "\n" |
|
|
| def _can_start_answer_with_structural_punctuation(self, token: str) -> bool: |
| rendered = self._render_token(token) |
| if len(rendered) != 1 or not self._starts_new_word(token): |
| return False |
| return unicodedata.category(rendered) in ("Ps", "Pi") |
|
|
| def _is_common_connector_token(self, token: str) -> bool: |
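| """Single lowercase letters count as connectors, softening repetition penalties.""" |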
| rendered = self._render_token(token) |
| return rendered.isalpha() and len(rendered) == 1 and rendered.islower() |
|
|
| def _can_attach_word_joiner(self, generated_tokens: list[str]) -> bool: |
| if not generated_tokens: |
| return False |
| rendered = self._render_token(generated_tokens[-1]) |
| if not rendered: |
| return False |
| if any(character.isalnum() for character in rendered): |
| return True |
| if len(rendered) != 1: |
| return False |
| return unicodedata.category(rendered) in ("Ps", "Pi") |
|
|
| def _words_since_clause_break(self, tokens: list[str]) -> int: |
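| """Count words emitted since the last structural punctuation break.""" |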
| assert self.tokenizer is not None |
|
|
| words = 0 |
| for token in reversed(tokens): |
| if token in self.tokenizer.special_tokens: |
| continue |
| rendered = self._render_token(token) |
| if self._is_structural_punctuation_text(rendered): |
| break |
| if self._starts_new_word(token) and not self._is_punctuation_piece(rendered): |
| words += 1 |
| return words |
|
|
| def _should_stop_generation(self, generated_tokens: list[str]) -> bool: |
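| """Stop at terminal punctuation once enough words (or two sentences) exist.""" |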
| if not generated_tokens: |
| return False |
| if not self._is_terminal_punctuation_text(self._render_token(generated_tokens[-1])): |
| return False |
| word_count = self._generated_word_count(generated_tokens) |
| if word_count >= MIN_COMPLETE_ANSWER_WORDS: |
| return True |
| return ( |
| word_count >= MIN_COMPLETE_MULTI_SENTENCE_WORDS |
| and self._terminal_sentence_count(generated_tokens) >= 2 |
| ) |
|
|
| def _terminal_sentence_count(self, tokens: list[str]) -> int: |
| return sum( |
| 1 |
| for token in tokens |
| if self._is_terminal_punctuation_text(self._render_token(token)) |
| ) |
|
|
| def _is_terminal_punctuation_text(self, text: str) -> bool: |
| stripped = text.strip() |
| if not stripped: |
| return False |
| terminal_character = stripped[-1] |
| if not self._is_structural_punctuation_text(terminal_character): |
| return False |
| return not self._is_word_joiner_text(terminal_character) |
|
|
| def _should_skip_prompt_overlap_token(self, token: str) -> bool: |
| rendered = self._render_token(token) |
| if not rendered.strip(): |
| return True |
| if ( |
| self.embedding_model is not None |
| and len(self.embedding_model.id_to_token) >= 1024 |
| and not self._starts_new_word(token) |
| ): |
| return True |
| if self._is_structural_punctuation_text(rendered): |
| return True |
| return rendered.strip().casefold() in PROMPT_ENVELOPE_TERMS |
|
|
| def _starts_new_word(self, token: str) -> bool: |
| assert self.tokenizer is not None |
| if token in self.tokenizer.special_tokens: |
| return True |
| if token.startswith(self.tokenizer.word_prefix): |
| return True |
| return len(token) == 1 and not token.isalnum() and not self._is_word_joiner_token(token) |
|
|
| def _generation_token_meta(self, token: str) -> GenerationTokenMeta: |
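| """Return cached surface metadata for ``token``, computing it lazily.""" |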
| cache = self.generation_token_meta_cache |
| if cache is None: |
| cache = {} |
| self.generation_token_meta_cache = cache |
| cached = cache.get(token) |
| if cached is not None: |
| return cached |
| rendered = self._render_token(token) |
| meta = GenerationTokenMeta( |
| rendered=rendered, |
| stripped=rendered.strip(), |
| starts_new_word=self._starts_new_word(token), |
| punctuation_piece=self._is_punctuation_piece(rendered), |
| structural_punctuation=self._is_structural_punctuation_token(token), |
| structural_symbol=self._is_structural_symbol_token(token), |
| word_joiner=self._is_word_joiner_token(token), |
| alphanumeric="".join(character for character in rendered if character.isalnum()), |
| common_connector=self._is_common_connector_token(token), |
| ) |
| cache[token] = meta |
| return meta |
|
|
| def _decode_tokens(self, tokens: list[str]) -> str: |
| assert self.tokenizer is not None |
| return self.tokenizer.decode( |
| tokens, |
| preserve_special_tokens=TOOL_PROTOCOL_TOKENS, |
| ) |
|
|
| @staticmethod |
| def _normalize_generated_tool_protocol_text(text: str, *, context: str | None = None) -> str: |
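| """Trim a generated tool call down to a single well-formed payload. |
|
| Cuts at the first boundary marker or second call, then walks the JSON |
| body with a string- and escape-aware brace counter, closing truncated |
| objects at the last top-level comma before handing off to payload repair. |
| """ |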
| marker = "<tool_call>" |
| call_index = text.find(marker) |
| if call_index < 0: |
| return text |
|
|
| cleaned = text |
| for boundary in ("<tool_result>", "<source>", "<final>"): |
| boundary_index = cleaned.find(boundary, call_index + len(marker)) |
| if boundary_index >= 0: |
| cleaned = cleaned[:boundary_index].rstrip() |
|
|
| second_call_index = cleaned.find(marker, call_index + len(marker)) |
| if second_call_index >= 0: |
| cleaned = cleaned[:second_call_index].rstrip() |
|
|
| brace_start = cleaned.find("{", call_index) |
| if brace_start < 0: |
| return cleaned.strip() |
|
|
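| # Brace-depth scan that ignores braces inside JSON strings (with escapes). |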
| depth = 0 |
| in_string = False |
| escaped = False |
| last_top_level_comma: int | None = None |
| for index in range(brace_start, len(cleaned)): |
| character = cleaned[index] |
| if escaped: |
| escaped = False |
| continue |
| if in_string and character == "\\": |
| escaped = True |
| continue |
| if character == '"': |
| in_string = not in_string |
| continue |
| if in_string: |
| continue |
| if character == "{": |
| depth += 1 |
| continue |
| if character == "}": |
| depth -= 1 |
| if depth <= 0: |
| candidate = cleaned[: index + 1].strip() |
| return ReframrModel._repair_tool_call_payload_if_needed( |
| candidate, |
| context=context, |
| ) |
| continue |
| if character == "," and depth == 1: |
| last_top_level_comma = index |
|
|
| if depth > 0: |
| if last_top_level_comma is not None: |
| candidate = cleaned[:last_top_level_comma].rstrip() + "}" |
| return ReframrModel._repair_tool_call_payload_if_needed( |
| candidate, |
| context=context, |
| ) |
| candidate = cleaned.rstrip() + "}" |
| return ReframrModel._repair_tool_call_payload_if_needed( |
| candidate, |
| context=context, |
| ) |
| return ReframrModel._repair_tool_call_payload_if_needed( |
| cleaned.strip(), |
| context=context, |
| ) |
|
|
| @staticmethod |
| def _repair_tool_call_payload_if_needed(text: str, *, context: str | None = None) -> str: |
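| """Re-serialize a ``<tool_call>`` payload, rebuilding it if JSON parsing fails. |
|
| Well-formed ``web.search`` payloads may still have their query replaced |
| from context when the generated query looks degenerate. |
| """ |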
| marker = "<tool_call>" |
| if not text.startswith(marker): |
| return text |
| brace_start = text.find("{", len(marker)) |
| if brace_start < 0: |
| return text |
| tool_name = text[len(marker) : brace_start].strip() |
| payload_text = text[brace_start:].strip() |
| try: |
| payload = json.loads(payload_text) |
| if isinstance(payload, dict) and tool_name == "web.search": |
| repaired_query = ReframrModel._repair_search_query_from_context_if_weak( |
| str(payload.get("query", "")), |
| context, |
| ) |
| if repaired_query is not None: |
| payload["query"] = repaired_query |
| return f"{marker} {tool_name} {json.dumps(payload, ensure_ascii=False)}" |
| return text |
| except (TypeError, json.JSONDecodeError): |
| pass |
| body = payload_text.strip() |
| if body.startswith("{"): |
| body = body[1:] |
| if body.endswith("}"): |
| body = body[:-1] |
| body = " ".join(body.replace('"', "").split()) |
| if not tool_name or not body: |
| return text |
| if tool_name == "web.search": |
| payload = { |
| "query": ReframrModel._repair_search_query_from_context_if_weak( |
| body, |
| context, |
| ) |
| or body |
| } |
| else: |
| payload = {"input": body} |
| return f"{marker} {tool_name} {json.dumps(payload, ensure_ascii=False)}" |
|
|
| @staticmethod |
| def _repair_search_query_from_context_if_weak( |
| query: str, |
| context: str | None, |
| ) -> str | None: |
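| """Return a context-derived query when the generated one looks degenerate. |
|
| A query is weak when it has fewer than three distinct content words or |
| still carries protocol markers; ``None`` means the original query stands. |
| """ |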
| cleaned_query = " ".join(query.replace("{", " ").replace("}", " ").split()) |
| normalized_words = [ |
| word.strip(" \t\r\n:,.;!?\"'()[]{}").casefold() |
| for word in cleaned_query.split() |
| if word.strip(" \t\r\n:,.;!?\"'()[]{}") |
| ] |
| unique_content_words = { |
| word |
| for word in normalized_words |
| if word not in {"query", "web.search", "tool_call"} |
| } |
| lowered_query = cleaned_query.casefold() |
| weak = ( |
| len(unique_content_words) < 3 |
| or lowered_query.startswith("query:") |
| or "web.search" in lowered_query |
| or any( |
| marker in lowered_query |
| for marker in ("<tool", "<source>", "<final>", "according to") |
| ) |
| ) |
| if not weak: |
| return None |
| context_query = ReframrModel._search_query_from_context(context or "") |
| return context_query or None |
|
|
| @staticmethod |
| def _search_query_from_context(context: str) -> str: |
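| """Extract the latest user/question line from the prompt context.""" |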
| if not context: |
| return "" |
| before_tool_result = context.split("<tool_result>", 1)[0] |
| before_final = before_tool_result.split("<final>", 1)[0] |
| lines = [line.strip() for line in before_final.splitlines() if line.strip()] |
| if not lines: |
| lines = [before_final.strip()] |
| latest_user = "" |
| for line in lines: |
| lowered = line.casefold() |
| if lowered.startswith("user:"): |
| latest_user = line.split(":", 1)[1].strip() |
| elif lowered.startswith("question:"): |
| latest_user = line.split(":", 1)[1].strip() |
| if not latest_user: |
| latest_user = lines[-1] |
| for prefix in ("User:", "Question:", "Prompt:", "Context:"): |
| if latest_user.casefold().startswith(prefix.casefold()): |
| latest_user = latest_user[len(prefix) :].strip() |
| cleaned = " ".join(latest_user.split()) |
| return cleaned.strip(" \t\r\n\"'") |
|
|
| @staticmethod |
| def _finalize_generated_text(text: str) -> str: |
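| """Tidy trailing separators and add a final period to unterminated prose.""" |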
| stripped = text.rstrip() |
| if not stripped: |
| return stripped |
| if stripped.startswith("<tool_call>"): |
| return stripped |
| stripped = ReframrModel._remove_separator_punctuation_before_boundary(stripped) |
| if stripped and ReframrModel._is_separator_punctuation(stripped[-1:]): |
| stripped = stripped[:-1].rstrip() |
| if not stripped: |
| return stripped |
| if ( |
| ReframrModel._is_surface_punctuation(stripped[:1]) |
| or ReframrModel._is_surface_punctuation(stripped[-1:]) |
| ): |
| return stripped |
| if any(character.isalnum() for character in stripped[-8:]): |
| return f"{stripped}." |
| return stripped |
|
|
| @staticmethod |
| def _remove_separator_punctuation_before_boundary(text: str) -> str: |
| cleaned: list[str] = [] |
| for character in text: |
| if ( |
| ReframrModel._is_separator_punctuation(character) |
| and cleaned |
| and ReframrModel._is_separator_punctuation(cleaned[-1]) |
| ): |
| cleaned.pop() |
| cleaned.append(character) |
| return "".join(cleaned) |
|
|
| @staticmethod |
| def _is_surface_punctuation(character: str) -> bool: |
| return len(character) == 1 and unicodedata.category(character).startswith("P") |
|
|
| @staticmethod |
| def _is_separator_punctuation(character: str) -> bool: |
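| # Bidi class "CS" (common separator) covers ",", ".", ":", "/", and similar. |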
| return ( |
| ReframrModel._is_surface_punctuation(character) |
| and unicodedata.bidirectional(character) == "CS" |
| ) |
|
|
| def _render_token(self, token: str) -> str: |
| assert self.tokenizer is not None |
| if token.startswith(self.tokenizer.word_prefix): |
| return token[len(self.tokenizer.word_prefix) :] |
| return token |
|
|
| def _require_fit(self) -> None: |
| if ( |
| self.tokenizer is None |
| or self.embedding_model is None |
| or self.memory_units is None |
| or self.readout_weights is None |
| or self.ternary_mask is None |
| or self.associative_keys is None |
| or ( |
| self.associative_key_norms is None |
| and self.associative_key_norms_array is None |
| ) |
| or self.associative_values is None |
| or self.transition_tables is None |
| ): |
| raise RuntimeError("Call fit() before using the REFRAMR model.") |
|
|
| def _ensure_numeric_caches(self) -> None: |
| if np is None: |
| return |
| if self.readout_weights_array is None: |
| self._refresh_numeric_caches() |
|
|