| import json |
| import hashlib |
| import random |
| import site |
| import string |
| import sys |
| import unicodedata |
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
# Directories bundled under ".vendor" (two levels above this file) take
# precedence over anything already importable, hence insert(0, ...).
_VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor"
for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"):
    if not _vendor_path.exists():
        continue
    vendor_text = str(_vendor_path)
    if vendor_text in sys.path:
        continue
    sys.path.insert(0, vendor_text)


# NumPy is optional.  When the first import fails, the per-user site-packages
# directory is appended to sys.path and the import is retried once; if that
# also fails the module runs with ``np = None``.
try:
    import numpy as np
except ModuleNotFoundError:
    np = None
    user_site = site.getusersitepackages()
    if user_site and user_site not in sys.path:
        sys.path.append(user_site)
    try:
        import numpy as np
    except ModuleNotFoundError:
        pass


# Something importable as "numpy" that lacks ``asarray`` is treated as absent.
if not (np is None or hasattr(np, "asarray")):
    np = None
|
|
| from .checkpoint import read_safetensor_file, write_safetensor_file |
| from .config import ReframrConfig |
| from .embeddings import EmbeddingModel, fit_ppmi_embedding_from_tokens |
| from .hippo import AnalyticalMemoryUnit, analytical_embedding_drive, analytical_embedding_drive_fast |
| from .linalg import Vector, dot, mean, norm, softmax, zeros_vector |
| from .reservoir import apply_readout, ridge_regression_readout |
| from .reasoning import reasoning_prefix |
| from .ternary import apply_ternary_mask, derive_ternary_mask_from_states |
| from .tokenizer import NativeTokenizer |
|
|
# --- Mixture weights -------------------------------------------------------
# NOTE(review): the consumers of these blend weights are outside this excerpt;
# the names suggest per-source weights for the standard and "fast" scoring
# paths -- confirm at the use sites.
ASSOCIATIVE_BLEND = 0.42
TRANSITION_BLEND = 0.08
COPY_BLEND = 0.04
BASE_BLEND = 0.34
FAST_ASSOCIATIVE_BLEND = 0.06
FAST_TRANSITION_BLEND = 0.14
FAST_COPY_BLEND = 0.04
FAST_BASE_BLEND = 0.58
FAST_PREFERENCE_BLEND = 0.15
FAST_ANSWER_BLEND = 0.30
PROMPT_READOUT_LOGIT_ZSCORE_SCALE = 0.48

# --- Retrieval limits and answer-sequence thresholds -----------------------
ASSOCIATIVE_TOP_K = 12
# Match limits passed when answer / answer-start matches are first scored.
ANSWER_TOP_K = 48
ANSWER_START_TOP_K = 32
ANSWER_SEQUENCE_MATCH_FLOOR = 0.30
# The three values below drive ReframrModel._answer_sequence_should_lock.
ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR = 0.45
ANSWER_SEQUENCE_LOCK_FLOOR = 0.55
ANSWER_SEQUENCE_SPIKE_CONFIDENCE = 0.80
READOUT_LOGIT_ZSCORE_SCALE = 0.22

# --- Trace identity ---------------------------------------------------------
TRACE_IDENTITY_SCALE = 0.78
# Ten 4-tuples of integer constants.
# NOTE(review): several values resemble classic linear-congruential
# multiplier/increment pairs; presumably used for hashing -- confirm.
TRACE_IDENTITY_HASHES = (
    (1103515245, 12345, 214013, 2531011),
    (1664525, 1013904223, 22695477, 1),
    (69069, 362437, 134775813, 17),
    (134775813, 97, 1103515245, 31),
    (22695477, 911, 1664525, 73),
    (214013, 2531011, 69069, 19),
    (48271, 0, 69621, 11),
    (16807, 37, 40692, 101),
    (279470273, 173, 1299709, 53),
    (39916801, 29, 2147483629, 7),
)

# --- N-gram transition tables ----------------------------------------------
# Control character used to join the tokens of an n-gram into one string key.
NGRAM_KEY_SEPARATOR = "\u0001"
# Listed from the longest order to the shortest.
TRANSITION_ORDERS = (10, 8, 6, 5, 4, 3, 2, 1)

# --- Generation defaults ----------------------------------------------------
DEFAULT_GENERATION_TEMPERATURE = 0.82
DEFAULT_GENERATION_TOP_K = 24
DEFAULT_GENERATION_TOP_P = 0.92
DEFAULT_REPETITION_PENALTY = 1.18
# Stored prompt/answer id sequences are truncated and "-1"-padded to this length.
ANSWER_SEQUENCE_MAX_TOKENS = 192
# float32 when NumPy was importable, otherwise None.
RUNTIME_ARRAY_DTYPE = None if np is None else np.float32
|
|
|
|
| @dataclass(frozen=True, slots=True) |
| class CharacterCountFact: |
| character: str |
| word: str |
| count: int |
| surface_seed: int |
|
|
|
|
| def _normalize_vector(values: Vector) -> Vector: |
| total = sum(values) |
| if total <= 0.0: |
| return [0.0 for _ in values] |
| return [value / total for value in values] |
|
|
|
|
def _encode_ngram_key(tokens: tuple[str, ...]) -> str:
    """Join an n-gram's tokens into one string using NGRAM_KEY_SEPARATOR."""
    separator = NGRAM_KEY_SEPARATOR
    return separator.join(tokens)
|
|
|
|
def _decode_ngram_key(key: str) -> tuple[str, ...]:
    """Split a key produced by ``_encode_ngram_key`` back into its tokens.

    Empty segments are dropped, so an empty key decodes to an empty tuple.
    """
    segments = key.split(NGRAM_KEY_SEPARATOR)
    return tuple(filter(None, segments))
|
|
|
|
| def _last_index(values: list[str], target: str) -> int | None: |
| for index in range(len(values) - 1, -1, -1): |
| if values[index] == target: |
| return index |
| return None |
|
|
|
|
| @dataclass(slots=True) |
| class DecodeState: |
| hidden_states: list[Vector] |
| context_traces: list[Vector] |
| combined_state: Vector |
| context_tokens: list[str] |
| answer_anchor_state: Vector | None = None |
| answer_matches: list[tuple[float, int, int]] | None = None |
| answer_start_matches: list[tuple[float, int, int]] | None = None |
| answer_sequence_matches: list[tuple[float, int, int]] | None = None |
| prompt_answer_prior: object | None = None |
| prompt_answer_start_prior: object | None = None |
|
|
|
|
| @dataclass(slots=True) |
| class ReframrModel: |
| config: ReframrConfig |
| tokenizer: NativeTokenizer | None = None |
| embedding_model: EmbeddingModel | None = None |
| memory_units: list[AnalyticalMemoryUnit] | None = None |
| ternary_scale: float = 1.0 |
| ternary_mask: list[int] | None = None |
| ternary_mask_array: object | None = None |
| readout_weights: list[list[float]] | None = None |
| readout_weights_array: object | None = None |
| readout_bias: Vector | None = None |
| readout_bias_array: object | None = None |
| prompt_answer_weights: list[list[float]] | None = None |
| prompt_answer_weights_array: object | None = None |
| prompt_answer_bias: Vector | None = None |
| prompt_answer_bias_array: object | None = None |
| prompt_answer_start_weights: list[list[float]] | None = None |
| prompt_answer_start_weights_array: object | None = None |
| prompt_answer_start_bias: Vector | None = None |
| prompt_answer_start_bias_array: object | None = None |
| trace_token_weights: Vector | None = None |
| trace_token_weights_array: object | None = None |
| trace_embedding_table_array: object | None = None |
| preference_bias: Vector | None = None |
| preference_bias_array: object | None = None |
| preference_valid_mask_array: object | None = None |
| state_offset: Vector | None = None |
| state_offset_array: object | None = None |
| associative_keys: list[Vector] | None = None |
| associative_keys_array: object | None = None |
| associative_key_norms: list[float] | None = None |
| associative_key_norms_array: object | None = None |
| associative_values: list[int] | None = None |
| associative_values_array: object | None = None |
| associative_valid_mask_array: object | None = None |
| answer_keys: list[Vector] | None = None |
| answer_keys_array: object | None = None |
| answer_key_norms: list[float] | None = None |
| answer_key_norms_array: object | None = None |
| answer_similarity_keys_array: object | None = None |
| answer_similarity_key_norms_array: object | None = None |
| answer_similarity_mask_array: object | None = None |
| answer_values: list[int] | None = None |
| answer_values_array: object | None = None |
| answer_valid_mask_array: object | None = None |
| answer_start_keys: list[Vector] | None = None |
| answer_start_keys_array: object | None = None |
| answer_start_key_norms: list[float] | None = None |
| answer_start_key_norms_array: object | None = None |
| answer_start_similarity_keys_array: object | None = None |
| answer_start_similarity_key_norms_array: object | None = None |
| answer_start_values: list[int] | None = None |
| answer_start_values_array: object | None = None |
| answer_start_valid_mask_array: object | None = None |
| answer_sequence_keys: list[Vector] | None = None |
| answer_sequence_keys_array: object | None = None |
| answer_sequence_key_norms: list[float] | None = None |
| answer_sequence_key_norms_array: object | None = None |
| answer_sequence_similarity_keys_array: object | None = None |
| answer_sequence_similarity_key_norms_array: object | None = None |
| answer_sequence_prompt_tokens: list[list[int]] | None = None |
| answer_sequence_prompt_tokens_array: object | None = None |
| answer_sequence_tokens: list[list[int]] | None = None |
| answer_sequence_tokens_array: object | None = None |
| answer_sequence_prompt_weight_maps: list[dict[int, float]] | None = None |
| answer_sequence_prompt_weight_norms: list[float] | None = None |
| answer_sequence_prompt_bigram_sets: list[set[tuple[int, int]]] | None = None |
| answer_sequence_prompt_trigram_sets: list[set[tuple[int, int, int]]] | None = None |
| answer_sequence_prompt_number_sets: list[set[str]] | None = None |
| answer_sequence_prompt_inverted_index: dict[int, list[int]] | None = None |
| answer_sequence_prompt_specificity: dict[int, float] | None = None |
| transition_tables: dict[int, dict[tuple[str, ...], dict[str, float]]] | None = None |
|
|
def fit(self, text: str) -> "ReframrModel":
    """Fit every model component from the raw training ``text``.

    Trains the tokenizer, fits the embedding model, builds the memory units,
    derives the ternary mask, fills the associative and answer memories,
    builds the transition tables and solves the ridge-regression readout.

    Returns:
        ``self``, so calls can be chained.

    Raises:
        ValueError: if ``text`` encodes to fewer than two tokens.
    """
    settings = self.config
    self.tokenizer = NativeTokenizer.train(
        text,
        vocab_size=settings.tokenizer_vocab_size,
        min_pair_frequency=settings.tokenizer_min_pair_frequency,
        lowercase=settings.lowercase,
    )
    tokens = self.tokenizer.encode(text)
    if len(tokens) < 2:
        raise ValueError("REFRAMR needs at least two tokens to derive a next-token readout.")

    self.embedding_model = fit_ppmi_embedding_from_tokens(
        tokens,
        embedding_dim=settings.embedding_dim,
        window_size=settings.window_size,
        min_frequency=settings.min_frequency,
        max_vocab=settings.max_vocab,
    )
    units = []
    for timescale in settings.timescales:
        units.append(AnalyticalMemoryUnit(settings.state_dim, timescale))
    self.memory_units = units

    # Float-valued occurrence counts per token.
    frequency: dict[str, float] = {}
    for piece in tokens:
        frequency.setdefault(piece, 0.0)
        frequency[piece] += 1.0
    self.trace_token_weights = self._derive_trace_token_weights_from_counts(frequency)

    raw_states, targets, target_ids = self._collect_training_examples(tokens)
    scale, mask = derive_ternary_mask_from_states(raw_states)
    self.ternary_scale = scale
    self.ternary_mask = mask
    analytical_states = [
        apply_ternary_mask(raw_state, self.ternary_mask, self.ternary_scale)
        for raw_state in raw_states
    ]

    # Associative memory: one (state copy, norm) pair per training example.
    key_copies = []
    key_norms = []
    for masked_state in analytical_states:
        key_copies.append(masked_state[:])
        key_norms.append(norm(masked_state))
    self.associative_keys = key_copies
    self.associative_key_norms = key_norms
    self.associative_values = target_ids[:]

    def vocabulary_zeros():
        # One fresh 0.0 entry per vocabulary item.
        return [0.0 for _ in self.embedding_model.id_to_token]

    # Every answer store starts out as its own empty list.
    for attribute in (
        "answer_keys",
        "answer_key_norms",
        "answer_values",
        "answer_start_keys",
        "answer_start_key_norms",
        "answer_start_values",
        "answer_sequence_keys",
        "answer_sequence_key_norms",
        "answer_sequence_prompt_tokens",
        "answer_sequence_tokens",
        "prompt_answer_weights",
        "prompt_answer_start_weights",
    ):
        setattr(self, attribute, [])
    self.prompt_answer_bias = vocabulary_zeros()
    self.prompt_answer_start_bias = vocabulary_zeros()

    self.transition_tables = self._build_transition_tables(tokens)
    self._fit_answer_memory_from_text(text)
    self.readout_weights = ridge_regression_readout(
        analytical_states,
        targets,
        regularization=settings.regularization,
    )
    self.readout_bias = vocabulary_zeros()
    self.preference_bias = vocabulary_zeros()
    if analytical_states:
        self.state_offset = [0.0 for _ in analytical_states[0]]
    else:
        self.state_offset = []
    self._refresh_numeric_caches()
    return self
|
|
def _fit_answer_memory_from_text(self, text: str) -> None:
    """Populate the answer memories from ``prompt <answer> answer`` lines.

    Each line of ``text`` containing ``<answer>`` is split at its first
    occurrence.  Lines with an empty side, no storable answer token, or a
    zero-norm prompt encoding are skipped.  Does nothing when any of the
    answer stores is still None.
    """
    assert self.tokenizer is not None
    assert self.embedding_model is not None
    stores = (
        self.answer_keys,
        self.answer_key_norms,
        self.answer_values,
        self.answer_start_keys,
        self.answer_start_key_norms,
        self.answer_start_values,
        self.answer_sequence_keys,
        self.answer_sequence_key_norms,
        self.answer_sequence_prompt_tokens,
        self.answer_sequence_tokens,
    )
    if any(store is None for store in stores):
        return

    def storable(token: str) -> bool:
        # In the embedding vocabulary and not a tokenizer special token.
        return (
            token in self.embedding_model.token_to_id
            and token not in self.tokenizer.special_tokens
        )

    def padded(ids: list[int]) -> list[int]:
        # Right-pad with -1 up to ANSWER_SEQUENCE_MAX_TOKENS entries.
        return ids + [-1] * (ANSWER_SEQUENCE_MAX_TOKENS - len(ids))

    for line in text.splitlines():
        if "<answer>" not in line:
            continue
        raw_prompt, _, raw_answer = line.partition("<answer>")
        prompt_text = raw_prompt.strip()
        answer_text = raw_answer.strip()
        if not (prompt_text and answer_text):
            continue

        prompt_tokens = self.tokenizer.encode(prompt_text) + ["<answer>"]
        answer_tokens = [
            token for token in self.tokenizer.encode(answer_text) if storable(token)
        ]
        if not prompt_tokens or not answer_tokens:
            continue

        key = self._encode_context(prompt_tokens)
        key_norm = norm(key)
        if key_norm <= 0.0:
            continue

        answer_ids = [
            self.embedding_model.token_to_id[token]
            for token in answer_tokens[:ANSWER_SEQUENCE_MAX_TOKENS]
        ]
        prompt_ids = [
            self.embedding_model.token_to_id[token]
            for token in prompt_tokens[:ANSWER_SEQUENCE_MAX_TOKENS]
            if storable(token)
        ]
        if not answer_ids:
            continue

        # The answer and answer-start memories both store the first answer id.
        first_answer_id = answer_ids[0]
        for keys, norms, values in (
            (self.answer_keys, self.answer_key_norms, self.answer_values),
            (self.answer_start_keys, self.answer_start_key_norms, self.answer_start_values),
        ):
            keys.append(key[:])
            norms.append(key_norm)
            values.append(first_answer_id)

        self.answer_sequence_keys.append(key[:])
        self.answer_sequence_key_norms.append(key_norm)
        self.answer_sequence_prompt_tokens.append(padded(prompt_ids))
        self.answer_sequence_tokens.append(padded(answer_ids))
|
|
| def predict_next_distribution( |
| self, |
| context: str, |
| *, |
| reasoning_mode: str | None = None, |
| ) -> dict[str, float]: |
| self._require_fit() |
| assert self.tokenizer is not None |
| assert self.embedding_model is not None |
| probabilities = self.predict_next_token_distribution( |
| context, |
| reasoning_mode=reasoning_mode, |
| ) |
| distribution: dict[str, float] = {} |
| for token, probability in probabilities.items(): |
| rendered = self._render_token(token) |
| distribution[rendered] = distribution.get(rendered, 0.0) + probability |
| return distribution |
|
|
def predict_next_token_distribution(
    self,
    context: str,
    *,
    reasoning_mode: str | None = None,
) -> dict[str, float]:
    """Return next-token probabilities keyed by raw token.

    The context is encoded and prefixed with the reasoning tokens for
    ``reasoning_mode`` (or the configured default profile when it is falsy).
    """
    self._require_fit()
    assert self.tokenizer is not None
    assert self.embedding_model is not None
    assert self.readout_weights is not None

    mode = reasoning_mode or self.config.default_reasoning_profile
    prefix_tokens = reasoning_prefix(mode)
    body_tokens = self.tokenizer.encode(context)
    return self._predict_next_token_distribution_from_tokens(prefix_tokens + body_tokens)
