import json
import hashlib
import random
import site
import string
import sys
import unicodedata
from dataclasses import dataclass
from pathlib import Path

_VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor"
for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"):
    if _vendor_path.exists():
        vendor_text = str(_vendor_path)
        if vendor_text not in sys.path:
            sys.path.insert(0, vendor_text)

try:
    import numpy as np
except ModuleNotFoundError:
    user_site = site.getusersitepackages()
    if user_site and user_site not in sys.path:
        sys.path.append(user_site)
    try:
        import numpy as np
    except ModuleNotFoundError:
        np = None

if np is not None and not hasattr(np, "asarray"):
    np = None

from .checkpoint import read_safetensor_file, write_safetensor_file
from .config import ReframrConfig
from .embeddings import EmbeddingModel, fit_ppmi_embedding_from_tokens
from .hippo import (
    AnalyticalMemoryUnit,
    analytical_embedding_drive,
    analytical_embedding_drive_fast,
)
from .linalg import Vector, dot, mean, norm, softmax, zeros_vector
from .reservoir import apply_readout, ridge_regression_readout
from .reasoning import reasoning_prefix
from .ternary import apply_ternary_mask, derive_ternary_mask_from_states
from .tokenizer import NativeTokenizer

ASSOCIATIVE_BLEND = 0.42
TRANSITION_BLEND = 0.08
COPY_BLEND = 0.04
BASE_BLEND = 0.34
FAST_ASSOCIATIVE_BLEND = 0.06
FAST_TRANSITION_BLEND = 0.14
FAST_COPY_BLEND = 0.04
FAST_BASE_BLEND = 0.58
FAST_PREFERENCE_BLEND = 0.15
FAST_ANSWER_BLEND = 0.30
PROMPT_READOUT_LOGIT_ZSCORE_SCALE = 0.48
ASSOCIATIVE_TOP_K = 12
ANSWER_TOP_K = 48
ANSWER_START_TOP_K = 32
ANSWER_SEQUENCE_MATCH_FLOOR = 0.30
ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR = 0.45
ANSWER_SEQUENCE_LOCK_FLOOR = 0.55
ANSWER_SEQUENCE_SPIKE_CONFIDENCE = 0.80
READOUT_LOGIT_ZSCORE_SCALE = 0.22
TRACE_IDENTITY_SCALE = 0.78
TRACE_IDENTITY_HASHES = (
    (1103515245, 12345, 214013, 2531011),
    (1664525, 1013904223, 22695477, 1),
    (69069, 362437, 134775813, 17),
    (134775813, 97, 1103515245, 31),
    (22695477, 911, 1664525, 73),
    (214013, 2531011, 69069, 19),
    (48271, 0, 69621, 11),
    (16807, 37, 40692, 101),
    (279470273, 173, 1299709, 53),
    (39916801, 29, 2147483629, 7),
)
NGRAM_KEY_SEPARATOR = "\u0001"
TRANSITION_ORDERS = (10, 8, 6, 5, 4, 3, 2, 1)
DEFAULT_GENERATION_TEMPERATURE = 0.82
DEFAULT_GENERATION_TOP_K = 24
DEFAULT_GENERATION_TOP_P = 0.92
DEFAULT_REPETITION_PENALTY = 1.18
ANSWER_SEQUENCE_MAX_TOKENS = 192
RUNTIME_ARRAY_DTYPE = np.float32 if np is not None else None


@dataclass(frozen=True, slots=True)
class CharacterCountFact:
    character: str
    word: str
    count: int
    surface_seed: int


def _normalize_vector(values: Vector) -> Vector:
    total = sum(values)
    if total <= 0.0:
        return [0.0 for _ in values]
    return [value / total for value in values]


def _encode_ngram_key(tokens: tuple[str, ...]) -> str:
    return NGRAM_KEY_SEPARATOR.join(tokens)


def _decode_ngram_key(key: str) -> tuple[str, ...]:
    return tuple(part for part in key.split(NGRAM_KEY_SEPARATOR) if part)


def _last_index(values: list[str], target: str) -> int | None:
    for index in range(len(values) - 1, -1, -1):
        if values[index] == target:
            return index
    return None


@dataclass(slots=True)
class DecodeState:
    hidden_states: list[Vector]
    context_traces: list[Vector]
    combined_state: Vector
    context_tokens: list[str]
    answer_anchor_state: Vector | None = None
    answer_matches: list[tuple[float, int, int]] | None = None
    answer_start_matches: list[tuple[float, int, int]] | None = None
    answer_sequence_matches: list[tuple[float, int, int]] | None = None
    prompt_answer_prior: object | None = None
    prompt_answer_start_prior: object | None = None
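
# A minimal illustration of the n-gram key helpers above, shown as
# doctest-style comments so nothing runs at import time:
#
#   >>> _encode_ngram_key(("the", "quick"))
#   'the\x01quick'
#   >>> _decode_ngram_key("the\x01quick")
#   ('the', 'quick')
#
# The "\u0001" separator cannot appear inside a token, so the
# encode/decode round-trip is unambiguous.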


@dataclass(slots=True)
class ReframrModel:
    config: ReframrConfig
    tokenizer: NativeTokenizer | None = None
    embedding_model: EmbeddingModel | None = None
    memory_units: list[AnalyticalMemoryUnit] | None = None
    ternary_scale: float = 1.0
    ternary_mask: list[int] | None = None
    ternary_mask_array: object | None = None
    readout_weights: list[list[float]] | None = None
    readout_weights_array: object | None = None
    readout_bias: Vector | None = None
    readout_bias_array: object | None = None
    prompt_answer_weights: list[list[float]] | None = None
    prompt_answer_weights_array: object | None = None
    prompt_answer_bias: Vector | None = None
    prompt_answer_bias_array: object | None = None
    prompt_answer_start_weights: list[list[float]] | None = None
    prompt_answer_start_weights_array: object | None = None
    prompt_answer_start_bias: Vector | None = None
    prompt_answer_start_bias_array: object | None = None
    trace_token_weights: Vector | None = None
    trace_token_weights_array: object | None = None
    trace_embedding_table_array: object | None = None
    preference_bias: Vector | None = None
    preference_bias_array: object | None = None
    preference_valid_mask_array: object | None = None
    state_offset: Vector | None = None
    state_offset_array: object | None = None
    associative_keys: list[Vector] | None = None
    associative_keys_array: object | None = None
    associative_key_norms: list[float] | None = None
    associative_key_norms_array: object | None = None
    associative_values: list[int] | None = None
    associative_values_array: object | None = None
    associative_valid_mask_array: object | None = None
    answer_keys: list[Vector] | None = None
    answer_keys_array: object | None = None
    answer_key_norms: list[float] | None = None
    answer_key_norms_array: object | None = None
    answer_similarity_keys_array: object | None = None
    answer_similarity_key_norms_array: object | None = None
    answer_similarity_mask_array: object | None = None
    answer_values: list[int] | None = None
    answer_values_array: object | None = None
    answer_valid_mask_array: object | None = None
    answer_start_keys: list[Vector] | None = None
    answer_start_keys_array: object | None = None
    answer_start_key_norms: list[float] | None = None
    answer_start_key_norms_array: object | None = None
    answer_start_similarity_keys_array: object | None = None
    answer_start_similarity_key_norms_array: object | None = None
    answer_start_values: list[int] | None = None
    answer_start_values_array: object | None = None
    answer_start_valid_mask_array: object | None = None
    answer_sequence_keys: list[Vector] | None = None
    answer_sequence_keys_array: object | None = None
    answer_sequence_key_norms: list[float] | None = None
    answer_sequence_key_norms_array: object | None = None
    answer_sequence_similarity_keys_array: object | None = None
    answer_sequence_similarity_key_norms_array: object | None = None
    answer_sequence_prompt_tokens: list[list[int]] | None = None
    answer_sequence_prompt_tokens_array: object | None = None
    answer_sequence_tokens: list[list[int]] | None = None
    answer_sequence_tokens_array: object | None = None
    answer_sequence_prompt_weight_maps: list[dict[int, float]] | None = None
    answer_sequence_prompt_weight_norms: list[float] | None = None
    answer_sequence_prompt_bigram_sets: list[set[tuple[int, int]]] | None = None
    answer_sequence_prompt_trigram_sets: list[set[tuple[int, int, int]]] | None = None
    answer_sequence_prompt_number_sets: list[set[str]] | None = None
    answer_sequence_prompt_inverted_index: dict[int, list[int]] | None = None
    answer_sequence_prompt_specificity: dict[int, float] | None = None
    transition_tables: dict[int, dict[tuple[str, ...], dict[str, float]]] | None = None
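
    # Typical lifecycle of this dataclass, sketched as doctest-style comments
    # (corpus text and file name are illustrative, not fixtures shipped here):
    #
    #   >>> model = ReframrModel(ReframrConfig()).fit(corpus_text)
    #   >>> model.save("reframr.safetensors")
    #   >>> restored = ReframrModel.load("reframr.safetensors")
    #   >>> restored.generate_text("Question text", max_tokens=32)
    #
    # fit() ends by calling _refresh_numeric_caches(), which derives every
    # *_array mirror field, so the numpy fast paths and the pure-Python
    # fallbacks stay interchangeable.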

    def fit(self, text: str) -> "ReframrModel":
        self.tokenizer = NativeTokenizer.train(
            text,
            vocab_size=self.config.tokenizer_vocab_size,
            min_pair_frequency=self.config.tokenizer_min_pair_frequency,
            lowercase=self.config.lowercase,
        )
        tokens = self.tokenizer.encode(text)
        if len(tokens) < 2:
            raise ValueError("REFRAMR needs at least two tokens to derive a next-token readout.")
        self.embedding_model = fit_ppmi_embedding_from_tokens(
            tokens,
            embedding_dim=self.config.embedding_dim,
            window_size=self.config.window_size,
            min_frequency=self.config.min_frequency,
            max_vocab=self.config.max_vocab,
        )
        self.memory_units = [
            AnalyticalMemoryUnit(self.config.state_dim, timescale)
            for timescale in self.config.timescales
        ]
        token_counts: dict[str, float] = {}
        for token in tokens:
            token_counts[token] = token_counts.get(token, 0.0) + 1.0
        self.trace_token_weights = self._derive_trace_token_weights_from_counts(token_counts)
        raw_states, targets, target_ids = self._collect_training_examples(tokens)
        self.ternary_scale, self.ternary_mask = derive_ternary_mask_from_states(raw_states)
        analytical_states = [
            apply_ternary_mask(state, self.ternary_mask, self.ternary_scale)
            for state in raw_states
        ]
        self.associative_keys = [state[:] for state in analytical_states]
        self.associative_key_norms = [norm(state) for state in analytical_states]
        self.associative_values = target_ids[:]
        self.answer_keys = []
        self.answer_key_norms = []
        self.answer_values = []
        self.answer_start_keys = []
        self.answer_start_key_norms = []
        self.answer_start_values = []
        self.answer_sequence_keys = []
        self.answer_sequence_key_norms = []
        self.answer_sequence_prompt_tokens = []
        self.answer_sequence_tokens = []
        self.prompt_answer_weights = []
        self.prompt_answer_bias = [0.0 for _ in self.embedding_model.id_to_token]
        self.prompt_answer_start_weights = []
        self.prompt_answer_start_bias = [0.0 for _ in self.embedding_model.id_to_token]
        self.transition_tables = self._build_transition_tables(tokens)
        self._fit_answer_memory_from_text(text)
        self.readout_weights = ridge_regression_readout(
            analytical_states,
            targets,
            regularization=self.config.regularization,
        )
        self.readout_bias = [0.0 for _ in self.embedding_model.id_to_token]
        self.preference_bias = [0.0 for _ in self.embedding_model.id_to_token]
        self.state_offset = [0.0 for _ in analytical_states[0]] if analytical_states else []
        self._refresh_numeric_caches()
        return self
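
    # The readout above is a plain ridge regression from masked reservoir
    # states to one-hot next-token targets. ridge_regression_readout is
    # assumed to implement the usual closed form; a minimal numpy sketch
    # with illustrative names:
    #
    #   S = np.asarray(states)    # (examples, state_dim)
    #   Y = np.asarray(targets)   # (examples, vocab)
    #   W = np.linalg.solve(S.T @ S + lam * np.eye(S.shape[1]), S.T @ Y).T
    #
    # so that apply_readout(W, state) is just W @ state per token logit.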

    def _fit_answer_memory_from_text(self, text: str) -> None:
        assert self.tokenizer is not None
        assert self.embedding_model is not None
        if (
            self.answer_keys is None
            or self.answer_key_norms is None
            or self.answer_values is None
            or self.answer_start_keys is None
            or self.answer_start_key_norms is None
            or self.answer_start_values is None
            or self.answer_sequence_keys is None
            or self.answer_sequence_key_norms is None
            or self.answer_sequence_prompt_tokens is None
            or self.answer_sequence_tokens is None
        ):
            return
        for line in text.splitlines():
            if "" not in line:
                continue
            prompt_text, answer_text = line.split("", 1)
            prompt_text = prompt_text.strip()
            answer_text = answer_text.strip()
            if not prompt_text or not answer_text:
                continue
            prompt_tokens = self.tokenizer.encode(prompt_text) + [""]
            answer_tokens = [
                token
                for token in self.tokenizer.encode(answer_text)
                if token in self.embedding_model.token_to_id
                and token not in self.tokenizer.special_tokens
            ]
            if not prompt_tokens or not answer_tokens:
                continue
            key = self._encode_context(prompt_tokens)
            key_norm = norm(key)
            if key_norm <= 0.0:
                continue
            answer_ids = [
                self.embedding_model.token_to_id[token]
                for token in answer_tokens[:ANSWER_SEQUENCE_MAX_TOKENS]
            ]
            prompt_ids = [
                self.embedding_model.token_to_id[token]
                for token in prompt_tokens[:ANSWER_SEQUENCE_MAX_TOKENS]
                if token in self.embedding_model.token_to_id
                and token not in self.tokenizer.special_tokens
            ]
            if not answer_ids:
                continue
            self.answer_keys.append(key[:])
            self.answer_key_norms.append(key_norm)
            self.answer_values.append(answer_ids[0])
            self.answer_start_keys.append(key[:])
            self.answer_start_key_norms.append(key_norm)
            self.answer_start_values.append(answer_ids[0])
            self.answer_sequence_keys.append(key[:])
            self.answer_sequence_key_norms.append(key_norm)
            self.answer_sequence_prompt_tokens.append(
                prompt_ids + [-1 for _ in range(ANSWER_SEQUENCE_MAX_TOKENS - len(prompt_ids))]
            )
            self.answer_sequence_tokens.append(
                answer_ids + [-1 for _ in range(ANSWER_SEQUENCE_MAX_TOKENS - len(answer_ids))]
            )

    def predict_next_distribution(
        self,
        context: str,
        *,
        reasoning_mode: str | None = None,
    ) -> dict[str, float]:
        self._require_fit()
        assert self.tokenizer is not None
        assert self.embedding_model is not None
        probabilities = self.predict_next_token_distribution(
            context,
            reasoning_mode=reasoning_mode,
        )
        distribution: dict[str, float] = {}
        for token, probability in probabilities.items():
            rendered = self._render_token(token)
            distribution[rendered] = distribution.get(rendered, 0.0) + probability
        return distribution

    def predict_next_token_distribution(
        self,
        context: str,
        *,
        reasoning_mode: str | None = None,
    ) -> dict[str, float]:
        self._require_fit()
        assert self.tokenizer is not None
        assert self.embedding_model is not None
        assert self.readout_weights is not None
        active_mode = reasoning_mode or self.config.default_reasoning_profile
        context_tokens = reasoning_prefix(active_mode) + self.tokenizer.encode(context)
        return self._predict_next_token_distribution_from_tokens(context_tokens)
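
    # Example use of the two prediction entry points (prompt and output are
    # illustrative only):
    #
    #   >>> dist = model.predict_next_distribution("the quick brown")
    #   >>> max(dist, key=dist.get)
    #   'fox'
    #
    # predict_next_distribution merges sub-token pieces that render to the
    # same surface string, which is why probabilities are re-accumulated
    # instead of copied across.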

    def generate_text(
        self,
        context: str,
        *,
        max_tokens: int = 64,
        reasoning_mode: str | None = None,
        temperature: float = 0.0,
        top_k: int = DEFAULT_GENERATION_TOP_K,
        top_p: float = DEFAULT_GENERATION_TOP_P,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
    ) -> str:
        character_count_response = self._character_count_response(
            context,
            temperature=temperature,
        )
        if character_count_response is not None:
            return character_count_response
        self._require_fit()
        self._ensure_numeric_caches()
        assert self.tokenizer is not None
        if (
            np is not None
            and self.readout_weights_array is not None
            and self.embedding_model is not None
            and len(self.embedding_model.id_to_token) >= 1024
        ):
            return self._generate_text_fast(
                context,
                max_tokens=max_tokens,
                reasoning_mode=reasoning_mode,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )
        active_mode = reasoning_mode or self.config.default_reasoning_profile
        _, context_tokens = self._generation_prompt_tokens(context, active_mode)
        decode_state = self._build_decode_state(context_tokens)
        generated_tokens: list[str] = []
        for _ in range(max_tokens):
            distribution, _ = self._score_next_token_from_state(
                decode_state,
                include_trace=False,
                generated_tokens=generated_tokens,
            )
            next_token = self._select_generation_token(
                distribution,
                context_tokens=decode_state.context_tokens,
                generated_tokens=generated_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                preserve_dominant_candidates=self._answer_decode_has_continuation(
                    decode_state,
                    generated_tokens,
                ),
            )
            if not next_token:
                break
            generated_tokens.append(next_token)
            self._advance_decode_state(decode_state, next_token)
            if self._should_stop_answer_sequence(decode_state, generated_tokens):
                break
            if self._should_stop_generation(
                generated_tokens
            ) and not self._answer_decode_has_continuation(decode_state, generated_tokens):
                break
        overflow_budget = 6
        while (
            generated_tokens
            and not self._starts_new_word(generated_tokens[-1])
            and overflow_budget > 0
        ):
            distribution, _ = self._score_next_token_from_state(
                decode_state,
                include_trace=False,
                generated_tokens=generated_tokens,
            )
            next_token = self._select_generation_token(
                distribution,
                context_tokens=decode_state.context_tokens,
                generated_tokens=generated_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                preserve_dominant_candidates=self._answer_decode_has_continuation(
                    decode_state,
                    generated_tokens,
                ),
            )
            if not next_token or self._starts_new_word(next_token):
                break
            generated_tokens.append(next_token)
            self._advance_decode_state(decode_state, next_token)
            overflow_budget -= 1
        return self._decode_tokens(generated_tokens)

    @staticmethod
    def _character_count_fact(context: str) -> CharacterCountFact | None:
        normalized = unicodedata.normalize("NFKC", context).strip()
        tokens = ReframrModel._character_count_word_tokens(normalized)
        if not tokens:
            return None
        lowered = [token.casefold() for token in tokens]
        count_terms = {"count", "counts", "counting", "many"}
        unit_terms = {"character", "characters", "letter", "letters"}
        if not any(token in count_terms for token in lowered):
            return None
        if not any(token in unit_terms for token in lowered) and "count" not in lowered:
            return None
        filler_terms = {"a", "an", "the", "single", "one", "please"}
        word_markers = {"in", "inside"}
        char_index = ReframrModel._character_count_target_index(
            lowered,
            unit_terms=unit_terms,
            filler_terms=filler_terms,
        )
        word_index = ReframrModel._character_count_word_index(
            lowered,
            char_index=char_index,
            filler_terms=filler_terms,
            word_markers=word_markers,
        )
        if char_index is None or word_index is None:
            return None
        character = tokens[char_index]
        word = tokens[word_index]
        if len(character) != 1 or not word:
            return None
        order_offset = 0 if char_index < word_index else 1
        surface_seed = ((char_index + 1) * 7 + (word_index + 1) * 3 + len(tokens) + order_offset) % 4
        return CharacterCountFact(
            character=character,
            word=word,
            count=word.casefold().count(character.casefold()),
            surface_seed=surface_seed,
        )

    @staticmethod
    def _character_count_word_tokens(text: str) -> list[str]:
        tokens: list[str] = []
        current: list[str] = []
        for character in text:
            if character != "_" and character.isalnum():
                current.append(character)
                continue
            if current:
                tokens.append("".join(current))
                current = []
        if current:
            tokens.append("".join(current))
        return tokens
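
    # How the deterministic counting path reads a prompt, traced as
    # doctest-style comments (both values follow from the code above):
    #
    #   >>> ReframrModel._character_count_word_tokens("how many letters r in strawberry?")
    #   ['how', 'many', 'letters', 'r', 'in', 'strawberry']
    #   >>> fact = ReframrModel._character_count_fact("how many letters r in strawberry?")
    #   >>> (fact.character, fact.word, fact.count)
    #   ('r', 'strawberry', 3)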

    @staticmethod
    def _character_count_target_index(
        tokens: list[str],
        *,
        unit_terms: set[str],
        filler_terms: set[str],
    ) -> int | None:
        for index, token in enumerate(tokens):
            if token not in unit_terms:
                continue
            for adjacent in (index - 1, index + 1):
                if 0 <= adjacent < len(tokens) and len(tokens[adjacent]) == 1:
                    return adjacent
            before = ReframrModel._nearest_content_index(tokens, index - 1, -1, filler_terms)
            after = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms)
            for candidate in (before, after):
                if candidate is not None and len(tokens[candidate]) == 1:
                    return candidate
        for index, token in enumerate(tokens):
            if token not in {"count", "counts", "counting"}:
                continue
            candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms)
            if candidate is not None and tokens[candidate] in unit_terms:
                candidate = ReframrModel._nearest_content_index(tokens, candidate + 1, 1, filler_terms)
            if candidate is not None and len(tokens[candidate]) == 1:
                return candidate
        return None

    @staticmethod
    def _character_count_word_index(
        tokens: list[str],
        *,
        char_index: int | None,
        filler_terms: set[str],
        word_markers: set[str],
    ) -> int | None:
        for index, token in enumerate(tokens):
            if token != "word":
                continue
            candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms)
            if candidate is not None and candidate != char_index and len(tokens[candidate]) > 1:
                return candidate
        for index, token in enumerate(tokens):
            if token not in word_markers:
                continue
            candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms)
            if candidate is not None and tokens[candidate] == "word":
                candidate = ReframrModel._nearest_content_index(tokens, candidate + 1, 1, filler_terms)
            if candidate is not None and candidate != char_index and len(tokens[candidate]) > 1:
                return candidate
        skipped_terms = {
            "how",
            "many",
            "count",
            "counts",
            "counting",
            "letter",
            "letters",
            "character",
            "characters",
            "word",
            "there",
            "are",
            "is",
            "appear",
            "appears",
            "times",
        } | filler_terms | word_markers
        for index in range(len(tokens) - 1, -1, -1):
            if index == char_index:
                continue
            if len(tokens[index]) <= 1 or tokens[index] in skipped_terms:
                continue
            return index
        return None

    @staticmethod
    def _nearest_content_index(
        tokens: list[str],
        start: int,
        direction: int,
        skipped_terms: set[str],
    ) -> int | None:
        index = start
        while 0 <= index < len(tokens):
            if tokens[index] not in skipped_terms:
                return index
            index += direction
        return None

    @classmethod
    def _character_count_response(cls, context: str, *, temperature: float = 0.0) -> str | None:
        fact = cls._character_count_fact(context)
        if fact is None:
            return None
        return cls._render_character_count_fact(fact, temperature=temperature)

    @staticmethod
    def _render_character_count_fact(fact: CharacterCountFact, *, temperature: float = 0.0) -> str:
        character_label = f"'{fact.character}'"
        word_label = f"'{fact.word}'"
        character_noun = "character" if fact.count == 1 else "characters"
        plural_times = "" if fact.count == 1 else "s"
        surfaces = (
            f"There {'is' if fact.count == 1 else 'are'} {fact.count} {character_label} {character_noun} in {word_label}.",
            f"{word_label} contains {fact.count} {character_label} {character_noun}.",
            f"In {word_label}, {character_label} appears {fact.count} time{plural_times}.",
            f"The count is {fact.count} for {character_label} in {word_label}.",
        )
        if temperature > 0.0:
            return surfaces[(random.randrange(len(surfaces)) + fact.surface_seed) % len(surfaces)]
        return surfaces[fact.surface_seed % len(surfaces)]
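
    # End-to-end view of the counting fast path. At temperature 0.0 the
    # surface template is chosen deterministically by surface_seed:
    #
    #   >>> ReframrModel._character_count_response("how many letters r in strawberry?")
    #   "There are 3 'r' characters in 'strawberry'."
    #
    # With temperature > 0.0 one of the four templates is drawn at random,
    # offset by the prompt-derived seed.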

    def _generate_text_fast(
        self,
        context: str,
        *,
        max_tokens: int,
        reasoning_mode: str | None,
        temperature: float,
        top_k: int,
        top_p: float,
        repetition_penalty: float,
    ) -> str:
        assert self.tokenizer is not None
        active_mode = reasoning_mode or self.config.default_reasoning_profile
        _, context_tokens = self._generation_prompt_tokens(context, active_mode)
        decode_state = self._build_decode_state(context_tokens)
        generated_tokens: list[str] = []
        for _ in range(max_tokens):
            probabilities, _ = self._score_next_token_array_from_state(
                decode_state,
                include_associative=True,
                generated_tokens=generated_tokens,
            )
            next_token = self._select_generation_token_from_array(
                probabilities,
                context_tokens=decode_state.context_tokens,
                generated_tokens=generated_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                preserve_dominant_candidates=self._answer_decode_has_continuation(
                    decode_state,
                    generated_tokens,
                ),
            )
            if not next_token:
                break
            generated_tokens.append(next_token)
            self._advance_decode_state(decode_state, next_token)
            if self._should_stop_answer_sequence(decode_state, generated_tokens):
                break
            if self._should_stop_generation(
                generated_tokens
            ) and not self._answer_decode_has_continuation(decode_state, generated_tokens):
                break
        overflow_budget = 6
        while (
            generated_tokens
            and not self._starts_new_word(generated_tokens[-1])
            and overflow_budget > 0
        ):
            probabilities, _ = self._score_next_token_array_from_state(
                decode_state,
                include_associative=True,
                generated_tokens=generated_tokens,
            )
            next_token = self._select_generation_token_from_array(
                probabilities,
                context_tokens=decode_state.context_tokens,
                generated_tokens=generated_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                preserve_dominant_candidates=self._answer_decode_has_continuation(
                    decode_state,
                    generated_tokens,
                ),
            )
            if not next_token or self._starts_new_word(next_token):
                break
            generated_tokens.append(next_token)
            self._advance_decode_state(decode_state, next_token)
            overflow_budget -= 1
        return self._decode_tokens(generated_tokens)

    def trace_next_token(
        self,
        context: str,
        *,
        reasoning_mode: str | None = None,
        top_k: int = 5,
    ) -> dict[str, object]:
        self._require_fit()
        assert self.tokenizer is not None
        active_mode = reasoning_mode or self.config.default_reasoning_profile
        context_tokens = reasoning_prefix(active_mode) + self.tokenizer.encode(context)
        _, trace = self._score_next_token_from_tokens(
            context_tokens,
            top_k=top_k,
            include_trace=True,
        )
        trace.update(
            {
                "context": context,
                "reasoning_mode": active_mode,
                "reasoning_tokens": reasoning_prefix(active_mode),
                "context_tokens": context_tokens,
            }
        )
        return trace

    def trace_generation(
        self,
        context: str,
        *,
        max_tokens: int = 16,
        reasoning_mode: str | None = None,
        top_k: int = 5,
        temperature: float = 0.0,
        top_p: float = DEFAULT_GENERATION_TOP_P,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
    ) -> dict[str, object]:
        character_count_response = self._character_count_response(
            context,
            temperature=temperature,
        )
        if character_count_response is not None:
            active_mode = reasoning_mode or self.config.default_reasoning_profile
            prompt = context if "" in context else f"{context} "
            return {
                "context": context,
                "prompt": prompt,
                "reasoning_mode": active_mode,
                "reasoning_tokens": reasoning_prefix(active_mode),
                "generation_policy": {
                    "temperature": temperature,
                    "top_k": max(DEFAULT_GENERATION_TOP_K, top_k),
                    "top_p": top_p,
                    "repetition_penalty": repetition_penalty,
                },
                "prompt_tokens": [],
                "generated_tokens": [],
                "generated_text": character_count_response,
                "generated_token_count": len(character_count_response.split()),
                "steps": [],
                "reasoning_summary": (
                    "The prompt matched the generic character-counting path, so Reframr "
                    "read the requested character and word from the prompt and counted "
                    "the characters directly."
                ),
            }
        self._require_fit()
        assert self.tokenizer is not None
        active_mode = reasoning_mode or self.config.default_reasoning_profile
        prompt, context_tokens = self._generation_prompt_tokens(context, active_mode)
        decode_state = self._build_decode_state(context_tokens)
        prompt_tokens = decode_state.context_tokens[:]
        generated_tokens: list[str] = []
        steps: list[dict[str, object]] = []
        for step_index in range(1, max_tokens + 1):
            distribution, trace = self._score_next_token_from_state(
                decode_state,
                top_k=top_k,
                include_trace=True,
                generated_tokens=generated_tokens,
            )
            next_token = self._select_generation_token(
                distribution,
                context_tokens=decode_state.context_tokens,
                generated_tokens=generated_tokens,
                temperature=temperature,
                top_k=max(DEFAULT_GENERATION_TOP_K, top_k),
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )
            if not next_token:
                break
            generated_tokens.append(next_token)
            self._advance_decode_state(decode_state, next_token)
            trace["step"] = step_index
            trace["chosen_token"] = next_token
            trace["chosen_text"] = self._render_token(next_token)
            trace["chosen_probability"] = distribution[next_token]
            steps.append(trace)
            if self._should_stop_generation(
                generated_tokens
            ) and not self._answer_decode_has_continuation(decode_state, generated_tokens):
                break
        overflow_budget = 6
        while (
            generated_tokens
            and not self._starts_new_word(generated_tokens[-1])
            and overflow_budget > 0
        ):
            distribution, trace = self._score_next_token_from_state(
                decode_state,
                top_k=top_k,
                include_trace=True,
                generated_tokens=generated_tokens,
            )
            next_token = self._select_generation_token(
                distribution,
                context_tokens=decode_state.context_tokens,
                generated_tokens=generated_tokens,
                temperature=temperature,
                top_k=max(DEFAULT_GENERATION_TOP_K, top_k),
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )
            if not next_token or self._starts_new_word(next_token):
                break
            generated_tokens.append(next_token)
            self._advance_decode_state(decode_state, next_token)
            trace["step"] = len(steps) + 1
            trace["chosen_token"] = next_token
            trace["chosen_text"] = self._render_token(next_token)
            trace["chosen_probability"] = distribution[next_token]
            steps.append(trace)
            overflow_budget -= 1
        return {
            "context": context,
            "prompt": prompt,
            "reasoning_mode": active_mode,
            "reasoning_tokens": reasoning_prefix(active_mode),
            "generation_policy": {
                "temperature": temperature,
                "top_k": max(DEFAULT_GENERATION_TOP_K, top_k),
                "top_p": top_p,
                "repetition_penalty": repetition_penalty,
            },
            "prompt_tokens": prompt_tokens,
            "generated_tokens": generated_tokens,
            "generated_text": self._decode_tokens(generated_tokens),
            "generated_token_count": len(generated_tokens),
            "steps": steps,
        }

    def _generation_prompt_tokens(self, context: str, active_mode: str) -> tuple[str, list[str]]:
        assert self.tokenizer is not None
        prompt = context if "" in context else f"{context} "
        prefix = reasoning_prefix(active_mode)
        prompt_tokens = self.tokenizer.encode(prompt)
        if (
            "" in prompt_tokens
            and "" not in prompt_tokens
            and "" not in prefix
        ):
            prompt_tokens = [""] + prompt_tokens
        return prompt, prefix + prompt_tokens

    def _predict_next_token_distribution_from_tokens(
        self,
        context_tokens: list[str],
    ) -> dict[str, float]:
        decode_state = self._build_decode_state(context_tokens)
        return self._predict_next_token_distribution_from_state(decode_state)

    def _predict_next_token_distribution_from_state(
        self,
        decode_state: DecodeState,
    ) -> dict[str, float]:
        probabilities, _ = self._score_next_token_from_state(
            decode_state,
            include_trace=False,
        )
        return probabilities

    @staticmethod
    def _answer_sequence_should_lock(
        *,
        answer_sequence_confidence: float,
        answer_sequence_match_confidence: float,
        has_answer_sequence_prior: bool,
    ) -> bool:
        if not has_answer_sequence_prior or answer_sequence_confidence <= 0.0:
            return False
        if answer_sequence_match_confidence >= ANSWER_SEQUENCE_LOCK_FLOOR:
            return True
        return (
            answer_sequence_match_confidence >= ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR
            and answer_sequence_confidence <= ANSWER_SEQUENCE_SPIKE_CONFIDENCE
        )

    @staticmethod
    def _answer_start_blend_weights(
        *,
        answer_sequence_match_confidence: float,
    ) -> dict[str, float]:
        if answer_sequence_match_confidence >= ANSWER_SEQUENCE_LOCK_FLOOR:
            return {
                "prompt_answer_start": 0.35,
                "prompt_answer": 0.10,
                "answer_sequence": 0.45,
                "answer_start": 0.10,
            }
        return {
            "prompt_answer_start": 0.55,
            "prompt_answer": 0.20,
            "answer_sequence": 0.15,
            "answer_start": 0.10,
        }
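
    # The lock rule, spelled out against the module constants: a match score
    # of 0.55 or more always locks, while a score in [0.45, 0.55) locks only
    # if the sequence prior has not already spiked past 0.80:
    #
    #   >>> ReframrModel._answer_sequence_should_lock(
    #   ...     answer_sequence_confidence=0.30,
    #   ...     answer_sequence_match_confidence=0.60,
    #   ...     has_answer_sequence_prior=True)
    #   True
    #   >>> ReframrModel._answer_sequence_should_lock(
    #   ...     answer_sequence_confidence=0.95,
    #   ...     answer_sequence_match_confidence=0.50,
    #   ...     has_answer_sequence_prior=True)
    #   False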

    def _score_next_token_from_tokens(
        self,
        context_tokens: list[str],
        *,
        top_k: int = 5,
        include_trace: bool = True,
    ) -> tuple[dict[str, float], dict[str, object]]:
        decode_state = self._build_decode_state(context_tokens)
        return self._score_next_token_from_state(
            decode_state,
            top_k=top_k,
            include_trace=include_trace,
        )

    def _score_next_token_from_state(
        self,
        decode_state: DecodeState,
        *,
        top_k: int = 5,
        include_trace: bool = True,
        generated_tokens: list[str] | None = None,
    ) -> tuple[dict[str, float], dict[str, object]]:
        assert self.embedding_model is not None
        assert self.readout_weights is not None
        generated_tokens = generated_tokens or []
        state = self._masked_decode_state(decode_state)
        logits = self._apply_readout_fast(state)
        base_probabilities = self._calibrated_softmax(logits)
        if decode_state.answer_matches is None:
            decode_state.answer_matches = self._score_answer_matches(
                decode_state.answer_anchor_state,
                limit=max(ANSWER_TOP_K, top_k) if include_trace else ANSWER_TOP_K,
            )
        answer_matches = decode_state.answer_matches
        if decode_state.answer_start_matches is None:
            decode_state.answer_start_matches = self._score_answer_start_matches(
                decode_state.answer_anchor_state,
                limit=max(ANSWER_START_TOP_K, top_k) if include_trace else ANSWER_START_TOP_K,
            )
        answer_start_matches = decode_state.answer_start_matches
        if decode_state.answer_sequence_matches is None:
            decode_state.answer_sequence_matches = self._score_answer_sequence_matches(
                decode_state.answer_anchor_state,
                decode_state.context_tokens,
                limit=max(ANSWER_START_TOP_K, top_k) if include_trace else ANSWER_START_TOP_K,
            )
        answer_sequence_matches = decode_state.answer_sequence_matches
        answer_prior = self._answer_prior_from_matches(answer_matches, generated_tokens)
        answer_start_prior = self._answer_prior_from_matches(answer_start_matches, generated_tokens)
        answer_sequence_prior = self._answer_sequence_prior_from_matches(
            answer_sequence_matches,
            generated_tokens,
        )
        answer_sequence_confidence = max(answer_sequence_prior) if answer_sequence_prior else 0.0
        answer_sequence_match_confidence = (
            answer_sequence_matches[0][0] if answer_sequence_matches else 0.0
        )
        has_answer_sequence_prior = any(value > 0.0 for value in answer_sequence_prior)
        answer_locked = self._answer_sequence_should_lock(
            answer_sequence_confidence=answer_sequence_confidence,
            answer_sequence_match_confidence=answer_sequence_match_confidence,
            has_answer_sequence_prior=has_answer_sequence_prior,
        )
        if decode_state.prompt_answer_prior is None:
            decode_state.prompt_answer_prior = self._prompt_answer_readout_prior(
                decode_state.answer_anchor_state,
                start=False,
            )
        prompt_answer_prior = decode_state.prompt_answer_prior
        prompt_answer_start_prior = (
            decode_state.prompt_answer_start_prior
            if not generated_tokens
            else [0.0 for _ in self.embedding_model.id_to_token]
        )
        if not generated_tokens and prompt_answer_start_prior is None:
            decode_state.prompt_answer_start_prior = self._prompt_answer_readout_prior(
                decode_state.answer_anchor_state,
                start=True,
            )
            prompt_answer_start_prior = decode_state.prompt_answer_start_prior
        use_answer_start = (
            not generated_tokens
            and (
                any(value > 0.0 for value in answer_start_prior)
                or any(value > 0.0 for value in prompt_answer_start_prior)
            )
        )
        if answer_locked:
            answer_prior = answer_sequence_prior
        elif use_answer_start:
            start_blend = self._answer_start_blend_weights(
                answer_sequence_match_confidence=answer_sequence_match_confidence
            )
            answer_prior = self._weighted_prior_sum(
                [
                    (start_blend["prompt_answer_start"], prompt_answer_start_prior),
                    (start_blend["prompt_answer"], prompt_answer_prior),
                    (start_blend["answer_sequence"], answer_sequence_prior),
                    (start_blend["answer_start"], answer_start_prior),
                ],
            )
        elif any(value > 0.0 for value in answer_sequence_prior):
            answer_prior = self._weighted_prior_sum(
                [
                    (0.50, prompt_answer_prior),
                    (0.30, answer_sequence_prior),
                    (0.20, answer_prior),
                ],
            )
        elif any(value > 0.0 for value in prompt_answer_prior):
            answer_prior = self._weighted_prior_sum(
                [
                    (0.65, prompt_answer_prior),
                    (0.35, answer_prior),
                ],
            )
        associative_matches = (
            []
            if use_answer_start
            else self._score_associative_matches(
                state,
                limit=max(ASSOCIATIVE_TOP_K, top_k) if include_trace else ASSOCIATIVE_TOP_K,
            )
        )
        associative_prior = (
            [0.0 for _ in self.embedding_model.id_to_token]
            if use_answer_start
            else self._associative_prior_from_matches(associative_matches)
        )
        transition_prior, transition_order = self._transition_prior_with_order(decode_state.context_tokens)
        copy_prior = self._copy_prior(decode_state.context_tokens)
        preference_prior = self._preference_prior()
        probabilities, blend_weights = self._blend_probabilities(
            base_probabilities,
            answer_prior,
            associative_prior,
            transition_prior,
            copy_prior,
            preference_prior,
            transition_order=transition_order,
            generated_count=len(generated_tokens),
            answer_locked=answer_locked,
            answer_guided_start=use_answer_start,
        )
        distribution = {
            token: probabilities[index]
            for index, token in enumerate(self.embedding_model.id_to_token)
        }
        if not include_trace:
            return distribution, {}
        trace = {
            "state_norm": norm(state),
            "blend_weights": blend_weights,
            "transition_order": transition_order,
            "base_top_predictions": self._top_entries_from_vector(base_probabilities, top_k),
            "answer_top_predictions": self._top_entries_from_vector(answer_prior, top_k),
            "prompt_answer_top_predictions": self._top_entries_from_vector(prompt_answer_prior, top_k),
            "prompt_answer_start_top_predictions": self._top_entries_from_vector(prompt_answer_start_prior, top_k),
            "answer_start_top_predictions": self._top_entries_from_vector(answer_start_prior, top_k),
            "answer_sequence_top_predictions": self._top_entries_from_vector(answer_sequence_prior, top_k),
            "associative_top_predictions": self._top_entries_from_vector(associative_prior, top_k),
            "transition_top_predictions": self._top_entries_from_vector(transition_prior, top_k),
            "copy_top_predictions": self._top_entries_from_vector(copy_prior, top_k),
            "preference_top_predictions": self._top_entries_from_vector(preference_prior, top_k),
            "final_top_predictions": self._top_entries_from_vector(probabilities, top_k),
            "associative_matches": [
                {
                    "example_index": example_index,
                    "similarity": similarity,
                    **self._token_entry(token_id, similarity),
                }
                for similarity, token_id, example_index in associative_matches[:top_k]
            ],
            "answer_matches": [
                {
                    "example_index": example_index,
                    "similarity": similarity,
                    **self._token_entry(token_id, similarity),
                }
                for similarity, token_id, example_index in answer_matches[:top_k]
            ],
            "answer_start_matches": [
                {
                    "example_index": example_index,
                    "similarity": similarity,
                    **self._token_entry(token_id, similarity),
                }
                for similarity, token_id, example_index in answer_start_matches[:top_k]
            ],
            "answer_sequence_matches": [
                {
                    "example_index": example_index,
                    "similarity": similarity,
                }
                for similarity, _, example_index in answer_sequence_matches[:top_k]
            ],
            "reasoning_summary": self._build_reasoning_summary(
                transition_order,
                blend_weights,
            ),
        }
        return distribution, trace

    def _score_next_token_array_from_state(
        self,
        decode_state: DecodeState,
        *,
        include_associative: bool,
        generated_tokens: list[str] | None = None,
    ) -> tuple[object, dict[str, float]]:
        assert np is not None
        assert self.embedding_model is not None
        generated_tokens = generated_tokens or []
        state = self._masked_decode_state_array(decode_state)
        logits = self._apply_readout_array(state)
        base_probabilities = self._calibrated_softmax_array(logits)
        if decode_state.answer_matches is None:
            decode_state.answer_matches = self._score_answer_matches(decode_state.answer_anchor_state)
        answer_prior = np.asarray(
            self._answer_prior_from_matches(
                decode_state.answer_matches,
                generated_tokens,
            ),
            dtype=np.float64,
        )
        if decode_state.answer_sequence_matches is None:
            decode_state.answer_sequence_matches = self._score_answer_sequence_matches(
                decode_state.answer_anchor_state,
                decode_state.context_tokens,
            )
        answer_sequence_matches = decode_state.answer_sequence_matches
        answer_sequence_prior = np.asarray(
            self._answer_sequence_prior_from_matches(
                answer_sequence_matches,
                generated_tokens,
            ),
            dtype=np.float64,
        )
        answer_sequence_confidence = (
            float(answer_sequence_prior.max()) if answer_sequence_prior.size else 0.0
        )
        answer_sequence_match_confidence = (
            answer_sequence_matches[0][0] if answer_sequence_matches else 0.0
        )
        has_answer_sequence_prior = bool(np.any(answer_sequence_prior > 0.0))
        answer_locked = self._answer_sequence_should_lock(
            answer_sequence_confidence=answer_sequence_confidence,
            answer_sequence_match_confidence=answer_sequence_match_confidence,
            has_answer_sequence_prior=has_answer_sequence_prior,
        )
        if decode_state.prompt_answer_prior is None:
            decode_state.prompt_answer_prior = self._prompt_answer_readout_prior_array(
                decode_state.answer_anchor_state,
                start=False,
            )
        prompt_answer_prior = decode_state.prompt_answer_prior
        prompt_answer_start_prior = np.zeros_like(base_probabilities)
        use_answer_start = False
        if answer_locked:
            answer_prior = answer_sequence_prior
        elif not generated_tokens:
            if decode_state.prompt_answer_start_prior is None:
                decode_state.prompt_answer_start_prior = self._prompt_answer_readout_prior_array(
                    decode_state.answer_anchor_state,
                    start=True,
                )
            prompt_answer_start_prior = decode_state.prompt_answer_start_prior
            if decode_state.answer_start_matches is None:
                decode_state.answer_start_matches = self._score_answer_start_matches(
                    decode_state.answer_anchor_state
                )
            answer_start_prior = np.asarray(
                self._answer_prior_from_matches(
                    decode_state.answer_start_matches,
                    generated_tokens,
                ),
                dtype=np.float64,
            )
            if np.any(answer_start_prior > 0.0) or np.any(prompt_answer_start_prior > 0.0):
                start_blend = self._answer_start_blend_weights(
                    answer_sequence_match_confidence=answer_sequence_match_confidence
                )
                answer_prior = self._weighted_prior_sum_array(
                    [
                        (start_blend["prompt_answer_start"], prompt_answer_start_prior),
                        (start_blend["prompt_answer"], prompt_answer_prior),
                        (start_blend["answer_sequence"], answer_sequence_prior),
                        (start_blend["answer_start"], answer_start_prior),
                    ],
                )
                use_answer_start = True
        if answer_locked:
            answer_prior = answer_sequence_prior
        elif not use_answer_start and np.any(answer_sequence_prior > 0.0):
            answer_prior = self._weighted_prior_sum_array(
                [
                    (0.50, prompt_answer_prior),
                    (0.30, answer_sequence_prior),
                    (0.20, answer_prior),
                ],
            )
        elif not use_answer_start and np.any(prompt_answer_prior > 0.0):
            answer_prior = self._weighted_prior_sum_array(
                [
                    (0.65, prompt_answer_prior),
                    (0.35, answer_prior),
                ],
            )
        if include_associative and not use_answer_start:
            associative_prior = np.asarray(
                self._associative_prior_from_matches(
                    self._score_associative_matches(state)
                ),
                dtype=np.float64,
            )
        else:
            associative_prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
        transition_prior, transition_order = self._transition_prior_array_with_order(
            decode_state.context_tokens
        )
        copy_prior = self._copy_prior_array(decode_state.context_tokens)
        preference_prior = self._preference_prior_array()
        return self._blend_probability_arrays(
            base_probabilities,
            answer_prior,
            associative_prior,
            transition_prior,
            copy_prior,
            preference_prior,
            transition_order=transition_order,
            generated_count=len(generated_tokens),
            answer_locked=answer_locked,
            answer_guided_start=use_answer_start,
        )
(start_blend["prompt_answer_start"], prompt_answer_start_prior), (start_blend["prompt_answer"], prompt_answer_prior), (start_blend["answer_sequence"], answer_sequence_prior), (start_blend["answer_start"], answer_start_prior), ], ) use_answer_start = True if answer_locked: answer_prior = answer_sequence_prior elif not use_answer_start and np.any(answer_sequence_prior > 0.0): answer_prior = self._weighted_prior_sum_array( [ (0.50, prompt_answer_prior), (0.30, answer_sequence_prior), (0.20, answer_prior), ], ) elif not use_answer_start and np.any(prompt_answer_prior > 0.0): answer_prior = self._weighted_prior_sum_array( [ (0.65, prompt_answer_prior), (0.35, answer_prior), ], ) if include_associative and not use_answer_start: associative_prior = np.asarray( self._associative_prior_from_matches( self._score_associative_matches(state) ), dtype=np.float64, ) else: associative_prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) transition_prior, transition_order = self._transition_prior_array_with_order( decode_state.context_tokens ) copy_prior = self._copy_prior_array(decode_state.context_tokens) preference_prior = self._preference_prior_array() return self._blend_probability_arrays( base_probabilities, answer_prior, associative_prior, transition_prior, copy_prior, preference_prior, transition_order=transition_order, generated_count=len(generated_tokens), answer_locked=answer_locked, answer_guided_start=use_answer_start, ) def _calibrated_softmax( self, logits: Vector, *, scale: float = READOUT_LOGIT_ZSCORE_SCALE, ) -> Vector: if np is not None: return self._calibrated_softmax_array( np.asarray(logits, dtype=np.float64), scale=scale, ).tolist() if not logits: return [] center = mean(logits) variance = mean([(value - center) * (value - center) for value in logits]) spread = variance**0.5 if spread <= 1e-12: return softmax(logits) calibrated = [ max(-20.0, min(20.0, ((value - center) / spread) * scale)) for value in logits ] return softmax(calibrated) def _calibrated_softmax_array( self, logits: object, *, scale: float = READOUT_LOGIT_ZSCORE_SCALE, ) -> object: assert np is not None values = np.asarray(logits, dtype=np.float64) if values.size == 0: return values spread = float(values.std()) if spread > 1e-12: values = ((values - float(values.mean())) / spread) * scale values = np.clip(values, -20.0, 20.0) else: values = values - float(values.max()) values = values - float(values.max()) exponentials = np.exp(values) total = float(exponentials.sum()) if total <= 0.0: return np.full(values.shape, 1.0 / max(1, values.size), dtype=np.float64) return exponentials / total def _weighted_prior_sum(self, sources: list[tuple[float, Vector]]) -> Vector: assert self.embedding_model is not None active_sources = [ (weight, vector) for weight, vector in sources if weight > 0.0 and any(value > 0.0 for value in vector) ] if not active_sources: return [0.0 for _ in self.embedding_model.id_to_token] total_weight = sum(weight for weight, _ in active_sources) merged = [0.0 for _ in self.embedding_model.id_to_token] for weight, vector in active_sources: normalized_weight = weight / total_weight for index, value in enumerate(vector): merged[index] += normalized_weight * value return _normalize_vector(merged) def _weighted_prior_sum_array(self, sources: list[tuple[float, object]]) -> object: assert np is not None assert self.embedding_model is not None active_sources = [ (weight, np.asarray(vector, dtype=np.float64)) for weight, vector in sources if weight > 0.0 and np.any(np.asarray(vector, 

    def _weighted_prior_sum(self, sources: list[tuple[float, Vector]]) -> Vector:
        assert self.embedding_model is not None
        active_sources = [
            (weight, vector)
            for weight, vector in sources
            if weight > 0.0 and any(value > 0.0 for value in vector)
        ]
        if not active_sources:
            return [0.0 for _ in self.embedding_model.id_to_token]
        total_weight = sum(weight for weight, _ in active_sources)
        merged = [0.0 for _ in self.embedding_model.id_to_token]
        for weight, vector in active_sources:
            normalized_weight = weight / total_weight
            for index, value in enumerate(vector):
                merged[index] += normalized_weight * value
        return _normalize_vector(merged)

    def _weighted_prior_sum_array(self, sources: list[tuple[float, object]]) -> object:
        assert np is not None
        assert self.embedding_model is not None
        active_sources = [
            (weight, np.asarray(vector, dtype=np.float64))
            for weight, vector in sources
            if weight > 0.0 and np.any(np.asarray(vector, dtype=np.float64) > 0.0)
        ]
        if not active_sources:
            return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
        total_weight = sum(weight for weight, _ in active_sources)
        merged = np.zeros_like(active_sources[0][1], dtype=np.float64)
        for weight, vector in active_sources:
            merged += (weight / total_weight) * vector
        total = float(merged.sum())
        if total > 0.0:
            merged /= total
        return merged

    def _prompt_answer_readout_prior(
        self,
        answer_anchor_state: Vector | None,
        *,
        start: bool,
    ) -> Vector:
        assert self.embedding_model is not None
        if answer_anchor_state is None:
            return [0.0 for _ in self.embedding_model.id_to_token]
        weights = self.prompt_answer_start_weights if start else self.prompt_answer_weights
        bias = self.prompt_answer_start_bias if start else self.prompt_answer_bias
        if np is not None:
            return self._prompt_answer_readout_prior_array(
                answer_anchor_state,
                start=start,
            ).tolist()
        if not weights:
            return [0.0 for _ in self.embedding_model.id_to_token]
        state = self._center_state_vector(self._masked_combined_state(answer_anchor_state))
        logits = apply_readout(weights, state)
        if bias:
            logits = [value + bias[index] for index, value in enumerate(logits)]
        return self._calibrated_softmax(
            logits,
            scale=PROMPT_READOUT_LOGIT_ZSCORE_SCALE,
        )

    def _prompt_answer_readout_prior_array(
        self,
        answer_anchor_state: Vector | None,
        *,
        start: bool,
    ) -> object:
        assert np is not None
        assert self.embedding_model is not None
        if answer_anchor_state is None:
            return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
        weights = (
            self.prompt_answer_start_weights_array
            if start
            else self.prompt_answer_weights_array
        )
        bias = self.prompt_answer_start_bias_array if start else self.prompt_answer_bias_array
        if weights is None:
            return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
        state_array = self._center_state_array(
            self._masked_combined_state_array(answer_anchor_state)
        )
        logits = weights @ state_array
        if bias is not None and bias.shape == logits.shape:
            logits = logits + bias
        return self._calibrated_softmax_array(
            logits,
            scale=PROMPT_READOUT_LOGIT_ZSCORE_SCALE,
        )
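
    # _weighted_prior_sum renormalizes the weights of the sources that are
    # actually non-empty, then renormalizes the merged vector to sum to 1.
    # Worked example with two active sources (0.5 and 0.3 become 5/8 and 3/8):
    #
    #   merged = 0.625 * [1.0, 0.0] + 0.375 * [0.0, 1.0]  ->  [0.625, 0.375]
    #
    # An all-zero source is dropped entirely rather than diluting the blend.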
"prompt_answer_start_bias": self.prompt_answer_start_bias or [0.0 for _ in self.embedding_model.id_to_token], "trace_token_weights": self.trace_token_weights or [1.0 for _ in self.embedding_model.id_to_token], "preference_bias": self.preference_bias or [0.0 for _ in self.embedding_model.id_to_token], "state_offset": self.state_offset or [0.0 for _ in range(self._combined_state_width())], "associative_keys": self.associative_keys, "associative_values": self.associative_values, "answer_keys": self.answer_keys if self.answer_keys is not None else [], "answer_values": self.answer_values if self.answer_values is not None else [], "answer_start_keys": self.answer_start_keys if self.answer_start_keys is not None else [], "answer_start_values": self.answer_start_values if self.answer_start_values is not None else [], "answer_sequence_keys": self.answer_sequence_keys if self.answer_sequence_keys is not None else [], "answer_sequence_prompt_tokens": self.answer_sequence_prompt_tokens if self.answer_sequence_prompt_tokens is not None else [], "answer_sequence_tokens": self.answer_sequence_tokens if self.answer_sequence_tokens is not None else [], } write_safetensor_file(path, tensors, metadata=metadata) @classmethod def load(cls, path: str | Path) -> "ReframrModel": checkpoint_path = Path(path) checkpoint = read_safetensor_file( checkpoint_path, arrays=np is not None and checkpoint_path.stat().st_size > 10_000_000, ) metadata = checkpoint.metadata config = ReframrConfig.from_dict(json.loads(metadata["config"])) model = cls(config) model.tokenizer = NativeTokenizer.from_dict(json.loads(metadata["tokenizer"])) id_to_token = [str(token) for token in json.loads(metadata["embedding_id_to_token"])] embedding_table = checkpoint.tensors["embedding_table"] if np is not None and hasattr(embedding_table, "shape"): embeddings = embedding_table.astype(float, copy=False) else: embeddings = [[float(value) for value in row] for row in embedding_table] model.embedding_model = EmbeddingModel( token_to_id={token: index for index, token in enumerate(id_to_token)}, id_to_token=id_to_token, embeddings=embeddings, ppmi_matrix=[], ) model.memory_units = [ AnalyticalMemoryUnit(model.config.state_dim, timescale) for timescale in model.config.timescales ] model.ternary_scale = float(checkpoint.tensors["ternary_scale"][0]) model.ternary_mask = [int(value) for value in checkpoint.tensors["ternary_mask"]] readout_tensor = checkpoint.tensors["readout_weights"] model.readout_weights = ( readout_tensor.astype(float, copy=False) if np is not None and hasattr(readout_tensor, "shape") else [[float(value) for value in row] for row in readout_tensor] ) readout_bias_tensor = checkpoint.tensors.get("readout_bias", []) model.readout_bias = [ float(value) for value in ( readout_bias_tensor.tolist() if hasattr(readout_bias_tensor, "tolist") else readout_bias_tensor ) ] if not model.readout_bias: model.readout_bias = [0.0 for _ in id_to_token] prompt_answer_tensor = checkpoint.tensors.get("prompt_answer_weights", []) model.prompt_answer_weights = ( prompt_answer_tensor.astype(float, copy=False) if np is not None and hasattr(prompt_answer_tensor, "shape") and len(prompt_answer_tensor.shape) == 2 else [[float(value) for value in row] for row in prompt_answer_tensor] ) prompt_answer_bias_tensor = checkpoint.tensors.get("prompt_answer_bias", []) model.prompt_answer_bias = [ float(value) for value in ( prompt_answer_bias_tensor.tolist() if hasattr(prompt_answer_bias_tensor, "tolist") else prompt_answer_bias_tensor ) ] if not 

    @classmethod
    def load(cls, path: str | Path) -> "ReframrModel":
        checkpoint_path = Path(path)
        checkpoint = read_safetensor_file(
            checkpoint_path,
            arrays=np is not None and checkpoint_path.stat().st_size > 10_000_000,
        )
        metadata = checkpoint.metadata
        config = ReframrConfig.from_dict(json.loads(metadata["config"]))
        model = cls(config)
        model.tokenizer = NativeTokenizer.from_dict(json.loads(metadata["tokenizer"]))
        id_to_token = [str(token) for token in json.loads(metadata["embedding_id_to_token"])]
        embedding_table = checkpoint.tensors["embedding_table"]
        if np is not None and hasattr(embedding_table, "shape"):
            embeddings = embedding_table.astype(float, copy=False)
        else:
            embeddings = [[float(value) for value in row] for row in embedding_table]
        model.embedding_model = EmbeddingModel(
            token_to_id={token: index for index, token in enumerate(id_to_token)},
            id_to_token=id_to_token,
            embeddings=embeddings,
            ppmi_matrix=[],
        )
        model.memory_units = [
            AnalyticalMemoryUnit(model.config.state_dim, timescale)
            for timescale in model.config.timescales
        ]
        model.ternary_scale = float(checkpoint.tensors["ternary_scale"][0])
        model.ternary_mask = [int(value) for value in checkpoint.tensors["ternary_mask"]]
        readout_tensor = checkpoint.tensors["readout_weights"]
        model.readout_weights = (
            readout_tensor.astype(float, copy=False)
            if np is not None and hasattr(readout_tensor, "shape")
            else [[float(value) for value in row] for row in readout_tensor]
        )
        readout_bias_tensor = checkpoint.tensors.get("readout_bias", [])
        model.readout_bias = [
            float(value)
            for value in (
                readout_bias_tensor.tolist()
                if hasattr(readout_bias_tensor, "tolist")
                else readout_bias_tensor
            )
        ]
        if not model.readout_bias:
            model.readout_bias = [0.0 for _ in id_to_token]
        prompt_answer_tensor = checkpoint.tensors.get("prompt_answer_weights", [])
        model.prompt_answer_weights = (
            prompt_answer_tensor.astype(float, copy=False)
            if np is not None
            and hasattr(prompt_answer_tensor, "shape")
            and len(prompt_answer_tensor.shape) == 2
            else [[float(value) for value in row] for row in prompt_answer_tensor]
        )
        prompt_answer_bias_tensor = checkpoint.tensors.get("prompt_answer_bias", [])
        model.prompt_answer_bias = [
            float(value)
            for value in (
                prompt_answer_bias_tensor.tolist()
                if hasattr(prompt_answer_bias_tensor, "tolist")
                else prompt_answer_bias_tensor
            )
        ]
        if not model.prompt_answer_bias:
            model.prompt_answer_bias = [0.0 for _ in id_to_token]
        prompt_answer_start_tensor = checkpoint.tensors.get("prompt_answer_start_weights", [])
        model.prompt_answer_start_weights = (
            prompt_answer_start_tensor.astype(float, copy=False)
            if np is not None
            and hasattr(prompt_answer_start_tensor, "shape")
            and len(prompt_answer_start_tensor.shape) == 2
            else [[float(value) for value in row] for row in prompt_answer_start_tensor]
        )
        prompt_answer_start_bias_tensor = checkpoint.tensors.get("prompt_answer_start_bias", [])
        model.prompt_answer_start_bias = [
            float(value)
            for value in (
                prompt_answer_start_bias_tensor.tolist()
                if hasattr(prompt_answer_start_bias_tensor, "tolist")
                else prompt_answer_start_bias_tensor
            )
        ]
        if not model.prompt_answer_start_bias:
            model.prompt_answer_start_bias = [0.0 for _ in id_to_token]
        trace_weight_tensor = checkpoint.tensors.get("trace_token_weights", [])
        model.trace_token_weights = [
            float(value)
            for value in (
                trace_weight_tensor.tolist()
                if hasattr(trace_weight_tensor, "tolist")
                else trace_weight_tensor
            )
        ]
        if not model.trace_token_weights:
            model.trace_token_weights = [
                0.0 if token in model.tokenizer.special_tokens else 1.0
                for token in id_to_token
            ]
        preference_bias_tensor = checkpoint.tensors.get("preference_bias", [])
        model.preference_bias = [
            float(value)
            for value in (
                preference_bias_tensor.tolist()
                if hasattr(preference_bias_tensor, "tolist")
                else preference_bias_tensor
            )
        ]
        if not model.preference_bias:
            model.preference_bias = [0.0 for _ in id_to_token]
        state_offset_tensor = checkpoint.tensors.get("state_offset", [])
        model.state_offset = [
            float(value)
            for value in (
                state_offset_tensor.tolist()
                if hasattr(state_offset_tensor, "tolist")
                else state_offset_tensor
            )
        ]
        if not model.state_offset:
            model.state_offset = [0.0 for _ in range(model._combined_state_width())]
        associative_tensor = checkpoint.tensors.get("associative_keys", [])
        model.associative_keys = (
            associative_tensor.astype(float, copy=False)
            if np is not None and hasattr(associative_tensor, "shape")
            else [[float(value) for value in row] for row in associative_tensor]
        )
        if np is not None and hasattr(model.associative_keys, "shape"):
            model.associative_key_norms = np.linalg.norm(model.associative_keys, axis=1).tolist()
        else:
            model.associative_key_norms = [norm(key) for key in model.associative_keys]
        raw_associative_values = checkpoint.tensors.get("associative_values", [])
        model.associative_values = [
            int(value)
            for value in (
                raw_associative_values.tolist()
                if hasattr(raw_associative_values, "tolist")
                else raw_associative_values
            )
        ]
        answer_tensor = checkpoint.tensors.get("answer_keys", [])
        if np is not None and hasattr(answer_tensor, "shape"):
            model.answer_keys = (
                answer_tensor.astype(float, copy=False)
                if len(answer_tensor.shape) == 2
                else []
            )
        else:
            model.answer_keys = [[float(value) for value in row] for row in answer_tensor]
        if (
            np is not None
            and hasattr(model.answer_keys, "shape")
            and len(model.answer_keys.shape) == 2
        ):
            model.answer_key_norms = np.linalg.norm(model.answer_keys, axis=1).tolist()
        else:
            model.answer_key_norms = [norm(key) for key in model.answer_keys]
        raw_answer_values = checkpoint.tensors.get("answer_values", [])
        model.answer_values = [
            int(value)
            for value in (
                raw_answer_values.tolist()
                if hasattr(raw_answer_values, "tolist")
                else raw_answer_values
            )
        ]
        answer_start_tensor = checkpoint.tensors.get("answer_start_keys", [])
        if np is not None and hasattr(answer_start_tensor, "shape"):
            model.answer_start_keys = (
                answer_start_tensor.astype(float, copy=False)
                if len(answer_start_tensor.shape) == 2
                else []
            )
        else:
            model.answer_start_keys = [
                [float(value) for value in row] for row in answer_start_tensor
            ]
        if (
            np is not None
            and hasattr(model.answer_start_keys, "shape")
            and len(model.answer_start_keys.shape) == 2
        ):
            model.answer_start_key_norms = np.linalg.norm(model.answer_start_keys, axis=1).tolist()
        else:
            model.answer_start_key_norms = [norm(key) for key in model.answer_start_keys]
        raw_answer_start_values = checkpoint.tensors.get("answer_start_values", [])
        model.answer_start_values = [
            int(value)
            for value in (
                raw_answer_start_values.tolist()
                if hasattr(raw_answer_start_values, "tolist")
                else raw_answer_start_values
            )
        ]
        answer_sequence_tensor = checkpoint.tensors.get("answer_sequence_keys", [])
        if np is not None and hasattr(answer_sequence_tensor, "shape"):
            model.answer_sequence_keys = (
                answer_sequence_tensor.astype(float, copy=False)
                if len(answer_sequence_tensor.shape) == 2
                else []
            )
        else:
            model.answer_sequence_keys = [
                [float(value) for value in row] for row in answer_sequence_tensor
            ]
        if (
            np is not None
            and hasattr(model.answer_sequence_keys, "shape")
            and len(model.answer_sequence_keys.shape) == 2
        ):
            model.answer_sequence_key_norms = np.linalg.norm(
                model.answer_sequence_keys,
                axis=1,
            ).tolist()
        else:
            model.answer_sequence_key_norms = [norm(key) for key in model.answer_sequence_keys]
        raw_answer_sequence_prompt_tokens = checkpoint.tensors.get("answer_sequence_prompt_tokens", [])
        if np is not None and hasattr(raw_answer_sequence_prompt_tokens, "shape"):
            model.answer_sequence_prompt_tokens = raw_answer_sequence_prompt_tokens.astype(int, copy=False)
        else:
            model.answer_sequence_prompt_tokens = [
                [int(value) for value in row] for row in raw_answer_sequence_prompt_tokens
            ]
        raw_answer_sequence_tokens = checkpoint.tensors.get("answer_sequence_tokens", [])
        if np is not None and hasattr(raw_answer_sequence_tokens, "shape"):
            model.answer_sequence_tokens = raw_answer_sequence_tokens.astype(int, copy=False)
        else:
            model.answer_sequence_tokens = [
                [int(value) for value in row] for row in raw_answer_sequence_tokens
            ]
        model.transition_tables = model._deserialize_transition_tables(
            json.loads(metadata.get("transition_tables", "{}"))
        )
        model._refresh_numeric_caches()
        return model

    def _collect_training_examples(
        self,
        tokens: list[str],
    ) -> tuple[list[Vector], list[Vector], list[int]]:
        assert self.embedding_model is not None
        if np is not None:
            hidden_states = [
                np.zeros(self.config.state_dim, dtype=np.float64)
                for _ in self.config.timescales
            ]
            context_traces = [
                np.zeros(self.config.embedding_dim, dtype=np.float64)
                for _ in self.config.timescales
            ]
            zero_embedding: Vector | object = np.zeros(self.config.embedding_dim, dtype=np.float64)
        else:
            hidden_states = [zeros_vector(self.config.state_dim) for _ in self.config.timescales]
            context_traces = [zeros_vector(self.config.embedding_dim) for _ in self.config.timescales]
            zero_embedding = zeros_vector(self.config.embedding_dim)
        states: list[Vector] = []
        labels: list[Vector] = []
        label_ids: list[int] = []
        token_ids = [
            self.embedding_model.token_to_id.get(token, -1)
            for token in tokens
        ]
        example_count = max(0, len(tokens) - 1)
        stride = 1
        if self.config.max_training_examples and example_count > self.config.max_training_examples:
            stride = max(
                1,
                (example_count + self.config.max_training_examples - 1)
                // self.config.max_training_examples,
            )
        for index in range(len(tokens) - 1):
            token = tokens[index]
            token_id = token_ids[index]
            embedding = (
                self.embedding_model.embeddings[token_id]
                if token_id >= 0
                else zero_embedding
            )
            trace_embedding = self._trace_embedding_from_token_id(embedding, token_id)
            hidden_states, context_traces, combined_state = self._step_hidden_states_from_embedding(
                hidden_states,
                context_traces,
                embedding,
                trace_embedding=trace_embedding,
            )
            if stride > 1 and index % stride != 0 and index != len(tokens) - 2:
                continue
            states.append(combined_state)
            next_token_id = token_ids[index + 1]
            labels.append(self._one_hot_from_id(next_token_id))
            label_ids.append(next_token_id)
        if self.config.max_training_examples and len(states) > self.config.max_training_examples:
            states = states[: self.config.max_training_examples]
            labels = labels[: self.config.max_training_examples]
            label_ids = label_ids[: self.config.max_training_examples]
        return states, labels, label_ids

    def _is_punctuation_piece(self, piece: str) -> bool:
        return bool(piece) and all(character in string.punctuation for character in piece)

    def _encode_context(self, tokens: list[str]) -> Vector:
        return self._masked_decode_state(self._build_decode_state(tokens))

    def _build_decode_state(self, tokens: list[str]) -> DecodeState:
        assert self.memory_units is not None
        state = DecodeState(
            hidden_states=(
                [
                    np.zeros(self.config.state_dim, dtype=np.float64)
                    for _ in self.config.timescales
                ]
                if np is not None
                else [zeros_vector(self.config.state_dim) for _ in self.config.timescales]
            ),
            context_traces=(
                [
                    np.zeros(self.config.embedding_dim, dtype=np.float64)
                    for _ in self.config.timescales
                ]
                if np is not None
                else [zeros_vector(self.config.embedding_dim) for _ in self.config.timescales]
            ),
            combined_state=self._zero_combined_state(),
            context_tokens=[],
        )
        for token in tokens:
            self._advance_decode_state(state, token)
        return state

    def _advance_decode_state(self, state: DecodeState, token: str) -> DecodeState:
        next_hidden_states, next_context_traces, combined_state = self._step_hidden_states(
            state.hidden_states,
            state.context_traces,
            token,
        )
        state.hidden_states = next_hidden_states
        state.context_traces = next_context_traces
        state.combined_state = combined_state
        state.context_tokens.append(token)
        if token == "":
            state.answer_anchor_state = (
                combined_state.copy()
                if hasattr(combined_state, "copy")
                else combined_state[:]
            )
            state.answer_matches = None
            state.answer_start_matches = None
            state.answer_sequence_matches = None
            state.prompt_answer_prior = None
            state.prompt_answer_start_prior = None
        return state

    def _masked_decode_state(self, state: DecodeState) -> Vector:
        assert self.ternary_mask is not None
        return apply_ternary_mask(state.combined_state, self.ternary_mask, self.ternary_scale)

    def _masked_combined_state(self, combined_state: Vector) -> Vector:
        assert self.ternary_mask is not None
        return apply_ternary_mask(combined_state, self.ternary_mask, self.ternary_scale)

    def _masked_decode_state_array(self, state: DecodeState) -> object:
        assert np is not None
        if self.ternary_mask_array is None:
            return np.asarray(self._masked_decode_state(state), dtype=RUNTIME_ARRAY_DTYPE)
        return (
            np.asarray(state.combined_state, dtype=RUNTIME_ARRAY_DTYPE)
            * self.ternary_scale
            * self.ternary_mask_array
        )

    def _masked_combined_state_array(self, combined_state: Vector) -> object:
        assert np is not None
        if self.ternary_mask_array is None:
            return np.asarray(self._masked_combined_state(combined_state), dtype=RUNTIME_ARRAY_DTYPE)
        return (
            np.asarray(combined_state, dtype=RUNTIME_ARRAY_DTYPE)
            * self.ternary_scale
            * self.ternary_mask_array
        )
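
    # The ternary masking applies elementwise, as the array fast path above
    # makes explicit: masked[i] = ternary_scale * mask[i] * state[i], with the
    # mask assumed to hold values in {-1, 0, +1}. Equivalent one-liner under
    # that assumption:
    #
    #   masked = ternary_scale * np.asarray(mask) * np.asarray(state)
    #
    # so the readout only ever sees a sign-quantized, uniformly rescaled view
    # of the reservoir state.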
self.state_offset[index] for index, value in enumerate(state)] def _center_state_array(self, state: object) -> object: assert np is not None state_array = np.asarray(state, dtype=RUNTIME_ARRAY_DTYPE) if self.state_offset_array is None or self.state_offset_array.shape != state_array.shape: return state_array return state_array - self.state_offset_array def _zero_combined_state(self) -> Vector: return [0.0 for _ in range(self._combined_state_width())] def _combined_state_width(self) -> int: return (self.config.state_dim + self.config.embedding_dim) * len(self.config.timescales) def _derive_trace_token_weights_from_counts(self, token_counts: dict[str, float]) -> Vector: assert self.embedding_model is not None assert self.tokenizer is not None counts = [ float(token_counts.get(token, 0.0)) for token in self.embedding_model.id_to_token ] positive_counts = sorted(value for value in counts if value > 0.0) reference = ( positive_counts[len(positive_counts) // 2] if positive_counts else 1.0 ) weights: Vector = [] for token, count in zip(self.embedding_model.id_to_token, counts): if token in self.tokenizer.special_tokens: weights.append(0.0) elif count <= 0.0: weights.append(1.0) else: weight = (reference / count) ** 0.75 weights.append(max(0.08, min(4.8, weight))) return weights def _token_id_for_token(self, token: str) -> int: assert self.embedding_model is not None token_id = self.embedding_model.token_to_id.get(token) if token_id is None and token.lower() != token: token_id = self.embedding_model.token_to_id.get(token.lower()) return int(token_id) if token_id is not None else -1 def _trace_embedding_from_token_id( self, embedding: Vector | object, token_id: int, ) -> Vector | object: if token_id < 0: return embedding if self.trace_embedding_table_array is not None: return self.trace_embedding_table_array[token_id] weight = self.trace_token_weights[token_id] if self.trace_token_weights is not None else 1.0 dimension = self.config.embedding_dim if hasattr(embedding, "shape"): trace_embedding = embedding * weight for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES: bucket = (token_id * bucket_multiplier + bucket_offset) % dimension sign = 1.0 if ((token_id * sign_multiplier + sign_offset) & 1) == 0 else -1.0 trace_embedding[bucket] += weight * TRACE_IDENTITY_SCALE * sign return trace_embedding trace_values = [float(value) * weight for value in embedding] for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES: bucket = (token_id * bucket_multiplier + bucket_offset) % dimension sign = 1.0 if ((token_id * sign_multiplier + sign_offset) & 1) == 0 else -1.0 trace_values[bucket] += weight * TRACE_IDENTITY_SCALE * sign return trace_values def _build_trace_embedding_table_array(self, embedding_array: object) -> object | None: if np is None or self.trace_token_weights is None: return None values = np.asarray(embedding_array, dtype=np.float64) if values.size == 0 or len(values.shape) != 2: return None weights = np.asarray(self.trace_token_weights, dtype=np.float64) if weights.shape[0] != values.shape[0]: return None trace_values = values * weights[:, None] if values.shape[1] <= 0: return trace_values token_ids = np.arange(values.shape[0], dtype=np.int64) for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES: buckets = ((token_ids * bucket_multiplier + bucket_offset) % values.shape[1]).astype( np.int64, copy=False, ) signs = np.where( ((token_ids * sign_multiplier + sign_offset) & 1) == 0, 1.0, 
-1.0, ) np.add.at(trace_values, (token_ids, buckets), weights * TRACE_IDENTITY_SCALE * signs) return trace_values def _refresh_numeric_caches(self) -> None: if np is None: self.ternary_mask_array = None self.readout_weights_array = None self.readout_bias_array = None self.prompt_answer_weights_array = None self.prompt_answer_bias_array = None self.prompt_answer_start_weights_array = None self.prompt_answer_start_bias_array = None self.trace_token_weights_array = None self.trace_embedding_table_array = None self.preference_bias_array = None self.preference_valid_mask_array = None self.state_offset_array = None self.associative_keys_array = None self.associative_key_norms_array = None self.associative_values_array = None self.associative_valid_mask_array = None self.answer_keys_array = None self.answer_key_norms_array = None self.answer_similarity_keys_array = None self.answer_similarity_key_norms_array = None self.answer_similarity_mask_array = None self.answer_values_array = None self.answer_valid_mask_array = None self.answer_start_keys_array = None self.answer_start_key_norms_array = None self.answer_start_similarity_keys_array = None self.answer_start_similarity_key_norms_array = None self.answer_start_values_array = None self.answer_start_valid_mask_array = None self.answer_sequence_keys_array = None self.answer_sequence_key_norms_array = None self.answer_sequence_similarity_keys_array = None self.answer_sequence_similarity_key_norms_array = None self.answer_sequence_prompt_tokens_array = None self.answer_sequence_tokens_array = None self.answer_sequence_prompt_weight_maps = None self.answer_sequence_prompt_weight_norms = None self.answer_sequence_prompt_bigram_sets = None self.answer_sequence_prompt_trigram_sets = None self.answer_sequence_prompt_number_sets = None self.answer_sequence_prompt_inverted_index = None self._refresh_answer_sequence_prompt_overlap_cache() return self.ternary_mask_array = ( np.asarray(self.ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) if self.ternary_mask is not None else None ) self.readout_weights_array = ( np.asarray(self.readout_weights, dtype=RUNTIME_ARRAY_DTYPE) if self.readout_weights is not None else None ) self.readout_bias_array = ( np.asarray(self.readout_bias, dtype=RUNTIME_ARRAY_DTYPE) if self.readout_bias is not None else None ) self.prompt_answer_weights_array = ( np.asarray(self.prompt_answer_weights, dtype=RUNTIME_ARRAY_DTYPE) if self.prompt_answer_weights is not None and len(self.prompt_answer_weights) > 0 else None ) self.prompt_answer_bias_array = ( np.asarray(self.prompt_answer_bias, dtype=RUNTIME_ARRAY_DTYPE) if self.prompt_answer_bias is not None else None ) self.prompt_answer_start_weights_array = ( np.asarray(self.prompt_answer_start_weights, dtype=RUNTIME_ARRAY_DTYPE) if self.prompt_answer_start_weights is not None and len(self.prompt_answer_start_weights) > 0 else None ) self.prompt_answer_start_bias_array = ( np.asarray(self.prompt_answer_start_bias, dtype=RUNTIME_ARRAY_DTYPE) if self.prompt_answer_start_bias is not None else None ) self.trace_token_weights_array = ( np.asarray(self.trace_token_weights, dtype=RUNTIME_ARRAY_DTYPE) if self.trace_token_weights is not None else None ) trace_embedding_table = ( self._build_trace_embedding_table_array(self.embedding_model.embeddings) if self.embedding_model is not None and self.trace_token_weights is not None else None ) self.trace_embedding_table_array = ( trace_embedding_table.astype(RUNTIME_ARRAY_DTYPE, copy=False) if trace_embedding_table is not None else None ) 
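        # _build_trace_embedding_table_array vectorises
        # _trace_embedding_from_token_id: for token id t and each
        # (mul, off, sign_mul, sign_off) in TRACE_IDENTITY_HASHES, bucket
        # (t * mul + off) % embedding_dim receives
        # weight * TRACE_IDENTITY_SCALE, signed +1 when
        # (t * sign_mul + sign_off) is even, else -1. Illustration only
        # (assuming embedding_dim = 64): token 7 with the first tuple lands in
        # bucket (7 * 1103515245 + 12345) % 64 = 52 with sign +1. Ten such
        # sparse taps give each token a near-unique identity signature on top
        # of its weighted embedding.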
        self.preference_bias_array = (
            np.asarray(self.preference_bias, dtype=RUNTIME_ARRAY_DTYPE)
            if self.preference_bias is not None
            else None
        )
        self.preference_valid_mask_array = (
            np.asarray(
                [
                    self._eligible_preference_token(token)
                    for token in self.embedding_model.id_to_token
                ],
                dtype=bool,
            )
            if self.embedding_model is not None and self.tokenizer is not None
            else None
        )
        self.state_offset_array = (
            np.asarray(self.state_offset, dtype=RUNTIME_ARRAY_DTYPE)
            if self.state_offset is not None
            else None
        )
        self.associative_keys_array = (
            np.asarray(self.associative_keys, dtype=RUNTIME_ARRAY_DTYPE)
            if self.associative_keys is not None and len(self.associative_keys) > 0
            else None
        )
        self.associative_key_norms_array = (
            np.asarray(self.associative_key_norms, dtype=RUNTIME_ARRAY_DTYPE)
            if self.associative_key_norms is not None and len(self.associative_key_norms) > 0
            else None
        )
        self.associative_values_array = (
            np.asarray(self.associative_values, dtype=np.int64)
            if self.associative_values is not None and len(self.associative_values) > 0
            else None
        )
        self.associative_valid_mask_array = (
            self.associative_values_array >= 0
            if self.associative_values_array is not None
            else None
        )
        self.answer_keys_array = (
            np.asarray(self.answer_keys, dtype=RUNTIME_ARRAY_DTYPE)
            if self.answer_keys is not None and len(self.answer_keys) > 0
            else None
        )
        self.answer_key_norms_array = (
            np.asarray(self.answer_key_norms, dtype=RUNTIME_ARRAY_DTYPE)
            if self.answer_key_norms is not None and len(self.answer_key_norms) > 0
            else None
        )
        self.answer_similarity_keys_array = None
        self.answer_similarity_key_norms_array = None
        self.answer_similarity_mask_array = None
        if self.answer_keys_array is not None and len(self.answer_keys_array.shape) == 2:
            width = int(self.answer_keys_array.shape[1])
            block_width = self.config.state_dim + self.config.embedding_dim
            expected_width = block_width * len(self.config.timescales)
            if block_width > 0 and width == expected_width:
                mask = np.zeros(width, dtype=RUNTIME_ARRAY_DTYPE)
                for scale_index in range(len(self.config.timescales)):
                    start = scale_index * block_width + self.config.state_dim
                    end = start + self.config.embedding_dim
                    mask[start:end] = 1.0
                self.answer_similarity_mask_array = mask
                self.answer_similarity_keys_array = self.answer_keys_array * mask[None, :]
                self.answer_similarity_key_norms_array = np.linalg.norm(
                    self.answer_similarity_keys_array,
                    axis=1,
                ).astype(RUNTIME_ARRAY_DTYPE, copy=False)
        self.answer_values_array = (
            np.asarray(self.answer_values, dtype=np.int64)
            if self.answer_values is not None and len(self.answer_values) > 0
            else None
        )
        self.answer_valid_mask_array = (
            self.answer_values_array >= 0
            if self.answer_values_array is not None
            else None
        )
        self.answer_start_keys_array = (
            np.asarray(self.answer_start_keys, dtype=RUNTIME_ARRAY_DTYPE)
            if self.answer_start_keys is not None and len(self.answer_start_keys) > 0
            else None
        )
        self.answer_start_key_norms_array = (
            np.asarray(self.answer_start_key_norms, dtype=RUNTIME_ARRAY_DTYPE)
            if self.answer_start_key_norms is not None
            and len(self.answer_start_key_norms) > 0
            else None
        )
        self.answer_start_similarity_keys_array = None
        self.answer_start_similarity_key_norms_array = None
        if (
            self.answer_start_keys_array is not None
            and len(self.answer_start_keys_array.shape) == 2
            and self.answer_similarity_mask_array is not None
            and int(self.answer_start_keys_array.shape[1])
            == int(self.answer_similarity_mask_array.shape[0])
        ):
            self.answer_start_similarity_keys_array = (
                self.answer_start_keys_array * self.answer_similarity_mask_array[None, :]
            )
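            # The similarity mask built above keeps only the context-trace
            # slice of every (state, trace) block in the combined-state layout
            # [state_0 | trace_0 | state_1 | trace_1 | ...], so anchor matching
            # compares what was read rather than the oscillator dynamics; the
            # same mask is reused for the answer-start and answer-sequence
            # key caches below.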
            self.answer_start_similarity_key_norms_array = np.linalg.norm(
                self.answer_start_similarity_keys_array,
                axis=1,
            ).astype(RUNTIME_ARRAY_DTYPE, copy=False)
        self.answer_start_values_array = (
            np.asarray(self.answer_start_values, dtype=np.int64)
            if self.answer_start_values is not None and len(self.answer_start_values) > 0
            else None
        )
        self.answer_start_valid_mask_array = (
            self.answer_start_values_array >= 0
            if self.answer_start_values_array is not None
            else None
        )
        self.answer_sequence_keys_array = (
            np.asarray(self.answer_sequence_keys, dtype=RUNTIME_ARRAY_DTYPE)
            if self.answer_sequence_keys is not None and len(self.answer_sequence_keys) > 0
            else None
        )
        self.answer_sequence_key_norms_array = (
            np.asarray(self.answer_sequence_key_norms, dtype=RUNTIME_ARRAY_DTYPE)
            if self.answer_sequence_key_norms is not None
            and len(self.answer_sequence_key_norms) > 0
            else None
        )
        self.answer_sequence_similarity_keys_array = None
        self.answer_sequence_similarity_key_norms_array = None
        if (
            self.answer_sequence_keys_array is not None
            and len(self.answer_sequence_keys_array.shape) == 2
            and self.answer_similarity_mask_array is not None
            and int(self.answer_sequence_keys_array.shape[1])
            == int(self.answer_similarity_mask_array.shape[0])
        ):
            self.answer_sequence_similarity_keys_array = (
                self.answer_sequence_keys_array * self.answer_similarity_mask_array[None, :]
            )
            self.answer_sequence_similarity_key_norms_array = np.linalg.norm(
                self.answer_sequence_similarity_keys_array,
                axis=1,
            ).astype(RUNTIME_ARRAY_DTYPE, copy=False)
        self.answer_sequence_tokens_array = (
            np.asarray(self.answer_sequence_tokens, dtype=np.int64)
            if self.answer_sequence_tokens is not None and len(self.answer_sequence_tokens) > 0
            else None
        )
        self.answer_sequence_prompt_tokens_array = (
            np.asarray(self.answer_sequence_prompt_tokens, dtype=np.int64)
            if self.answer_sequence_prompt_tokens is not None
            and len(self.answer_sequence_prompt_tokens) > 0
            else None
        )
        self._refresh_answer_sequence_prompt_overlap_cache()

    def _refresh_answer_sequence_prompt_overlap_cache(self) -> None:
        self.answer_sequence_prompt_weight_maps = None
        self.answer_sequence_prompt_weight_norms = None
        self.answer_sequence_prompt_bigram_sets = None
        self.answer_sequence_prompt_trigram_sets = None
        self.answer_sequence_prompt_number_sets = None
        self.answer_sequence_prompt_inverted_index = None
        self.answer_sequence_prompt_specificity = None
        if self.answer_sequence_prompt_tokens is None or self.trace_token_weights is None:
            return
        inverted: dict[int, list[int]] = {}
        row_id_lists: list[list[int]] = []
        for row in self.answer_sequence_prompt_tokens:
            row_values = row.tolist() if hasattr(row, "tolist") else row
            row_ids: list[int] = []
            for raw_token_id in row_values:
                token_id = int(raw_token_id)
                if token_id < 0 or token_id >= len(self.trace_token_weights):
                    continue
                row_ids.append(token_id)
            sequence_index = len(row_id_lists)
            for token_id in set(row_ids):
                inverted.setdefault(token_id, []).append(sequence_index)
            row_id_lists.append(row_ids)
        total_rows = len(row_id_lists)
        specificity = {
            token_id: self._prompt_overlap_token_specificity(len(indices), total_rows)
            for token_id, indices in inverted.items()
        }
        self.answer_sequence_prompt_inverted_index = inverted
        self.answer_sequence_prompt_specificity = specificity
        weight_maps: list[dict[int, float]] = []
        weight_norms: list[float] = []
        bigram_sets: list[set[tuple[int, int]]] = []
        trigram_sets: list[set[tuple[int, int, int]]] = []
        number_sets: list[set[str]] = []
        for row_ids in row_id_lists:
            row_weights: dict[int, float] = {}
            for token_id in row_ids:
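                # Deduplicate repeated tokens by keeping the largest
                # trace-weight x specificity product, so rare tokens dominate
                # each cached prompt's weight map.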
row_weights[token_id] = max( row_weights.get(token_id, 0.0), float(self.trace_token_weights[token_id]) * specificity.get(token_id, 1.0), ) weight_maps.append(row_weights) weight_norms.append(sum(value * value for value in row_weights.values()) ** 0.5) bigram_sets.append( { (row_ids[index], row_ids[index + 1]) for index in range(len(row_ids) - 1) } ) trigram_sets.append( { (row_ids[index], row_ids[index + 1], row_ids[index + 2]) for index in range(len(row_ids) - 2) } ) number_sets.append(self._number_strings_from_token_ids(row_ids)) self.answer_sequence_prompt_weight_maps = weight_maps self.answer_sequence_prompt_weight_norms = weight_norms self.answer_sequence_prompt_bigram_sets = bigram_sets self.answer_sequence_prompt_trigram_sets = trigram_sets self.answer_sequence_prompt_number_sets = number_sets @staticmethod def _prompt_overlap_token_specificity(document_frequency: int, total_documents: int) -> float: if document_frequency <= 0 or total_documents <= 0: return 1.0 coverage = min(1.0, document_frequency / total_documents) return max(0.02, 1.0 - (coverage ** 0.5)) def _number_strings_from_token_ids(self, token_ids: list[int]) -> set[str]: assert self.embedding_model is not None tokens = [ self.embedding_model.id_to_token[token_id] for token_id in token_ids if 0 <= token_id < len(self.embedding_model.id_to_token) ] return self._number_strings_from_tokens(tokens) def _number_strings_from_tokens(self, tokens: list[str]) -> set[str]: numbers: set[str] = set() current = "" for token in tokens: if self.tokenizer is not None and token in self.tokenizer.special_tokens: if current: numbers.add(current) current = "" continue rendered = self._render_token(token) digits = "".join(character for character in rendered if character.isdigit()) starts_number = self._starts_new_word(token) if self.tokenizer is not None else True if digits and starts_number: if current: numbers.add(current) current = digits elif digits and current: current += digits else: if current: numbers.add(current) current = "" if current: numbers.add(current) return numbers @staticmethod def _numeric_prompt_can_match(query_numbers: set[str], row_numbers: set[str]) -> bool: if not query_numbers: return True if not row_numbers: return False return query_numbers.issubset(row_numbers) def _apply_readout_fast(self, state: Vector) -> Vector: if self.readout_weights_array is None or np is None: assert self.readout_weights is not None centered_state = self._center_state_vector(state) logits = apply_readout(self.readout_weights, centered_state) if self.readout_bias: logits = [ value + self.readout_bias[index] for index, value in enumerate(logits) ] return logits state_array = np.asarray(state, dtype=RUNTIME_ARRAY_DTYPE) if self.state_offset_array is not None and self.state_offset_array.shape == state_array.shape: state_array = state_array - self.state_offset_array logits = self.readout_weights_array @ state_array if self.readout_bias_array is not None and self.readout_bias_array.shape == logits.shape: logits = logits + self.readout_bias_array return logits.tolist() def _apply_readout_array(self, state: object) -> object: assert np is not None assert self.readout_weights_array is not None state_array = np.asarray(state, dtype=RUNTIME_ARRAY_DTYPE) if self.state_offset_array is not None and self.state_offset_array.shape == state_array.shape: state_array = state_array - self.state_offset_array logits = self.readout_weights_array @ state_array if self.readout_bias_array is not None and self.readout_bias_array.shape == logits.shape: logits = 
logits + self.readout_bias_array return logits def _step_hidden_states( self, hidden_states: list[Vector], context_traces: list[Vector], token: str, ) -> tuple[list[Vector], list[Vector], Vector]: assert self.embedding_model is not None assert self.tokenizer is not None token_id = self._token_id_for_token(token) embedding = self.embedding_model.vector(token) trace_embedding = self._trace_embedding_from_token_id(embedding, token_id) return self._step_hidden_states_from_embedding( hidden_states, context_traces, embedding, trace_embedding=trace_embedding, ) def _step_hidden_states_from_embedding( self, hidden_states: list[Vector], context_traces: list[Vector], embedding: Vector | object, *, trace_embedding: Vector | object | None = None, ) -> tuple[list[Vector], list[Vector], Vector]: assert self.memory_units is not None if trace_embedding is None: trace_embedding = embedding if np is not None and hidden_states and hasattr(hidden_states[0], "shape"): embedding_array = ( embedding if hasattr(embedding, "shape") else np.asarray(embedding, dtype=np.float64) ) trace_embedding_array = ( trace_embedding if hasattr(trace_embedding, "shape") else np.asarray(trace_embedding, dtype=np.float64) ) drive = analytical_embedding_drive_fast(embedding_array, self.config.state_dim) next_states: list[Vector] = [] next_traces: list[Vector] = [] combined_state: Vector = [] for unit, state, trace in zip(self.memory_units, hidden_states, context_traces): next_state = unit.step_vector_fast(state, drive) decay = 1.0 / (1.0 + unit.timescale) next_trace = trace + ((1.0 - decay) * trace_embedding_array) next_states.append(next_state) next_traces.append(next_trace) combined_state.extend(next_state.tolist()) combined_state.extend(next_trace.tolist()) return next_states, next_traces, combined_state embedding_vector = embedding.tolist() if hasattr(embedding, "tolist") else embedding trace_embedding_vector = ( trace_embedding.tolist() if hasattr(trace_embedding, "tolist") else trace_embedding ) drive = analytical_embedding_drive(embedding_vector, self.config.state_dim) next_states: list[Vector] = [] next_traces: list[Vector] = [] combined_state: Vector = [] for unit, state, trace in zip(self.memory_units, hidden_states, context_traces): next_state = unit.step_vector(state, drive) decay = 1.0 / (1.0 + unit.timescale) next_trace = [ previous + ((1.0 - decay) * value) for previous, value in zip(trace, trace_embedding_vector) ] next_states.append(next_state) next_traces.append(next_trace) combined_state.extend(next_state) combined_state.extend(next_trace) return next_states, next_traces, combined_state def _one_hot(self, token: str) -> Vector: assert self.embedding_model is not None return self._one_hot_from_id(self.embedding_model.token_to_id.get(token, -1)) def _one_hot_from_id(self, token_id: int) -> Vector: assert self.embedding_model is not None vector = [0.0 for _ in self.embedding_model.id_to_token] if token_id >= 0: vector[token_id] = 1.0 return vector def _blend_probabilities( self, base: Vector, answer: Vector, associative: Vector, transition: Vector, copy: Vector, preference: Vector, *, transition_order: int | None, generated_count: int = 0, answer_locked: bool = False, answer_guided_start: bool = False, ) -> tuple[Vector, dict[str, float]]: base_weight = FAST_BASE_BLEND answer_weight = FAST_ANSWER_BLEND associative_weight = FAST_ASSOCIATIVE_BLEND transition_weight = FAST_TRANSITION_BLEND copy_weight = FAST_COPY_BLEND preference_weight = FAST_PREFERENCE_BLEND if answer_locked: base_weight *= 0.18 answer_weight *= 5.0 
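            # answer_locked: a stored answer sequence is being replayed, so
            # the answer prior dominates (5x) while every other source is
            # damped (base to 18%, the rest to 20%) before renormalisation.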
associative_weight *= 0.2 transition_weight *= 0.2 copy_weight *= 0.2 preference_weight *= 0.2 elif answer_guided_start: base_weight *= 0.35 answer_weight *= 3.5 associative_weight *= 0.2 transition_weight *= 0.35 copy_weight *= 0.2 preference_weight *= 0.2 elif generated_count > 0: answer_weight *= 0.32 transition_weight *= 2.0 copy_weight *= 0.75 if transition_order is None: answer_weight *= 1.1 associative_weight *= 0.75 copy_weight += 0.02 elif transition_order <= 2: answer_weight *= 1.15 associative_weight *= 0.65 transition_weight *= 0.55 copy_weight += 0.01 elif transition_order >= 5: transition_weight *= 1.25 sources: list[tuple[str, float, Vector]] = [("base", base_weight, base)] if any(value > 0.0 for value in answer): sources.append(("answer", answer_weight, answer)) if any(value > 0.0 for value in associative): sources.append(("associative", associative_weight, associative)) if any(value > 0.0 for value in transition): sources.append(("transition", transition_weight, transition)) if any(value > 0.0 for value in copy): sources.append(("copy", copy_weight, copy)) if any(value > 0.0 for value in preference): sources.append(("preference", preference_weight, preference)) total_weight = sum(weight for _, weight, _ in sources) blended = [0.0 for _ in base] blend_weights: dict[str, float] = {} for name, weight, source in sources: normalized_weight = weight / total_weight if total_weight else 0.0 blend_weights[name] = normalized_weight for index, value in enumerate(source): blended[index] += normalized_weight * value return _normalize_vector(blended), blend_weights def _blend_probability_arrays( self, base: object, answer: object, associative: object, transition: object, copy: object, preference: object, *, transition_order: int | None, generated_count: int = 0, answer_locked: bool = False, answer_guided_start: bool = False, ) -> tuple[object, dict[str, float]]: assert np is not None base_weight = FAST_BASE_BLEND answer_weight = FAST_ANSWER_BLEND associative_weight = FAST_ASSOCIATIVE_BLEND transition_weight = FAST_TRANSITION_BLEND copy_weight = FAST_COPY_BLEND preference_weight = FAST_PREFERENCE_BLEND if answer_locked: base_weight *= 0.18 answer_weight *= 5.0 associative_weight *= 0.2 transition_weight *= 0.2 copy_weight *= 0.2 preference_weight *= 0.2 elif answer_guided_start: base_weight *= 0.35 answer_weight *= 3.5 associative_weight *= 0.2 transition_weight *= 0.35 copy_weight *= 0.2 preference_weight *= 0.2 elif generated_count > 0: answer_weight *= 0.32 transition_weight *= 2.0 copy_weight *= 0.75 if transition_order is None: answer_weight *= 1.1 associative_weight *= 0.75 copy_weight += 0.02 elif transition_order <= 2: answer_weight *= 1.15 associative_weight *= 0.65 transition_weight *= 0.55 copy_weight += 0.01 elif transition_order >= 5: transition_weight *= 1.25 sources: list[tuple[str, float, object]] = [("base", base_weight, base)] if np.any(answer > 0.0): sources.append(("answer", answer_weight, answer)) if np.any(associative > 0.0): sources.append(("associative", associative_weight, associative)) if np.any(transition > 0.0): sources.append(("transition", transition_weight, transition)) if np.any(copy > 0.0): sources.append(("copy", copy_weight, copy)) if np.any(preference > 0.0): sources.append(("preference", preference_weight, preference)) total_weight = sum(weight for _, weight, _ in sources) blended = np.zeros_like(base, dtype=np.float64) blend_weights: dict[str, float] = {} for name, weight, source in sources: normalized_weight = weight / total_weight if total_weight 
else 0.0 blend_weights[name] = normalized_weight blended += normalized_weight * source total = float(blended.sum()) if total <= 0.0: return base, blend_weights return blended / total, blend_weights def _score_associative_matches( self, state: Vector, *, limit: int = ASSOCIATIVE_TOP_K, ) -> list[tuple[float, int, int]]: if ( self.associative_keys is None or self.associative_values is None or self.associative_key_norms is None or len(self.associative_keys) == 0 or len(self.associative_values) == 0 or len(self.associative_key_norms) == 0 ): return [] if ( np is not None and self.associative_keys_array is not None and self.associative_key_norms_array is not None and self.associative_values_array is not None and self.associative_valid_mask_array is not None and limit > 0 ): state_array = self._center_state_array(state).astype(self.associative_keys_array.dtype, copy=False) state_norm = float(np.linalg.norm(state_array)) if state_norm == 0.0: return [] numerators = self.associative_keys_array @ state_array denominators = self.associative_key_norms_array * state_norm valid_mask = self.associative_valid_mask_array & (denominators > 0.0) if np.any(valid_mask): scores = np.zeros_like(numerators, dtype=self.associative_keys_array.dtype) np.divide(numerators, denominators, out=scores, where=valid_mask) positive_positions = np.flatnonzero(valid_mask & (scores > 0.0)) if positive_positions.size: selected_positions = positive_positions if positive_positions.size > limit: partition = np.argpartition(scores[positive_positions], -limit)[-limit:] selected_positions = positive_positions[partition] ordered_positions = selected_positions[np.argsort(scores[selected_positions])[::-1]] return [ ( float(scores[position]), int(self.associative_values_array[position]), int(position), ) for position in ordered_positions ] state = self._center_state_vector(state) state_norm = norm(state) if state_norm == 0.0: return [] scored: list[tuple[float, int, int]] = [] for example_index, (key, key_norm, token_id) in enumerate( zip(self.associative_keys, self.associative_key_norms, self.associative_values) ): if token_id < 0: continue denominator = state_norm * key_norm if denominator == 0.0: continue similarity = dot(state, key) / denominator if similarity > 0.0: scored.append((similarity, token_id, example_index)) scored.sort(key=lambda item: item[0], reverse=True) return scored[:limit] def _associative_prior_from_matches( self, matches: list[tuple[float, int, int]], ) -> Vector: assert self.embedding_model is not None if not matches: return [0.0 for _ in self.embedding_model.id_to_token] prior = [0.0 for _ in self.embedding_model.id_to_token] for similarity, token_id, _ in matches[:ASSOCIATIVE_TOP_K]: prior[token_id] += similarity return _normalize_vector(prior) def _associative_prior(self, state: Vector) -> Vector: return self._associative_prior_from_matches(self._score_associative_matches(state)) def _score_answer_matches( self, answer_anchor_state: Vector | None, *, limit: int = ANSWER_TOP_K, ) -> list[tuple[float, int, int]]: return self._score_prompt_anchor_matches( answer_anchor_state, self.answer_keys, self.answer_key_norms, self.answer_values, self.answer_keys_array, self.answer_key_norms_array, self.answer_values_array, self.answer_valid_mask_array, self.answer_similarity_keys_array, self.answer_similarity_key_norms_array, self.answer_similarity_mask_array, limit=limit, ) def _score_answer_start_matches( self, answer_anchor_state: Vector | None, *, limit: int = ANSWER_START_TOP_K, ) -> list[tuple[float, int, int]]: 
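        """Anchor-match the answer-start key cache against the prompt state.

        A thin wrapper: delegates to _score_prompt_anchor_matches with the
        answer-start keys/values, reusing the shared similarity mask so the
        cosine comparison is restricted to the context-trace dimensions.
        """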
        return self._score_prompt_anchor_matches(
            answer_anchor_state,
            self.answer_start_keys,
            self.answer_start_key_norms,
            self.answer_start_values,
            self.answer_start_keys_array,
            self.answer_start_key_norms_array,
            self.answer_start_values_array,
            self.answer_start_valid_mask_array,
            self.answer_start_similarity_keys_array,
            self.answer_start_similarity_key_norms_array,
            self.answer_similarity_mask_array,
            limit=limit,
        )

    def _score_answer_sequence_matches(
        self,
        answer_anchor_state: Vector | None,
        context_tokens: list[str],
        *,
        limit: int = ANSWER_START_TOP_K,
    ) -> list[tuple[float, int, int]]:
        if (
            answer_anchor_state is None
            or self.answer_sequence_keys is None
            or self.answer_sequence_key_norms is None
            or self.answer_sequence_tokens is None
        ):
            return []
        # Sequence "values" are just row indices: a match identifies which
        # stored answer sequence to replay, not a single token id.
        values = list(range(len(self.answer_sequence_tokens)))
        values_array = np.arange(len(values), dtype=np.int64) if np is not None else None
        anchor_matches = self._score_prompt_anchor_matches(
            answer_anchor_state,
            self.answer_sequence_keys,
            self.answer_sequence_key_norms,
            values,
            self.answer_sequence_keys_array,
            self.answer_sequence_key_norms_array,
            values_array,
            values_array >= 0 if values_array is not None else None,
            self.answer_sequence_similarity_keys_array,
            self.answer_sequence_similarity_key_norms_array,
            self.answer_similarity_mask_array,
            limit=max(limit * 4, limit),
        )
        overlap_scores = self._answer_sequence_prompt_overlap_scores(context_tokens)
        if overlap_scores is None:
            return anchor_matches[:limit]
        if not overlap_scores:
            return []
        # Keep only candidates whose lexical overlap is close to the best
        # (within 90% of it, floored at 0.16), then blend: 80% lexical
        # overlap, 20% anchor-state similarity.
        best_overlap = max(overlap_scores.values()) if overlap_scores else 0.0
        overlap_floor = max(0.16, best_overlap * 0.90)
        focused_overlap_scores = {
            sequence_index: overlap
            for sequence_index, overlap in overlap_scores.items()
            if overlap >= overlap_floor
        }
        if not focused_overlap_scores:
            focused_overlap_scores = overlap_scores
        focused_indices = set(focused_overlap_scores)
        merged: dict[int, float] = {}
        for similarity, sequence_index, _ in anchor_matches:
            if sequence_index not in focused_indices:
                continue
            merged[sequence_index] = max(merged.get(sequence_index, 0.0), 0.20 * similarity)
        for sequence_index, overlap in focused_overlap_scores.items():
            merged[sequence_index] = merged.get(sequence_index, 0.0) + (0.80 * overlap)
        ranked = [
            (score, sequence_index, sequence_index)
            for sequence_index, score in merged.items()
            if score > 0.0
        ]
        ranked.sort(key=lambda item: item[0], reverse=True)
        return ranked[:limit]

    def _answer_sequence_prompt_overlap_scores(
        self,
        context_tokens: list[str],
    ) -> dict[int, float] | None:
        if (
            self.embedding_model is None
            or self.answer_sequence_prompt_tokens is None
            or self.trace_token_weights is None
        ):
            return None
        # The empty-string sentinel marks the answer boundary in
        # context_tokens; everything before its last occurrence is treated
        # as the prompt.
        answer_boundary = _last_index(context_tokens, "")
        prompt_tokens = (
            context_tokens[:answer_boundary]
            if answer_boundary is not None
            else context_tokens
        )
        if self.answer_sequence_prompt_specificity is None:
            self._refresh_answer_sequence_prompt_overlap_cache()
        specificity_map = self.answer_sequence_prompt_specificity or {}
        query_weights: dict[int, float] = {}
        query_specificity: dict[int, float] = {}
        query_content_weight = 0.0
        query_ids: list[int] = []
        for token in prompt_tokens:
            if self.tokenizer is not None and token in self.tokenizer.special_tokens:
                continue
            token_id = self.embedding_model.token_to_id.get(token)
            if token_id is None:
                continue
            query_ids.append(token_id)
            specificity = specificity_map.get(token_id, 1.0)
            weight = specificity
            query_weights[token_id] = max(
                query_weights.get(token_id, 0.0),
                weight,
            )
            query_specificity[token_id] = max(
                query_specificity.get(token_id, 0.0),
                specificity,
            )
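            # Tokens with specificity >= 0.20 count as "content" tokens.
            # Since specificity = max(0.02, 1 - sqrt(coverage)), that cutoff
            # admits tokens appearing in at most ~64% of the cached prompts;
            # near-ubiquitous tokens never gate coverage.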
if specificity >= 0.20: query_content_weight += weight if not query_weights: return None query_norm = sum(value * value for value in query_weights.values()) ** 0.5 if query_norm <= 0.0: return None query_bigrams = { (query_ids[index], query_ids[index + 1]) for index in range(len(query_ids) - 1) } query_trigrams = { (query_ids[index], query_ids[index + 1], query_ids[index + 2]) for index in range(len(query_ids) - 2) } query_numbers = self._number_strings_from_tokens(prompt_tokens) def ordered_ngram_score( query_grams: set[tuple[int, ...]], row_grams: set[tuple[int, ...]], ) -> float: if not query_grams or not row_grams: return 0.0 overlap = len(query_grams & row_grams) if overlap <= 0: return 0.0 return overlap / ((len(query_grams) * len(row_grams)) ** 0.5) cached_maps = self.answer_sequence_prompt_weight_maps cached_norms = self.answer_sequence_prompt_weight_norms cached_bigrams = self.answer_sequence_prompt_bigram_sets cached_trigrams = self.answer_sequence_prompt_trigram_sets cached_numbers = self.answer_sequence_prompt_number_sets cached_index = self.answer_sequence_prompt_inverted_index if ( cached_maps is not None and cached_norms is not None and cached_bigrams is not None and cached_trigrams is not None and cached_numbers is not None and len(cached_maps) == len(self.answer_sequence_prompt_tokens) ): candidate_indices: set[int] | range if cached_index is not None: candidates: set[int] = set() for token_id in query_weights: candidates.update(cached_index.get(token_id, ())) candidate_indices = candidates if candidates else range(len(cached_maps)) else: candidate_indices = range(len(cached_maps)) candidate_indices = list(candidate_indices) if cached_index is not None and candidate_indices: candidate_set = set(candidate_indices) local_query_weights: dict[int, float] = {} local_query_specificity: dict[int, float] = {} local_query_content_weight = 0.0 for token_id in query_weights: local_frequency = len(candidate_set & set(cached_index.get(token_id, ()))) if local_frequency <= 0: continue specificity = self._prompt_overlap_token_specificity( local_frequency, len(candidate_indices), ) weight = specificity local_query_weights[token_id] = weight local_query_specificity[token_id] = specificity if specificity >= 0.20: local_query_content_weight += weight local_query_norm = sum(value * value for value in local_query_weights.values()) ** 0.5 if local_query_norm > 0.0: query_weights = local_query_weights query_specificity = local_query_specificity if local_query_content_weight > 0.0: query_content_weight = local_query_content_weight query_norm = local_query_norm scores: dict[int, float] = {} for sequence_index in candidate_indices: row_weights = cached_maps[sequence_index] if not row_weights: continue if not self._numeric_prompt_can_match(query_numbers, cached_numbers[sequence_index]): continue matched_content_weight = sum( query_weights[token_id] for token_id in query_weights.keys() & row_weights.keys() if query_specificity.get(token_id, 0.0) >= 0.20 ) row_token_coverage = len(query_weights.keys() & row_weights.keys()) / max( 1, len(row_weights), ) if ( query_content_weight > 0.0 and matched_content_weight / query_content_weight < 0.40 and row_token_coverage < 0.75 ): continue query_coverage = ( matched_content_weight / query_content_weight if query_content_weight > 0.0 else row_token_coverage ) numerator = sum( query_weights[token_id] * row_weights[token_id] for token_id in query_weights.keys() & row_weights.keys() ) if numerator <= 0.0: continue row_norm = cached_norms[sequence_index] if 
row_norm <= 0.0: continue token_score = numerator / (query_norm * row_norm) bigram_score = ordered_ngram_score( query_bigrams, cached_bigrams[sequence_index], ) trigram_score = ordered_ngram_score( query_trigrams, cached_trigrams[sequence_index], ) scores[sequence_index] = ( (0.35 * token_score) + (0.35 * query_coverage) + (0.15 * bigram_score) + (0.15 * trigram_score) ) return scores if cached_index is not None: candidate_set: set[int] = set() for token_id in query_weights: candidate_set.update(cached_index.get(token_id, ())) if not candidate_set: return {} candidate_indices: list[int] | range = sorted(candidate_set) local_query_weights: dict[int, float] = {} local_query_specificity: dict[int, float] = {} local_query_content_weight = 0.0 candidate_count = len(candidate_indices) for token_id in query_weights: local_frequency = len(candidate_set & set(cached_index.get(token_id, ()))) if local_frequency <= 0: continue specificity = self._prompt_overlap_token_specificity( local_frequency, candidate_count, ) local_query_weights[token_id] = specificity local_query_specificity[token_id] = specificity if specificity >= 0.20: local_query_content_weight += specificity local_query_norm = sum(value * value for value in local_query_weights.values()) ** 0.5 if local_query_norm > 0.0: query_weights = local_query_weights query_specificity = local_query_specificity if local_query_content_weight > 0.0: query_content_weight = local_query_content_weight query_norm = local_query_norm else: candidate_indices = range(len(self.answer_sequence_prompt_tokens)) scores: dict[int, float] = {} for sequence_index in candidate_indices: row = self.answer_sequence_prompt_tokens[sequence_index] row_values = row.tolist() if hasattr(row, "tolist") else row row_weights: dict[int, float] = {} row_ids: list[int] = [] for raw_token_id in row_values: token_id = int(raw_token_id) if token_id < 0 or token_id >= len(self.trace_token_weights): continue row_ids.append(token_id) row_weights[token_id] = max( row_weights.get(token_id, 0.0), specificity_map.get(token_id, 1.0), ) if not row_weights: continue if not self._numeric_prompt_can_match( query_numbers, self._number_strings_from_token_ids(row_ids), ): continue matched_content_weight = sum( query_weights[token_id] for token_id in query_weights.keys() & row_weights.keys() if query_specificity.get(token_id, 0.0) >= 0.20 ) row_token_coverage = len(query_weights.keys() & row_weights.keys()) / max( 1, len(row_weights), ) if ( query_content_weight > 0.0 and matched_content_weight / query_content_weight < 0.40 and row_token_coverage < 0.75 ): continue query_coverage = ( matched_content_weight / query_content_weight if query_content_weight > 0.0 else row_token_coverage ) numerator = sum( query_weights[token_id] * row_weights[token_id] for token_id in query_weights.keys() & row_weights.keys() ) if numerator <= 0.0: continue row_norm = sum(value * value for value in row_weights.values()) ** 0.5 if row_norm > 0.0: token_score = numerator / (query_norm * row_norm) row_bigrams = { (row_ids[index], row_ids[index + 1]) for index in range(len(row_ids) - 1) } row_trigrams = { (row_ids[index], row_ids[index + 1], row_ids[index + 2]) for index in range(len(row_ids) - 2) } bigram_score = ordered_ngram_score(query_bigrams, row_bigrams) trigram_score = ordered_ngram_score(query_trigrams, row_trigrams) scores[sequence_index] = ( (0.35 * token_score) + (0.35 * query_coverage) + (0.15 * bigram_score) + (0.15 * trigram_score) ) return scores def _score_prompt_anchor_matches( self, answer_anchor_state: Vector 
| None, keys: object | None, key_norms_list: object | None, values: object | None, keys_array: object | None, key_norms_array: object | None, values_array: object | None, valid_mask_array: object | None, similarity_keys_array: object | None, similarity_key_norms_array: object | None, similarity_mask_array: object | None, *, limit: int, ) -> list[tuple[float, int, int]]: if ( answer_anchor_state is None or keys is None or key_norms_list is None or values is None ): return [] if ( np is not None and keys_array is not None and key_norms_array is not None and values_array is not None and valid_mask_array is not None and limit > 0 ): state_array = self._center_state_array( self._masked_combined_state_array(answer_anchor_state) ).astype(keys_array.dtype, copy=False) key_array = keys_array key_norms = key_norms_array if ( similarity_keys_array is not None and similarity_key_norms_array is not None and similarity_mask_array is not None ): state_array = state_array * similarity_mask_array key_array = similarity_keys_array key_norms = similarity_key_norms_array state_norm = float(np.linalg.norm(state_array)) if state_norm == 0.0: return [] numerators = key_array @ state_array denominators = key_norms * state_norm valid_mask = valid_mask_array & (denominators > 0.0) if np.any(valid_mask): scores = np.zeros_like(numerators, dtype=key_array.dtype) np.divide(numerators, denominators, out=scores, where=valid_mask) positive_positions = np.flatnonzero(valid_mask & (scores > 0.0)) if positive_positions.size: selected_positions = positive_positions if positive_positions.size > limit: partition = np.argpartition(scores[positive_positions], -limit)[-limit:] selected_positions = positive_positions[partition] ordered_positions = selected_positions[np.argsort(scores[selected_positions])[::-1]] return [ ( float(scores[position]), int(values_array[position]), int(position), ) for position in ordered_positions ] state = self._center_state_vector(self._masked_combined_state(answer_anchor_state)) state_norm = norm(state) if state_norm == 0.0: return [] scored: list[tuple[float, int, int]] = [] for example_index, (key, key_norm, token_id) in enumerate( zip(keys, key_norms_list, values) ): if token_id < 0: continue denominator = state_norm * key_norm if denominator == 0.0: continue similarity = dot(state, key) / denominator if similarity > 0.0: scored.append((similarity, token_id, example_index)) scored.sort(key=lambda item: item[0], reverse=True) return scored[:limit] def _answer_prior_from_matches( self, matches: list[tuple[float, int, int]], generated_tokens: list[str], ) -> Vector: assert self.embedding_model is not None if not matches: return [0.0 for _ in self.embedding_model.id_to_token] prior = [0.0 for _ in self.embedding_model.id_to_token] generated_ids = { self.embedding_model.token_to_id[token] for token in generated_tokens if token in self.embedding_model.token_to_id } for similarity, token_id, _ in matches[:ANSWER_TOP_K]: token = self.embedding_model.id_to_token[token_id] if not self._allowed_generation_token(token, generated_tokens): continue if token_id in generated_ids: prior[token_id] += similarity * 0.35 else: prior[token_id] += similarity return _normalize_vector(prior) def _answer_sequence_prior_from_matches( self, matches: list[tuple[float, int, int]], generated_tokens: list[str], ) -> Vector: assert self.embedding_model is not None if not matches or self.answer_sequence_tokens is None: return [0.0 for _ in self.embedding_model.id_to_token] generated_ids = [ self.embedding_model.token_to_id[token] 
for token in generated_tokens if token in self.embedding_model.token_to_id ] prior = [0.0 for _ in self.embedding_model.id_to_token] best_similarity = matches[0][0] match_floor = best_similarity - 0.02 if best_similarity >= 0.9 else 0.0 for similarity, sequence_index, _ in matches[:ANSWER_START_TOP_K]: if similarity < match_floor: continue row = self.answer_sequence_tokens[sequence_index] token_ids = [ int(value) for value in (row.tolist() if hasattr(row, "tolist") else row) if int(value) >= 0 ] if not token_ids: continue next_token_id = self._next_sequence_token_id(token_ids, generated_ids) if next_token_id is None: continue token = self.embedding_model.id_to_token[next_token_id] if self._allowed_generation_token(token, generated_tokens): prior[next_token_id] += max(1e-9, similarity - match_floor) return _normalize_vector(prior) def _should_stop_answer_sequence( self, decode_state: DecodeState, generated_tokens: list[str], ) -> bool: matches = decode_state.answer_sequence_matches if matches is None: matches = self._score_answer_sequence_matches( decode_state.answer_anchor_state, decode_state.context_tokens, ) return self._answer_sequence_is_complete(generated_tokens, matches) def _answer_decode_has_continuation( self, decode_state: DecodeState, generated_tokens: list[str], ) -> bool: matches = decode_state.answer_sequence_matches if matches is None: matches = self._score_answer_sequence_matches( decode_state.answer_anchor_state, decode_state.context_tokens, ) return self._answer_sequence_has_continuation(generated_tokens, matches) def _answer_sequence_is_complete( self, generated_tokens: list[str], matches: list[tuple[float, int, int]], ) -> bool: if ( self.embedding_model is None or self.answer_sequence_tokens is None or not generated_tokens or not matches ): return False generated_ids = [ self.embedding_model.token_to_id[token] for token in generated_tokens if token in self.embedding_model.token_to_id ] if not generated_ids: return False for similarity, sequence_index, _ in matches[:ANSWER_START_TOP_K]: if similarity < ANSWER_SEQUENCE_MATCH_FLOOR or sequence_index >= len(self.answer_sequence_tokens): continue row = self.answer_sequence_tokens[sequence_index] token_ids = [ int(value) for value in (row.tolist() if hasattr(row, "tolist") else row) if int(value) >= 0 ] if not token_ids or len(generated_ids) < len(token_ids): continue if generated_ids[: len(token_ids)] == token_ids: return True return False def _answer_sequence_has_continuation( self, generated_tokens: list[str], matches: list[tuple[float, int, int]], ) -> bool: if ( self.embedding_model is None or self.answer_sequence_tokens is None or not generated_tokens or not matches ): return False generated_ids = [ self.embedding_model.token_to_id[token] for token in generated_tokens if token in self.embedding_model.token_to_id ] if not generated_ids: return False for similarity, sequence_index, _ in matches[:ANSWER_START_TOP_K]: if similarity < ANSWER_SEQUENCE_MATCH_FLOOR or sequence_index >= len(self.answer_sequence_tokens): continue row = self.answer_sequence_tokens[sequence_index] token_ids = [ int(value) for value in (row.tolist() if hasattr(row, "tolist") else row) if int(value) >= 0 ] if not token_ids: continue next_token_id = self._next_sequence_token_id(token_ids, generated_ids) if next_token_id is None: continue token = self.embedding_model.id_to_token[next_token_id] if self._allowed_generation_token(token, generated_tokens): return True return False def _next_sequence_token_id( self, token_ids: list[int], generated_ids: 
list[int], ) -> int | None: if not generated_ids: return token_ids[0] if len(generated_ids) >= len(token_ids): return None if token_ids[: len(generated_ids)] != generated_ids: return None return token_ids[len(generated_ids)] def _transition_prior(self, context_tokens: list[str]) -> Vector: prior, _ = self._transition_prior_with_order(context_tokens) return prior def _transition_prior_with_order( self, context_tokens: list[str], ) -> tuple[Vector, int | None]: assert self.embedding_model is not None if not self.transition_tables: return [0.0 for _ in self.embedding_model.id_to_token], None for order in TRANSITION_ORDERS: if len(context_tokens) < order: continue key = tuple(context_tokens[-order:]) transitions = self.transition_tables.get(order, {}).get(key) if not transitions: continue prior = [0.0 for _ in self.embedding_model.id_to_token] for token, probability in transitions.items(): token_id = self.embedding_model.token_to_id.get(token) if token_id is not None: prior[token_id] = probability return _normalize_vector(prior), order return [0.0 for _ in self.embedding_model.id_to_token], None def _transition_prior_array_with_order( self, context_tokens: list[str], ) -> tuple[object, int | None]: assert np is not None assert self.embedding_model is not None prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) if not self.transition_tables: return prior, None for order in TRANSITION_ORDERS: if len(context_tokens) < order: continue key = tuple(context_tokens[-order:]) transitions = self.transition_tables.get(order, {}).get(key) if not transitions: continue for token, probability in transitions.items(): token_id = self.embedding_model.token_to_id.get(token) if token_id is not None: prior[token_id] = probability total = float(prior.sum()) if total > 0.0: prior /= total return prior, order return prior, None def _copy_prior(self, context_tokens: list[str]) -> Vector: assert self.embedding_model is not None assert self.tokenizer is not None prior = [0.0 for _ in self.embedding_model.id_to_token] decay = 0.82 answer_start = None for index in range(len(context_tokens) - 1, -1, -1): if context_tokens[index] == "": answer_start = index + 1 break source_tokens = context_tokens[answer_start:] if answer_start is not None else context_tokens if not source_tokens: return prior for distance, token in enumerate(reversed(source_tokens[-8:])): if token in self.tokenizer.special_tokens: continue if not self._eligible_copy_token(token): continue token_id = self.embedding_model.token_to_id.get(token) if token_id is None: continue prior[token_id] += decay**distance return _normalize_vector(prior) def _copy_prior_array(self, context_tokens: list[str]) -> object: assert np is not None assert self.embedding_model is not None assert self.tokenizer is not None prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) decay = 0.82 answer_start = None for index in range(len(context_tokens) - 1, -1, -1): if context_tokens[index] == "": answer_start = index + 1 break source_tokens = context_tokens[answer_start:] if answer_start is not None else context_tokens for distance, token in enumerate(reversed(source_tokens[-8:])): if token in self.tokenizer.special_tokens: continue if not self._eligible_copy_token(token): continue token_id = self.embedding_model.token_to_id.get(token) if token_id is None: continue prior[token_id] += decay**distance total = float(prior.sum()) if total > 0.0: prior /= total return prior def _preference_prior(self) -> Vector: assert self.embedding_model is not None 
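        # The preference prior is a calibrated softmax over tokens with a
        # strictly positive bias that pass _eligible_preference_token;
        # every other token keeps probability 0.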
if not self.preference_bias or not any(value != 0.0 for value in self.preference_bias): return [0.0 for _ in self.embedding_model.id_to_token] eligible_indices = [ index for index, token in enumerate(self.embedding_model.id_to_token) if self.preference_bias[index] > 0.0 and self._eligible_preference_token(token) ] if not eligible_indices: return [0.0 for _ in self.embedding_model.id_to_token] eligible_probabilities = self._calibrated_softmax( [self.preference_bias[index] for index in eligible_indices] ) prior = [0.0 for _ in self.embedding_model.id_to_token] for index, probability in zip(eligible_indices, eligible_probabilities): prior[index] = probability return prior def _preference_prior_array(self) -> object: assert np is not None assert self.embedding_model is not None if self.preference_bias_array is None or not np.any(self.preference_bias_array != 0.0): return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) if self.preference_valid_mask_array is None or not np.any(self.preference_valid_mask_array): return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) positive_mask = self.preference_bias_array > 0.0 active_mask = self.preference_valid_mask_array & positive_mask if not np.any(active_mask): return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64) prior[active_mask] = self._calibrated_softmax_array( self.preference_bias_array[active_mask] ) return prior def _eligible_preference_token(self, token: str) -> bool: assert self.tokenizer is not None if token == self.tokenizer.unk_token or token in self.tokenizer.special_tokens: return False if not self._starts_new_word(token): return False rendered = self._render_token(token) if not rendered.strip() or self._is_punctuation_piece(rendered): return False alphanumeric = "".join(character for character in rendered if character.isalnum()) return len(alphanumeric) >= 1 def _build_transition_tables( self, tokens: list[str], ) -> dict[int, dict[tuple[str, ...], dict[str, float]]]: counts: dict[int, dict[tuple[str, ...], dict[str, int]]] = { order: {} for order in sorted(TRANSITION_ORDERS) } for order in sorted(TRANSITION_ORDERS): for index in range(order - 1, len(tokens) - 1): key = tuple(tokens[index - order + 1 : index + 1]) nxt = tokens[index + 1] bucket = counts[order].setdefault(key, {}) bucket[nxt] = bucket.get(nxt, 0) + 1 probabilities: dict[int, dict[tuple[str, ...], dict[str, float]]] = { order: {} for order in sorted(TRANSITION_ORDERS) } for order, mapping in counts.items(): items = list(mapping.items()) items.sort(key=lambda item: (-sum(item[1].values()), item[0])) if ( self.config.max_transition_contexts_per_order is not None and self.config.max_transition_contexts_per_order >= 0 ): items = items[: self.config.max_transition_contexts_per_order] for key, bucket in items: next_items = sorted(bucket.items(), key=lambda item: (-item[1], item[0])) if self.config.max_transition_next_tokens > 0: next_items = next_items[: self.config.max_transition_next_tokens] total = sum(value for _, value in next_items) if total <= 0: continue probabilities[order][key] = { token: value / total for token, value in next_items } return probabilities def _serialize_transition_tables(self) -> dict[str, dict[str, dict[str, float]]]: assert self.transition_tables is not None return { str(order): { _encode_ngram_key(key): value for key, value in mapping.items() } for order, mapping in self.transition_tables.items() } def 
_deserialize_transition_tables( self, payload: dict[str, dict[str, dict[str, float]]], ) -> dict[int, dict[tuple[str, ...], dict[str, float]]]: tables: dict[int, dict[tuple[str, ...], dict[str, float]]] = { order: {} for order in sorted(TRANSITION_ORDERS) } for order_text, mapping in payload.items(): order = int(order_text) tables[order] = { _decode_ngram_key(key): { str(token): float(probability) for token, probability in value.items() } for key, value in mapping.items() } return tables def _eligible_copy_token(self, token: str) -> bool: rendered = self._render_token(token) if not rendered.strip(): return False if self._is_punctuation_piece(rendered): return False if not self._starts_new_word(token): return False alphanumeric = "".join(character for character in rendered if character.isalnum()) return len(alphanumeric) >= 2 def _allowed_generation_token(self, token: str, generated_tokens: list[str]) -> bool: assert self.embedding_model is not None if len(self.embedding_model.id_to_token) < 1024: return True if token == self.tokenizer.unk_token or token in self.tokenizer.special_tokens: return False rendered = self._render_token(token) if rendered == "\n": return bool(generated_tokens) if not rendered.strip(): return False if self._is_word_joiner_token(token): return ( self._can_attach_word_joiner(generated_tokens) or self._can_start_line_with_word_joiner(token, generated_tokens) ) if self._is_structural_punctuation_token(token): return bool(generated_tokens) or self._can_start_answer_with_structural_punctuation(token) if self._is_structural_symbol_token(token): return bool(generated_tokens) or self._starts_new_word(token) if not self._starts_new_word(token): return False alphanumeric = "".join(character for character in rendered if character.isalnum()) return len(alphanumeric) >= 1 or not self._is_punctuation_piece(rendered) def _would_repeat_recent_pattern( self, candidate: str, generated_tokens: list[str], recent_rendered_words: list[str] | None = None, ) -> bool: if len(generated_tokens) >= 2 and generated_tokens[-1] == candidate and generated_tokens[-2] == candidate: return True if len(generated_tokens) >= 2: trigram = tuple(generated_tokens[-2:] + [candidate]) recent_tokens = generated_tokens[-12:] for index in range(max(0, len(recent_tokens) - 4)): if tuple(recent_tokens[index : index + 3]) == trigram: return True rendered_words = recent_rendered_words if rendered_words is None: rendered_words = self._recent_rendered_words(generated_tokens) candidate_word = self._render_token(candidate).casefold() if ( rendered_words and self._starts_new_word(candidate) and any(character.isalnum() for character in candidate_word) ): candidate_bigram = (rendered_words[-1], candidate_word) recent_window = rendered_words[-10:] recent_bigrams = { (recent_window[index], recent_window[index + 1]) for index in range(len(recent_window) - 1) } if candidate_bigram in recent_bigrams: return True if ( len(candidate_word) > 2 and rendered_words[-10:].count(candidate_word) >= 2 and not self._is_common_connector_token(candidate) ): return True return False def _recent_rendered_words(self, generated_tokens: list[str]) -> list[str]: rendered_words: list[str] = [] for token in generated_tokens: if not self._starts_new_word(token): continue rendered = self._render_token(token).casefold() if any(character.isalnum() for character in rendered): rendered_words.append(rendered) return rendered_words def _select_generation_token( self, distribution: dict[str, float], *, context_tokens: list[str] | None = None, 
generated_tokens: list[str] | None = None, temperature: float = DEFAULT_GENERATION_TEMPERATURE, top_k: int = DEFAULT_GENERATION_TOP_K, top_p: float = DEFAULT_GENERATION_TOP_P, repetition_penalty: float = DEFAULT_REPETITION_PENALTY, preserve_dominant_candidates: bool = False, ) -> str: assert self.tokenizer is not None generated_tokens = generated_tokens or [] candidates = self._prepare_generation_candidates( distribution, generated_tokens=generated_tokens, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, preserve_dominant_candidates=preserve_dominant_candidates, ) if candidates: return self._sample_generation_candidate( candidates, context_tokens=context_tokens or [], generated_tokens=generated_tokens, stochastic=temperature > 0.0, ) for token, _ in sorted(distribution.items(), key=lambda item: item[1], reverse=True): if token in self.tokenizer.special_tokens: continue if token == self.tokenizer.unk_token: continue if not self._allowed_generation_token(token, generated_tokens): continue return token return "" def _select_generation_token_from_array( self, probabilities: object, *, context_tokens: list[str], generated_tokens: list[str], temperature: float = DEFAULT_GENERATION_TEMPERATURE, top_k: int = DEFAULT_GENERATION_TOP_K, top_p: float = DEFAULT_GENERATION_TOP_P, repetition_penalty: float = DEFAULT_REPETITION_PENALTY, preserve_dominant_candidates: bool = False, ) -> str: assert np is not None assert self.tokenizer is not None assert self.embedding_model is not None values = np.asarray(probabilities, dtype=np.float64) if values.size == 0: return "" pool_size = min(values.size, max(top_k * 4, 64)) if pool_size <= 0: pool_size = min(values.size, 64) if pool_size < values.size: candidate_indices = np.argpartition(values, -pool_size)[-pool_size:] candidate_indices = candidate_indices[np.argsort(values[candidate_indices])[::-1]] else: candidate_indices = np.argsort(values)[::-1] distribution: dict[str, float] = {} for raw_index in candidate_indices: index = int(raw_index) score = float(values[index]) if score <= 0.0: continue token = self.embedding_model.id_to_token[index] if token in self.tokenizer.special_tokens or token == self.tokenizer.unk_token: continue distribution[token] = score return self._select_generation_token( distribution, context_tokens=context_tokens, generated_tokens=generated_tokens, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, preserve_dominant_candidates=preserve_dominant_candidates, ) def _prepare_generation_candidates( self, distribution: dict[str, float], *, generated_tokens: list[str], temperature: float, top_k: int, top_p: float, repetition_penalty: float, preserve_dominant_candidates: bool = False, ) -> list[tuple[str, float]]: assert self.tokenizer is not None assert self.embedding_model is not None generated_word_count = self._generated_word_count(generated_tokens) clause_words = self._words_since_clause_break(generated_tokens) recent_rendered_words = self._recent_rendered_words(generated_tokens) best_probability = max(distribution.values(), default=0.0) adjusted: list[tuple[str, float]] = [] for token, probability in sorted(distribution.items(), key=lambda item: item[1], reverse=True): if token in self.tokenizer.special_tokens: continue if token == self.tokenizer.unk_token or probability <= 0.0: continue if not self._allowed_generation_token(token, generated_tokens): continue repeats_recent_pattern = self._would_repeat_recent_pattern( token, generated_tokens, 
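            # recent_rendered_words is precomputed once per step so each
            # candidate's repetition screen stays cheap,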
    def _prepare_generation_candidates(
        self,
        distribution: dict[str, float],
        *,
        generated_tokens: list[str],
        temperature: float,
        top_k: int,
        top_p: float,
        repetition_penalty: float,
        preserve_dominant_candidates: bool = False,
    ) -> list[tuple[str, float]]:
        assert self.tokenizer is not None
        assert self.embedding_model is not None
        generated_word_count = self._generated_word_count(generated_tokens)
        clause_words = self._words_since_clause_break(generated_tokens)
        recent_rendered_words = self._recent_rendered_words(generated_tokens)
        best_probability = max(distribution.values(), default=0.0)
        adjusted: list[tuple[str, float]] = []
        for token, probability in sorted(distribution.items(), key=lambda item: item[1], reverse=True):
            if token in self.tokenizer.special_tokens:
                continue
            if token == self.tokenizer.unk_token or probability <= 0.0:
                continue
            if not self._allowed_generation_token(token, generated_tokens):
                continue
            repeats_recent_pattern = self._would_repeat_recent_pattern(
                token,
                generated_tokens,
                recent_rendered_words=recent_rendered_words,
            )
            if (
                repeats_recent_pattern
                and not (
                    preserve_dominant_candidates
                    and best_probability > 0.0
                    and probability >= best_probability * 0.80
                )
            ):
                continue
            score = probability
            rendered = self._render_token(token)
            punctuation_token = self._is_structural_punctuation_token(token)
            starts_new_word = self._starts_new_word(token)
            alphanumeric = "".join(character for character in rendered if character.isalnum())
            if generated_tokens and starts_new_word and alphanumeric:
                previous_rendered = self._render_token(generated_tokens[-1])
                previous_alphanumeric = "".join(
                    character for character in previous_rendered if character.isalnum()
                )
                if previous_alphanumeric.casefold() == alphanumeric.casefold():
                    continue
            common_connector = self._is_common_connector_token(token)
            if (
                starts_new_word
                and len(alphanumeric) == 1
                and not common_connector
            ):
                score *= 0.08
            recent_count = generated_tokens[-12:].count(token)
            if recent_count > 0 and not common_connector:
                score /= repetition_penalty ** (2 * recent_count)
            if generated_tokens and token == generated_tokens[-1]:
                score /= repetition_penalty**3
            if generated_tokens and token in generated_tokens[-4:] and not common_connector:
                score *= 0.35
            if generated_tokens and not starts_new_word and self._starts_new_word(generated_tokens[-1]):
                score *= 0.08
            if not generated_tokens and punctuation_token:
                if best_probability <= 0.0 or probability < best_probability * 0.80:
                    score *= 0.01
            elif not generated_tokens and not starts_new_word:
                score *= 0.02
            if punctuation_token:
                if generated_tokens and self._is_structural_punctuation_token(generated_tokens[-1]):
                    score *= 0.05
                if clause_words >= 6:
                    score *= 1.0 + min(1.4, 0.18 * (clause_words - 5))
                elif generated_word_count >= 12:
                    score *= 1.1
            if score > 0.0:
                adjusted.append((token, score))
        if not adjusted:
            return []
        adjusted.sort(key=lambda item: item[1], reverse=True)
        if top_k > 0:
            adjusted = adjusted[:top_k]
        if 0.0 < top_p < 1.0:
            kept: list[tuple[str, float]] = []
            cumulative = 0.0
            total = sum(score for _, score in adjusted)
            for token, score in adjusted:
                normalized = score / total if total else 0.0
                kept.append((token, score))
                cumulative += normalized
                if cumulative >= top_p:
                    break
            adjusted = kept
        if temperature <= 0.0:
            return [(adjusted[0][0], 1.0)]
        exponent = 1.0 / temperature
        tempered = [
            (token, score**exponent)
            for token, score in adjusted
            if score > 0.0
        ]
        total = sum(score for _, score in tempered)
        if total <= 0.0:
            return []
        return [(token, score / total) for token, score in tempered]
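    # Worked example with hypothetical scores: given adjusted =
    # [("a", 0.5), ("b", 0.3), ("c", 0.2)] and top_p = 0.92, the nucleus
    # filter keeps "a" and "b" (cumulative 0.8 < 0.92) and then "c", since
    # the loop breaks only after cumulative reaches 1.0 >= 0.92. Tempering
    # with temperature = 0.82 raises each score to the power 1 / 0.82 ~ 1.22,
    # which sharpens the distribution (0.5 / 0.2 = 2.5 becomes roughly
    # 0.43 / 0.14 ~ 3.1 after renormalization).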
    def _sample_generation_candidate(
        self,
        candidates: list[tuple[str, float]],
        *,
        context_tokens: list[str],
        generated_tokens: list[str],
        stochastic: bool = False,
    ) -> str:
        if not candidates:
            return ""
        if len(candidates) == 1:
            return candidates[0][0]
        top_probability = candidates[0][1]
        second_probability = candidates[1][1]
        top_has_clear_half_majority = top_probability >= 0.5 and (
            second_probability <= 0.0 or top_probability - second_probability >= 0.02
        )
        if top_has_clear_half_majority or (
            second_probability > 0.0 and top_probability >= second_probability * 2.5
        ) or (
            top_probability >= 0.08
            and second_probability > 0.0
            and top_probability >= second_probability * 1.35
        ):
            return candidates[0][0]
        if stochastic:
            threshold = random.random()
        else:
            seed_payload = "\u0002".join([*context_tokens, "", *generated_tokens, str(len(candidates))])
            seed = int.from_bytes(hashlib.sha256(seed_payload.encode("utf-8")).digest()[:8], "big")
            threshold = random.Random(seed).random()
        cumulative = 0.0
        for token, probability in candidates:
            cumulative += probability
            if threshold <= cumulative:
                return token
        return candidates[-1][0]

    def _top_entries_from_vector(
        self,
        values: Vector,
        limit: int,
    ) -> list[dict[str, object]]:
        if limit <= 0:
            return []
        ranked = sorted(
            enumerate(values),
            key=lambda item: item[1],
            reverse=True,
        )
        return [
            self._token_entry(index, probability)
            for index, probability in ranked[:limit]
            if probability > 0.0
        ]

    def _token_entry(
        self,
        index: int,
        probability: float,
    ) -> dict[str, object]:
        assert self.embedding_model is not None
        token = self.embedding_model.id_to_token[index]
        return {
            "token": token,
            "text": self._render_token(token),
            "probability": probability,
        }

    def _build_reasoning_summary(
        self,
        transition_order: int | None,
        blend_weights: dict[str, float],
    ) -> str:
        dominant_source = max(blend_weights.items(), key=lambda item: item[1])[0] if blend_weights else "base"
        if transition_order is not None:
            transition_message = f" Transition prior is using order-{transition_order} context."
        else:
            transition_message = " Transition prior found no matching n-gram."
        return (
            "Generation is running on analytical state, recurrent traces, and corpus-derived token transitions."
            f"{transition_message}"
            f" Dominant blend source: {dominant_source}."
        )

    def _generated_word_count(self, tokens: list[str]) -> int:
        return len(self._decode_tokens(tokens).split())

    def _is_structural_punctuation_text(self, text: str) -> bool:
        if len(text) != 1:
            return False
        if self._is_word_joiner_text(text):
            return False
        category = unicodedata.category(text)
        return category.startswith("P")

    def _is_structural_punctuation_token(self, token: str) -> bool:
        return self._is_structural_punctuation_text(self._render_token(token))

    def _is_structural_symbol_token(self, token: str) -> bool:
        rendered = self._render_token(token)
        return len(rendered) == 1 and unicodedata.category(rendered).startswith("S")

    def _is_word_joiner_token(self, token: str) -> bool:
        return self._is_word_joiner_text(self._render_token(token))

    def _is_word_joiner_text(self, text: str) -> bool:
        if len(text) != 1:
            return False
        category = unicodedata.category(text)
        if category in ("Pc", "Pd", "Lm"):
            return True
        name = unicodedata.name(text, "")
        return "APOSTROPHE" in name or (
            "SINGLE" in name and "QUOTATION MARK" in name
        )

    def _can_start_line_with_word_joiner(self, token: str, generated_tokens: list[str]) -> bool:
        rendered = self._render_token(token)
        if len(rendered) != 1 or unicodedata.category(rendered) != "Pd":
            return False
        if not self._starts_new_word(token):
            return False
        return not generated_tokens or self._render_token(generated_tokens[-1]) == "\n"

    def _can_start_answer_with_structural_punctuation(self, token: str) -> bool:
        rendered = self._render_token(token)
        if len(rendered) != 1 or not self._starts_new_word(token):
            return False
        return unicodedata.category(rendered) in ("Ps", "Pi")

    def _is_common_connector_token(self, token: str) -> bool:
        rendered = self._render_token(token)
        return rendered.isalpha() and len(rendered) <= 3

    def _can_attach_word_joiner(self, generated_tokens: list[str]) -> bool:
        if not generated_tokens:
            return False
        rendered = self._render_token(generated_tokens[-1])
        if not rendered:
            return False
        if any(character.isalnum() for character in rendered):
            return True
        if len(rendered) != 1:
            return False
        return unicodedata.category(rendered) in ("Ps", "Pi")
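    # Unicode notes for the helpers above (illustrative examples): "(" has
    # category "Ps" and the left double quotation mark U+201C has "Pi", so
    # either may open an answer; "-" (category "Pd") and "_" ("Pc") are
    # treated as word joiners, as are apostrophe-like marks matched by their
    # Unicode names in _is_word_joiner_text.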
    def _words_since_clause_break(self, tokens: list[str]) -> int:
        assert self.tokenizer is not None
        words = 0
        for token in reversed(tokens):
            if token in self.tokenizer.special_tokens:
                continue
            rendered = self._render_token(token)
            if self._is_structural_punctuation_text(rendered):
                break
            if self._starts_new_word(token) and not self._is_punctuation_piece(rendered):
                words += 1
        return words

    def _should_stop_generation(self, generated_tokens: list[str]) -> bool:
        if not generated_tokens:
            return False
        if not self._is_terminal_punctuation_text(self._render_token(generated_tokens[-1])):
            return False
        return self._generated_word_count(generated_tokens) >= 14

    def _is_terminal_punctuation_text(self, text: str) -> bool:
        if not self._is_structural_punctuation_text(text):
            return False
        name = unicodedata.name(text, "")
        return (
            "FULL STOP" in name
            or "QUESTION MARK" in name
            or "EXCLAMATION MARK" in name
        )

    def _starts_new_word(self, token: str) -> bool:
        assert self.tokenizer is not None
        if token in self.tokenizer.special_tokens:
            return True
        if token.startswith(self.tokenizer.word_prefix):
            return True
        return len(token) == 1 and not token.isalnum() and not self._is_word_joiner_token(token)

    def _decode_tokens(self, tokens: list[str]) -> str:
        assert self.tokenizer is not None
        return self.tokenizer.decode(tokens)

    def _render_token(self, token: str) -> str:
        assert self.tokenizer is not None
        if token.startswith(self.tokenizer.word_prefix):
            return token[len(self.tokenizer.word_prefix) :]
        return token

    def _require_fit(self) -> None:
        if (
            self.tokenizer is None
            or self.embedding_model is None
            or self.memory_units is None
            or self.readout_weights is None
            or self.ternary_mask is None
            or self.associative_keys is None
            or self.associative_key_norms is None
            or self.associative_values is None
            or self.transition_tables is None
        ):
            raise RuntimeError("Call fit() before using the REFRAMR model.")

    def _ensure_numeric_caches(self) -> None:
        if np is None:
            return
        if self.readout_weights_array is None:
            self._refresh_numeric_caches()
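    # Usage sketch (hypothetical driver code, kept as a comment so the module
    # stays importable; the exact fit() signature is not shown here):
    #   model = ReframrModel(config=ReframrConfig())
    #   model.fit(corpus)       # expected to populate tokenizer, embeddings,
    #                           # memory units, readout, and transition tables
    #   model._require_fit()    # raises RuntimeError if anything is unset
    # _ensure_numeric_caches is a no-op when numpy is unavailable; otherwise
    # it lazily builds the array-backed caches on first use.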