|
|
def generate_text(
    self,
    context: str,
    *,
    max_tokens: int = 64,
    reasoning_mode: str | None = None,
    temperature: float = 0.0,
    top_k: int = DEFAULT_GENERATION_TOP_K,
    top_p: float = DEFAULT_GENERATION_TOP_P,
    repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
) -> str:
    """Generate a continuation for ``context`` and return it as text.

    Character-counting prompts are answered directly, before the fit check.
    When NumPy caches exist and the vocabulary has at least 1024 entries the
    array-based ``_generate_text_fast`` is used; otherwise tokens are decoded
    one at a time here.  After the main loop up to six extra tokens may be
    appended while the last token does not start a new word.
    """
    shortcut = self._character_count_response(
        context,
        temperature=temperature,
    )
    if shortcut is not None:
        return shortcut
    self._require_fit()
    self._ensure_numeric_caches()
    assert self.tokenizer is not None
    use_array_path = (
        np is not None
        and self.readout_weights_array is not None
        and self.embedding_model is not None
        and len(self.embedding_model.id_to_token) >= 1024
    )
    if use_array_path:
        return self._generate_text_fast(
            context,
            max_tokens=max_tokens,
            reasoning_mode=reasoning_mode,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

    mode = reasoning_mode or self.config.default_reasoning_profile
    _, context_tokens = self._generation_prompt_tokens(context, mode)
    decode_state = self._build_decode_state(context_tokens)
    emitted: list[str] = []

    def propose():
        # Score the current state, then sample one candidate token from it.
        distribution, _ = self._score_next_token_from_state(
            decode_state,
            include_trace=False,
            generated_tokens=emitted,
        )
        return self._select_generation_token(
            distribution,
            context_tokens=decode_state.context_tokens,
            generated_tokens=emitted,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            preserve_dominant_candidates=self._answer_decode_has_continuation(
                decode_state,
                emitted,
            ),
        )

    for _ in range(max_tokens):
        candidate = propose()
        if not candidate:
            break
        emitted.append(candidate)
        self._advance_decode_state(decode_state, candidate)
        if self._should_stop_answer_sequence(decode_state, emitted):
            break
        if self._should_stop_generation(emitted) and not self._answer_decode_has_continuation(
            decode_state, emitted
        ):
            break

    # Overflow: keep appending pieces until a candidate starts a new word.
    overflow_budget = 6
    while (
        emitted
        and not self._starts_new_word(emitted[-1])
        and overflow_budget > 0
    ):
        candidate = propose()
        if not candidate or self._starts_new_word(candidate):
            break
        emitted.append(candidate)
        self._advance_decode_state(decode_state, candidate)
        overflow_budget -= 1
    return self._decode_tokens(emitted)
|
|
@staticmethod
def _character_count_fact(context: str) -> CharacterCountFact | None:
    """Parse a "how many <char> in <word>" style prompt.

    Returns None unless the prompt contains a counting term, a unit term
    (or the literal token "count"), a single-character target and a
    multi-character word.  The count is case-insensitive.
    """
    cleaned = unicodedata.normalize("NFKC", context).strip()
    tokens = ReframrModel._character_count_word_tokens(cleaned)
    if not tokens:
        return None
    folded = [token.casefold() for token in tokens]
    seen = set(folded)

    count_terms = {"count", "counts", "counting", "many"}
    unit_terms = {"character", "characters", "letter", "letters"}
    if seen.isdisjoint(count_terms):
        return None
    if seen.isdisjoint(unit_terms) and "count" not in seen:
        return None

    filler_terms = {"a", "an", "the", "single", "one", "please"}
    word_markers = {"in", "inside"}
    char_index = ReframrModel._character_count_target_index(
        folded,
        unit_terms=unit_terms,
        filler_terms=filler_terms,
    )
    word_index = ReframrModel._character_count_word_index(
        folded,
        char_index=char_index,
        filler_terms=filler_terms,
        word_markers=word_markers,
    )
    if char_index is None or word_index is None:
        return None

    # Report the original (non-casefolded) spellings.
    character, word = tokens[char_index], tokens[word_index]
    if len(character) != 1 or not word:
        return None

    # The seed depends only on token positions, so a prompt always maps to
    # the same phrasing when sampling is disabled.
    order_offset = int(char_index >= word_index)
    surface_seed = (
        (char_index + 1) * 7 + (word_index + 1) * 3 + len(tokens) + order_offset
    ) % 4
    return CharacterCountFact(
        character=character,
        word=word,
        count=word.casefold().count(character.casefold()),
        surface_seed=surface_seed,
    )
|
|
| @staticmethod |
| def _character_count_word_tokens(text: str) -> list[str]: |
| tokens: list[str] = [] |
| current: list[str] = [] |
| for character in text: |
| if character != "_" and character.isalnum(): |
| current.append(character) |
| continue |
| if current: |
| tokens.append("".join(current)) |
| current = [] |
| if current: |
| tokens.append("".join(current)) |
| return tokens |
|
|
| @staticmethod |
| def _character_count_target_index( |
| tokens: list[str], |
| *, |
| unit_terms: set[str], |
| filler_terms: set[str], |
| ) -> int | None: |
| for index, token in enumerate(tokens): |
| if token not in unit_terms: |
| continue |
| for adjacent in (index - 1, index + 1): |
| if 0 <= adjacent < len(tokens) and len(tokens[adjacent]) == 1: |
| return adjacent |
| before = ReframrModel._nearest_content_index(tokens, index - 1, -1, filler_terms) |
| after = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms) |
| for candidate in (before, after): |
| if candidate is not None and len(tokens[candidate]) == 1: |
| return candidate |
| for index, token in enumerate(tokens): |
| if token not in {"count", "counts", "counting"}: |
| continue |
| candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms) |
| if candidate is not None and tokens[candidate] in unit_terms: |
| candidate = ReframrModel._nearest_content_index(tokens, candidate + 1, 1, filler_terms) |
| if candidate is not None and len(tokens[candidate]) == 1: |
| return candidate |
| return None |
|
|
| @staticmethod |
| def _character_count_word_index( |
| tokens: list[str], |
| *, |
| char_index: int | None, |
| filler_terms: set[str], |
| word_markers: set[str], |
| ) -> int | None: |
| for index, token in enumerate(tokens): |
| if token != "word": |
| continue |
| candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms) |
| if candidate is not None and candidate != char_index and len(tokens[candidate]) > 1: |
| return candidate |
| for index, token in enumerate(tokens): |
| if token not in word_markers: |
| continue |
| candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms) |
| if candidate is not None and tokens[candidate] == "word": |
| candidate = ReframrModel._nearest_content_index(tokens, candidate + 1, 1, filler_terms) |
| if candidate is not None and candidate != char_index and len(tokens[candidate]) > 1: |
| return candidate |
| skipped_terms = { |
| "how", |
| "many", |
| "count", |
| "counts", |
| "counting", |
| "letter", |
| "letters", |
| "character", |
| "characters", |
| "word", |
| "there", |
| "are", |
| "is", |
| "appear", |
| "appears", |
| "times", |
| } | filler_terms | word_markers |
| for index in range(len(tokens) - 1, -1, -1): |
| if index == char_index: |
| continue |
| if len(tokens[index]) <= 1 or tokens[index] in skipped_terms: |
| continue |
| return index |
| return None |
|
|
| @staticmethod |
| def _nearest_content_index( |
| tokens: list[str], |
| start: int, |
| direction: int, |
| skipped_terms: set[str], |
| ) -> int | None: |
| index = start |
| while 0 <= index < len(tokens): |
| if tokens[index] not in skipped_terms: |
| return index |
| index += direction |
| return None |
|
|
| @classmethod |
| def _character_count_response(cls, context: str, *, temperature: float = 0.0) -> str | None: |
| fact = cls._character_count_fact(context) |
| if fact is None: |
| return None |
| return cls._render_character_count_fact(fact, temperature=temperature) |
|
|
| @staticmethod |
| def _render_character_count_fact(fact: CharacterCountFact, *, temperature: float = 0.0) -> str: |
| character_label = f"'{fact.character}'" |
| word_label = f"'{fact.word}'" |
| character_noun = "character" if fact.count == 1 else "characters" |
| plural_times = "" if fact.count == 1 else "s" |
| surfaces = ( |
| f"There {'is' if fact.count == 1 else 'are'} {fact.count} {character_label} {character_noun} in {word_label}.", |
| f"{word_label} contains {fact.count} {character_label} {character_noun}.", |
| f"In {word_label}, {character_label} appears {fact.count} time{plural_times}.", |
| f"The count is {fact.count} for {character_label} in {word_label}.", |
| ) |
| if temperature > 0.0: |
| return surfaces[(random.randrange(len(surfaces)) + fact.surface_seed) % len(surfaces)] |
| return surfaces[fact.surface_seed % len(surfaces)] |
|
|
| def _generate_text_fast( |
| self, |
| context: str, |
| *, |
| max_tokens: int, |
| reasoning_mode: str | None, |
| temperature: float, |
| top_k: int, |
| top_p: float, |
| repetition_penalty: float, |
| ) -> str: |
| assert self.tokenizer is not None |
|
|
| active_mode = reasoning_mode or self.config.default_reasoning_profile |
| _, context_tokens = self._generation_prompt_tokens(context, active_mode) |
| decode_state = self._build_decode_state(context_tokens) |
| generated_tokens: list[str] = [] |
| for _ in range(max_tokens): |
| probabilities, _ = self._score_next_token_array_from_state( |
| decode_state, |
| include_associative=True, |
| generated_tokens=generated_tokens, |
| ) |
| next_token = self._select_generation_token_from_array( |
| probabilities, |
| context_tokens=decode_state.context_tokens, |
| generated_tokens=generated_tokens, |
| temperature=temperature, |
| top_k=top_k, |
| top_p=top_p, |
| repetition_penalty=repetition_penalty, |
| preserve_dominant_candidates=self._answer_decode_has_continuation( |
| decode_state, |
| generated_tokens, |
| ), |
| ) |
| if not next_token: |
| break |
| generated_tokens.append(next_token) |
| self._advance_decode_state(decode_state, next_token) |
| if self._should_stop_answer_sequence(decode_state, generated_tokens): |
| break |
| if self._should_stop_generation( |
| generated_tokens |
| ) and not self._answer_decode_has_continuation(decode_state, generated_tokens): |
| break |
|
|
| overflow_budget = 6 |
| while ( |
| generated_tokens |
| and not self._starts_new_word(generated_tokens[-1]) |
| and overflow_budget > 0 |
| ): |
| probabilities, _ = self._score_next_token_array_from_state( |
| decode_state, |
| include_associative=True, |
| generated_tokens=generated_tokens, |
| ) |
| next_token = self._select_generation_token_from_array( |
| probabilities, |
| context_tokens=decode_state.context_tokens, |
| generated_tokens=generated_tokens, |
| temperature=temperature, |
| top_k=top_k, |
| top_p=top_p, |
| repetition_penalty=repetition_penalty, |
| preserve_dominant_candidates=self._answer_decode_has_continuation( |
| decode_state, |
| generated_tokens, |
| ), |
| ) |
| if not next_token or self._starts_new_word(next_token): |
| break |
| generated_tokens.append(next_token) |
| self._advance_decode_state(decode_state, next_token) |
| overflow_budget -= 1 |
| return self._decode_tokens(generated_tokens) |
|
|
def trace_next_token(
    self,
    context: str,
    *,
    reasoning_mode: str | None = None,
    top_k: int = 5,
) -> dict[str, object]:
    """Score the next token for ``context`` and return the scoring trace.

    The trace from ``_score_next_token_from_tokens`` is extended with the
    original context, the active reasoning mode, its prefix tokens and the
    full token sequence that was scored.
    """
    self._require_fit()
    assert self.tokenizer is not None

    mode = reasoning_mode or self.config.default_reasoning_profile
    context_tokens = reasoning_prefix(mode) + self.tokenizer.encode(context)
    _, trace = self._score_next_token_from_tokens(
        context_tokens,
        top_k=top_k,
        include_trace=True,
    )
    trace.update(
        context=context,
        reasoning_mode=mode,
        reasoning_tokens=reasoning_prefix(mode),
        context_tokens=context_tokens,
    )
    return trace
|
|
def trace_generation(
    self,
    context: str,
    *,
    max_tokens: int = 16,
    reasoning_mode: str | None = None,
    top_k: int = 5,
    temperature: float = 0.0,
    top_p: float = DEFAULT_GENERATION_TOP_P,
    repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
) -> dict[str, object]:
    """Generate text while recording a per-step trace.

    Character-counting prompts short-circuit to a report with no steps.
    Otherwise each decode step stores the scoring trace plus the chosen
    token, its rendered text and its probability.  Sampling uses
    ``max(DEFAULT_GENERATION_TOP_K, top_k)`` candidates, while ``top_k``
    itself is forwarded to the scorer.

    NOTE(review): unlike ``generate_text``, this loop neither passes
    ``preserve_dominant_candidates`` nor consults
    ``_should_stop_answer_sequence`` -- confirm whether that is intended.
    """
    shortcut = self._character_count_response(
        context,
        temperature=temperature,
    )
    if shortcut is not None:
        active_mode = reasoning_mode or self.config.default_reasoning_profile
        prompt = context if "<answer>" in context else f"{context} <answer>"
        return {
            "context": context,
            "prompt": prompt,
            "reasoning_mode": active_mode,
            "reasoning_tokens": reasoning_prefix(active_mode),
            "generation_policy": {
                "temperature": temperature,
                "top_k": max(DEFAULT_GENERATION_TOP_K, top_k),
                "top_p": top_p,
                "repetition_penalty": repetition_penalty,
            },
            "prompt_tokens": [],
            "generated_tokens": [],
            "generated_text": shortcut,
            "generated_token_count": len(shortcut.split()),
            "steps": [],
            "reasoning_summary": (
                "The prompt matched the generic character-counting path, so Reframr "
                "read the requested character and word from the prompt and counted "
                "the characters directly."
            ),
        }
    self._require_fit()
    assert self.tokenizer is not None

    active_mode = reasoning_mode or self.config.default_reasoning_profile
    prompt, context_tokens = self._generation_prompt_tokens(context, active_mode)
    decode_state = self._build_decode_state(context_tokens)
    prompt_tokens = decode_state.context_tokens[:]
    generated_tokens: list[str] = []
    steps: list[dict[str, object]] = []
    sampling_top_k = max(DEFAULT_GENERATION_TOP_K, top_k)

    def propose():
        # Score the current state and sample one candidate token from it.
        distribution, trace = self._score_next_token_from_state(
            decode_state,
            top_k=top_k,
            include_trace=True,
            generated_tokens=generated_tokens,
        )
        token = self._select_generation_token(
            distribution,
            context_tokens=decode_state.context_tokens,
            generated_tokens=generated_tokens,
            temperature=temperature,
            top_k=sampling_top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )
        return token, distribution, trace

    def commit(token, distribution, trace, step_number):
        # Accept the token, advance the decoder and record the step.
        generated_tokens.append(token)
        self._advance_decode_state(decode_state, token)
        trace["step"] = step_number
        trace["chosen_token"] = token
        trace["chosen_text"] = self._render_token(token)
        trace["chosen_probability"] = distribution[token]
        steps.append(trace)

    for step_index in range(1, max_tokens + 1):
        token, distribution, trace = propose()
        if not token:
            break
        commit(token, distribution, trace, step_index)
        if self._should_stop_generation(
            generated_tokens
        ) and not self._answer_decode_has_continuation(decode_state, generated_tokens):
            break

    # Overflow: keep appending pieces until a candidate starts a new word.
    overflow_budget = 6
    while (
        generated_tokens
        and not self._starts_new_word(generated_tokens[-1])
        and overflow_budget > 0
    ):
        token, distribution, trace = propose()
        if not token or self._starts_new_word(token):
            break
        commit(token, distribution, trace, len(steps) + 1)
        overflow_budget -= 1

    return {
        "context": context,
        "prompt": prompt,
        "reasoning_mode": active_mode,
        "reasoning_tokens": reasoning_prefix(active_mode),
        "generation_policy": {
            "temperature": temperature,
            "top_k": sampling_top_k,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
        },
        "prompt_tokens": prompt_tokens,
        "generated_tokens": generated_tokens,
        "generated_text": self._decode_tokens(generated_tokens),
        "generated_token_count": len(generated_tokens),
        "steps": steps,
    }
|
|
def _generation_prompt_tokens(self, context: str, active_mode: str) -> tuple[str, list[str]]:
    """Build the prompt string and the full token sequence for generation.

    ``" <answer>"`` is appended to ``context`` unless it already contains
    ``<answer>``.  A ``<reason>`` token is prepended to the encoded prompt
    when the prompt has an ``<answer>`` token and neither the prompt nor the
    reasoning prefix already carries ``<reason>``.

    Returns:
        ``(prompt, reasoning prefix tokens + prompt tokens)``.
    """
    assert self.tokenizer is not None
    if "<answer>" in context:
        prompt = context
    else:
        prompt = f"{context} <answer>"
    prefix = reasoning_prefix(active_mode)
    prompt_tokens = self.tokenizer.encode(prompt)
    has_answer_marker = "<answer>" in prompt_tokens
    if has_answer_marker and not ("<reason>" in prompt_tokens or "<reason>" in prefix):
        prompt_tokens = ["<reason>"] + prompt_tokens
    return prompt, prefix + prompt_tokens
|
|
| def _predict_next_token_distribution_from_tokens( |
| self, |
| context_tokens: list[str], |
| ) -> dict[str, float]: |
| decode_state = self._build_decode_state(context_tokens) |
| return self._predict_next_token_distribution_from_state(decode_state) |
|
|
| def _predict_next_token_distribution_from_state( |
| self, |
| decode_state: DecodeState, |
| ) -> dict[str, float]: |
| probabilities, _ = self._score_next_token_from_state( |
| decode_state, |
| include_trace=False, |
| ) |
| return probabilities |
|
|
@staticmethod
def _answer_sequence_should_lock(
    *,
    answer_sequence_confidence: float,
    answer_sequence_match_confidence: float,
    has_answer_sequence_prior: bool,
) -> bool:
    """Decide whether decoding should lock onto a stored answer sequence.

    Locking needs a prior and a positive sequence confidence, and then
    either a match confidence at or above ANSWER_SEQUENCE_LOCK_FLOOR, or a
    match confidence at or above ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR
    together with a sequence confidence no higher than
    ANSWER_SEQUENCE_SPIKE_CONFIDENCE.
    """
    if has_answer_sequence_prior and not answer_sequence_confidence <= 0.0:
        if answer_sequence_match_confidence >= ANSWER_SEQUENCE_LOCK_FLOOR:
            return True
        return (
            answer_sequence_match_confidence >= ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR
            and answer_sequence_confidence <= ANSWER_SEQUENCE_SPIKE_CONFIDENCE
        )
    return False
|
|
@staticmethod
def _answer_start_blend_weights(
    *,
    answer_sequence_match_confidence: float,
) -> dict[str, float]:
    """Return the mixing weights for the four answer-start prior sources.

    When the match confidence reaches ``ANSWER_SEQUENCE_LOCK_FLOOR`` the
    answer-sequence source gets the largest share; otherwise the
    prompt-answer-start source does. ``answer_start`` is 0.10 in both cases.
    """
    sequence_dominant = answer_sequence_match_confidence >= ANSWER_SEQUENCE_LOCK_FLOOR
    if sequence_dominant:
        start_weight, prompt_weight, sequence_weight = 0.35, 0.10, 0.45
    else:
        start_weight, prompt_weight, sequence_weight = 0.55, 0.20, 0.15
    return {
        "prompt_answer_start": start_weight,
        "prompt_answer": prompt_weight,
        "answer_sequence": sequence_weight,
        "answer_start": 0.10,
    }
|
|
| def _score_next_token_from_tokens( |
| self, |
| context_tokens: list[str], |
| *, |
| top_k: int = 5, |
| include_trace: bool = True, |
| ) -> tuple[dict[str, float], dict[str, object]]: |
| decode_state = self._build_decode_state(context_tokens) |
| return self._score_next_token_from_state( |
| decode_state, |
| top_k=top_k, |
| include_trace=include_trace, |
| ) |
|
|
def _score_next_token_from_state(
    self,
    decode_state: DecodeState,
    *,
    top_k: int = 5,
    include_trace: bool = True,
    generated_tokens: list[str] | None = None,
) -> tuple[dict[str, float], dict[str, object]]:
    """Score the next token for ``decode_state`` using list-based priors.

    Computes the base readout distribution, gathers the answer, answer-start,
    answer-sequence, prompt-answer, associative, transition, copy and
    preference priors, and hands them to ``_blend_probabilities``.

    Match lists and prompt-answer priors are cached on ``decode_state`` the
    first time they are needed, so this method mutates ``decode_state``.

    Args:
        decode_state: Running decode state; its cache fields are filled in
            lazily here.
        top_k: Number of entries reported per trace section; when tracing it
            also raises the match-retrieval limits above their defaults.
        include_trace: When false, the returned trace is an empty dict.
        generated_tokens: Tokens produced so far; ``None`` is treated as an
            empty list.

    Returns:
        ``(distribution, trace)`` where ``distribution`` maps every token in
        ``embedding_model.id_to_token`` to its blended probability.
    """
    assert self.embedding_model is not None
    assert self.readout_weights is not None
    generated_tokens = generated_tokens or []


    state = self._masked_decode_state(decode_state)
    logits = self._apply_readout_fast(state)
    base_probabilities = self._calibrated_softmax(logits)
    # Match lists are computed once per decode state and reused afterwards.
    # NOTE(review): the limit used on the first call is the one that sticks,
    # so a later call with a larger top_k sees the earlier, shorter list.
    if decode_state.answer_matches is None:
        decode_state.answer_matches = self._score_answer_matches(
            decode_state.answer_anchor_state,
            limit=max(ANSWER_TOP_K, top_k) if include_trace else ANSWER_TOP_K,
        )
    answer_matches = decode_state.answer_matches
    if decode_state.answer_start_matches is None:
        decode_state.answer_start_matches = self._score_answer_start_matches(
            decode_state.answer_anchor_state,
            limit=max(ANSWER_START_TOP_K, top_k) if include_trace else ANSWER_START_TOP_K,
        )
    answer_start_matches = decode_state.answer_start_matches
    if decode_state.answer_sequence_matches is None:
        decode_state.answer_sequence_matches = self._score_answer_sequence_matches(
            decode_state.answer_anchor_state,
            decode_state.context_tokens,
            limit=max(ANSWER_START_TOP_K, top_k) if include_trace else ANSWER_START_TOP_K,
        )
    answer_sequence_matches = decode_state.answer_sequence_matches
    answer_prior = self._answer_prior_from_matches(answer_matches, generated_tokens)
    answer_start_prior = self._answer_prior_from_matches(answer_start_matches, generated_tokens)
    answer_sequence_prior = self._answer_sequence_prior_from_matches(
        answer_sequence_matches,
        generated_tokens,
    )
    # Two confidence signals feed the lock decision: the peak of the sequence
    # prior, and the first element of the first sequence match (presumably
    # the best similarity; ordering is decided by the scorer, not here).
    answer_sequence_confidence = max(answer_sequence_prior) if answer_sequence_prior else 0.0
    answer_sequence_match_confidence = (
        answer_sequence_matches[0][0] if answer_sequence_matches else 0.0
    )
    has_answer_sequence_prior = any(value > 0.0 for value in answer_sequence_prior)
    answer_locked = self._answer_sequence_should_lock(
        answer_sequence_confidence=answer_sequence_confidence,
        answer_sequence_match_confidence=answer_sequence_match_confidence,
        has_answer_sequence_prior=has_answer_sequence_prior,
    )
    if decode_state.prompt_answer_prior is None:
        decode_state.prompt_answer_prior = self._prompt_answer_readout_prior(
            decode_state.answer_anchor_state,
            start=False,
        )
    prompt_answer_prior = decode_state.prompt_answer_prior
    # The start-specific prompt prior only applies before any token has been
    # generated; afterwards it is replaced by an all-zero vector.
    prompt_answer_start_prior = (
        decode_state.prompt_answer_start_prior
        if not generated_tokens
        else [0.0 for _ in self.embedding_model.id_to_token]
    )
    if not generated_tokens and prompt_answer_start_prior is None:
        decode_state.prompt_answer_start_prior = self._prompt_answer_readout_prior(
            decode_state.answer_anchor_state,
            start=True,
        )
        prompt_answer_start_prior = decode_state.prompt_answer_start_prior
    # Answer-start guidance is active on the first generated position when
    # either start-oriented prior has any positive mass.
    use_answer_start = (
        not generated_tokens
        and (
            any(value > 0.0 for value in answer_start_prior)
            or any(value > 0.0 for value in prompt_answer_start_prior)
        )
    )
    # Pick the effective answer prior, in priority order: locked sequence,
    # start blend, sequence blend, prompt blend; otherwise the plain
    # answer-match prior computed above is kept.
    if answer_locked:
        answer_prior = answer_sequence_prior
    elif use_answer_start:
        start_blend = self._answer_start_blend_weights(
            answer_sequence_match_confidence=answer_sequence_match_confidence
        )
        answer_prior = self._weighted_prior_sum(
            [
                (start_blend["prompt_answer_start"], prompt_answer_start_prior),
                (start_blend["prompt_answer"], prompt_answer_prior),
                (start_blend["answer_sequence"], answer_sequence_prior),
                (start_blend["answer_start"], answer_start_prior),
            ],
        )
    elif any(value > 0.0 for value in answer_sequence_prior):
        answer_prior = self._weighted_prior_sum(
            [
                (0.50, prompt_answer_prior),
                (0.30, answer_sequence_prior),
                (0.20, answer_prior),
            ],
        )
    elif any(value > 0.0 for value in prompt_answer_prior):
        answer_prior = self._weighted_prior_sum(
            [
                (0.65, prompt_answer_prior),
                (0.35, answer_prior),
            ],
        )
    # Associative retrieval is skipped entirely while answer-start guidance
    # is active; its prior is then all zeros.
    associative_matches = (
        []
        if use_answer_start
        else self._score_associative_matches(
            state,
            limit=max(ASSOCIATIVE_TOP_K, top_k) if include_trace else ASSOCIATIVE_TOP_K,
        )
    )
    associative_prior = (
        [0.0 for _ in self.embedding_model.id_to_token]
        if use_answer_start
        else self._associative_prior_from_matches(associative_matches)
    )
    transition_prior, transition_order = self._transition_prior_with_order(decode_state.context_tokens)
    copy_prior = self._copy_prior(decode_state.context_tokens)
    preference_prior = self._preference_prior()
    probabilities, blend_weights = self._blend_probabilities(
        base_probabilities,
        answer_prior,
        associative_prior,
        transition_prior,
        copy_prior,
        preference_prior,
        transition_order=transition_order,
        generated_count=len(generated_tokens),
        answer_locked=answer_locked,
        answer_guided_start=use_answer_start,
    )
    # Re-key the blended vector by token string, in vocabulary order.
    distribution = {
        token: probabilities[index]
        for index, token in enumerate(self.embedding_model.id_to_token)
    }
    if not include_trace:
        return distribution, {}


    # Diagnostic payload: top-k views of every prior plus the raw matches.
    trace = {
        "state_norm": norm(state),
        "blend_weights": blend_weights,
        "transition_order": transition_order,
        "base_top_predictions": self._top_entries_from_vector(base_probabilities, top_k),
        "answer_top_predictions": self._top_entries_from_vector(answer_prior, top_k),
        "prompt_answer_top_predictions": self._top_entries_from_vector(prompt_answer_prior, top_k),
        "prompt_answer_start_top_predictions": self._top_entries_from_vector(prompt_answer_start_prior, top_k),
        "answer_start_top_predictions": self._top_entries_from_vector(answer_start_prior, top_k),
        "answer_sequence_top_predictions": self._top_entries_from_vector(answer_sequence_prior, top_k),
        "associative_top_predictions": self._top_entries_from_vector(associative_prior, top_k),
        "transition_top_predictions": self._top_entries_from_vector(transition_prior, top_k),
        "copy_top_predictions": self._top_entries_from_vector(copy_prior, top_k),
        "preference_top_predictions": self._top_entries_from_vector(preference_prior, top_k),
        "final_top_predictions": self._top_entries_from_vector(probabilities, top_k),
        "associative_matches": [
            {
                "example_index": example_index,
                "similarity": similarity,
                **self._token_entry(token_id, similarity),
            }
            for similarity, token_id, example_index in associative_matches[:top_k]
        ],
        "answer_matches": [
            {
                "example_index": example_index,
                "similarity": similarity,
                **self._token_entry(token_id, similarity),
            }
            for similarity, token_id, example_index in answer_matches[:top_k]
        ],
        "answer_start_matches": [
            {
                "example_index": example_index,
                "similarity": similarity,
                **self._token_entry(token_id, similarity),
            }
            for similarity, token_id, example_index in answer_start_matches[:top_k]
        ],
        # Sequence matches carry no single token, so only index and
        # similarity are reported.
        "answer_sequence_matches": [
            {
                "example_index": example_index,
                "similarity": similarity,
            }
            for similarity, _, example_index in answer_sequence_matches[:top_k]
        ],
        "reasoning_summary": self._build_reasoning_summary(
            transition_order,
            blend_weights,
        ),
    }
    return distribution, trace
|
|
def _score_next_token_array_from_state(
    self,
    decode_state: DecodeState,
    *,
    include_associative: bool,
    generated_tokens: list[str] | None = None,
) -> tuple[object, dict[str, float]]:
    """Score the next token for ``decode_state`` using NumPy arrays.

    Array counterpart of ``_score_next_token_from_state`` without trace
    output. Match lists and prompt-answer priors are cached on
    ``decode_state`` the first time they are needed, so this method mutates
    ``decode_state``.

    Args:
        decode_state: Running decode state; its cache fields are filled in
            lazily here.
        include_associative: When false (or when answer-start guidance is
            active) the associative prior is all zeros and no associative
            retrieval is performed.
        generated_tokens: Tokens produced so far; ``None`` is treated as an
            empty list.

    Returns:
        Whatever ``_blend_probability_arrays`` returns for the gathered
        priors.
    """
    assert np is not None
    assert self.embedding_model is not None
    generated_tokens = generated_tokens or []

    state = self._masked_decode_state_array(decode_state)
    logits = self._apply_readout_array(state)
    base_probabilities = self._calibrated_softmax_array(logits)
    if decode_state.answer_matches is None:
        decode_state.answer_matches = self._score_answer_matches(decode_state.answer_anchor_state)
    answer_prior = np.asarray(
        self._answer_prior_from_matches(
            decode_state.answer_matches,
            generated_tokens,
        ),
        dtype=np.float64,
    )
    if decode_state.answer_sequence_matches is None:
        decode_state.answer_sequence_matches = self._score_answer_sequence_matches(
            decode_state.answer_anchor_state,
            decode_state.context_tokens,
        )
    answer_sequence_matches = decode_state.answer_sequence_matches
    answer_sequence_prior = np.asarray(
        self._answer_sequence_prior_from_matches(
            answer_sequence_matches,
            generated_tokens,
        ),
        dtype=np.float64,
    )
    answer_sequence_confidence = (
        float(answer_sequence_prior.max()) if answer_sequence_prior.size else 0.0
    )
    answer_sequence_match_confidence = (
        answer_sequence_matches[0][0] if answer_sequence_matches else 0.0
    )
    has_answer_sequence_prior = bool(np.any(answer_sequence_prior > 0.0))
    answer_locked = self._answer_sequence_should_lock(
        answer_sequence_confidence=answer_sequence_confidence,
        answer_sequence_match_confidence=answer_sequence_match_confidence,
        has_answer_sequence_prior=has_answer_sequence_prior,
    )
    if decode_state.prompt_answer_prior is None:
        decode_state.prompt_answer_prior = self._prompt_answer_readout_prior_array(
            decode_state.answer_anchor_state,
            start=False,
        )
    # The cache slot is shared with the list-based scorer, which stores a
    # plain list there. Coerce before any array comparison, otherwise
    # ``prior > 0.0`` raises TypeError. For an existing float64 array this
    # returns the same object.
    prompt_answer_prior = np.asarray(decode_state.prompt_answer_prior, dtype=np.float64)

    # NOTE(review): the list-based scorer evaluates its answer-start flag
    # independently of the lock, while here a lock always leaves it False.
    # Behaviour is kept as-is; confirm which of the two is intended.
    use_answer_start = False
    if answer_locked:
        answer_prior = answer_sequence_prior
    else:
        if not generated_tokens:
            # Start-oriented priors only matter for the first generated token.
            if decode_state.prompt_answer_start_prior is None:
                decode_state.prompt_answer_start_prior = self._prompt_answer_readout_prior_array(
                    decode_state.answer_anchor_state,
                    start=True,
                )
            # Same shared-cache coercion as for prompt_answer_prior above.
            prompt_answer_start_prior = np.asarray(
                decode_state.prompt_answer_start_prior,
                dtype=np.float64,
            )
            if decode_state.answer_start_matches is None:
                decode_state.answer_start_matches = self._score_answer_start_matches(
                    decode_state.answer_anchor_state
                )
            answer_start_prior = np.asarray(
                self._answer_prior_from_matches(
                    decode_state.answer_start_matches,
                    generated_tokens,
                ),
                dtype=np.float64,
            )
            if np.any(answer_start_prior > 0.0) or np.any(prompt_answer_start_prior > 0.0):
                start_blend = self._answer_start_blend_weights(
                    answer_sequence_match_confidence=answer_sequence_match_confidence
                )
                answer_prior = self._weighted_prior_sum_array(
                    [
                        (start_blend["prompt_answer_start"], prompt_answer_start_prior),
                        (start_blend["prompt_answer"], prompt_answer_prior),
                        (start_blend["answer_sequence"], answer_sequence_prior),
                        (start_blend["answer_start"], answer_start_prior),
                    ],
                )
                use_answer_start = True
        if not use_answer_start:
            # Fall back to the sequence blend, then the prompt blend; if
            # neither has mass the plain answer-match prior is kept.
            if np.any(answer_sequence_prior > 0.0):
                answer_prior = self._weighted_prior_sum_array(
                    [
                        (0.50, prompt_answer_prior),
                        (0.30, answer_sequence_prior),
                        (0.20, answer_prior),
                    ],
                )
            elif np.any(prompt_answer_prior > 0.0):
                answer_prior = self._weighted_prior_sum_array(
                    [
                        (0.65, prompt_answer_prior),
                        (0.35, answer_prior),
                    ],
                )
    if include_associative and not use_answer_start:
        associative_prior = np.asarray(
            self._associative_prior_from_matches(
                self._score_associative_matches(state)
            ),
            dtype=np.float64,
        )
    else:
        associative_prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
    transition_prior, transition_order = self._transition_prior_array_with_order(
        decode_state.context_tokens
    )
    copy_prior = self._copy_prior_array(decode_state.context_tokens)
    preference_prior = self._preference_prior_array()
    return self._blend_probability_arrays(
        base_probabilities,
        answer_prior,
        associative_prior,
        transition_prior,
        copy_prior,
        preference_prior,
        transition_order=transition_order,
        generated_count=len(generated_tokens),
        answer_locked=answer_locked,
        answer_guided_start=use_answer_start,
    )
|
|
def _calibrated_softmax(
    self,
    logits: Vector,
    *,
    scale: float = READOUT_LOGIT_ZSCORE_SCALE,
) -> Vector:
    """Softmax over z-scored, scaled and clipped logits.

    With NumPy available the work is delegated to
    ``_calibrated_softmax_array`` and converted back to a list. Otherwise the
    logits are standardised by their mean and population standard deviation,
    multiplied by ``scale``, clamped to ``[-20, 20]`` and passed to
    ``softmax``. Empty input yields ``[]``; a spread of at most ``1e-12``
    falls back to ``softmax`` on the raw logits.
    """
    if np is not None:
        as_array = np.asarray(logits, dtype=np.float64)
        return self._calibrated_softmax_array(as_array, scale=scale).tolist()
    if not logits:
        return []
    center = mean(logits)
    deviations = [value - center for value in logits]
    variance = mean([delta * delta for delta in deviations])
    spread = variance**0.5
    if spread <= 1e-12:
        return softmax(logits)
    calibrated = []
    for delta in deviations:
        scaled = (delta / spread) * scale
        calibrated.append(max(-20.0, min(20.0, scaled)))
    return softmax(calibrated)
|
|
def _calibrated_softmax_array(
    self,
    logits: object,
    *,
    scale: float = READOUT_LOGIT_ZSCORE_SCALE,
) -> object:
    """NumPy softmax over z-scored, scaled and clipped logits.

    Logits are converted to ``float64``. When their standard deviation
    exceeds ``1e-12`` they are standardised, multiplied by ``scale`` and
    clipped to ``[-20, 20]``; otherwise they are used as-is. The maximum is
    subtracted before exponentiation. An empty input is returned unchanged,
    and a non-positive exponential sum yields a uniform distribution.
    """
    assert np is not None
    values = np.asarray(logits, dtype=np.float64)
    if values.size == 0:
        return values
    spread = float(values.std())
    if spread > 1e-12:
        centered = values - float(values.mean())
        values = np.clip((centered / spread) * scale, -20.0, 20.0)
    # Shift so the largest entry is zero before exponentiating.
    shifted = values - float(values.max())
    exponentials = np.exp(shifted)
    total = float(exponentials.sum())
    if total <= 0.0:
        return np.full(shifted.shape, 1.0 / max(1, shifted.size), dtype=np.float64)
    return exponentials / total
|
|
def _weighted_prior_sum(self, sources: list[tuple[float, Vector]]) -> Vector:
    """Merge several prior vectors into one normalised vector.

    A source takes part only if its weight is positive and its vector has at
    least one positive entry. Weights of the participating sources are
    renormalised to sum to one before mixing, and the mix is passed through
    ``_normalize_vector``. With no participating source the result is an
    all-zero vector of vocabulary length.
    """
    assert self.embedding_model is not None
    vocab_size = len(self.embedding_model.id_to_token)
    contributing = []
    for weight, vector in sources:
        if weight > 0.0 and any(value > 0.0 for value in vector):
            contributing.append((weight, vector))
    if not contributing:
        return [0.0] * vocab_size
    weight_total = sum(weight for weight, _ in contributing)
    merged = [0.0] * vocab_size
    for weight, vector in contributing:
        share = weight / weight_total
        for position, value in enumerate(vector):
            merged[position] += share * value
    return _normalize_vector(merged)
|
|
| def _weighted_prior_sum_array(self, sources: list[tuple[float, object]]) -> object: |
| assert np is not None |
| assert self.embedding_model is not None |
| active_sources = [ |
| (weight, np.asarray(vector, dtype=np.float64)) |
| for weight, vector in sources |
| if weight > 0.0 and np.any(np.asarray(vector, dtype=np.float64) > 0.0) |
| ] |
| if not active_sources: |
| return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) |
| total_weight = sum(weight for weight, _ in active_sources) |
| merged = np.zeros_like(active_sources[0][1], dtype=np.float64) |
| for weight, vector in active_sources: |
| merged += (weight / total_weight) * vector |
| total = float(merged.sum()) |
| if total > 0.0: |
| merged /= total |
| return merged |
|
|
def _prompt_answer_readout_prior(
    self,
    answer_anchor_state: Vector | None,
    *,
    start: bool,
) -> Vector:
    """Return the prompt-answer readout prior as a list.

    ``start`` selects between the answer-start and the general prompt-answer
    readout weights and bias. A missing anchor state or empty weights yield
    an all-zero vector of vocabulary length. With NumPy available the array
    variant is used and converted to a list; otherwise the anchor state is
    masked, centred, read out, biased and passed through
    ``_calibrated_softmax`` with ``PROMPT_READOUT_LOGIT_ZSCORE_SCALE``.
    """
    assert self.embedding_model is not None
    if answer_anchor_state is None:
        return [0.0 for _ in self.embedding_model.id_to_token]
    if start:
        weights = self.prompt_answer_start_weights
        bias = self.prompt_answer_start_bias
    else:
        weights = self.prompt_answer_weights
        bias = self.prompt_answer_bias
    if np is not None:
        as_array = self._prompt_answer_readout_prior_array(
            answer_anchor_state,
            start=start,
        )
        return as_array.tolist()
    if not weights:
        return [0.0 for _ in self.embedding_model.id_to_token]
    masked = self._masked_combined_state(answer_anchor_state)
    centered = self._center_state_vector(masked)
    logits = apply_readout(weights, centered)
    if bias:
        logits = [logit + bias[position] for position, logit in enumerate(logits)]
    return self._calibrated_softmax(
        logits,
        scale=PROMPT_READOUT_LOGIT_ZSCORE_SCALE,
    )
|
|
def _prompt_answer_readout_prior_array(
    self,
    answer_anchor_state: Vector | None,
    *,
    start: bool,
) -> object:
    """Return the prompt-answer readout prior as a NumPy array.

    ``start`` selects between the answer-start and the general prompt-answer
    weight/bias arrays. A missing anchor state or missing weight array yields
    an all-zero ``float64`` array of vocabulary length. Otherwise the anchor
    state is masked, centred, multiplied by the weights, offset by the bias
    when its shape matches the logits, and passed through
    ``_calibrated_softmax_array`` with ``PROMPT_READOUT_LOGIT_ZSCORE_SCALE``.
    """
    assert np is not None
    assert self.embedding_model is not None
    if answer_anchor_state is None:
        return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
    if start:
        weights = self.prompt_answer_start_weights_array
        bias = self.prompt_answer_start_bias_array
    else:
        weights = self.prompt_answer_weights_array
        bias = self.prompt_answer_bias_array
    if weights is None:
        return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
    masked = self._masked_combined_state_array(answer_anchor_state)
    centered = self._center_state_array(masked)
    logits = weights @ centered
    bias_applies = bias is not None and bias.shape == logits.shape
    if bias_applies:
        logits = logits + bias
    return self._calibrated_softmax_array(
        logits,
        scale=PROMPT_READOUT_LOGIT_ZSCORE_SCALE,
    )
|
|
def save(self, path: str | Path) -> None:
    """Write the fitted model to ``path`` as a safetensors checkpoint.

    Structured state (config, tokenizer, vocabulary order, transition
    tables) is stored as compact JSON strings in the metadata; numeric state
    is stored as tensors. Falsy optional per-token vectors are replaced by
    constant fills, and ``None`` optional matrices by empty lists.
    """
    self._require_fit()
    assert self.tokenizer is not None
    assert self.embedding_model is not None
    assert self.ternary_mask is not None
    assert self.readout_weights is not None
    assert self.associative_keys is not None
    assert self.associative_values is not None
    assert self.transition_tables is not None

    def compact_json(payload: object) -> str:
        # Compact separators keep the metadata strings small.
        return json.dumps(payload, separators=(",", ":"))

    def per_token(fill: float) -> list[float]:
        # Fresh constant vector with one entry per vocabulary token.
        return [fill for _ in self.embedding_model.id_to_token]

    def or_empty(value: object) -> object:
        # Missing optional tensors are written as empty lists.
        return [] if value is None else value

    metadata = {
        "schema_version": "1",
        "checkpoint_kind": "reframr-analytical",
        "tokenizer_name": self.tokenizer.name,
        "config": compact_json(self.config.to_dict()),
        "tokenizer": compact_json(self.tokenizer.to_dict()),
        "embedding_id_to_token": compact_json(self.embedding_model.id_to_token),
        "tokenizer_vocab_size": str(self.tokenizer.vocab_size),
        "transition_tables": compact_json(self._serialize_transition_tables()),
    }
    tensors = {
        "embedding_table": self.embedding_model.embeddings,
        "ternary_scale": [self.ternary_scale],
        "ternary_mask": self.ternary_mask,
        "readout_weights": self.readout_weights,
        "readout_bias": self.readout_bias or per_token(0.0),
        "prompt_answer_weights": or_empty(self.prompt_answer_weights),
        "prompt_answer_bias": self.prompt_answer_bias or per_token(0.0),
        "prompt_answer_start_weights": or_empty(self.prompt_answer_start_weights),
        "prompt_answer_start_bias": self.prompt_answer_start_bias or per_token(0.0),
        "trace_token_weights": self.trace_token_weights or per_token(1.0),
        "preference_bias": self.preference_bias or per_token(0.0),
        "state_offset": self.state_offset
        or [0.0 for _ in range(self._combined_state_width())],
        "associative_keys": self.associative_keys,
        "associative_values": self.associative_values,
        "answer_keys": or_empty(self.answer_keys),
        "answer_values": or_empty(self.answer_values),
        "answer_start_keys": or_empty(self.answer_start_keys),
        "answer_start_values": or_empty(self.answer_start_values),
        "answer_sequence_keys": or_empty(self.answer_sequence_keys),
        "answer_sequence_prompt_tokens": or_empty(self.answer_sequence_prompt_tokens),
        "answer_sequence_tokens": or_empty(self.answer_sequence_tokens),
    }
    write_safetensor_file(path, tensors, metadata=metadata)
|
|
@classmethod
def load(cls, path: str | Path) -> "ReframrModel":
    """Rebuild a model from a safetensors checkpoint at ``path``.

    Tensors are requested as arrays only when NumPy is importable and the
    file is larger than 10,000,000 bytes. Every tensor is then accepted
    either as an array (kept as an array where the model supports it) or as
    nested Python sequences (converted to lists of ``float``/``int``).
    Optional per-token vectors that are absent or empty get the same default
    fills that ``save`` writes. Numeric caches are refreshed before the
    model is returned.
    """
    checkpoint_path = Path(path)
    numpy_ready = np is not None
    checkpoint = read_safetensor_file(
        checkpoint_path,
        arrays=numpy_ready and checkpoint_path.stat().st_size > 10_000_000,
    )
    metadata = checkpoint.metadata
    tensors = checkpoint.tensors

    def is_array(value: object) -> bool:
        # Array-backed tensors are only honoured when NumPy is usable.
        return numpy_ready and hasattr(value, "shape")

    def unwrap(value: object) -> object:
        # Arrays become plain lists; lists pass through untouched.
        return value.tolist() if hasattr(value, "tolist") else value

    def float_vector(name: str) -> list[float]:
        return [float(item) for item in unwrap(tensors.get(name, []))]

    def int_vector(name: str) -> list[int]:
        return [int(item) for item in unwrap(tensors.get(name, []))]

    def float_rows(value: object) -> list[list[float]]:
        return [[float(item) for item in row] for row in value]

    def int_rows(value: object) -> list[list[int]]:
        return [[int(item) for item in row] for row in value]

    def dense_matrix(value: object) -> object:
        # Any array is kept as an array; everything else becomes float rows.
        return value.astype(float, copy=False) if is_array(value) else float_rows(value)

    def readout_matrix(value: object) -> object:
        # Only 2-D arrays are kept as arrays; other inputs become float rows.
        if is_array(value) and len(value.shape) == 2:
            return value.astype(float, copy=False)
        return float_rows(value)

    def key_matrix(value: object) -> object:
        # Arrays that are not 2-D are treated as "no keys stored".
        if is_array(value):
            return value.astype(float, copy=False) if len(value.shape) == 2 else []
        return float_rows(value)

    def key_norms(keys: object) -> list[float]:
        if is_array(keys) and len(keys.shape) == 2:
            return np.linalg.norm(keys, axis=1).tolist()
        return [norm(key) for key in keys]

    def id_matrix(value: object) -> object:
        return value.astype(int, copy=False) if is_array(value) else int_rows(value)

    config = ReframrConfig.from_dict(json.loads(metadata["config"]))
    model = cls(config)
    model.tokenizer = NativeTokenizer.from_dict(json.loads(metadata["tokenizer"]))
    id_to_token = [str(token) for token in json.loads(metadata["embedding_id_to_token"])]
    model.embedding_model = EmbeddingModel(
        token_to_id={token: index for index, token in enumerate(id_to_token)},
        id_to_token=id_to_token,
        embeddings=dense_matrix(tensors["embedding_table"]),
        ppmi_matrix=[],
    )
    model.memory_units = [
        AnalyticalMemoryUnit(model.config.state_dim, timescale)
        for timescale in model.config.timescales
    ]
    model.ternary_scale = float(tensors["ternary_scale"][0])
    model.ternary_mask = [int(value) for value in tensors["ternary_mask"]]

    # Readouts and their biases.
    model.readout_weights = dense_matrix(tensors["readout_weights"])
    model.readout_bias = float_vector("readout_bias") or [0.0 for _ in id_to_token]
    model.prompt_answer_weights = readout_matrix(tensors.get("prompt_answer_weights", []))
    model.prompt_answer_bias = (
        float_vector("prompt_answer_bias") or [0.0 for _ in id_to_token]
    )
    model.prompt_answer_start_weights = readout_matrix(
        tensors.get("prompt_answer_start_weights", [])
    )
    model.prompt_answer_start_bias = (
        float_vector("prompt_answer_start_bias") or [0.0 for _ in id_to_token]
    )

    # Per-token and per-state auxiliary vectors.
    model.trace_token_weights = float_vector("trace_token_weights") or [
        0.0 if token in model.tokenizer.special_tokens else 1.0
        for token in id_to_token
    ]
    model.preference_bias = float_vector("preference_bias") or [0.0 for _ in id_to_token]
    model.state_offset = float_vector("state_offset") or [
        0.0 for _ in range(model._combined_state_width())
    ]

    # Associative memory: any array is accepted without a rank check.
    model.associative_keys = dense_matrix(tensors.get("associative_keys", []))
    if is_array(model.associative_keys):
        model.associative_key_norms = np.linalg.norm(model.associative_keys, axis=1).tolist()
    else:
        model.associative_key_norms = [norm(key) for key in model.associative_keys]
    model.associative_values = int_vector("associative_values")

    # Answer memories.
    model.answer_keys = key_matrix(tensors.get("answer_keys", []))
    model.answer_key_norms = key_norms(model.answer_keys)
    model.answer_values = int_vector("answer_values")

    model.answer_start_keys = key_matrix(tensors.get("answer_start_keys", []))
    model.answer_start_key_norms = key_norms(model.answer_start_keys)
    model.answer_start_values = int_vector("answer_start_values")

    model.answer_sequence_keys = key_matrix(tensors.get("answer_sequence_keys", []))
    model.answer_sequence_key_norms = key_norms(model.answer_sequence_keys)
    model.answer_sequence_prompt_tokens = id_matrix(
        tensors.get("answer_sequence_prompt_tokens", [])
    )
    model.answer_sequence_tokens = id_matrix(tensors.get("answer_sequence_tokens", []))

    model.transition_tables = model._deserialize_transition_tables(
        json.loads(metadata.get("transition_tables", "{}"))
    )
    model._refresh_numeric_caches()
    return model
|
|
def _collect_training_examples(
    self,
    tokens: list[str],
) -> tuple[list[Vector], list[Vector], list[int]]:
    """Run the recurrent state over ``tokens`` and collect training pairs.

    Every token except the last is fed through the hidden-state update. A
    pair of (combined state, next-token label) is recorded at each step, or
    only at every ``stride``-th step plus the final step when the corpus
    would exceed ``config.max_training_examples``. Tokens missing from the
    vocabulary use a zero embedding and the id ``-1``. The three result
    lists are finally truncated to ``config.max_training_examples`` entries.

    Returns:
        ``(states, one_hot_labels, label_ids)``, all of equal length.
    """
    assert self.embedding_model is not None
    cfg = self.config
    if np is not None:
        hidden_states = [
            np.zeros(cfg.state_dim, dtype=np.float64) for _ in cfg.timescales
        ]
        context_traces = [
            np.zeros(cfg.embedding_dim, dtype=np.float64) for _ in cfg.timescales
        ]
        zero_embedding = np.zeros(cfg.embedding_dim, dtype=np.float64)
    else:
        hidden_states = [zeros_vector(cfg.state_dim) for _ in cfg.timescales]
        context_traces = [zeros_vector(cfg.embedding_dim) for _ in cfg.timescales]
        zero_embedding = zeros_vector(cfg.embedding_dim)

    token_to_id = self.embedding_model.token_to_id
    token_ids = [token_to_id.get(token, -1) for token in tokens]

    cap = cfg.max_training_examples
    example_count = max(0, len(tokens) - 1)
    stride = 1
    if cap and example_count > cap:
        # Ceiling division: the smallest stride that fits under the cap.
        stride = max(1, (example_count + cap - 1) // cap)

    final_index = len(tokens) - 2
    states: list[Vector] = []
    labels: list[Vector] = []
    label_ids: list[int] = []
    for index in range(len(tokens) - 1):
        token_id = token_ids[index]
        if token_id >= 0:
            embedding = self.embedding_model.embeddings[token_id]
        else:
            embedding = zero_embedding
        trace_embedding = self._trace_embedding_from_token_id(embedding, token_id)
        # The state always advances, even for steps that are not recorded.
        hidden_states, context_traces, combined_state = self._step_hidden_states_from_embedding(
            hidden_states,
            context_traces,
            embedding,
            trace_embedding=trace_embedding,
        )
        recorded = stride <= 1 or index % stride == 0 or index == final_index
        if not recorded:
            continue
        next_token_id = token_ids[index + 1]
        states.append(combined_state)
        labels.append(self._one_hot_from_id(next_token_id))
        label_ids.append(next_token_id)

    if cap and len(states) > cap:
        states = states[:cap]
        labels = labels[:cap]
        label_ids = label_ids[:cap]
    return states, labels, label_ids
|
|
| def _is_punctuation_piece(self, piece: str) -> bool: |
| return bool(piece) and all(character in string.punctuation for character in piece) |
|
|
| def _encode_context(self, tokens: list[str]) -> Vector: |
| return self._masked_decode_state(self._build_decode_state(tokens)) |
|
|
def _build_decode_state(self, tokens: list[str]) -> DecodeState:
    """Create a fresh decode state and advance it through ``tokens`` in order.

    Hidden states and context traces start as one zero vector per configured
    timescale: float64 NumPy arrays when NumPy is available, otherwise the
    result of ``zeros_vector``. The combined state starts from
    ``_zero_combined_state`` and the context token list starts empty.
    """
    assert self.memory_units is not None

    timescales = self.config.timescales
    if np is not None:
        initial_hidden = [
            np.zeros(self.config.state_dim, dtype=np.float64) for _ in timescales
        ]
        initial_traces = [
            np.zeros(self.config.embedding_dim, dtype=np.float64) for _ in timescales
        ]
    else:
        initial_hidden = [zeros_vector(self.config.state_dim) for _ in timescales]
        initial_traces = [zeros_vector(self.config.embedding_dim) for _ in timescales]

    decode_state = DecodeState(
        hidden_states=initial_hidden,
        context_traces=initial_traces,
        combined_state=self._zero_combined_state(),
        context_tokens=[],
    )
    for token in tokens:
        self._advance_decode_state(decode_state, token)
    return decode_state
|
|
| def _advance_decode_state(self, state: DecodeState, token: str) -> DecodeState: |
| next_hidden_states, next_context_traces, combined_state = self._step_hidden_states( |
| state.hidden_states, |
| state.context_traces, |
| token, |
| ) |
| state.hidden_states = next_hidden_states |
| state.context_traces = next_context_traces |
| state.combined_state = combined_state |
| state.context_tokens.append(token) |
| if token == "<answer>": |
| state.answer_anchor_state = combined_state.copy() if hasattr(combined_state, "copy") else combined_state[:] |
| state.answer_matches = None |
| state.answer_start_matches = None |
| state.answer_sequence_matches = None |
| state.prompt_answer_prior = None |
| state.prompt_answer_start_prior = None |
| return state |
|
|
def _masked_decode_state(self, state: DecodeState) -> Vector:
    """Return ``state.combined_state`` with the ternary mask and scale applied.

    Delegates to ``_masked_combined_state``, which requires
    ``self.ternary_mask`` to be set.
    """
    return self._masked_combined_state(state.combined_state)
|
|
def _masked_combined_state(self, combined_state: Vector) -> Vector:
    """Apply the ternary mask and scale to ``combined_state``.

    ``self.ternary_mask`` must be set; the masking itself is performed by
    ``apply_ternary_mask`` using ``self.ternary_scale``.
    """
    mask = self.ternary_mask
    assert mask is not None
    return apply_ternary_mask(combined_state, mask, self.ternary_scale)
|
|
def _masked_decode_state_array(self, state: DecodeState) -> object:
    """Return the masked combined state of ``state`` as a NumPy array.

    Requires NumPy. Delegates to ``_masked_combined_state_array``, which uses
    the cached mask array when present and otherwise converts the result of
    the list-based masking.
    """
    assert np is not None
    return self._masked_combined_state_array(state.combined_state)
|
|
def _masked_combined_state_array(self, combined_state: Vector) -> object:
    """Return the masked ``combined_state`` as a runtime-dtype NumPy array.

    Requires NumPy. With a cached ``ternary_mask_array`` the state is
    multiplied by ``ternary_scale`` and the mask; without one the result of
    ``_masked_combined_state`` is converted to an array instead.
    """
    assert np is not None
    mask_array = self.ternary_mask_array
    if mask_array is not None:
        values = np.asarray(combined_state, dtype=RUNTIME_ARRAY_DTYPE)
        return values * self.ternary_scale * mask_array
    masked = self._masked_combined_state(combined_state)
    return np.asarray(masked, dtype=RUNTIME_ARRAY_DTYPE)
|
|
| def _center_state_vector(self, state: Vector) -> Vector: |
| if not self.state_offset or len(self.state_offset) != len(state): |
| return state |
| return [value - self.state_offset[index] for index, value in enumerate(state)] |
|
|
def _center_state_array(self, state: object) -> object:
    """Return ``state`` as a runtime-dtype array with the cached offset removed.

    Requires NumPy. The offset is subtracted only when ``state_offset_array``
    exists and has the same shape as the converted state; otherwise the
    converted state is returned as is.
    """
    assert np is not None
    values = np.asarray(state, dtype=RUNTIME_ARRAY_DTYPE)
    offset = self.state_offset_array
    if offset is not None and offset.shape == values.shape:
        return values - offset
    return values
|
|
| def _zero_combined_state(self) -> Vector: |
| return [0.0 for _ in range(self._combined_state_width())] |
|
|
| def _combined_state_width(self) -> int: |
| return (self.config.state_dim + self.config.embedding_dim) * len(self.config.timescales) |
|
|
| def _derive_trace_token_weights_from_counts(self, token_counts: dict[str, float]) -> Vector: |
| assert self.embedding_model is not None |
| assert self.tokenizer is not None |
| counts = [ |
| float(token_counts.get(token, 0.0)) |
| for token in self.embedding_model.id_to_token |
| ] |
| positive_counts = sorted(value for value in counts if value > 0.0) |
| reference = ( |
| positive_counts[len(positive_counts) // 2] |
| if positive_counts |
| else 1.0 |
| ) |
| weights: Vector = [] |
| for token, count in zip(self.embedding_model.id_to_token, counts): |
| if token in self.tokenizer.special_tokens: |
| weights.append(0.0) |
| elif count <= 0.0: |
| weights.append(1.0) |
| else: |
| weight = (reference / count) ** 0.75 |
| weights.append(max(0.08, min(4.8, weight))) |
| return weights |
|
|
| def _token_id_for_token(self, token: str) -> int: |
| assert self.embedding_model is not None |
| token_id = self.embedding_model.token_to_id.get(token) |
| if token_id is None and token.lower() != token: |
| token_id = self.embedding_model.token_to_id.get(token.lower()) |
| return int(token_id) if token_id is not None else -1 |
|
|
| def _trace_embedding_from_token_id( |
| self, |
| embedding: Vector | object, |
| token_id: int, |
| ) -> Vector | object: |
| if token_id < 0: |
| return embedding |
| if self.trace_embedding_table_array is not None: |
| return self.trace_embedding_table_array[token_id] |
| weight = self.trace_token_weights[token_id] if self.trace_token_weights is not None else 1.0 |
| dimension = self.config.embedding_dim |
| if hasattr(embedding, "shape"): |
| trace_embedding = embedding * weight |
| for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES: |
| bucket = (token_id * bucket_multiplier + bucket_offset) % dimension |
| sign = 1.0 if ((token_id * sign_multiplier + sign_offset) & 1) == 0 else -1.0 |
| trace_embedding[bucket] += weight * TRACE_IDENTITY_SCALE * sign |
| return trace_embedding |
| trace_values = [float(value) * weight for value in embedding] |
| for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES: |
| bucket = (token_id * bucket_multiplier + bucket_offset) % dimension |
| sign = 1.0 if ((token_id * sign_multiplier + sign_offset) & 1) == 0 else -1.0 |
| trace_values[bucket] += weight * TRACE_IDENTITY_SCALE * sign |
| return trace_values |
|
|
| def _build_trace_embedding_table_array(self, embedding_array: object) -> object | None: |
| if np is None or self.trace_token_weights is None: |
| return None |
| values = np.asarray(embedding_array, dtype=np.float64) |
| if values.size == 0 or len(values.shape) != 2: |
| return None |
| weights = np.asarray(self.trace_token_weights, dtype=np.float64) |
| if weights.shape[0] != values.shape[0]: |
| return None |
| trace_values = values * weights[:, None] |
| if values.shape[1] <= 0: |
| return trace_values |
| token_ids = np.arange(values.shape[0], dtype=np.int64) |
| for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES: |
| buckets = ((token_ids * bucket_multiplier + bucket_offset) % values.shape[1]).astype( |
| np.int64, |
| copy=False, |
| ) |
| signs = np.where( |
| ((token_ids * sign_multiplier + sign_offset) & 1) == 0, |
| 1.0, |
| -1.0, |
| ) |
| np.add.at(trace_values, (token_ids, buckets), weights * TRACE_IDENTITY_SCALE * signs) |
| return trace_values |
|
|
def _refresh_numeric_caches(self) -> None:
    """Rebuild the cached NumPy arrays derived from the model's attributes.

    Without NumPy every cached array attribute is reset to ``None``. With
    NumPy each source attribute is converted with ``np.asarray`` (``None``
    stays ``None``; conversions marked ``skip_empty`` also map an empty
    container to ``None``), the validity masks and embedding-only similarity
    keys are derived from those arrays, and the trace embedding table is
    rebuilt. Both paths finish by refreshing the answer-sequence prompt
    overlap cache.
    """
    if np is None:
        for cache_name in (
            "ternary_mask_array",
            "readout_weights_array",
            "readout_bias_array",
            "prompt_answer_weights_array",
            "prompt_answer_bias_array",
            "prompt_answer_start_weights_array",
            "prompt_answer_start_bias_array",
            "trace_token_weights_array",
            "trace_embedding_table_array",
            "preference_bias_array",
            "preference_valid_mask_array",
            "state_offset_array",
            "associative_keys_array",
            "associative_key_norms_array",
            "associative_values_array",
            "associative_valid_mask_array",
            "answer_keys_array",
            "answer_key_norms_array",
            "answer_similarity_keys_array",
            "answer_similarity_key_norms_array",
            "answer_similarity_mask_array",
            "answer_values_array",
            "answer_valid_mask_array",
            "answer_start_keys_array",
            "answer_start_key_norms_array",
            "answer_start_similarity_keys_array",
            "answer_start_similarity_key_norms_array",
            "answer_start_values_array",
            "answer_start_valid_mask_array",
            "answer_sequence_keys_array",
            "answer_sequence_key_norms_array",
            "answer_sequence_similarity_keys_array",
            "answer_sequence_similarity_key_norms_array",
            "answer_sequence_prompt_tokens_array",
            "answer_sequence_tokens_array",
            "answer_sequence_prompt_weight_maps",
            "answer_sequence_prompt_weight_norms",
            "answer_sequence_prompt_bigram_sets",
            "answer_sequence_prompt_trigram_sets",
            "answer_sequence_prompt_number_sets",
            "answer_sequence_prompt_inverted_index",
        ):
            setattr(self, cache_name, None)
        self._refresh_answer_sequence_prompt_overlap_cache()
        return

    def as_array(values, dtype=RUNTIME_ARRAY_DTYPE, *, skip_empty=False):
        # None passes through; with skip_empty an empty container does too.
        if values is None:
            return None
        if skip_empty and len(values) == 0:
            return None
        return np.asarray(values, dtype=dtype)

    def non_negative(values):
        # Boolean mask of entries holding a usable (>= 0) token id.
        return values >= 0 if values is not None else None

    def masked_similarity(keys, mask):
        # Keys restricted by ``mask`` plus their row norms, or (None, None)
        # when the keys are missing, not 2-D, or do not match the mask width.
        if keys is None or mask is None or keys.ndim != 2:
            return None, None
        if int(keys.shape[1]) != int(mask.shape[0]):
            return None, None
        masked = keys * mask[None, :]
        norms = np.linalg.norm(masked, axis=1).astype(RUNTIME_ARRAY_DTYPE, copy=False)
        return masked, norms

    # Readout and prompt-answer parameters.
    self.ternary_mask_array = as_array(self.ternary_mask)
    self.readout_weights_array = as_array(self.readout_weights)
    self.readout_bias_array = as_array(self.readout_bias)
    self.prompt_answer_weights_array = as_array(self.prompt_answer_weights, skip_empty=True)
    self.prompt_answer_bias_array = as_array(self.prompt_answer_bias)
    self.prompt_answer_start_weights_array = as_array(
        self.prompt_answer_start_weights,
        skip_empty=True,
    )
    self.prompt_answer_start_bias_array = as_array(self.prompt_answer_start_bias)

    # Trace weights and the per-token trace embedding table.
    self.trace_token_weights_array = as_array(self.trace_token_weights)
    trace_table = None
    if self.embedding_model is not None and self.trace_token_weights is not None:
        trace_table = self._build_trace_embedding_table_array(self.embedding_model.embeddings)
    if trace_table is not None:
        trace_table = trace_table.astype(RUNTIME_ARRAY_DTYPE, copy=False)
    self.trace_embedding_table_array = trace_table

    # Preference bias and the mask of tokens eligible for it.
    self.preference_bias_array = as_array(self.preference_bias)
    if self.embedding_model is not None and self.tokenizer is not None:
        self.preference_valid_mask_array = np.asarray(
            [
                self._eligible_preference_token(token)
                for token in self.embedding_model.id_to_token
            ],
            dtype=bool,
        )
    else:
        self.preference_valid_mask_array = None
    self.state_offset_array = as_array(self.state_offset)

    # Associative memory.
    self.associative_keys_array = as_array(self.associative_keys, skip_empty=True)
    self.associative_key_norms_array = as_array(self.associative_key_norms, skip_empty=True)
    self.associative_values_array = as_array(self.associative_values, np.int64, skip_empty=True)
    self.associative_valid_mask_array = non_negative(self.associative_values_array)

    # Answer memory, plus a mask selecting only the trace (embedding) slice
    # of every timescale block when the key width matches the configuration.
    self.answer_keys_array = as_array(self.answer_keys, skip_empty=True)
    self.answer_key_norms_array = as_array(self.answer_key_norms, skip_empty=True)
    answer_keys = self.answer_keys_array
    similarity_mask = None
    if answer_keys is not None and answer_keys.ndim == 2:
        key_width = int(answer_keys.shape[1])
        block_width = self.config.state_dim + self.config.embedding_dim
        timescale_count = len(self.config.timescales)
        if block_width > 0 and key_width == block_width * timescale_count:
            similarity_mask = np.zeros(key_width, dtype=RUNTIME_ARRAY_DTYPE)
            for scale_index in range(timescale_count):
                trace_start = scale_index * block_width + self.config.state_dim
                trace_end = trace_start + self.config.embedding_dim
                similarity_mask[trace_start:trace_end] = 1.0
    self.answer_similarity_mask_array = similarity_mask
    (
        self.answer_similarity_keys_array,
        self.answer_similarity_key_norms_array,
    ) = masked_similarity(answer_keys, similarity_mask)
    self.answer_values_array = as_array(self.answer_values, np.int64, skip_empty=True)
    self.answer_valid_mask_array = non_negative(self.answer_values_array)

    # Answer-start memory.
    self.answer_start_keys_array = as_array(self.answer_start_keys, skip_empty=True)
    self.answer_start_key_norms_array = as_array(self.answer_start_key_norms, skip_empty=True)
    (
        self.answer_start_similarity_keys_array,
        self.answer_start_similarity_key_norms_array,
    ) = masked_similarity(self.answer_start_keys_array, self.answer_similarity_mask_array)
    self.answer_start_values_array = as_array(
        self.answer_start_values,
        np.int64,
        skip_empty=True,
    )
    self.answer_start_valid_mask_array = non_negative(self.answer_start_values_array)

    # Answer-sequence memory.
    self.answer_sequence_keys_array = as_array(self.answer_sequence_keys, skip_empty=True)
    self.answer_sequence_key_norms_array = as_array(
        self.answer_sequence_key_norms,
        skip_empty=True,
    )
    (
        self.answer_sequence_similarity_keys_array,
        self.answer_sequence_similarity_key_norms_array,
    ) = masked_similarity(self.answer_sequence_keys_array, self.answer_similarity_mask_array)
    self.answer_sequence_tokens_array = as_array(
        self.answer_sequence_tokens,
        np.int64,
        skip_empty=True,
    )
    self.answer_sequence_prompt_tokens_array = as_array(
        self.answer_sequence_prompt_tokens,
        np.int64,
        skip_empty=True,
    )
    self._refresh_answer_sequence_prompt_overlap_cache()
|
|
| def _refresh_answer_sequence_prompt_overlap_cache(self) -> None: |
| self.answer_sequence_prompt_weight_maps = None |
| self.answer_sequence_prompt_weight_norms = None |
| self.answer_sequence_prompt_bigram_sets = None |
| self.answer_sequence_prompt_trigram_sets = None |
| self.answer_sequence_prompt_number_sets = None |
| self.answer_sequence_prompt_inverted_index = None |
| self.answer_sequence_prompt_specificity = None |
| if self.answer_sequence_prompt_tokens is None or self.trace_token_weights is None: |
| return |
| inverted: dict[int, list[int]] = {} |
| row_id_lists: list[list[int]] = [] |
| for row in self.answer_sequence_prompt_tokens: |
| row_values = row.tolist() if hasattr(row, "tolist") else row |
| row_ids: list[int] = [] |
| for raw_token_id in row_values: |
| token_id = int(raw_token_id) |
| if token_id < 0 or token_id >= len(self.trace_token_weights): |
| continue |
| row_ids.append(token_id) |
| sequence_index = len(row_id_lists) |
| for token_id in set(row_ids): |
| inverted.setdefault(token_id, []).append(sequence_index) |
| row_id_lists.append(row_ids) |
|
|
| total_rows = len(row_id_lists) |
| specificity = { |
| token_id: self._prompt_overlap_token_specificity(len(indices), total_rows) |
| for token_id, indices in inverted.items() |
| } |
| self.answer_sequence_prompt_inverted_index = inverted |
| self.answer_sequence_prompt_specificity = specificity |
|
|
| weight_maps: list[dict[int, float]] = [] |
| weight_norms: list[float] = [] |
| bigram_sets: list[set[tuple[int, int]]] = [] |
| trigram_sets: list[set[tuple[int, int, int]]] = [] |
| number_sets: list[set[str]] = [] |
| for row_ids in row_id_lists: |
| row_weights: dict[int, float] = {} |
| for token_id in row_ids: |
| row_weights[token_id] = max( |
| row_weights.get(token_id, 0.0), |
| float(self.trace_token_weights[token_id]) * specificity.get(token_id, 1.0), |
| ) |
| weight_maps.append(row_weights) |
| weight_norms.append(sum(value * value for value in row_weights.values()) ** 0.5) |
| bigram_sets.append( |
| { |
| (row_ids[index], row_ids[index + 1]) |
| for index in range(len(row_ids) - 1) |
| } |
| ) |
| trigram_sets.append( |
| { |
| (row_ids[index], row_ids[index + 1], row_ids[index + 2]) |
| for index in range(len(row_ids) - 2) |
| } |
| ) |
| number_sets.append(self._number_strings_from_token_ids(row_ids)) |
| self.answer_sequence_prompt_weight_maps = weight_maps |
| self.answer_sequence_prompt_weight_norms = weight_norms |
| self.answer_sequence_prompt_bigram_sets = bigram_sets |
| self.answer_sequence_prompt_trigram_sets = trigram_sets |
| self.answer_sequence_prompt_number_sets = number_sets |
|
|
| @staticmethod |
| def _prompt_overlap_token_specificity(document_frequency: int, total_documents: int) -> float: |
| if document_frequency <= 0 or total_documents <= 0: |
| return 1.0 |
| coverage = min(1.0, document_frequency / total_documents) |
| return max(0.02, 1.0 - (coverage ** 0.5)) |
|
|
| def _number_strings_from_token_ids(self, token_ids: list[int]) -> set[str]: |
| assert self.embedding_model is not None |
| tokens = [ |
| self.embedding_model.id_to_token[token_id] |
| for token_id in token_ids |
| if 0 <= token_id < len(self.embedding_model.id_to_token) |
| ] |
| return self._number_strings_from_tokens(tokens) |
|
|
| def _number_strings_from_tokens(self, tokens: list[str]) -> set[str]: |
| numbers: set[str] = set() |
| current = "" |
| for token in tokens: |
| if self.tokenizer is not None and token in self.tokenizer.special_tokens: |
| if current: |
| numbers.add(current) |
| current = "" |
| continue |
| rendered = self._render_token(token) |
| digits = "".join(character for character in rendered if character.isdigit()) |
| starts_number = self._starts_new_word(token) if self.tokenizer is not None else True |
| if digits and starts_number: |
| if current: |
| numbers.add(current) |
| current = digits |
| elif digits and current: |
| current += digits |
| else: |
| if current: |
| numbers.add(current) |
| current = "" |
| if current: |
| numbers.add(current) |
| return numbers |
|
|
| @staticmethod |
| def _numeric_prompt_can_match(query_numbers: set[str], row_numbers: set[str]) -> bool: |
| if not query_numbers: |
| return True |
| if not row_numbers: |
| return False |
| return query_numbers.issubset(row_numbers) |
|
|
def _apply_readout_fast(self, state: Vector) -> Vector:
    """Compute readout logits for ``state`` and return them as a list.

    When NumPy and the cached readout weight array are available the work is
    done by ``_apply_readout_array``. Otherwise the state is centred with
    ``_center_state_vector``, passed through ``apply_readout`` and, when a
    non-empty readout bias is stored, the bias is added element-wise.
    """
    if self.readout_weights_array is not None and np is not None:
        return self._apply_readout_array(state).tolist()

    assert self.readout_weights is not None
    logits = apply_readout(self.readout_weights, self._center_state_vector(state))
    bias = self.readout_bias
    if not bias:
        return logits
    return [logit + bias[position] for position, logit in enumerate(logits)]
|
|
def _apply_readout_array(self, state: object) -> object:
    """Compute readout logits for ``state`` as a NumPy array.

    Requires NumPy and the cached readout weight array. The state is centred
    with ``_center_state_array``, multiplied by the readout weights, and the
    cached bias is added when it exists and matches the logits' shape.
    """
    assert np is not None
    assert self.readout_weights_array is not None
    centered = self._center_state_array(state)
    logits = self.readout_weights_array @ centered
    bias = self.readout_bias_array
    if bias is None or bias.shape != logits.shape:
        return logits
    return logits + bias
|
|
| def _step_hidden_states( |
| self, |
| hidden_states: list[Vector], |
| context_traces: list[Vector], |
| token: str, |
| ) -> tuple[list[Vector], list[Vector], Vector]: |
| assert self.embedding_model is not None |
| assert self.tokenizer is not None |
| token_id = self._token_id_for_token(token) |
| embedding = self.embedding_model.vector(token) |
| trace_embedding = self._trace_embedding_from_token_id(embedding, token_id) |
| return self._step_hidden_states_from_embedding( |
| hidden_states, |
| context_traces, |
| embedding, |
| trace_embedding=trace_embedding, |
| ) |
|
|
def _step_hidden_states_from_embedding(
    self,
    hidden_states: list[Vector],
    context_traces: list[Vector],
    embedding: Vector | object,
    *,
    trace_embedding: Vector | object | None = None,
) -> tuple[list[Vector], list[Vector], Vector]:
    """Advance every memory unit and context trace by one embedding.

    For each memory unit the hidden state is stepped with a drive derived
    from ``embedding``, and the matching trace becomes
    ``trace + (1 - 1 / (1 + timescale)) * trace_embedding``
    (``trace_embedding`` defaults to ``embedding``). The NumPy path is used
    when NumPy is available and the first hidden state is an array;
    otherwise plain lists are used.

    Returns ``(hidden_states, context_traces, combined_state)``, where the
    combined state is a flat list holding, unit by unit, the new hidden
    state followed by the new trace.
    """
    assert self.memory_units is not None
    if trace_embedding is None:
        trace_embedding = embedding

    updated_states: list[Vector] = []
    updated_traces: list[Vector] = []
    flattened: Vector = []

    if np is not None and hidden_states and hasattr(hidden_states[0], "shape"):
        embedding_array = (
            embedding
            if hasattr(embedding, "shape")
            else np.asarray(embedding, dtype=np.float64)
        )
        trace_array = (
            trace_embedding
            if hasattr(trace_embedding, "shape")
            else np.asarray(trace_embedding, dtype=np.float64)
        )
        drive = analytical_embedding_drive_fast(embedding_array, self.config.state_dim)
        for unit, previous_state, previous_trace in zip(
            self.memory_units, hidden_states, context_traces
        ):
            new_state = unit.step_vector_fast(previous_state, drive)
            gain = 1.0 - (1.0 / (1.0 + unit.timescale))
            new_trace = previous_trace + (gain * trace_array)
            updated_states.append(new_state)
            updated_traces.append(new_trace)
            flattened.extend(new_state.tolist())
            flattened.extend(new_trace.tolist())
        return updated_states, updated_traces, flattened

    embedding_values = embedding.tolist() if hasattr(embedding, "tolist") else embedding
    trace_values = (
        trace_embedding.tolist()
        if hasattr(trace_embedding, "tolist")
        else trace_embedding
    )
    drive = analytical_embedding_drive(embedding_values, self.config.state_dim)
    for unit, previous_state, previous_trace in zip(
        self.memory_units, hidden_states, context_traces
    ):
        new_state = unit.step_vector(previous_state, drive)
        gain = 1.0 - (1.0 / (1.0 + unit.timescale))
        new_trace = [
            previous + (gain * value)
            for previous, value in zip(previous_trace, trace_values)
        ]
        updated_states.append(new_state)
        updated_traces.append(new_trace)
        flattened.extend(new_state)
        flattened.extend(new_trace)
    return updated_states, updated_traces, flattened
|
|
| def _one_hot(self, token: str) -> Vector: |
| assert self.embedding_model is not None |
| return self._one_hot_from_id(self.embedding_model.token_to_id.get(token, -1)) |
|
|
| def _one_hot_from_id(self, token_id: int) -> Vector: |
| assert self.embedding_model is not None |
| vector = [0.0 for _ in self.embedding_model.id_to_token] |
| if token_id >= 0: |
| vector[token_id] = 1.0 |
| return vector |
|
|
def _blend_probabilities(
    self,
    base: Vector,
    answer: Vector,
    associative: Vector,
    transition: Vector,
    copy: Vector,
    preference: Vector,
    *,
    transition_order: int | None,
    generated_count: int = 0,
    answer_locked: bool = False,
    answer_guided_start: bool = False,
) -> tuple[Vector, dict[str, float]]:
    """Mix the candidate distributions into one normalised distribution.

    Each source starts from its ``FAST_*_BLEND`` weight. The weights are
    first rescaled by decoding mode (``answer_locked``, else
    ``answer_guided_start``, else ``generated_count > 0``) and then by
    ``transition_order``. ``base`` always takes part; any other source is
    included only if it has at least one positive entry. The included
    weights are renormalised to sum to one before blending.

    Returns the blended vector passed through ``_normalize_vector`` and a
    dict of the normalised weight used for each included source.
    """
    weights = {
        "base": FAST_BASE_BLEND,
        "answer": FAST_ANSWER_BLEND,
        "associative": FAST_ASSOCIATIVE_BLEND,
        "transition": FAST_TRANSITION_BLEND,
        "copy": FAST_COPY_BLEND,
        "preference": FAST_PREFERENCE_BLEND,
    }

    # Mode-dependent rescaling; only the first matching mode applies.
    if answer_locked:
        mode_scaling = {
            "base": 0.18,
            "answer": 5.0,
            "associative": 0.2,
            "transition": 0.2,
            "copy": 0.2,
            "preference": 0.2,
        }
    elif answer_guided_start:
        mode_scaling = {
            "base": 0.35,
            "answer": 3.5,
            "associative": 0.2,
            "transition": 0.35,
            "copy": 0.2,
            "preference": 0.2,
        }
    elif generated_count > 0:
        mode_scaling = {"answer": 0.32, "transition": 2.0, "copy": 0.75}
    else:
        mode_scaling = {}
    for name, factor in mode_scaling.items():
        weights[name] *= factor

    # Adjustment by the order of the transition evidence, if any.
    if transition_order is None:
        weights["answer"] *= 1.1
        weights["associative"] *= 0.75
        weights["copy"] += 0.02
    elif transition_order <= 2:
        weights["answer"] *= 1.15
        weights["associative"] *= 0.65
        weights["transition"] *= 0.55
        weights["copy"] += 0.01
    elif transition_order >= 5:
        weights["transition"] *= 1.25

    optional_sources = {
        "answer": answer,
        "associative": associative,
        "transition": transition,
        "copy": copy,
        "preference": preference,
    }
    sources: list[tuple[str, float, Vector]] = [("base", weights["base"], base)]
    for name, distribution in optional_sources.items():
        if any(value > 0.0 for value in distribution):
            sources.append((name, weights[name], distribution))

    total_weight = sum(weight for _, weight, _ in sources)
    blended = [0.0] * len(base)
    blend_weights: dict[str, float] = {}
    for name, weight, distribution in sources:
        share = weight / total_weight if total_weight else 0.0
        blend_weights[name] = share
        for index, value in enumerate(distribution):
            blended[index] += share * value
    return _normalize_vector(blended), blend_weights
|
|
def _blend_probability_arrays(
    self,
    base: object,
    answer: object,
    associative: object,
    transition: object,
    copy: object,
    preference: object,
    *,
    transition_order: int | None,
    generated_count: int = 0,
    answer_locked: bool = False,
    answer_guided_start: bool = False,
) -> tuple[object, dict[str, float]]:
    """Array version of the distribution blend; requires NumPy.

    Each source starts from its ``FAST_*_BLEND`` weight, rescaled first by
    decoding mode (``answer_locked``, else ``answer_guided_start``, else
    ``generated_count > 0``) and then by ``transition_order``. ``base``
    always takes part; any other source is included only if it has at least
    one positive entry. Included weights are renormalised to sum to one.

    Returns the blended float64 array divided by its sum, or ``base``
    itself when that sum is not positive, together with a dict of the
    normalised weight used for each included source.
    """
    assert np is not None

    weights = {
        "base": FAST_BASE_BLEND,
        "answer": FAST_ANSWER_BLEND,
        "associative": FAST_ASSOCIATIVE_BLEND,
        "transition": FAST_TRANSITION_BLEND,
        "copy": FAST_COPY_BLEND,
        "preference": FAST_PREFERENCE_BLEND,
    }

    # Mode-dependent rescaling; only the first matching mode applies.
    if answer_locked:
        mode_scaling = {
            "base": 0.18,
            "answer": 5.0,
            "associative": 0.2,
            "transition": 0.2,
            "copy": 0.2,
            "preference": 0.2,
        }
    elif answer_guided_start:
        mode_scaling = {
            "base": 0.35,
            "answer": 3.5,
            "associative": 0.2,
            "transition": 0.35,
            "copy": 0.2,
            "preference": 0.2,
        }
    elif generated_count > 0:
        mode_scaling = {"answer": 0.32, "transition": 2.0, "copy": 0.75}
    else:
        mode_scaling = {}
    for name, factor in mode_scaling.items():
        weights[name] *= factor

    # Adjustment by the order of the transition evidence, if any.
    if transition_order is None:
        weights["answer"] *= 1.1
        weights["associative"] *= 0.75
        weights["copy"] += 0.02
    elif transition_order <= 2:
        weights["answer"] *= 1.15
        weights["associative"] *= 0.65
        weights["transition"] *= 0.55
        weights["copy"] += 0.01
    elif transition_order >= 5:
        weights["transition"] *= 1.25

    optional_sources = {
        "answer": answer,
        "associative": associative,
        "transition": transition,
        "copy": copy,
        "preference": preference,
    }
    sources: list[tuple[str, float, object]] = [("base", weights["base"], base)]
    for name, distribution in optional_sources.items():
        if np.any(distribution > 0.0):
            sources.append((name, weights[name], distribution))

    total_weight = sum(weight for _, weight, _ in sources)
    blended = np.zeros_like(base, dtype=np.float64)
    blend_weights: dict[str, float] = {}
    for name, weight, distribution in sources:
        share = weight / total_weight if total_weight else 0.0
        blend_weights[name] = share
        blended += share * distribution

    total = float(blended.sum())
    if total <= 0.0:
        return base, blend_weights
    return blended / total, blend_weights
|
|
def _score_associative_matches(
    self,
    state: Vector,
    *,
    limit: int = ASSOCIATIVE_TOP_K,
) -> list[tuple[float, int, int]]:
    """Rank stored associative keys by cosine similarity to ``state``.

    The state is centred with the stored offset before scoring. Only entries
    with a token id ``>= 0``, a non-zero denominator and a strictly positive
    similarity are kept.

    Args:
        state: Combined state vector to compare against the stored keys.
        limit: Maximum number of matches to return; a non-positive limit
            returns no matches.

    Returns:
        Up to ``limit`` tuples ``(similarity, token_id, example_index)``
        ordered from most to least similar. Empty when the associative
        memory is missing or empty, or when the centred state has zero norm.
    """
    # A non-positive limit can never yield matches. Returning early also
    # avoids ``scored[:limit]`` with a negative limit, which would drop the
    # weakest matches instead of returning none.
    if limit <= 0:
        return []
    if (
        self.associative_keys is None
        or self.associative_values is None
        or self.associative_key_norms is None
        or len(self.associative_keys) == 0
        or len(self.associative_values) == 0
        or len(self.associative_key_norms) == 0
    ):
        return []

    if (
        np is not None
        and self.associative_keys_array is not None
        and self.associative_key_norms_array is not None
        and self.associative_values_array is not None
        and self.associative_valid_mask_array is not None
    ):
        # Vectorised path. Its verdict is final: when it finds nothing we
        # return immediately instead of rescanning every key in pure Python.
        state_array = self._center_state_array(state).astype(
            self.associative_keys_array.dtype,
            copy=False,
        )
        state_norm = float(np.linalg.norm(state_array))
        if state_norm == 0.0:
            return []
        numerators = self.associative_keys_array @ state_array
        denominators = self.associative_key_norms_array * state_norm
        valid_mask = self.associative_valid_mask_array & (denominators > 0.0)
        if not np.any(valid_mask):
            return []
        scores = np.zeros_like(numerators, dtype=self.associative_keys_array.dtype)
        np.divide(numerators, denominators, out=scores, where=valid_mask)
        positive_positions = np.flatnonzero(valid_mask & (scores > 0.0))
        if positive_positions.size == 0:
            return []
        selected_positions = positive_positions
        if positive_positions.size > limit:
            # Select the top ``limit`` scores without sorting every candidate.
            partition = np.argpartition(scores[positive_positions], -limit)[-limit:]
            selected_positions = positive_positions[partition]
        ordered_positions = selected_positions[np.argsort(scores[selected_positions])[::-1]]
        return [
            (
                float(scores[position]),
                int(self.associative_values_array[position]),
                int(position),
            )
            for position in ordered_positions
        ]

    # Pure-Python fallback, used when the cached arrays are unavailable.
    state = self._center_state_vector(state)
    state_norm = norm(state)
    if state_norm == 0.0:
        return []

    scored: list[tuple[float, int, int]] = []
    for example_index, (key, key_norm, token_id) in enumerate(
        zip(self.associative_keys, self.associative_key_norms, self.associative_values)
    ):
        if token_id < 0:
            continue
        denominator = state_norm * key_norm
        if denominator == 0.0:
            continue
        similarity = dot(state, key) / denominator
        if similarity > 0.0:
            scored.append((similarity, token_id, example_index))
    scored.sort(key=lambda item: item[0], reverse=True)
    return scored[:limit]
|
|
def _associative_prior_from_matches(
    self,
    matches: list[tuple[float, int, int]],
) -> Vector:
    """Convert associative matches into a vocabulary-sized prior.

    Each of the first ``ASSOCIATIVE_TOP_K`` matches adds its similarity to
    the slot of its token id; the accumulated vector is then passed through
    ``_normalize_vector``. With no matches an all-zero vector is returned
    without normalization.
    """
    assert self.embedding_model is not None
    votes = [0.0] * len(self.embedding_model.id_to_token)
    if not matches:
        return votes

    for score, token_id, _example_index in matches[:ASSOCIATIVE_TOP_K]:
        votes[token_id] += score
    return _normalize_vector(votes)
|
|
def _associative_prior(self, state: Vector) -> Vector:
    """Score associative matches for ``state`` and fold them into a prior."""
    matches = self._score_associative_matches(state)
    return self._associative_prior_from_matches(matches)
|
|
def _score_answer_matches(
    self,
    answer_anchor_state: Vector | None,
    *,
    limit: int = ANSWER_TOP_K,
) -> list[tuple[float, int, int]]:
    """Score the answer memory against the prompt anchor state.

    Thin wrapper that forwards the ``answer_*`` stores (list form, array
    form and similarity-view arrays) to ``_score_prompt_anchor_matches``.
    """
    return self._score_prompt_anchor_matches(
        answer_anchor_state,
        keys=self.answer_keys,
        key_norms_list=self.answer_key_norms,
        values=self.answer_values,
        keys_array=self.answer_keys_array,
        key_norms_array=self.answer_key_norms_array,
        values_array=self.answer_values_array,
        valid_mask_array=self.answer_valid_mask_array,
        similarity_keys_array=self.answer_similarity_keys_array,
        similarity_key_norms_array=self.answer_similarity_key_norms_array,
        similarity_mask_array=self.answer_similarity_mask_array,
        limit=limit,
    )
|
|
def _score_answer_start_matches(
    self,
    answer_anchor_state: Vector | None,
    *,
    limit: int = ANSWER_START_TOP_K,
) -> list[tuple[float, int, int]]:
    """Score the answer-start memory against the prompt anchor state.

    Thin wrapper that forwards the ``answer_start_*`` stores to
    ``_score_prompt_anchor_matches``.
    """
    # NOTE(review): the similarity mask forwarded here is
    # `answer_similarity_mask_array` (the same one the answer memory uses),
    # not a start-specific mask -- confirm the sharing is intended.
    return self._score_prompt_anchor_matches(
        answer_anchor_state,
        keys=self.answer_start_keys,
        key_norms_list=self.answer_start_key_norms,
        values=self.answer_start_values,
        keys_array=self.answer_start_keys_array,
        key_norms_array=self.answer_start_key_norms_array,
        values_array=self.answer_start_values_array,
        valid_mask_array=self.answer_start_valid_mask_array,
        similarity_keys_array=self.answer_start_similarity_keys_array,
        similarity_key_norms_array=self.answer_start_similarity_key_norms_array,
        similarity_mask_array=self.answer_similarity_mask_array,
        limit=limit,
    )
|
|
def _score_answer_sequence_matches(
    self,
    answer_anchor_state: Vector | None,
    context_tokens: list[str],
    *,
    limit: int = ANSWER_START_TOP_K,
) -> list[tuple[float, int, int]]:
    """Rank stored answer sequences for the current prompt.

    Anchor-state similarity (via ``_score_prompt_anchor_matches``) is
    combined with the lexical overlap from
    ``_answer_sequence_prompt_overlap_scores``.

    Args:
        answer_anchor_state: State captured at the answer anchor, or None.
        context_tokens: Tokens of the current context.
        limit: Maximum number of matches to return.

    Returns:
        ``(score, sequence_index, sequence_index)`` tuples sorted by
        descending score. When overlap scoring is unavailable (None) the
        anchor matches are returned as-is; when it yields no rows the result
        is empty.
    """
    prerequisites = (
        answer_anchor_state,
        self.answer_sequence_keys,
        self.answer_sequence_key_norms,
        self.answer_sequence_tokens,
    )
    if any(item is None for item in prerequisites):
        return []

    # The "value" stored for each key is simply its own row index.
    row_indices = list(range(len(self.answer_sequence_tokens)))
    if np is not None:
        row_index_array = np.arange(len(row_indices), dtype=np.int64)
        row_valid_mask = row_index_array >= 0
    else:
        row_index_array = None
        row_valid_mask = None

    anchor_matches = self._score_prompt_anchor_matches(
        answer_anchor_state,
        self.answer_sequence_keys,
        self.answer_sequence_key_norms,
        row_indices,
        self.answer_sequence_keys_array,
        self.answer_sequence_key_norms_array,
        row_index_array,
        row_valid_mask,
        self.answer_sequence_similarity_keys_array,
        self.answer_sequence_similarity_key_norms_array,
        self.answer_similarity_mask_array,
        # Over-fetch so the overlap filter below still has candidates.
        limit=max(limit * 4, limit),
    )

    overlaps = self._answer_sequence_prompt_overlap_scores(context_tokens)
    if overlaps is None:
        return anchor_matches[:limit]
    if not overlaps:
        return []

    # Keep rows within 90% of the best overlap (and at least 0.16); if that
    # filter removes everything, fall back to every scored row.
    floor = max(0.16, max(overlaps.values()) * 0.90)
    focused = {
        row_index: overlap
        for row_index, overlap in overlaps.items()
        if overlap >= floor
    } or overlaps

    blended: dict[int, float] = {}
    for similarity, row_index, _ in anchor_matches:
        if row_index in focused:
            blended[row_index] = max(blended.get(row_index, 0.0), 0.20 * similarity)
    for row_index, overlap in focused.items():
        blended[row_index] = blended.get(row_index, 0.0) + (0.80 * overlap)

    ranked = sorted(
        (
            (score, row_index, row_index)
            for row_index, score in blended.items()
            if score > 0.0
        ),
        key=lambda entry: entry[0],
        reverse=True,
    )
    return ranked[:limit]
|
|
def _answer_sequence_prompt_overlap_scores(
    self,
    context_tokens: list[str],
) -> dict[int, float] | None:
    """Score stored answer-sequence prompts by lexical overlap with the prompt.

    The query is the part of ``context_tokens`` before the last ``<answer>``
    marker (or all of it when there is no marker). Each candidate row gets
    ``0.35 * token cosine + 0.35 * query coverage + 0.15 * bigram overlap
    + 0.15 * trigram overlap``.

    Two implementations follow: one that reads the precomputed per-row
    caches, and a fallback that rebuilds the per-row data from
    ``answer_sequence_prompt_tokens`` on the fly.

    Args:
        context_tokens: Tokens of the current context.

    Returns:
        Mapping of sequence index to overlap score for rows that pass the
        numeric and coverage filters. ``None`` when scoring is not possible
        (embedding model, prompt tokens or trace weights missing, or no
        query token is known to the vocabulary).
    """
    if (
        self.embedding_model is None
        or self.answer_sequence_prompt_tokens is None
        or self.trace_token_weights is None
    ):
        return None
    # Only the text before the last "<answer>" marker counts as the prompt.
    answer_boundary = _last_index(context_tokens, "<answer>")
    prompt_tokens = (
        context_tokens[:answer_boundary]
        if answer_boundary is not None
        else context_tokens
    )
    # Lazily build the specificity cache on first use.
    if self.answer_sequence_prompt_specificity is None:
        self._refresh_answer_sequence_prompt_overlap_cache()
    specificity_map = self.answer_sequence_prompt_specificity or {}
    query_weights: dict[int, float] = {}
    query_specificity: dict[int, float] = {}
    query_content_weight = 0.0
    query_ids: list[int] = []
    for token in prompt_tokens:
        if self.tokenizer is not None and token in self.tokenizer.special_tokens:
            continue
        token_id = self.embedding_model.token_to_id.get(token)
        if token_id is None:
            continue
        query_ids.append(token_id)
        # Tokens absent from the specificity map default to full weight.
        specificity = specificity_map.get(token_id, 1.0)
        weight = specificity
        query_weights[token_id] = max(
            query_weights.get(token_id, 0.0),
            weight,
        )
        query_specificity[token_id] = max(
            query_specificity.get(token_id, 0.0),
            specificity,
        )
        # NOTE(review): this accumulates once per occurrence, while the
        # matched weight further down sums over unique token ids -- confirm
        # repeated prompt tokens are meant to inflate the denominator.
        if specificity >= 0.20:
            query_content_weight += weight
    if not query_weights:
        return None
    query_norm = sum(value * value for value in query_weights.values()) ** 0.5
    if query_norm <= 0.0:
        return None

    # Ordered n-grams of query token ids, used for the phrase-overlap terms.
    query_bigrams = {
        (query_ids[index], query_ids[index + 1])
        for index in range(len(query_ids) - 1)
    }
    query_trigrams = {
        (query_ids[index], query_ids[index + 1], query_ids[index + 2])
        for index in range(len(query_ids) - 2)
    }
    query_numbers = self._number_strings_from_tokens(prompt_tokens)

    def ordered_ngram_score(
        query_grams: set[tuple[int, ...]],
        row_grams: set[tuple[int, ...]],
    ) -> float:
        # Shared n-gram count divided by the geometric mean of the set sizes.
        if not query_grams or not row_grams:
            return 0.0
        overlap = len(query_grams & row_grams)
        if overlap <= 0:
            return 0.0
        return overlap / ((len(query_grams) * len(row_grams)) ** 0.5)

    cached_maps = self.answer_sequence_prompt_weight_maps
    cached_norms = self.answer_sequence_prompt_weight_norms
    cached_bigrams = self.answer_sequence_prompt_bigram_sets
    cached_trigrams = self.answer_sequence_prompt_trigram_sets
    cached_numbers = self.answer_sequence_prompt_number_sets
    cached_index = self.answer_sequence_prompt_inverted_index
    # --- Cached path: all per-row caches exist and match the row count. ---
    if (
        cached_maps is not None
        and cached_norms is not None
        and cached_bigrams is not None
        and cached_trigrams is not None
        and cached_numbers is not None
        and len(cached_maps) == len(self.answer_sequence_prompt_tokens)
    ):
        candidate_indices: set[int] | range
        if cached_index is not None:
            # Restrict to rows sharing at least one token with the query.
            candidates: set[int] = set()
            for token_id in query_weights:
                candidates.update(cached_index.get(token_id, ()))
            # NOTE(review): with no indexed candidate this scans every row,
            # whereas the uncached path below returns {} -- confirm intended.
            candidate_indices = candidates if candidates else range(len(cached_maps))
        else:
            candidate_indices = range(len(cached_maps))
        candidate_indices = list(candidate_indices)
        if cached_index is not None and candidate_indices:
            # Re-derive token specificity relative to the candidate set only.
            candidate_set = set(candidate_indices)
            local_query_weights: dict[int, float] = {}
            local_query_specificity: dict[int, float] = {}
            local_query_content_weight = 0.0
            for token_id in query_weights:
                local_frequency = len(candidate_set & set(cached_index.get(token_id, ())))
                if local_frequency <= 0:
                    continue
                specificity = self._prompt_overlap_token_specificity(
                    local_frequency,
                    len(candidate_indices),
                )
                weight = specificity
                local_query_weights[token_id] = weight
                local_query_specificity[token_id] = specificity
                if specificity >= 0.20:
                    local_query_content_weight += weight
            local_query_norm = sum(value * value for value in local_query_weights.values()) ** 0.5
            # Swap in the local weights only when they are usable.
            if local_query_norm > 0.0:
                query_weights = local_query_weights
                query_specificity = local_query_specificity
                if local_query_content_weight > 0.0:
                    query_content_weight = local_query_content_weight
                query_norm = local_query_norm
        scores: dict[int, float] = {}
        for sequence_index in candidate_indices:
            row_weights = cached_maps[sequence_index]
            if not row_weights:
                continue
            if not self._numeric_prompt_can_match(query_numbers, cached_numbers[sequence_index]):
                continue
            # Weight of shared tokens whose specificity is at least 0.20.
            matched_content_weight = sum(
                query_weights[token_id]
                for token_id in query_weights.keys() & row_weights.keys()
                if query_specificity.get(token_id, 0.0) >= 0.20
            )
            row_token_coverage = len(query_weights.keys() & row_weights.keys()) / max(
                1,
                len(row_weights),
            )
            # Drop rows that cover too little of the query content unless
            # the query covers most of the row's tokens.
            if (
                query_content_weight > 0.0
                and matched_content_weight / query_content_weight < 0.40
                and row_token_coverage < 0.75
            ):
                continue
            query_coverage = (
                matched_content_weight / query_content_weight
                if query_content_weight > 0.0
                else row_token_coverage
            )
            numerator = sum(
                query_weights[token_id] * row_weights[token_id]
                for token_id in query_weights.keys() & row_weights.keys()
            )
            if numerator <= 0.0:
                continue
            row_norm = cached_norms[sequence_index]
            if row_norm <= 0.0:
                continue
            token_score = numerator / (query_norm * row_norm)
            bigram_score = ordered_ngram_score(
                query_bigrams,
                cached_bigrams[sequence_index],
            )
            trigram_score = ordered_ngram_score(
                query_trigrams,
                cached_trigrams[sequence_index],
            )
            scores[sequence_index] = (
                (0.35 * token_score)
                + (0.35 * query_coverage)
                + (0.15 * bigram_score)
                + (0.15 * trigram_score)
            )
        return scores

    # --- Uncached path: rebuild per-row weights and n-grams on the fly. ---
    if cached_index is not None:
        candidate_set: set[int] = set()
        for token_id in query_weights:
            candidate_set.update(cached_index.get(token_id, ()))
        if not candidate_set:
            return {}
        candidate_indices: list[int] | range = sorted(candidate_set)
        local_query_weights: dict[int, float] = {}
        local_query_specificity: dict[int, float] = {}
        local_query_content_weight = 0.0
        candidate_count = len(candidate_indices)
        for token_id in query_weights:
            local_frequency = len(candidate_set & set(cached_index.get(token_id, ())))
            if local_frequency <= 0:
                continue
            specificity = self._prompt_overlap_token_specificity(
                local_frequency,
                candidate_count,
            )
            local_query_weights[token_id] = specificity
            local_query_specificity[token_id] = specificity
            if specificity >= 0.20:
                local_query_content_weight += specificity
        local_query_norm = sum(value * value for value in local_query_weights.values()) ** 0.5
        if local_query_norm > 0.0:
            query_weights = local_query_weights
            query_specificity = local_query_specificity
            if local_query_content_weight > 0.0:
                query_content_weight = local_query_content_weight
            query_norm = local_query_norm
    else:
        candidate_indices = range(len(self.answer_sequence_prompt_tokens))

    scores: dict[int, float] = {}
    for sequence_index in candidate_indices:
        row = self.answer_sequence_prompt_tokens[sequence_index]
        # Rows may be array-like (with .tolist()) or plain sequences.
        row_values = row.tolist() if hasattr(row, "tolist") else row
        row_weights: dict[int, float] = {}
        row_ids: list[int] = []
        for raw_token_id in row_values:
            token_id = int(raw_token_id)
            # Skip ids outside the range of trace_token_weights (this also
            # drops negative ids).
            if token_id < 0 or token_id >= len(self.trace_token_weights):
                continue
            row_ids.append(token_id)
            row_weights[token_id] = max(
                row_weights.get(token_id, 0.0),
                specificity_map.get(token_id, 1.0),
            )
        if not row_weights:
            continue
        if not self._numeric_prompt_can_match(
            query_numbers,
            self._number_strings_from_token_ids(row_ids),
        ):
            continue
        matched_content_weight = sum(
            query_weights[token_id]
            for token_id in query_weights.keys() & row_weights.keys()
            if query_specificity.get(token_id, 0.0) >= 0.20
        )
        row_token_coverage = len(query_weights.keys() & row_weights.keys()) / max(
            1,
            len(row_weights),
        )
        if (
            query_content_weight > 0.0
            and matched_content_weight / query_content_weight < 0.40
            and row_token_coverage < 0.75
        ):
            continue
        query_coverage = (
            matched_content_weight / query_content_weight
            if query_content_weight > 0.0
            else row_token_coverage
        )
        numerator = sum(
            query_weights[token_id] * row_weights[token_id]
            for token_id in query_weights.keys() & row_weights.keys()
        )
        if numerator <= 0.0:
            continue
        row_norm = sum(value * value for value in row_weights.values()) ** 0.5
        if row_norm > 0.0:
            token_score = numerator / (query_norm * row_norm)
            row_bigrams = {
                (row_ids[index], row_ids[index + 1])
                for index in range(len(row_ids) - 1)
            }
            row_trigrams = {
                (row_ids[index], row_ids[index + 1], row_ids[index + 2])
                for index in range(len(row_ids) - 2)
            }
            bigram_score = ordered_ngram_score(query_bigrams, row_bigrams)
            trigram_score = ordered_ngram_score(query_trigrams, row_trigrams)
            scores[sequence_index] = (
                (0.35 * token_score)
                + (0.35 * query_coverage)
                + (0.15 * bigram_score)
                + (0.15 * trigram_score)
            )
    return scores
|
|
def _score_prompt_anchor_matches(
    self,
    answer_anchor_state: Vector | None,
    keys: object | None,
    key_norms_list: object | None,
    values: object | None,
    keys_array: object | None,
    key_norms_array: object | None,
    values_array: object | None,
    valid_mask_array: object | None,
    similarity_keys_array: object | None,
    similarity_key_norms_array: object | None,
    similarity_mask_array: object | None,
    *,
    limit: int,
) -> list[tuple[float, int, int]]:
    """Rank a key/value memory by similarity to the prompt anchor state.

    The anchor state is passed through the masked-combined-state and
    centering helpers, then scored against each key as
    ``dot / (|state| * key_norm)``. Only strictly positive scores are kept.

    Args:
        answer_anchor_state: Anchor state to score, or None.
        keys, key_norms_list, values: List-form memory for the fallback scan.
        keys_array, key_norms_array, values_array, valid_mask_array:
            Array-form memory for the vectorized path.
        similarity_keys_array, similarity_key_norms_array,
        similarity_mask_array: Optional alternate keys/norms plus a mask
            multiplied into the state; used only when all three are given.
        limit: Maximum number of matches to return.

    Returns:
        ``(similarity, value, example_index)`` tuples sorted by descending
        similarity; empty when required inputs are None or the prepared
        state has zero norm.
    """
    if any(item is None for item in (answer_anchor_state, keys, key_norms_list, values)):
        return []

    arrays_available = np is not None and all(
        array is not None
        for array in (keys_array, key_norms_array, values_array, valid_mask_array)
    )
    if arrays_available and limit > 0:
        query = self._center_state_array(
            self._masked_combined_state_array(answer_anchor_state)
        ).astype(keys_array.dtype, copy=False)
        has_similarity_view = (
            similarity_keys_array is not None
            and similarity_key_norms_array is not None
            and similarity_mask_array is not None
        )
        if has_similarity_view:
            # Score against the alternate keys, with the mask applied to the
            # query as well.
            query = query * similarity_mask_array
            active_keys = similarity_keys_array
            active_norms = similarity_key_norms_array
        else:
            active_keys = keys_array
            active_norms = key_norms_array
        query_norm = float(np.linalg.norm(query))
        if query_norm == 0.0:
            return []
        dots = active_keys @ query
        scales = active_norms * query_norm
        usable = valid_mask_array & (scales > 0.0)
        if np.any(usable):
            cosine = np.zeros_like(dots, dtype=active_keys.dtype)
            np.divide(dots, scales, out=cosine, where=usable)
            hits = np.flatnonzero(usable & (cosine > 0.0))
            if hits.size:
                if hits.size > limit:
                    # Keep only the `limit` best hits before the full sort.
                    top = np.argpartition(cosine[hits], -limit)[-limit:]
                    hits = hits[top]
                ranked = hits[np.argsort(cosine[hits])[::-1]]
                return [
                    (
                        float(cosine[position]),
                        int(values_array[position]),
                        int(position),
                    )
                    for position in ranked
                ]
        # No positive vectorized hit: fall through to the list-based scan.

    prepared = self._center_state_vector(self._masked_combined_state(answer_anchor_state))
    prepared_norm = norm(prepared)
    if prepared_norm == 0.0:
        return []

    hits_list: list[tuple[float, int, int]] = []
    for example_index, (key, key_norm, token_id) in enumerate(
        zip(keys, key_norms_list, values)
    ):
        if token_id < 0:
            continue
        scale = prepared_norm * key_norm
        if scale == 0.0:
            continue
        similarity = dot(prepared, key) / scale
        if similarity > 0.0:
            hits_list.append((similarity, token_id, example_index))
    hits_list.sort(key=lambda entry: entry[0], reverse=True)
    return hits_list[:limit]
|
|
def _answer_prior_from_matches(
    self,
    matches: list[tuple[float, int, int]],
    generated_tokens: list[str],
) -> Vector:
    """Convert answer-memory matches into a vocabulary-sized prior.

    Each of the first ``ANSWER_TOP_K`` matches whose token passes
    ``_allowed_generation_token`` adds its similarity to that token's slot.
    Tokens that were already generated contribute at 35% strength. The
    result is passed through ``_normalize_vector``; with no matches an
    all-zero vector is returned.
    """
    assert self.embedding_model is not None
    prior = [0.0] * len(self.embedding_model.id_to_token)
    if not matches:
        return prior

    vocabulary = self.embedding_model.token_to_id
    already_emitted = {
        vocabulary[token] for token in generated_tokens if token in vocabulary
    }
    for score, token_id, _example_index in matches[:ANSWER_TOP_K]:
        candidate = self.embedding_model.id_to_token[token_id]
        if not self._allowed_generation_token(candidate, generated_tokens):
            continue
        # Damp tokens that already appear in the generated answer.
        prior[token_id] += (score * 0.35) if token_id in already_emitted else score
    return _normalize_vector(prior)
|
|
def _answer_sequence_prior_from_matches(
    self,
    matches: list[tuple[float, int, int]],
    generated_tokens: list[str],
) -> Vector:
    """Build a next-token prior from matched stored answer sequences.

    For each of the first ``ANSWER_START_TOP_K`` matches, the stored
    sequence is consulted for the token that follows what has been generated
    so far (see ``_next_sequence_token_id``). Allowed tokens receive
    ``max(1e-9, similarity - match_floor)``. When the best match scores at
    least 0.9, ``match_floor`` is ``best - 0.02`` so only near-best
    sequences vote; otherwise it is 0.

    Args:
        matches: ``(similarity, sequence_index, _)`` tuples, best first.
        generated_tokens: Tokens generated so far for this answer.

    Returns:
        The accumulated prior passed through ``_normalize_vector``, or an
        all-zero vector when there are no matches or no stored sequences.
    """
    assert self.embedding_model is not None
    id_to_token = self.embedding_model.id_to_token
    prior = [0.0 for _ in id_to_token]
    if not matches or self.answer_sequence_tokens is None:
        return prior

    token_to_id = self.embedding_model.token_to_id
    generated_ids = [
        token_to_id[token]
        for token in generated_tokens
        if token in token_to_id
    ]
    sequence_count = len(self.answer_sequence_tokens)
    best_similarity = matches[0][0]
    match_floor = best_similarity - 0.02 if best_similarity >= 0.9 else 0.0
    for similarity, sequence_index, _ in matches[:ANSWER_START_TOP_K]:
        if similarity < match_floor:
            continue
        # Skip indices outside the stored sequences instead of raising (or
        # wrapping around for negatives); the sibling completion checks
        # guard the upper bound the same way.
        if not 0 <= sequence_index < sequence_count:
            continue
        row = self.answer_sequence_tokens[sequence_index]
        # Negative ids are dropped before the sequence is compared.
        token_ids = [
            int(value)
            for value in (row.tolist() if hasattr(row, "tolist") else row)
            if int(value) >= 0
        ]
        if not token_ids:
            continue
        next_token_id = self._next_sequence_token_id(token_ids, generated_ids)
        # Also ignore ids that fall outside the vocabulary.
        if next_token_id is None or next_token_id >= len(prior):
            continue
        token = id_to_token[next_token_id]
        if self._allowed_generation_token(token, generated_tokens):
            prior[next_token_id] += max(1e-9, similarity - match_floor)
    return _normalize_vector(prior)
|
|
def _should_stop_answer_sequence(
    self,
    decode_state: DecodeState,
    generated_tokens: list[str],
) -> bool:
    """Report whether the generated tokens complete a matched answer sequence.

    Uses the matches stored on ``decode_state`` when present; otherwise
    they are scored from the state's anchor and context tokens.
    """
    cached = decode_state.answer_sequence_matches
    matches = (
        cached
        if cached is not None
        else self._score_answer_sequence_matches(
            decode_state.answer_anchor_state,
            decode_state.context_tokens,
        )
    )
    return self._answer_sequence_is_complete(generated_tokens, matches)
|
|
def _answer_decode_has_continuation(
    self,
    decode_state: DecodeState,
    generated_tokens: list[str],
) -> bool:
    """Report whether a matched answer sequence offers a next token.

    Uses the matches stored on ``decode_state`` when present; otherwise
    they are scored from the state's anchor and context tokens.
    """
    cached = decode_state.answer_sequence_matches
    matches = (
        cached
        if cached is not None
        else self._score_answer_sequence_matches(
            decode_state.answer_anchor_state,
            decode_state.context_tokens,
        )
    )
    return self._answer_sequence_has_continuation(generated_tokens, matches)
|
|
def _answer_sequence_is_complete(
    self,
    generated_tokens: list[str],
    matches: list[tuple[float, int, int]],
) -> bool:
    """Return True when the generated ids start with a full stored sequence.

    Only the first ``ANSWER_START_TOP_K`` matches scoring at least
    ``ANSWER_SEQUENCE_MATCH_FLOOR`` are considered. Generated tokens unknown
    to the vocabulary are ignored, as are negative ids in the stored rows.
    """
    model = self.embedding_model
    sequences = self.answer_sequence_tokens
    if model is None or sequences is None or not generated_tokens or not matches:
        return False

    vocabulary = model.token_to_id
    produced = [vocabulary[token] for token in generated_tokens if token in vocabulary]
    if not produced:
        return False

    for score, sequence_index, _ in matches[:ANSWER_START_TOP_K]:
        if score < ANSWER_SEQUENCE_MATCH_FLOOR:
            continue
        if sequence_index >= len(sequences):
            continue
        stored = sequences[sequence_index]
        if hasattr(stored, "tolist"):
            stored = stored.tolist()
        expected = [int(value) for value in stored if int(value) >= 0]
        if (
            expected
            and len(produced) >= len(expected)
            and produced[: len(expected)] == expected
        ):
            return True
    return False
|
|
def _answer_sequence_has_continuation(
    self,
    generated_tokens: list[str],
    matches: list[tuple[float, int, int]],
) -> bool:
    """Return True when some matched sequence can supply an allowed next token.

    Only the first ``ANSWER_START_TOP_K`` matches scoring at least
    ``ANSWER_SEQUENCE_MATCH_FLOOR`` are considered. A sequence qualifies
    when the generated ids are a proper prefix of it and its next token
    passes ``_allowed_generation_token``.
    """
    model = self.embedding_model
    sequences = self.answer_sequence_tokens
    if model is None or sequences is None or not generated_tokens or not matches:
        return False

    vocabulary = model.token_to_id
    produced = [vocabulary[token] for token in generated_tokens if token in vocabulary]
    if not produced:
        return False

    for score, sequence_index, _ in matches[:ANSWER_START_TOP_K]:
        if score < ANSWER_SEQUENCE_MATCH_FLOOR:
            continue
        if sequence_index >= len(sequences):
            continue
        stored = sequences[sequence_index]
        if hasattr(stored, "tolist"):
            stored = stored.tolist()
        expected = [int(value) for value in stored if int(value) >= 0]
        if not expected:
            continue
        upcoming = self._next_sequence_token_id(expected, produced)
        if upcoming is None:
            continue
        candidate = model.id_to_token[upcoming]
        if self._allowed_generation_token(candidate, generated_tokens):
            return True
    return False
|
|
| def _next_sequence_token_id( |
| self, |
| token_ids: list[int], |
| generated_ids: list[int], |
| ) -> int | None: |
| if not generated_ids: |
| return token_ids[0] |
| if len(generated_ids) >= len(token_ids): |
| return None |
| if token_ids[: len(generated_ids)] != generated_ids: |
| return None |
| return token_ids[len(generated_ids)] |
|
|
def _transition_prior(self, context_tokens: list[str]) -> Vector:
    """Return only the prior from ``_transition_prior_with_order``."""
    return self._transition_prior_with_order(context_tokens)[0]
|
|
def _transition_prior_with_order(
    self,
    context_tokens: list[str],
) -> tuple[Vector, int | None]:
    """Look up an n-gram transition prior for the end of the context.

    Orders are tried in ``TRANSITION_ORDERS`` iteration order; the first
    order whose trailing n-gram has stored transitions wins.

    Returns:
        ``(prior, order)`` where ``prior`` is the stored probabilities laid
        out over the vocabulary and passed through ``_normalize_vector``,
        or ``(all-zero vector, None)`` when nothing matches.
    """
    assert self.embedding_model is not None
    vocabulary_size = len(self.embedding_model.id_to_token)
    tables = self.transition_tables
    if tables:
        for order in TRANSITION_ORDERS:
            if len(context_tokens) < order:
                continue
            context_key = tuple(context_tokens[-order:])
            followers = tables.get(order, {}).get(context_key)
            if not followers:
                continue
            prior = [0.0] * vocabulary_size
            for token, probability in followers.items():
                token_id = self.embedding_model.token_to_id.get(token)
                # Followers missing from the vocabulary are ignored.
                if token_id is not None:
                    prior[token_id] = probability
            return _normalize_vector(prior), order
    return [0.0] * vocabulary_size, None
|
|
def _transition_prior_array_with_order(
    self,
    context_tokens: list[str],
) -> tuple[object, int | None]:
    """NumPy variant of ``_transition_prior_with_order``.

    Returns:
        ``(prior, order)`` where ``prior`` is a float64 array over the
        vocabulary, divided by its sum when that sum is positive, or
        ``(zeros, None)`` when no order matches.
    """
    assert np is not None
    assert self.embedding_model is not None
    prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
    tables = self.transition_tables
    if not tables:
        return prior, None

    vocabulary = self.embedding_model.token_to_id
    for order in TRANSITION_ORDERS:
        if len(context_tokens) < order:
            continue
        followers = tables.get(order, {}).get(tuple(context_tokens[-order:]))
        if not followers:
            continue
        for token, probability in followers.items():
            token_id = vocabulary.get(token)
            # Followers missing from the vocabulary are ignored.
            if token_id is not None:
                prior[token_id] = probability
        mass = float(prior.sum())
        if mass > 0.0:
            prior /= mass
        return prior, order
    return prior, None
|
|
def _copy_prior(self, context_tokens: list[str]) -> Vector:
    """Build a prior that favours recently seen, copy-eligible tokens.

    Looks at the last 8 tokens after the final ``<answer>`` marker (or of
    the whole context when there is no marker). Each eligible token adds
    ``0.82 ** distance`` to its slot, where distance 0 is the most recent
    token. The result is passed through ``_normalize_vector``.
    """
    assert self.embedding_model is not None
    assert self.tokenizer is not None

    weights = [0.0] * len(self.embedding_model.id_to_token)
    # Position just past the last "<answer>" marker, if there is one.
    answer_start = next(
        (
            position + 1
            for position in range(len(context_tokens) - 1, -1, -1)
            if context_tokens[position] == "<answer>"
        ),
        None,
    )
    window = context_tokens if answer_start is None else context_tokens[answer_start:]
    if not window:
        return weights

    decay = 0.82
    for distance, token in enumerate(reversed(window[-8:])):
        if token in self.tokenizer.special_tokens or not self._eligible_copy_token(token):
            continue
        token_id = self.embedding_model.token_to_id.get(token)
        if token_id is not None:
            weights[token_id] += decay**distance
    return _normalize_vector(weights)
|
|
| def _copy_prior_array(self, context_tokens: list[str]) -> object: |
| assert np is not None |
| assert self.embedding_model is not None |
| assert self.tokenizer is not None |
|
|
| prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) |
| decay = 0.82 |
| answer_start = None |
| for index in range(len(context_tokens) - 1, -1, -1): |
| if context_tokens[index] == "<answer>": |
| answer_start = index + 1 |
| break |
| source_tokens = context_tokens[answer_start:] if answer_start is not None else context_tokens |
| for distance, token in enumerate(reversed(source_tokens[-8:])): |
| if token in self.tokenizer.special_tokens: |
| continue |
| if not self._eligible_copy_token(token): |
| continue |
| token_id = self.embedding_model.token_to_id.get(token) |
| if token_id is None: |
| continue |
| prior[token_id] += decay**distance |
| total = float(prior.sum()) |
| if total > 0.0: |
| prior /= total |
| return prior |
|
|
def _preference_prior(self) -> Vector:
    """Build a vocabulary-sized prior from the positive preference biases.

    Tokens with a strictly positive bias that pass
    ``_eligible_preference_token`` receive the output of
    ``_calibrated_softmax`` over their biases; every other slot stays 0.
    An all-zero vector is returned when there is no non-zero bias or no
    eligible token.
    """
    assert self.embedding_model is not None
    vocabulary = self.embedding_model.id_to_token
    prior = [0.0 for _ in vocabulary]
    if not self.preference_bias or not any(value != 0.0 for value in self.preference_bias):
        return prior

    eligible = [
        position
        for position, token in enumerate(vocabulary)
        if self.preference_bias[position] > 0.0 and self._eligible_preference_token(token)
    ]
    if not eligible:
        return prior

    probabilities = self._calibrated_softmax(
        [self.preference_bias[position] for position in eligible]
    )
    for position, probability in zip(eligible, probabilities):
        prior[position] = probability
    return prior
|
|
| def _preference_prior_array(self) -> object: |
| assert np is not None |
| assert self.embedding_model is not None |
| if self.preference_bias_array is None or not np.any(self.preference_bias_array != 0.0): |
| return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) |
| if self.preference_valid_mask_array is None or not np.any(self.preference_valid_mask_array): |
| return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) |
| positive_mask = self.preference_bias_array > 0.0 |
| active_mask = self.preference_valid_mask_array & positive_mask |
| if not np.any(active_mask): |
| return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) |
| prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) |
| prior[active_mask] = self._calibrated_softmax_array( |
| self.preference_bias_array[active_mask] |
| ) |
| return prior |
|
|
| def _eligible_preference_token(self, token: str) -> bool: |
| assert self.tokenizer is not None |
| if token == self.tokenizer.unk_token or token in self.tokenizer.special_tokens: |
| return False |
| if not self._starts_new_word(token): |
| return False |
| rendered = self._render_token(token) |
| if not rendered.strip() or self._is_punctuation_piece(rendered): |
| return False |
| alphanumeric = "".join(character for character in rendered if character.isalnum()) |
| return len(alphanumeric) >= 1 |
|
|
def _build_transition_tables(
    self,
    tokens: list[str],
) -> dict[int, dict[tuple[str, ...], dict[str, float]]]:
    """Count n-gram followers in ``tokens`` and turn them into probabilities.

    For every order in ``TRANSITION_ORDERS`` each window of that many tokens
    is a context and the token right after it is a follower. Contexts are
    ranked by total count (ties broken by the context tuple) and truncated
    to ``config.max_transition_contexts_per_order`` when that is a
    non-negative number. Followers are ranked by count (ties broken by
    token) and truncated to ``config.max_transition_next_tokens`` when that
    is positive. Probabilities are renormalized over the kept followers.

    Returns:
        ``{order: {context: {follower: probability}}}`` with an entry for
        every order, possibly empty.
    """
    orders = sorted(TRANSITION_ORDERS)

    tallies: dict[int, dict[tuple[str, ...], dict[str, int]]] = {
        order: {} for order in orders
    }
    for order in orders:
        per_context = tallies[order]
        for end in range(order - 1, len(tokens) - 1):
            context = tuple(tokens[end - order + 1 : end + 1])
            follower = tokens[end + 1]
            followers = per_context.setdefault(context, {})
            followers[follower] = followers.get(follower, 0) + 1

    tables: dict[int, dict[tuple[str, ...], dict[str, float]]] = {
        order: {} for order in orders
    }
    config = self.config
    for order in orders:
        # Most frequent contexts first; the tuple itself breaks ties.
        ranked_contexts = sorted(
            tallies[order].items(),
            key=lambda entry: (-sum(entry[1].values()), entry[0]),
        )
        if (
            config.max_transition_contexts_per_order is not None
            and config.max_transition_contexts_per_order >= 0
        ):
            ranked_contexts = ranked_contexts[: config.max_transition_contexts_per_order]
        for context, followers in ranked_contexts:
            ranked_followers = sorted(
                followers.items(),
                key=lambda entry: (-entry[1], entry[0]),
            )
            if config.max_transition_next_tokens > 0:
                ranked_followers = ranked_followers[: config.max_transition_next_tokens]
            kept_total = sum(count for _, count in ranked_followers)
            if kept_total <= 0:
                continue
            tables[order][context] = {
                token: count / kept_total
                for token, count in ranked_followers
            }
    return tables
|
|
def _serialize_transition_tables(self) -> dict[str, dict[str, dict[str, float]]]:
    """Return the transition tables with string keys.

    Orders become ``str(order)`` and context tuples are encoded with
    ``_encode_ngram_key``. The follower mappings are reused as-is, not
    copied.
    """
    assert self.transition_tables is not None
    payload: dict[str, dict[str, dict[str, float]]] = {}
    for order, contexts in self.transition_tables.items():
        encoded_contexts: dict[str, dict[str, float]] = {}
        for context, followers in contexts.items():
            encoded_contexts[_encode_ngram_key(context)] = followers
        payload[str(order)] = encoded_contexts
    return payload
|
|
def _deserialize_transition_tables(
    self,
    payload: dict[str, dict[str, dict[str, float]]],
) -> dict[int, dict[tuple[str, ...], dict[str, float]]]:
    """Rebuild transition tables from their string-keyed form.

    The result starts with an empty table for every order in
    ``TRANSITION_ORDERS``; each order present in ``payload`` then replaces
    its entry. Context keys are decoded with ``_decode_ngram_key``, and
    followers are coerced to ``str`` / ``float``.
    """
    restored: dict[int, dict[tuple[str, ...], dict[str, float]]] = {
        order: {} for order in sorted(TRANSITION_ORDERS)
    }
    for order_text, contexts in payload.items():
        order = int(order_text)
        decoded_contexts: dict[tuple[str, ...], dict[str, float]] = {}
        for encoded_key, followers in contexts.items():
            decoded_contexts[_decode_ngram_key(encoded_key)] = {
                str(token): float(probability)
                for token, probability in followers.items()
            }
        restored[order] = decoded_contexts
    return restored
|
|
| def _eligible_copy_token(self, token: str) -> bool: |
| rendered = self._render_token(token) |
| if not rendered.strip(): |
| return False |
| if self._is_punctuation_piece(rendered): |
| return False |
| if not self._starts_new_word(token): |
| return False |
| alphanumeric = "".join(character for character in rendered if character.isalnum()) |
| return len(alphanumeric) >= 2 |
|
|
| def _allowed_generation_token(self, token: str, generated_tokens: list[str]) -> bool: |
| assert self.embedding_model is not None |
| if len(self.embedding_model.id_to_token) < 1024: |
| return True |
| if token == self.tokenizer.unk_token or token in self.tokenizer.special_tokens: |
| return False |
| rendered = self._render_token(token) |
| if rendered == "\n": |
| return bool(generated_tokens) |
| if not rendered.strip(): |
| return False |
| if self._is_word_joiner_token(token): |
| return ( |
| self._can_attach_word_joiner(generated_tokens) |
| or self._can_start_line_with_word_joiner(token, generated_tokens) |
| ) |
| if self._is_structural_punctuation_token(token): |
| return bool(generated_tokens) or self._can_start_answer_with_structural_punctuation(token) |
| if self._is_structural_symbol_token(token): |
| return bool(generated_tokens) or self._starts_new_word(token) |
| if not self._starts_new_word(token): |
| return False |
| alphanumeric = "".join(character for character in rendered if character.isalnum()) |
| return len(alphanumeric) >= 1 or not self._is_punctuation_piece(rendered) |
|
|
| def _would_repeat_recent_pattern( |
| self, |
| candidate: str, |
| generated_tokens: list[str], |
| recent_rendered_words: list[str] | None = None, |
| ) -> bool: |
| if len(generated_tokens) >= 2 and generated_tokens[-1] == candidate and generated_tokens[-2] == candidate: |
| return True |
|
|
| if len(generated_tokens) >= 2: |
| trigram = tuple(generated_tokens[-2:] + [candidate]) |
| recent_tokens = generated_tokens[-12:] |
| for index in range(max(0, len(recent_tokens) - 4)): |
| if tuple(recent_tokens[index : index + 3]) == trigram: |
| return True |
|
|
| rendered_words = recent_rendered_words |
| if rendered_words is None: |
| rendered_words = self._recent_rendered_words(generated_tokens) |
| candidate_word = self._render_token(candidate).casefold() |
| if ( |
| rendered_words |
| and self._starts_new_word(candidate) |
| and any(character.isalnum() for character in candidate_word) |
| ): |
| candidate_bigram = (rendered_words[-1], candidate_word) |
| recent_window = rendered_words[-10:] |
| recent_bigrams = { |
| (recent_window[index], recent_window[index + 1]) |
| for index in range(len(recent_window) - 1) |
| } |
| if candidate_bigram in recent_bigrams: |
| return True |
| if ( |
| len(candidate_word) > 2 |
| and rendered_words[-10:].count(candidate_word) >= 2 |
| and not self._is_common_connector_token(candidate) |
| ): |
| return True |
|
|
| return False |
|
|
| def _recent_rendered_words(self, generated_tokens: list[str]) -> list[str]: |
| rendered_words: list[str] = [] |
| for token in generated_tokens: |
| if not self._starts_new_word(token): |
| continue |
| rendered = self._render_token(token).casefold() |
| if any(character.isalnum() for character in rendered): |
| rendered_words.append(rendered) |
| return rendered_words |
|
|
def _select_generation_token(
    self,
    distribution: dict[str, float],
    *,
    context_tokens: list[str] | None = None,
    generated_tokens: list[str] | None = None,
    temperature: float = DEFAULT_GENERATION_TEMPERATURE,
    top_k: int = DEFAULT_GENERATION_TOP_K,
    top_p: float = DEFAULT_GENERATION_TOP_P,
    repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
    preserve_dominant_candidates: bool = False,
) -> str:
    """Pick the next token from a token -> score mapping.

    The distribution is first filtered and re-scored by
    ``_prepare_generation_candidates``; when that yields candidates one of
    them is chosen by ``_sample_generation_candidate`` (stochastically only
    when ``temperature`` is positive). If no candidate survives, the
    highest-scoring token that is not special, not the unknown token and is
    allowed by ``_allowed_generation_token`` is returned instead.

    Returns:
        The chosen token, or ``""`` when nothing is usable.
    """
    assert self.tokenizer is not None
    history = generated_tokens or []
    pool = self._prepare_generation_candidates(
        distribution,
        generated_tokens=history,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        preserve_dominant_candidates=preserve_dominant_candidates,
    )
    if pool:
        return self._sample_generation_candidate(
            pool,
            context_tokens=context_tokens or [],
            generated_tokens=history,
            stochastic=temperature > 0.0,
        )

    # Fallback: every candidate was filtered out, so walk the raw
    # distribution from best to worst and take the first usable token.
    ranked_tokens = sorted(distribution, key=distribution.get, reverse=True)
    for token in ranked_tokens:
        if token in self.tokenizer.special_tokens or token == self.tokenizer.unk_token:
            continue
        if self._allowed_generation_token(token, history):
            return token
    return ""
|
|
def _select_generation_token_from_array(
    self,
    probabilities: object,
    *,
    context_tokens: list[str],
    generated_tokens: list[str],
    temperature: float = DEFAULT_GENERATION_TEMPERATURE,
    top_k: int = DEFAULT_GENERATION_TOP_K,
    top_p: float = DEFAULT_GENERATION_TOP_P,
    repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
    preserve_dominant_candidates: bool = False,
) -> str:
    """Pick the next token from an array of per-token scores.

    ``probabilities`` is converted with ``np.asarray`` and indexed by
    vocabulary id (``embedding_model.id_to_token``). Only the best
    ``max(top_k * 4, 64)`` entries are turned into a token -> score mapping,
    which is then handed to ``_select_generation_token``.

    Non-finite scores (NaN, +/-inf) are treated as 0 so they can neither
    occupy slots in the candidate pool nor reach the sampler.

    Returns:
        The chosen token, or ``""`` for an empty array or when nothing is
        usable.
    """
    assert np is not None
    assert self.tokenizer is not None
    assert self.embedding_model is not None

    values = np.asarray(probabilities, dtype=np.float64)
    if values.size == 0:
        return ""
    # NaN sorts as the largest value in argsort/argpartition; neutralise
    # non-finite entries up front so they are dropped by the score filter.
    values = np.where(np.isfinite(values), values, 0.0)

    # values.size >= 1 here and the pool floor is 64, so pool_size >= 1.
    pool_size = min(values.size, max(top_k * 4, 64))
    if pool_size < values.size:
        # Partial selection first, then sort only the selected pool.
        candidate_indices = np.argpartition(values, -pool_size)[-pool_size:]
        candidate_indices = candidate_indices[np.argsort(values[candidate_indices])[::-1]]
    else:
        candidate_indices = np.argsort(values)[::-1]

    special_tokens = self.tokenizer.special_tokens
    unk_token = self.tokenizer.unk_token
    id_to_token = self.embedding_model.id_to_token
    distribution: dict[str, float] = {}
    for raw_index in candidate_indices:
        index = int(raw_index)
        score = float(values[index])
        if score <= 0.0:
            continue
        token = id_to_token[index]
        if token in special_tokens or token == unk_token:
            continue
        distribution[token] = score
    return self._select_generation_token(
        distribution,
        context_tokens=context_tokens,
        generated_tokens=generated_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        preserve_dominant_candidates=preserve_dominant_candidates,
    )
|
|
| def _prepare_generation_candidates( |
| self, |
| distribution: dict[str, float], |
| *, |
| generated_tokens: list[str], |
| temperature: float, |
| top_k: int, |
| top_p: float, |
| repetition_penalty: float, |
| preserve_dominant_candidates: bool = False, |
| ) -> list[tuple[str, float]]: |
| assert self.tokenizer is not None |
| assert self.embedding_model is not None |
|
|
| generated_word_count = self._generated_word_count(generated_tokens) |
| clause_words = self._words_since_clause_break(generated_tokens) |
| recent_rendered_words = self._recent_rendered_words(generated_tokens) |
| best_probability = max(distribution.values(), default=0.0) |
| adjusted: list[tuple[str, float]] = [] |
| for token, probability in sorted(distribution.items(), key=lambda item: item[1], reverse=True): |
| if token in self.tokenizer.special_tokens: |
| continue |
| if token == self.tokenizer.unk_token or probability <= 0.0: |
| continue |
| if not self._allowed_generation_token(token, generated_tokens): |
| continue |
| repeats_recent_pattern = self._would_repeat_recent_pattern( |
| token, |
| generated_tokens, |
| recent_rendered_words=recent_rendered_words, |
| ) |
| if ( |
| repeats_recent_pattern |
| and not ( |
| preserve_dominant_candidates |
| and best_probability > 0.0 |
| and probability >= best_probability * 0.80 |
| ) |
| ): |
| continue |
|
|
| score = probability |
| rendered = self._render_token(token) |
| punctuation_token = self._is_structural_punctuation_token(token) |
| starts_new_word = self._starts_new_word(token) |
| alphanumeric = "".join(character for character in rendered if character.isalnum()) |
| if generated_tokens and starts_new_word and alphanumeric: |
| previous_rendered = self._render_token(generated_tokens[-1]) |
| previous_alphanumeric = "".join( |
| character for character in previous_rendered if character.isalnum() |
| ) |
| if previous_alphanumeric.casefold() == alphanumeric.casefold(): |
| continue |
| common_connector = self._is_common_connector_token(token) |
| if ( |
| starts_new_word |
| and len(alphanumeric) == 1 |
| and not common_connector |
| ): |
| score *= 0.08 |
| recent_count = generated_tokens[-12:].count(token) |
| if recent_count > 0 and not common_connector: |
| score /= repetition_penalty ** (2 * recent_count) |
| if generated_tokens and token == generated_tokens[-1]: |
| score /= repetition_penalty**3 |
| if generated_tokens and token in generated_tokens[-4:] and not common_connector: |
| score *= 0.35 |
| if generated_tokens and not starts_new_word and self._starts_new_word(generated_tokens[-1]): |
| score *= 0.08 |
| if not generated_tokens and punctuation_token: |
| if best_probability <= 0.0 or probability < best_probability * 0.80: |
| score *= 0.01 |
| elif not generated_tokens and not starts_new_word: |
| score *= 0.02 |
| if punctuation_token: |
| if generated_tokens and self._is_structural_punctuation_token(generated_tokens[-1]): |
| score *= 0.05 |
| if clause_words >= 6: |
| score *= 1.0 + min(1.4, 0.18 * (clause_words - 5)) |
| elif generated_word_count >= 12: |
| score *= 1.1 |
| if score > 0.0: |
| adjusted.append((token, score)) |
|
|
| if not adjusted: |
| return [] |
| adjusted.sort(key=lambda item: item[1], reverse=True) |
| if top_k > 0: |
| adjusted = adjusted[:top_k] |
| if 0.0 < top_p < 1.0: |
| kept: list[tuple[str, float]] = [] |
| cumulative = 0.0 |
| total = sum(score for _, score in adjusted) |
| for token, score in adjusted: |
| normalized = score / total if total else 0.0 |
| kept.append((token, score)) |
| cumulative += normalized |
| if cumulative >= top_p: |
| break |
| adjusted = kept |
|
|
| if temperature <= 0.0: |
| return [(adjusted[0][0], 1.0)] |
|
|
| exponent = 1.0 / temperature |
| tempered = [ |
| (token, score**exponent) |
| for token, score in adjusted |
| if score > 0.0 |
| ] |
| total = sum(score for _, score in tempered) |
| if total <= 0.0: |
| return [] |
| return [(token, score / total) for token, score in tempered] |
|
|
| def _sample_generation_candidate( |
| self, |
| candidates: list[tuple[str, float]], |
| *, |
| context_tokens: list[str], |
| generated_tokens: list[str], |
| stochastic: bool = False, |
| ) -> str: |
| if not candidates: |
| return "" |
| if len(candidates) == 1: |
| return candidates[0][0] |
| top_probability = candidates[0][1] |
| second_probability = candidates[1][1] |
| top_has_clear_half_majority = top_probability >= 0.5 and ( |
| second_probability <= 0.0 |
| or top_probability - second_probability >= 0.02 |
| ) |
| if top_has_clear_half_majority or ( |
| second_probability > 0.0 and top_probability >= second_probability * 2.5 |
| ) or ( |
| top_probability >= 0.08 |
| and second_probability > 0.0 |
| and top_probability >= second_probability * 1.35 |
| ): |
| return candidates[0][0] |
| if stochastic: |
| threshold = random.random() |
| else: |
| seed_payload = "\u0002".join([*context_tokens, "<generated>", *generated_tokens, str(len(candidates))]) |
| seed = int.from_bytes(hashlib.sha256(seed_payload.encode("utf-8")).digest()[:8], "big") |
| threshold = random.Random(seed).random() |
| cumulative = 0.0 |
| for token, probability in candidates: |
| cumulative += probability |
| if threshold <= cumulative: |
| return token |
| return candidates[-1][0] |
|
|
| def _top_entries_from_vector( |
| self, |
| values: Vector, |
| limit: int, |
| ) -> list[dict[str, object]]: |
| if limit <= 0: |
| return [] |
| ranked = sorted( |
| enumerate(values), |
| key=lambda item: item[1], |
| reverse=True, |
| ) |
| return [ |
| self._token_entry(index, probability) |
| for index, probability in ranked[:limit] |
| if probability > 0.0 |
| ] |
|
|
| def _token_entry( |
| self, |
| index: int, |
| probability: float, |
| ) -> dict[str, object]: |
| assert self.embedding_model is not None |
| token = self.embedding_model.id_to_token[index] |
| return { |
| "token": token, |
| "text": self._render_token(token), |
| "probability": probability, |
| } |
|
|
| def _build_reasoning_summary( |
| self, |
| transition_order: int | None, |
| blend_weights: dict[str, float], |
| ) -> str: |
| dominant_source = max(blend_weights.items(), key=lambda item: item[1])[0] if blend_weights else "base" |
| if transition_order is not None: |
| transition_message = f" Transition prior is using order-{transition_order} context." |
| else: |
| transition_message = " Transition prior found no matching n-gram." |
|
|
| return ( |
| "Generation is running on analytical state, recurrent traces, and corpus-derived token transitions." |
| f"{transition_message}" |
| f" Dominant blend source: {dominant_source}." |
| ) |
|
|
| def _generated_word_count(self, tokens: list[str]) -> int: |
| return len(self._decode_tokens(tokens).split()) |
|
|
| def _is_structural_punctuation_text(self, text: str) -> bool: |
| if len(text) != 1: |
| return False |
| if self._is_word_joiner_text(text): |
| return False |
| category = unicodedata.category(text) |
| return category.startswith("P") |
|
|
| def _is_structural_punctuation_token(self, token: str) -> bool: |
| return self._is_structural_punctuation_text(self._render_token(token)) |
|
|
| def _is_structural_symbol_token(self, token: str) -> bool: |
| rendered = self._render_token(token) |
| return len(rendered) == 1 and unicodedata.category(rendered).startswith("S") |
|
|
| def _is_word_joiner_token(self, token: str) -> bool: |
| return self._is_word_joiner_text(self._render_token(token)) |
|
|
| def _is_word_joiner_text(self, text: str) -> bool: |
| if len(text) != 1: |
| return False |
| category = unicodedata.category(text) |
| if category in ("Pc", "Pd", "Lm"): |
| return True |
| name = unicodedata.name(text, "") |
| return "APOSTROPHE" in name or ( |
| "SINGLE" in name and "QUOTATION MARK" in name |
| ) |
|
|
| def _can_start_line_with_word_joiner(self, token: str, generated_tokens: list[str]) -> bool: |
| rendered = self._render_token(token) |
| if len(rendered) != 1 or unicodedata.category(rendered) != "Pd": |
| return False |
| if not self._starts_new_word(token): |
| return False |
| return not generated_tokens or self._render_token(generated_tokens[-1]) == "\n" |
|
|
| def _can_start_answer_with_structural_punctuation(self, token: str) -> bool: |
| rendered = self._render_token(token) |
| if len(rendered) != 1 or not self._starts_new_word(token): |
| return False |
| return unicodedata.category(rendered) in ("Ps", "Pi") |
|
|
| def _is_common_connector_token(self, token: str) -> bool: |
| rendered = self._render_token(token) |
| return rendered.isalpha() and len(rendered) <= 3 |
|
|
| def _can_attach_word_joiner(self, generated_tokens: list[str]) -> bool: |
| if not generated_tokens: |
| return False |
| rendered = self._render_token(generated_tokens[-1]) |
| if not rendered: |
| return False |
| if any(character.isalnum() for character in rendered): |
| return True |
| if len(rendered) != 1: |
| return False |
| return unicodedata.category(rendered) in ("Ps", "Pi") |
|
|
| def _words_since_clause_break(self, tokens: list[str]) -> int: |
| assert self.tokenizer is not None |
|
|
| words = 0 |
| for token in reversed(tokens): |
| if token in self.tokenizer.special_tokens: |
| continue |
| rendered = self._render_token(token) |
| if self._is_structural_punctuation_text(rendered): |
| break |
| if self._starts_new_word(token) and not self._is_punctuation_piece(rendered): |
| words += 1 |
| return words |
|
|
| def _should_stop_generation(self, generated_tokens: list[str]) -> bool: |
| if not generated_tokens: |
| return False |
| if not self._is_terminal_punctuation_text(self._render_token(generated_tokens[-1])): |
| return False |
| return self._generated_word_count(generated_tokens) >= 14 |
|
|
| def _is_terminal_punctuation_text(self, text: str) -> bool: |
| if not self._is_structural_punctuation_text(text): |
| return False |
| name = unicodedata.name(text, "") |
| return ( |
| "FULL STOP" in name |
| or "QUESTION MARK" in name |
| or "EXCLAMATION MARK" in name |
| ) |
|
|
| def _starts_new_word(self, token: str) -> bool: |
| assert self.tokenizer is not None |
| if token in self.tokenizer.special_tokens: |
| return True |
| if token.startswith(self.tokenizer.word_prefix): |
| return True |
| return len(token) == 1 and not token.isalnum() and not self._is_word_joiner_token(token) |
|
|
| def _decode_tokens(self, tokens: list[str]) -> str: |
| assert self.tokenizer is not None |
| return self.tokenizer.decode(tokens) |
|
|
| def _render_token(self, token: str) -> str: |
| assert self.tokenizer is not None |
| if token.startswith(self.tokenizer.word_prefix): |
| return token[len(self.tokenizer.word_prefix) :] |
| return token |
|
|
| def _require_fit(self) -> None: |
| if ( |
| self.tokenizer is None |
| or self.embedding_model is None |
| or self.memory_units is None |
| or self.readout_weights is None |
| or self.ternary_mask is None |
| or self.associative_keys is None |
| or self.associative_key_norms is None |
| or self.associative_values is None |
| or self.transition_tables is None |
| ): |
| raise RuntimeError("Call fit() before using the REFRAMR model.") |
|
|
def _ensure_numeric_caches(self) -> None:
    """Refresh the numeric caches when NumPy is usable and they are missing."""
    # Without NumPy (np is None) there is nothing to refresh.
    if np is None:
        return
    # A missing readout_weights_array is the signal that triggers the refresh.
    if self.readout_weights_array is None:
        self._refresh_numeric_caches()
|
|