from __future__ import annotations

import json
import random
import re
import site
import sys
import time
from collections import Counter
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from pathlib import Path

from .config import ReframrConfig
from .corpus import build_vocabulary_from_counts
from .embeddings import fit_ppmi_embedding_from_cooccurrence, fit_randomized_ppmi_embedding_from_counts
from .hippo import AnalyticalMemoryUnit
from .linalg import Matrix, Vector, norm, zeros, zeros_vector
# ``np`` is re-exported by ``.model`` and may be None: code in this module
# branches on ``np is None`` / ``np is not None`` before touching numpy.
from .model import ReframrModel, RUNTIME_ARRAY_DTYPE, TRANSITION_ORDERS, np
from .reservoir import (
    ridge_regression_readout_from_diagonal_moments,
    ridge_regression_readout_from_moments,
)
from .ternary import apply_ternary_mask, derive_ternary_mask_from_feature_energy
from .text_quality import clean_answer_text, clean_context_text, clean_training_text
from .tokenizer import NativeTokenizer

# scipy is optional; when it is missing, sparse conversions fall back to dense
# matrices. NOTE(review): OSError is also caught, presumably to tolerate a
# broken native install — confirm that is the intent.
try:
    from scipy import sparse as scipy_sparse
except (ImportError, ModuleNotFoundError, OSError):
    scipy_sparse = None

# Plain-text fields tried, in order, by ``_extract_row_text`` once no
# structured (context/answer, response pair, instruction pair) layout matched.
TEXT_FIELD_PREFERENCES = (
    "text",
    "content",
    "body",
    "article",
    "document",
    "passage",
    "markdown",
    "answer",
    "response",
)
# Conversation-style fields tried by ``_extract_row_text`` after the text fields.
DIALOGUE_FIELD_PREFERENCES = (
    "messages",
    "conversation",
    "conversations",
    "dialogue",
    "dialog",
    "turns",
    "chosen",
)
# (prompt_field, answer_field) pairs composed via ``_compose_training_text``.
INSTRUCTION_FIELD_PAIRS = (
    ("instruction", "output"),
    ("prompt", "completion"),
    ("prompt", "response"),
    ("question", "answer"),
    ("question", "response"),
    ("query", "answer"),
    ("query", "response"),
)
# Splits "Human: ... / Assistant: ... / System: ..." transcripts into turns
# (used by ``_parse_transcript_messages``); a role label must start the text
# or follow a blank line.
TRANSCRIPT_ROLE_PATTERN = re.compile(r"(?:^|\n\s*\n)(Human|Assistant|System)\s*:\s*", re.IGNORECASE)
# Maps casefolded raw role names onto the canonical system/user/assistant
# roles (``_normalize_role``); unknown roles pass through unchanged.
ROLE_ALIASES = {
    "assistant": "assistant",
    "assistant_response": "assistant",
    "bot": "assistant",
    "gpt": "assistant",
    "model": "assistant",
    "human": "user",
    "prompter": "user",
    "user": "user",
    "customer": "user",
    "system": "system",
}
# Weight multiplier for answer-span statistics and answer readout targets.
ANSWER_READOUT_WEIGHT = 1.0
# Weight multiplier for readout targets at or before the answer boundary
# (0.0 gives such targets no weight in ``_readout_weight_for_target``).
CONTEXT_READOUT_WEIGHT = 0.0
# Weight multiplier for context-span tokenizer / vocabulary / co-occurrence statistics.
CONTEXT_STAT_WEIGHT = 0.02
# Readout weight multiplier for documents that have no answer boundary.
PLAIN_TEXT_READOUT_WEIGHT = 0.03
# Weight given to rejected preference text in tokenizer and vocabulary
# statistics; 0.0 means rejected text contributes nothing to those counts.
PREFERENCE_REJECTED_TOKENIZER_WEIGHT = 0.0
PREFERENCE_BIAS_SCALE = 0.95 MAX_PREFERENCE_STATE_PAIRS = 512 ANSWER_START_TOKEN_WINDOW = 12 ANSWER_START_DECAY = 0.86 MAX_ANSWER_SEQUENCE_EXAMPLES = 196608 MAX_ANSWER_SEQUENCE_TOKENS = 192 HF_STREAM_MAX_RETRIES = 5 HF_STREAM_RETRY_BASE_DELAY_SECONDS = 0.25 FULL_READOUT_FEATURE_LIMIT = 2304 FULL_READOUT_EXAMPLE_LIMIT = 25000 @dataclass(slots=True) class CorpusPlanEntry: source: str name: str dataset: str = "" path: str = "" config: str | None = None split: str = "train" limit: int = 0 weight: float = 1.0 text_field: str | None = None min_words: int = 0 max_words: int = 0 min_alpha_ratio: float = 0.0 allowed_languages: tuple[str, ...] = () records: tuple[object, ...] = () streaming: bool = True trust_remote_code: bool = False @dataclass(slots=True) class StreamDocument: text: str weight: float source: str language: str = "" preference_rejected_text: str = "" class StreamingCooccurrenceAccumulator: def __init__(self, token_to_id: dict[str, int], window_size: int) -> None: self.token_to_id = token_to_id self.window_size = window_size self.rows: dict[int, dict[int, float]] = {} def update_tokens(self, tokens: list[str], *, weight: float) -> None: token_ids = [self.token_to_id[token] for token in tokens if token in self.token_to_id] for index, token_id in enumerate(token_ids): for offset in range(1, self.window_size + 1): other_index = index + offset if other_index >= len(token_ids): break other_id = token_ids[other_index] delta = weight * (1.0 / offset) self.rows.setdefault(token_id, {})[other_id] = ( self.rows.setdefault(token_id, {}).get(other_id, 0.0) + delta ) self.rows.setdefault(other_id, {})[token_id] = ( self.rows.setdefault(other_id, {}).get(token_id, 0.0) + delta ) def to_dense(self) -> Matrix: size = len(self.token_to_id) matrix = zeros(size, size) for row, columns in self.rows.items(): for col, value in columns.items(): matrix[row][col] = value return matrix def to_sparse(self) -> object: if scipy_sparse is None or np is None: return self.to_dense() rows: 
list[int] = [] cols: list[int] = [] data: list[float] = [] for row, columns in self.rows.items(): for col, value in columns.items(): rows.append(row) cols.append(col) data.append(value) size = len(self.token_to_id) return scipy_sparse.coo_matrix( ( np.asarray(data, dtype=np.float64), (np.asarray(rows, dtype=np.int64), np.asarray(cols, dtype=np.int64)), ), shape=(size, size), dtype=np.float64, ).tocsr() class TransitionAccumulator: def __init__( self, *, max_contexts_per_order: int | None = None, max_next_tokens: int = 0, ) -> None: self.max_contexts_per_order = max_contexts_per_order self.max_next_tokens = max_next_tokens self.context_soft_limit = ( max_contexts_per_order * 4 if max_contexts_per_order is not None and max_contexts_per_order > 0 else None ) self.next_token_soft_limit = max_next_tokens * 4 if max_next_tokens > 0 else None self.counts: dict[int, dict[tuple[str, ...], dict[str, float]]] = { order: {} for order in sorted(TRANSITION_ORDERS) } def update_tokens(self, tokens: list[str], *, weight: float) -> None: for order in sorted(TRANSITION_ORDERS): order_counts = self.counts[order] for index in range(order - 1, len(tokens) - 1): key = tuple(tokens[index - order + 1 : index + 1]) nxt = tokens[index + 1] if ( self.context_soft_limit is not None and key not in order_counts and len(order_counts) >= self.context_soft_limit ): continue bucket = order_counts.setdefault(key, {}) if ( self.next_token_soft_limit is not None and nxt not in bucket and len(bucket) >= self.next_token_soft_limit ): continue bucket[nxt] = bucket.get(nxt, 0.0) + weight def finalize( self, *, max_contexts_per_order: int | None, max_next_tokens: int, ) -> dict[int, dict[tuple[str, ...], dict[str, float]]]: probabilities: dict[int, dict[tuple[str, ...], dict[str, float]]] = { order: {} for order in sorted(TRANSITION_ORDERS) } for order, mapping in self.counts.items(): items = list(mapping.items()) items.sort(key=lambda item: (-sum(item[1].values()), item[0])) if max_contexts_per_order is 
not None and max_contexts_per_order >= 0: items = items[:max_contexts_per_order] for key, bucket in items: next_items = sorted(bucket.items(), key=lambda item: (-item[1], item[0])) if max_next_tokens > 0: next_items = next_items[:max_next_tokens] total = sum(value for _, value in next_items) if total <= 0.0: continue probabilities[order][key] = { token: value / total for token, value in next_items } return probabilities class StateReservoir: def __init__(self, capacity: int | None, *, seed: int = 13) -> None: self.capacity = capacity self.random = random.Random(seed) self.states: list[Vector] = [] self.labels: list[int] = [] self.weights: list[float] = [] self.seen = 0 self.total_weight = 0.0 def reserve_slot(self, weight: float = 1.0) -> int | None: if weight <= 0.0: return None self.seen += 1 self.total_weight += weight if self.capacity is None: return len(self.states) if self.capacity <= 0: return None if len(self.states) < self.capacity: return len(self.states) keep_probability = min(1.0, (self.capacity * weight) / max(self.total_weight, 1e-12)) if self.random.random() >= keep_probability: return None return self.random.randrange(self.capacity) def store_reserved( self, slot: int, state: Vector, label_id: int, *, example_weight: float = 1.0, ) -> None: stored_state = state.copy() if hasattr(state, "copy") else state[:] if slot == len(self.states): self.states.append(stored_state) self.labels.append(label_id) self.weights.append(example_weight) elif 0 <= slot < len(self.states): self.states[slot] = stored_state self.labels[slot] = label_id self.weights[slot] = example_weight def consider(self, state: Vector, label_id: int, weight: float = 1.0) -> None: slot = self.reserve_slot(weight=weight) if slot is not None: self.store_reserved(slot, state, label_id, example_weight=weight) class SequenceReservoir: def __init__(self, capacity: int | None, *, seed: int = 41) -> None: self.capacity = capacity self.random = random.Random(seed) self.keys: list[Vector] = [] 
self.prompt_rows: list[list[int]] = [] self.token_rows: list[list[int]] = [] self.weights: list[float] = [] self.seen_weight = 0.0 def reserve_slot(self, *, weight: float = 1.0) -> int | None: if self.capacity == 0 or weight <= 0.0: return None self.seen_weight += weight if self.capacity is None or len(self.keys) < self.capacity: return len(self.keys) probability = min(1.0, (self.capacity * weight) / max(self.seen_weight, 1e-12)) if self.random.random() >= probability: return None return self.random.randrange(self.capacity) def store_reserved( self, slot: int, key: Vector, prompt_token_ids: list[int], token_ids: list[int], *, example_weight: float = 1.0, ) -> None: key_copy = key.tolist() if hasattr(key, "tolist") else list(key) prompt_row = prompt_token_ids[:MAX_ANSWER_SEQUENCE_TOKENS] row = token_ids[:MAX_ANSWER_SEQUENCE_TOKENS] if self.capacity is None or slot >= len(self.keys): self.keys.append(key_copy) self.prompt_rows.append(prompt_row) self.token_rows.append(row) self.weights.append(example_weight) return self.keys[slot] = key_copy self.prompt_rows[slot] = prompt_row self.token_rows[slot] = row self.weights[slot] = example_weight def consider( self, key: Vector, prompt_token_ids: list[int], token_ids: list[int], weight: float = 1.0, ) -> None: if not token_ids: return slot = self.reserve_slot(weight=weight) if slot is not None: self.store_reserved(slot, key, prompt_token_ids, token_ids, example_weight=weight) def _word_count(text: str) -> int: return len(text.split()) def _alpha_ratio(text: str) -> float: if not text: return 0.0 alpha_count = sum(character.isalpha() for character in text) return alpha_count / len(text) def _row_language(row: dict[str, object]) -> str: for candidate in ("lang", "language", "locale"): value = row.get(candidate) if isinstance(value, str) and value.strip(): return value.strip() return "" def _normalize_role(raw_role: object) -> str: role = str(raw_role or "").strip().casefold() return ROLE_ALIASES.get(role, role) def 
_message_content(message: dict[str, object]) -> str: for field in ("content", "value", "text", "message"): value = message.get(field) if isinstance(value, str) and value.strip(): return clean_training_text(value) return "" def _message_role(message: dict[str, object]) -> str: for field in ("role", "from", "speaker", "author"): value = message.get(field) if value is not None: normalized = _normalize_role(value) if normalized: return normalized return "" def _parse_dialogue_messages(raw_messages: object) -> list[dict[str, str]]: if not isinstance(raw_messages, list): return [] parsed: list[dict[str, str]] = [] for message in raw_messages: if not isinstance(message, dict): continue role = _message_role(message) content = _message_content(message) if role not in {"system", "user", "assistant"} or not content: continue parsed.append({"role": role, "content": content}) return parsed def _parse_transcript_messages(raw_text: object) -> list[dict[str, str]]: if not isinstance(raw_text, str): return [] text = raw_text.strip() if not text: return [] matches = list(TRANSCRIPT_ROLE_PATTERN.finditer(text)) if not matches: return [] parsed: list[dict[str, str]] = [] for index, match in enumerate(matches): role = _normalize_role(match.group(1)) start = match.end() end = matches[index + 1].start() if index + 1 < len(matches) else len(text) content = clean_training_text(text[start:end].strip()) if role in {"system", "user", "assistant"} and content: parsed.append({"role": role, "content": content}) return parsed def _render_prompt(messages: list[dict[str, str]]) -> str: parts = [] for message in messages: content = clean_context_text(message["content"]) if content: parts.append(content) return "\n".join(parts).strip() def _last_user_prompt_before(messages: list[dict[str, str]], end_index: int) -> str: for message in reversed(messages[:end_index]): if message["role"] == "user": return clean_context_text(message["content"]) return _render_prompt(messages[:end_index]) def 
_compose_training_text(context: object, answer: object) -> str: prompt_text = clean_context_text(_flatten_value(context)) answer_text = clean_answer_text(_flatten_value(answer)) if prompt_text and answer_text: return f" {prompt_text} {answer_text}".strip() return clean_training_text(answer_text or prompt_text) def _compose_from_messages(messages: list[dict[str, str]]) -> str: assistant_index = None for index in range(len(messages) - 1, -1, -1): if messages[index]["role"] == "assistant": assistant_index = index break if assistant_index is not None: prompt = _last_user_prompt_before(messages, assistant_index) answer = clean_answer_text(messages[assistant_index]["content"]) if prompt and answer: return f" {prompt} {answer}".strip() return "\n".join( message["content"] for message in messages if message.get("content") ).strip() def _flatten_message_list(messages: object) -> str: parsed = _parse_dialogue_messages(messages) if parsed: return _compose_from_messages(parsed) if not isinstance(messages, list): return "" parts: list[str] = [] for message in messages: if not isinstance(message, dict): continue content = str( message.get("content", message.get("value", message.get("text", ""))) ).strip() if not content: continue parts.append(clean_training_text(content)) return "\n".join(parts).strip() def _flatten_value(value: object) -> str: if isinstance(value, str): parsed = _parse_transcript_messages(value) if parsed: return _compose_from_messages(parsed) return clean_training_text(value.strip()) if isinstance(value, list): return _flatten_message_list(value) if isinstance(value, dict): for field in ("messages", "conversation", "conversations", "dialogue", "turns"): nested_messages = value.get(field) text = _flatten_message_list(nested_messages) if text: return text for field in ("text", "content", "value", "message"): nested = value.get(field) if isinstance(nested, str) and nested.strip(): return _flatten_value(nested) return "" def _safe_flag(value: object) -> bool | 
None: if isinstance(value, bool): return value if isinstance(value, str): normalized = value.strip().casefold() if normalized in {"true", "1", "yes", "safe"}: return True if normalized in {"false", "0", "no", "unsafe"}: return False return None def _selected_response_fields(row: dict[str, object]) -> tuple[str, str]: if "response_0" not in row or "response_1" not in row: return "", "" safe_0 = _safe_flag(row.get("is_response_0_safe")) safe_1 = _safe_flag(row.get("is_response_1_safe")) if safe_0 is not None and safe_1 is not None: if safe_0 and not safe_1: return "response_0", "response_1" if safe_1 and not safe_0: return "response_1", "response_0" if safe_0 and safe_1: return "response_0", "" return "", "" for selector in ("safer_response_id", "better_response_id"): raw_value = row.get(selector) try: preferred = int(raw_value) except (TypeError, ValueError): continue chosen = "response_1" if preferred == 1 else "response_0" rejected = "response_0" if chosen == "response_1" else "response_1" return chosen, rejected return "response_0", "response_1" def _extract_preference_pair(row: dict[str, object]) -> tuple[str, str]: if "chosen" in row and "rejected" in row: chosen_text = clean_training_text(_flatten_value(row.get("chosen"))) rejected_text = clean_training_text(_flatten_value(row.get("rejected"))) if chosen_text and rejected_text: return chosen_text, rejected_text if "response_0" in row and "response_1" in row: preferred_field, rejected_field = _selected_response_fields(row) if not preferred_field or not rejected_field: return "", "" prompt = row.get("prompt", row.get("question", row.get("query", ""))) if prompt: chosen_text = _compose_training_text(prompt, row.get(preferred_field)) rejected_text = _compose_training_text(prompt, row.get(rejected_field)) if chosen_text and rejected_text: return clean_training_text(chosen_text), clean_training_text(rejected_text) chosen_text = clean_training_text(_flatten_value(row.get(preferred_field))) rejected_text = 
clean_training_text(_flatten_value(row.get(rejected_field))) if chosen_text and rejected_text: return chosen_text, rejected_text return "", "" def _extract_preference_value(row: dict[str, object]) -> str: chosen_text, _ = _extract_preference_pair(row) return chosen_text def _extract_row_text(row: dict[str, object], text_field: str | None) -> str: if "context" in row and "answer" in row: context = clean_context_text(_flatten_value(row.get("context"))) answer = clean_answer_text(_flatten_value(row.get("answer"))) if context and answer: return f" {context} {answer}".strip() if "response_0" in row and "response_1" in row: preferred_field, _ = _selected_response_fields(row) prompt = row.get("prompt", row.get("question", row.get("query", ""))) if preferred_field and prompt: text = _compose_training_text(prompt, row.get(preferred_field)) if text: return text for prompt_field, answer_field in INSTRUCTION_FIELD_PAIRS: if prompt_field in row and answer_field in row: text = _compose_training_text(row.get(prompt_field), row.get(answer_field)) if text: return text if text_field is not None: return clean_training_text(_flatten_value(row.get(text_field))) preferred = _extract_preference_value(row) if preferred: return clean_training_text(preferred) for field in TEXT_FIELD_PREFERENCES: text = _flatten_value(row.get(field)) if text: return clean_training_text(text) for field in DIALOGUE_FIELD_PREFERENCES: text = _flatten_value(row.get(field)) if text: return clean_training_text(text) return "" def _passes_text_quality(text: str, language: str, entry: CorpusPlanEntry) -> bool: if not text: return False word_count = _word_count(text) if entry.min_words > 0 and word_count < entry.min_words: return False if entry.max_words > 0 and word_count > entry.max_words: return False if entry.min_alpha_ratio > 0.0 and _alpha_ratio(text) < entry.min_alpha_ratio: return False if entry.allowed_languages: if not language or language.casefold() not in entry.allowed_languages: return False return True 
def load_corpus_plan(source: str | Path) -> list[CorpusPlanEntry]: payload = json.loads(Path(source).read_text(encoding="utf-8-sig")) raw_entries = payload.get("sources", payload.get("datasets", [])) if not isinstance(raw_entries, list) or not raw_entries: raise ValueError("Corpus plan must define a non-empty 'sources' list.") entries: list[CorpusPlanEntry] = [] for index, raw_entry in enumerate(raw_entries, start=1): if not isinstance(raw_entry, dict): raise ValueError("Each corpus plan entry must be an object.") source = str(raw_entry.get("source", "hf")).strip() or "hf" name = str( raw_entry.get("name", raw_entry.get("dataset", f"source-{index}")) ).strip() or f"source-{index}" raw_languages = raw_entry.get("allowed_languages", []) allowed_languages = tuple( str(value).strip().casefold() for value in raw_languages if str(value).strip() ) if isinstance(raw_languages, list) else () raw_records = raw_entry.get("records", raw_entry.get("texts", [])) if source == "inline" and not isinstance(raw_records, list): raise ValueError("Inline corpus plan entries must provide a records/texts list.") entries.append( CorpusPlanEntry( source=source, name=name, dataset=str(raw_entry.get("dataset", "")), path=str(raw_entry.get("path", raw_entry.get("file", ""))), config=( str(raw_entry["config"]) if raw_entry.get("config") is not None else None ), split=str(raw_entry.get("split", "train")), limit=int(raw_entry.get("limit", 0)), weight=float(raw_entry.get("weight", 1.0)), text_field=( str(raw_entry["text_field"]) if raw_entry.get("text_field") is not None else None ), min_words=int(raw_entry.get("min_words", 0)), max_words=int(raw_entry.get("max_words", 0)), min_alpha_ratio=float(raw_entry.get("min_alpha_ratio", 0.0)), allowed_languages=allowed_languages, records=tuple(raw_records) if isinstance(raw_records, list) else (), streaming=bool(raw_entry.get("streaming", True)), trust_remote_code=bool(raw_entry.get("trust_remote_code", False)), ) ) return entries def _iter_hf_rows(entry: 
CorpusPlanEntry) -> Iterator[dict[str, object]]: try: from datasets import load_dataset except ModuleNotFoundError: user_site = site.getusersitepackages() if user_site and user_site not in sys.path: sys.path.append(user_site) from datasets import load_dataset dataset_kwargs: dict[str, object] = { "split": entry.split, "streaming": entry.streaming, } if entry.config: dataset_kwargs["name"] = entry.config if entry.trust_remote_code: dataset_kwargs["trust_remote_code"] = True for row in load_dataset(entry.dataset, **dataset_kwargs): yield dict(row) def _iter_file_rows(entry: CorpusPlanEntry) -> Iterator[dict[str, object]]: raw_path = entry.path or entry.dataset if not raw_path: raise ValueError("File corpus plan entries must provide a path.") path = Path(raw_path) suffix = path.suffix.lower() if suffix == ".jsonl": with path.open("r", encoding="utf-8") as handle: for line in handle: if line.strip(): row = json.loads(line) yield row if isinstance(row, dict) else {"text": str(row)} return if suffix == ".json": payload = json.loads(path.read_text(encoding="utf-8")) if isinstance(payload, list): for row in payload: yield row if isinstance(row, dict) else {"text": str(row)} return if isinstance(payload, dict): rows = payload.get("records", payload.get("texts")) if isinstance(rows, list): for row in rows: yield row if isinstance(row, dict) else {"text": str(row)} return yield payload return if suffix in {".txt", ".md", ".text"}: yield {"text": path.read_text(encoding="utf-8")} return raise ValueError(f"Unsupported file corpus source: {path}") def iter_corpus_plan_documents(plan: Iterable[CorpusPlanEntry]) -> Iterator[StreamDocument]: for entry in plan: accepted = 0 attempts = 0 while True: accepted_seen_this_attempt = 0 try: if entry.source == "inline": row_iterator = ( item if isinstance(item, dict) else {"text": str(item)} for item in entry.records ) elif entry.source == "hf": row_iterator = _iter_hf_rows(entry) elif entry.source == "file": row_iterator = 
_iter_file_rows(entry) else: raise ValueError(f"Unsupported corpus plan source: {entry.source}") for row in row_iterator: language = _row_language(row) _, rejected_text = _extract_preference_pair(row) text = clean_training_text(_extract_row_text(row, entry.text_field)) if not _passes_text_quality(text, language, entry): continue accepted_seen_this_attempt += 1 if accepted_seen_this_attempt <= accepted: continue yield StreamDocument( text=text, weight=entry.weight, source=entry.name, language=language, preference_rejected_text=rejected_text, ) accepted += 1 if entry.limit > 0 and accepted >= entry.limit: break break except Exception as exc: if entry.source != "hf": raise if attempts >= HF_STREAM_MAX_RETRIES: print( f"[source] {entry.name} skipped after {attempts} retries; " f"accepted {accepted} documents before final error: {exc}" ) break attempts += 1 delay = min( 15.0, HF_STREAM_RETRY_BASE_DELAY_SECONDS * (2 ** (attempts - 1)), ) print( f"[source] {entry.name} stream interrupted after {accepted} accepted " f"documents; retry {attempts}/{HF_STREAM_MAX_RETRIES} in {delay:.2f}s: {exc}" ) time.sleep(delay) def _log_progress(label: str, processed: int, log_every: int) -> None: if log_every > 0 and processed % log_every == 0: print(f"[{label}] processed {processed} documents") def _answer_boundary(tokens: list[str]) -> int | None: try: return tokens.index("") except ValueError: return None def _weighted_text_parts_for_statistics(text: str, document_weight: float) -> list[tuple[str, float]]: if "" not in text: return [(text, document_weight)] context, answer = text.split("", 1) context = clean_context_text(context.replace("", " ")) answer = clean_answer_text(answer) parts: list[tuple[str, float]] = [] if context: parts.append((context, document_weight * CONTEXT_STAT_WEIGHT)) if answer: parts.append((answer, document_weight * ANSWER_READOUT_WEIGHT)) return parts or [(text, document_weight)] def _weighted_token_sequences_for_statistics( tokens: list[str], tokenizer: 
NativeTokenizer, document_weight: float, ) -> list[tuple[list[str], float]]: answer_index = _answer_boundary(tokens) if answer_index is None: sequence = [token for token in tokens if token not in tokenizer.special_tokens] return [(sequence, document_weight)] if sequence else [] context_tokens = [ token for token in tokens[:answer_index] if token not in tokenizer.special_tokens ] answer_tokens = [ token for token in tokens[answer_index + 1 :] if token not in tokenizer.special_tokens ] sequences: list[tuple[list[str], float]] = [] if context_tokens: sequences.append((context_tokens, document_weight * CONTEXT_STAT_WEIGHT)) if answer_tokens: sequences.append((answer_tokens, document_weight * ANSWER_READOUT_WEIGHT)) return sequences def _readout_weight_for_target( answer_index: int | None, target_index: int, document_weight: float, ) -> float: if answer_index is None: return document_weight * PLAIN_TEXT_READOUT_WEIGHT if target_index <= answer_index: return document_weight * CONTEXT_READOUT_WEIGHT return document_weight * ANSWER_READOUT_WEIGHT def _answer_payload_tokens(tokens: list[str], tokenizer: NativeTokenizer) -> list[str]: answer_index = _answer_boundary(tokens) payload = tokens[answer_index + 1 :] if answer_index is not None else tokens return [token for token in payload if token not in tokenizer.special_tokens] def _standardized_preference_bias(values: object, active_mask: object | None = None) -> list[float]: if np is not None: bias = np.asarray(values, dtype=np.float64) if bias.size == 0: return [] active = ( np.asarray(active_mask, dtype=bool) if active_mask is not None else np.ones(bias.shape, dtype=bool) ) if not np.any(active): return [0.0 for _ in range(int(bias.size))] active_values = bias[active] spread = float(active_values.std()) if spread <= 1e-12: return [0.0 for _ in range(int(bias.size))] standardized = np.zeros_like(bias, dtype=np.float64) standardized[active] = ( (active_values - float(active_values.mean())) / spread ) * PREFERENCE_BIAS_SCALE 
return np.clip(standardized, -2.5, 2.5).astype(float).tolist() raw_values = [float(value) for value in values] if not raw_values: return [] average = sum(raw_values) / len(raw_values) variance = sum((value - average) * (value - average) for value in raw_values) / len(raw_values) spread = variance**0.5 if spread <= 1e-12: return [0.0 for _ in raw_values] active_indices = ( [ index for index, active in enumerate(active_mask) if active ] if active_mask is not None else list(range(len(raw_values))) ) if not active_indices: return [0.0 for _ in raw_values] active_values = [raw_values[index] for index in active_indices] average = mean(active_values) spread = (mean([(value - average) * (value - average) for value in active_values])) ** 0.5 if spread <= 1e-12: return [0.0 for _ in raw_values] standardized = [0.0 for _ in raw_values] for index in active_indices: standardized[index] = max( -2.5, min(2.5, ((raw_values[index] - average) / spread) * PREFERENCE_BIAS_SCALE), ) return standardized def _candidate_preference_bias_from_state_vector( model: ReframrModel, preference_state: object, ) -> object: if np is None: return None assert model.embedding_model is not None assert model.memory_units is not None assert model.ternary_mask is not None embeddings = np.asarray(model.embedding_model.embeddings, dtype=np.float64) if embeddings.size == 0: return np.zeros(0, dtype=np.float64) state_vector = np.asarray(preference_state, dtype=np.float64) mask = np.asarray(model.ternary_mask, dtype=np.float64) * float(model.ternary_scale) if state_vector.shape[0] != mask.shape[0]: return np.zeros(embeddings.shape[0], dtype=np.float64) state_indices = np.arange(model.config.state_dim, dtype=np.int64) drive = ( embeddings[:, state_indices % model.config.embedding_dim] + (0.5 * embeddings[:, (3 * state_indices + 1) % model.config.embedding_dim]) - (0.25 * embeddings[:, (5 * state_indices + 2) % model.config.embedding_dim]) ) scores = np.zeros(embeddings.shape[0], dtype=np.float64) offset = 0 for 
unit in model.memory_units: hidden_end = offset + model.config.state_dim trace_end = hidden_end + model.config.embedding_dim hidden_pref = state_vector[offset:hidden_end] * mask[offset:hidden_end] trace_pref = state_vector[hidden_end:trace_end] * mask[hidden_end:trace_end] hidden_delta_axis = np.asarray(unit.input_projection, dtype=np.float64) * hidden_pref trace_gain = 1.0 - (1.0 / (1.0 + unit.timescale)) scores += drive @ hidden_delta_axis scores += embeddings @ (trace_gain * trace_pref) offset = trace_end return scores def _derive_preference_bias_from_pairs( model: ReframrModel, preference_token_pairs: list[tuple[list[str], list[str], float]], tokenizer: NativeTokenizer, ) -> tuple[list[float], int]: assert model.embedding_model is not None vocab_size = len(model.embedding_model.id_to_token) if not preference_token_pairs: return [0.0 for _ in range(vocab_size)], 0 if np is not None: token_bias = np.zeros(vocab_size, dtype=np.float64) active_token_mask = np.zeros(vocab_size, dtype=bool) state_delta = np.zeros(model._combined_state_width(), dtype=np.float64) else: token_bias = [0.0 for _ in range(vocab_size)] active_token_ids: set[int] = set() state_delta = [0.0 for _ in range(model._combined_state_width())] pair_weight_total = 0.0 state_pair_count = 0 state_stride = max( 1, (len(preference_token_pairs) + MAX_PREFERENCE_STATE_PAIRS - 1) // MAX_PREFERENCE_STATE_PAIRS, ) for pair_index, (chosen_tokens, rejected_tokens, pair_weight) in enumerate(preference_token_pairs): chosen_answer = _answer_payload_tokens(chosen_tokens, tokenizer) rejected_answer = _answer_payload_tokens(rejected_tokens, tokenizer) if chosen_answer: delta = pair_weight / max(1, len(chosen_answer)) for token in chosen_answer: token_id = model.embedding_model.token_to_id.get(token) if token_id is not None: token_bias[token_id] += delta if np is not None: active_token_mask[token_id] = True else: active_token_ids.add(token_id) if rejected_answer: delta = pair_weight / max(1, len(rejected_answer)) for 
token in rejected_answer: token_id = model.embedding_model.token_to_id.get(token) if token_id is not None: token_bias[token_id] -= delta if np is not None: active_token_mask[token_id] = True else: active_token_ids.add(token_id) if pair_index % state_stride != 0 or state_pair_count >= MAX_PREFERENCE_STATE_PAIRS: continue chosen_state = model._masked_decode_state(model._build_decode_state(chosen_tokens)) rejected_state = model._masked_decode_state(model._build_decode_state(rejected_tokens)) if len(chosen_state) != len(rejected_state): continue pair_weight_total += pair_weight state_pair_count += 1 if np is not None: state_delta += pair_weight * ( np.asarray(chosen_state, dtype=np.float64) - np.asarray(rejected_state, dtype=np.float64) ) else: for index, (chosen_value, rejected_value) in enumerate(zip(chosen_state, rejected_state)): state_delta[index] += pair_weight * (chosen_value - rejected_value) if pair_weight_total > 0.0: if np is not None: state_delta = state_delta / pair_weight_total candidate_bias = _candidate_preference_bias_from_state_vector(model, state_delta) if candidate_bias is not None: token_bias[active_token_mask] = ( token_bias[active_token_mask] + candidate_bias[active_token_mask] ) else: state_delta = [value / pair_weight_total for value in state_delta] if np is not None: return _standardized_preference_bias(token_bias, active_token_mask), state_pair_count active_mask = [index in active_token_ids for index in range(vocab_size)] return _standardized_preference_bias(token_bias, active_mask), state_pair_count def _solve_weighted_prompt_readout( states: list[Vector], labels: list[int], weights: list[float], *, vocab_size: int, diagonal: object, state_offset: object, regularization: float, ) -> tuple[object, object, int]: if np is None or not states or not labels or not weights: return [], [0.0 for _ in range(vocab_size)], 0 state_matrix = np.asarray(states, dtype=np.float64) label_array = np.asarray(labels, dtype=np.int64) weight_vector = 
np.asarray(weights, dtype=np.float64) valid_mask = ( (label_array >= 0) & (label_array < vocab_size) & (weight_vector > 0.0) ) if not np.any(valid_mask): return [], [0.0 for _ in range(vocab_size)], 0 state_matrix = state_matrix[valid_mask] label_array = label_array[valid_mask] weight_vector = weight_vector[valid_mask] diagonal_array = np.asarray(diagonal, dtype=np.float64) offset_array = np.asarray(state_offset, dtype=np.float64) if ( len(state_matrix.shape) != 2 or diagonal_array.shape[0] != state_matrix.shape[1] or offset_array.shape[0] != state_matrix.shape[1] ): return [], [0.0 for _ in range(vocab_size)], 0 masked_states = state_matrix * diagonal_array[None, :] centered_states = masked_states - offset_array[None, :] weighted_centered_states = weight_vector[:, None] * centered_states gram = centered_states.T @ weighted_centered_states cross = np.zeros((vocab_size, centered_states.shape[1]), dtype=np.float64) np.add.at(cross, label_array, weighted_centered_states) total_weight = float(weight_vector.sum()) if total_weight <= 0.0: return [], [0.0 for _ in range(vocab_size)], 0 bias = np.zeros(vocab_size, dtype=np.float64) np.add.at(bias, label_array, weight_vector) bias /= total_weight readout = ridge_regression_readout_from_moments( gram, cross, regularization=regularization, ) return readout, bias, int(label_array.shape[0]) def fit_model_from_corpus_plan( plan: Iterable[CorpusPlanEntry], config: ReframrConfig, *, log_every: int = 0, ) -> tuple[ReframrModel, dict[str, object]]: entries = list(plan) if not entries: raise ValueError("Cannot fit REFRAMR without any corpus plan entries.") stage_seconds: dict[str, float] = {} stage_started = time.perf_counter() def finish_stage(name: str) -> None: nonlocal stage_started now = time.perf_counter() elapsed = round(now - stage_started, 6) stage_seconds[name] = elapsed if log_every > 0: print(f"[stage] {name} finished in {elapsed:.3f}s") stage_started = now seed_tokenizer = NativeTokenizer( merges=[], vocab=[], 
base_symbols=[], lowercase=config.lowercase, ) segment_counts: Counter[str] = Counter() source_counts: dict[str, int] = {} documents: list[StreamDocument] = [] processed = 0 for entry in entries: if log_every > 0: print(f"[source] {entry.name} started") source_start = processed for document in iter_corpus_plan_documents([entry]): documents.append(document) processed += 1 source_counts[document.source] = source_counts.get(document.source, 0) + 1 for text_part, part_weight in _weighted_text_parts_for_statistics( document.text, document.weight, ): for segment in seed_tokenizer.pretokenize(text_part): segment_counts[segment] += part_weight if document.preference_rejected_text: rejected_weight = document.weight * PREFERENCE_REJECTED_TOKENIZER_WEIGHT for text_part, part_weight in _weighted_text_parts_for_statistics( document.preference_rejected_text, rejected_weight, ): for segment in seed_tokenizer.pretokenize(text_part): segment_counts[segment] += part_weight _log_progress("tokenizer", processed, log_every) if log_every > 0: print(f"[source] {entry.name} accepted {processed - source_start} documents") if processed == 0: raise ValueError("Corpus plan did not yield any usable documents after filtering.") finish_stage("stream_and_segment") tokenizer = NativeTokenizer.train_from_segment_counts( segment_counts, vocab_size=config.tokenizer_vocab_size, min_pair_frequency=config.tokenizer_min_pair_frequency, lowercase=config.lowercase, ) finish_stage("tokenizer_fit") token_counts: Counter[str] = Counter() raw_tokenized_documents: list[list[str]] = [] raw_rejected_tokenized_documents: list[list[str]] = [] processed = 0 for document in documents: processed += 1 tokens = tokenizer.encode(document.text) raw_tokenized_documents.append(tokens) for token in tokens: if token in tokenizer.special_tokens: token_counts[token] += document.weight for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics( tokens, tokenizer, document.weight, ): for token in 
token_sequence: token_counts[token] += sequence_weight rejected_tokens = ( tokenizer.encode(document.preference_rejected_text) if document.preference_rejected_text else [] ) raw_rejected_tokenized_documents.append(rejected_tokens) rejected_weight = document.weight * PREFERENCE_REJECTED_TOKENIZER_WEIGHT for token in rejected_tokens: if token in tokenizer.special_tokens: token_counts[token] += rejected_weight for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics( rejected_tokens, tokenizer, rejected_weight, ): for token in token_sequence: token_counts[token] += sequence_weight _log_progress("vocab", processed, log_every) token_to_id, id_to_token = build_vocabulary_from_counts( token_counts, min_frequency=config.min_frequency, max_vocab=config.max_vocab, ) if not id_to_token: raise ValueError("Streaming recompute could not derive an embedding vocabulary.") finish_stage("vocabulary") cooccurrence = StreamingCooccurrenceAccumulator(token_to_id, config.window_size) tokenized_documents: list[list[str]] = [] preference_token_pairs: list[tuple[list[str], list[str], float]] = [] processed = 0 for document, raw_tokens, raw_rejected_tokens in zip( documents, raw_tokenized_documents, raw_rejected_tokenized_documents, ): processed += 1 tokens = [token for token in raw_tokens if token in token_to_id] tokenized_documents.append(tokens) rejected_tokens = [token for token in raw_rejected_tokens if token in token_to_id] if len(tokens) > 1 and len(rejected_tokens) > 1: preference_token_pairs.append((tokens, rejected_tokens, document.weight)) for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics( tokens, tokenizer, document.weight, ): if len(token_sequence) > 1: cooccurrence.update_tokens(token_sequence, weight=sequence_weight) _log_progress("cooccurrence", processed, log_every) finish_stage("cooccurrence") if np is not None: embedding_model = fit_randomized_ppmi_embedding_from_counts( id_to_token, cooccurrence.rows, 
embedding_dim=config.embedding_dim, ) else: embedding_model = fit_ppmi_embedding_from_cooccurrence( id_to_token, cooccurrence.to_sparse(), embedding_dim=config.embedding_dim, ) finish_stage("embedding") model = ReframrModel(config) model.tokenizer = tokenizer model.embedding_model = embedding_model model.memory_units = [ AnalyticalMemoryUnit(config.state_dim, timescale) for timescale in config.timescales ] model.trace_token_weights = model._derive_trace_token_weights_from_counts(token_counts) feature_count = len(model._zero_combined_state()) if np is not None: feature_second_moment = np.zeros(feature_count, dtype=np.float64) raw_cross = np.zeros((len(embedding_model.id_to_token), feature_count), dtype=np.float64) else: feature_second_moment = zeros_vector(feature_count) raw_cross = zeros(len(embedding_model.id_to_token), feature_count) example_weight_total = 0.0 has_answer_targets = any(_answer_boundary(tokens) is not None for tokens in tokenized_documents) if config.max_training_examples is None: answer_reservoir_capacity = None general_reservoir_capacity = None elif config.max_training_examples <= 0: answer_reservoir_capacity = 0 general_reservoir_capacity = 0 elif has_answer_targets: answer_reservoir_capacity = max(1, int(config.max_training_examples * 0.75)) general_reservoir_capacity = max(0, config.max_training_examples - answer_reservoir_capacity) else: answer_reservoir_capacity = 0 general_reservoir_capacity = config.max_training_examples answer_sequence_capacity = MAX_ANSWER_SEQUENCE_EXAMPLES if has_answer_targets else 0 answer_reservoir = StateReservoir(answer_reservoir_capacity, seed=17) general_reservoir = StateReservoir(general_reservoir_capacity, seed=13) answer_intent_reservoir = StateReservoir(answer_reservoir_capacity, seed=29) answer_start_reservoir = StateReservoir(answer_reservoir_capacity, seed=37) answer_sequence_reservoir = SequenceReservoir(answer_sequence_capacity, seed=41) moment_reservoir = StateReservoir( config.max_training_examples if 
config.max_training_examples is not None else None, seed=31, ) transitions = TransitionAccumulator( max_contexts_per_order=config.max_transition_contexts_per_order, max_next_tokens=config.max_transition_next_tokens, ) if np is not None: target_label_mass = np.zeros(len(embedding_model.id_to_token), dtype=np.float64) else: target_label_mass = zeros_vector(len(embedding_model.id_to_token)) for document, tokens in zip(documents, tokenized_documents): answer_index = _answer_boundary(tokens) for index in range(len(tokens) - 1): next_token = tokens[index + 1] if tokenizer is not None and next_token in tokenizer.special_tokens: continue next_token_id = embedding_model.token_to_id.get(next_token, -1) if next_token_id < 0: continue label_weight = _readout_weight_for_target(answer_index, index + 1, document.weight) if label_weight > 0.0: target_label_mass[next_token_id] += label_weight if np is not None: positive_label_mass = target_label_mass[target_label_mass > 0.0] reference_label_mass = ( float(np.median(positive_label_mass)) if positive_label_mass.size else 1.0 ) target_balance = np.ones(len(embedding_model.id_to_token), dtype=np.float64) np.divide( reference_label_mass, np.maximum(target_label_mass, 1e-12), out=target_balance, where=target_label_mass > 0.0, ) target_balance = np.clip(np.sqrt(target_balance), 0.25, 4.0) else: positive_label_mass = [value for value in target_label_mass if value > 0.0] if positive_label_mass: sorted_mass = sorted(positive_label_mass) reference_label_mass = sorted_mass[len(sorted_mass) // 2] else: reference_label_mass = 1.0 target_balance = [ max(0.25, min(4.0, (reference_label_mass / max(value, 1e-12)) ** 0.5)) if value > 0.0 else 1.0 for value in target_label_mass ] processed = 0 embedding_array = ( np.asarray(embedding_model.embeddings, dtype=RUNTIME_ARRAY_DTYPE) if np is not None else None ) trace_embedding_array = ( model._build_trace_embedding_table_array(embedding_array) if np is not None and embedding_array is not None else None ) 
if np is not None: trace_decay = np.asarray( [1.0 / (1.0 + unit.timescale) for unit in model.memory_units], dtype=RUNTIME_ARRAY_DTYPE, ) trace_gain = 1.0 - trace_decay transition_stack = np.asarray( [unit.transition for unit in model.memory_units], dtype=RUNTIME_ARRAY_DTYPE, ) input_projection_stack = np.asarray( [unit.input_projection for unit in model.memory_units], dtype=RUNTIME_ARRAY_DTYPE, ) drive_indices = np.arange(config.state_dim, dtype=np.int64) drive_primary = drive_indices % config.embedding_dim drive_secondary = (3 * drive_indices + 1) % config.embedding_dim drive_tertiary = (5 * drive_indices + 2) % config.embedding_dim else: trace_decay = None trace_gain = None transition_stack = None input_projection_stack = None drive_primary = None drive_secondary = None drive_tertiary = None for document, tokens in zip(documents, tokenized_documents): processed += 1 if len(tokens) < 2: _log_progress("state", processed, log_every) continue answer_index = _answer_boundary(tokens) for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics( tokens, tokenizer, document.weight, ): if len(token_sequence) > 1: transitions.update_tokens(token_sequence, weight=sequence_weight) if np is not None: hidden_state_matrix = np.zeros((len(config.timescales), config.state_dim), dtype=RUNTIME_ARRAY_DTYPE) context_trace_matrix = np.zeros((len(config.timescales), config.embedding_dim), dtype=RUNTIME_ARRAY_DTYPE) else: hidden_states = [zeros_vector(config.state_dim) for _ in config.timescales] context_traces = [zeros_vector(config.embedding_dim) for _ in config.timescales] answer_anchor_state = None for index in range(len(tokens) - 1): token = tokens[index] token_id = embedding_model.token_to_id.get(token, -1) if ( np is not None and embedding_array is not None and trace_decay is not None and trace_gain is not None and transition_stack is not None and input_projection_stack is not None and drive_primary is not None and drive_secondary is not None and drive_tertiary 
is not None and trace_embedding_array is not None and token_id >= 0 ): embedding = embedding_array[token_id] trace_embedding = trace_embedding_array[token_id] drive = ( embedding[drive_primary] + (0.5 * embedding[drive_secondary]) - (0.25 * embedding[drive_tertiary]) ) hidden_state_matrix = ( (transition_stack @ hidden_state_matrix[:, :, None])[:, :, 0] + (input_projection_stack * drive[None, :]) ) context_trace_matrix = ( context_trace_matrix + (trace_gain[:, None] * trace_embedding[None, :]) ) else: hidden_states, context_traces, combined_state = model._step_hidden_states( hidden_states, context_traces, token, ) if token == "": if np is not None: answer_anchor_state = np.concatenate( (hidden_state_matrix, context_trace_matrix), axis=1, ).reshape(-1).copy() else: answer_anchor_state = combined_state.copy() if hasattr(combined_state, "copy") else combined_state[:] next_token = tokens[index + 1] if next_token in tokenizer.special_tokens: continue next_token_id = embedding_model.token_to_id.get(next_token, -1) if next_token_id < 0: continue raw_readout_weight = _readout_weight_for_target(answer_index, index + 1, document.weight) readout_weight = raw_readout_weight * float(target_balance[next_token_id]) if readout_weight <= 0.0: continue moment_slot = moment_reservoir.reserve_slot(weight=readout_weight) is_answer_target = answer_index is not None and index + 1 > answer_index target_reservoir = answer_reservoir if is_answer_target else general_reservoir memory_weight = readout_weight * float(target_balance[next_token_id]) answer_token_offset = ( index - answer_index if is_answer_target and answer_index is not None else None ) intent_slot = ( answer_intent_reservoir.reserve_slot(weight=memory_weight) if is_answer_target and answer_anchor_state is not None else None ) answer_start_weight = ( raw_readout_weight * (ANSWER_START_DECAY ** answer_token_offset) if ( answer_token_offset is not None and answer_token_offset < ANSWER_START_TOKEN_WINDOW ) else 0.0 ) 
answer_start_slot = ( answer_start_reservoir.reserve_slot(weight=answer_start_weight) if answer_start_weight > 0.0 and answer_anchor_state is not None else None ) if np is not None: reservoir_slot = target_reservoir.reserve_slot(weight=memory_weight) if moment_slot is not None or reservoir_slot is not None: combined_state = np.concatenate( (hidden_state_matrix, context_trace_matrix), axis=1, ).reshape(-1).copy() if moment_slot is not None: moment_reservoir.store_reserved( moment_slot, combined_state, next_token_id, example_weight=readout_weight, ) if reservoir_slot is not None: target_reservoir.store_reserved(reservoir_slot, combined_state, next_token_id) if intent_slot is not None: answer_intent_reservoir.store_reserved( intent_slot, answer_anchor_state, next_token_id, example_weight=memory_weight, ) if answer_start_slot is not None: answer_start_reservoir.store_reserved( answer_start_slot, answer_anchor_state, next_token_id, example_weight=answer_start_weight * float(target_balance[next_token_id]), ) else: reservoir_slot = target_reservoir.reserve_slot(weight=memory_weight) if moment_slot is None and reservoir_slot is None: continue if moment_slot is not None: moment_reservoir.store_reserved( moment_slot, combined_state, next_token_id, example_weight=readout_weight, ) if reservoir_slot is not None: target_reservoir.store_reserved(reservoir_slot, combined_state, next_token_id) if intent_slot is not None: answer_intent_reservoir.store_reserved( intent_slot, answer_anchor_state, next_token_id, example_weight=memory_weight, ) if answer_start_slot is not None: answer_start_reservoir.store_reserved( answer_start_slot, answer_anchor_state, next_token_id, example_weight=answer_start_weight * target_balance[next_token_id], ) if answer_anchor_state is not None and answer_index is not None: prompt_token_ids = [ embedding_model.token_to_id[token] for token in tokens[:answer_index] if token not in tokenizer.special_tokens and token in embedding_model.token_to_id ] 
answer_token_ids = [ embedding_model.token_to_id[token] for token in tokens[answer_index + 1 :] if token not in tokenizer.special_tokens and token in embedding_model.token_to_id ] answer_sequence_reservoir.consider( answer_anchor_state, prompt_token_ids, answer_token_ids, weight=document.weight * ANSWER_READOUT_WEIGHT, ) _log_progress("state", processed, log_every) moment_states = moment_reservoir.states moment_labels = moment_reservoir.labels moment_weights = moment_reservoir.weights example_weight_total = sum(moment_weights) if np is not None and moment_states: state_matrix = np.asarray(moment_states, dtype=np.float64) weight_vector = np.asarray(moment_weights, dtype=np.float64) weighted_states = weight_vector[:, None] * state_matrix feature_second_moment += (weighted_states * state_matrix).sum(axis=0) np.add.at(raw_cross, moment_labels, weighted_states) elif moment_states: for state, label_id, readout_weight in zip(moment_states, moment_labels, moment_weights): for feature, value in enumerate(state): weighted_value = readout_weight * value feature_second_moment[feature] += weighted_value * value raw_cross[label_id][feature] += weighted_value if example_weight_total <= 0.0: raise ValueError("Streaming recompute did not collect any next-token training examples.") if np is not None: feature_energy = (feature_second_moment / example_weight_total).tolist() else: feature_energy = [ feature_second_moment[index] / example_weight_total for index in range(feature_count) ] ternary_scale, ternary_mask = derive_ternary_mask_from_feature_energy(feature_energy) if np is not None: diagonal = np.asarray([ternary_scale * value for value in ternary_mask], dtype=np.float64) masked_feature_second_moment = feature_second_moment * diagonal * diagonal masked_cross = raw_cross * diagonal[None, :] else: diagonal = [ternary_scale * value for value in ternary_mask] masked_feature_second_moment = [ feature_second_moment[index] * diagonal[index] * diagonal[index] for index in 
range(feature_count) ] masked_cross = [ [ raw_cross[row][col] * diagonal[col] for col in range(feature_count) ] for row in range(len(raw_cross)) ] readout_solver = "diagonal" state_offset_values: object readout_bias_values: object if ( np is not None and moment_states and feature_count <= FULL_READOUT_FEATURE_LIMIT and len(moment_states) <= FULL_READOUT_EXAMPLE_LIMIT ): state_matrix = np.asarray(moment_states, dtype=np.float64) weight_vector = np.asarray(moment_weights, dtype=np.float64) label_array = np.asarray(moment_labels, dtype=np.int64) masked_states = state_matrix * diagonal[None, :] total_weight = float(weight_vector.sum()) if total_weight <= 0.0: total_weight = 1.0 state_offset_values = (weight_vector[:, None] * masked_states).sum(axis=0) / total_weight centered_states = masked_states - state_offset_values[None, :] weighted_centered_states = weight_vector[:, None] * centered_states gram = centered_states.T @ weighted_centered_states full_cross = np.zeros((len(embedding_model.id_to_token), feature_count), dtype=np.float64) np.add.at(full_cross, label_array, weighted_centered_states) readout_bias_values = np.zeros(len(embedding_model.id_to_token), dtype=np.float64) np.add.at(readout_bias_values, label_array, weight_vector) readout_bias_values /= total_weight readout_weights = ridge_regression_readout_from_moments( gram, full_cross, regularization=config.regularization, ) readout_solver = "full" else: state_offset_values = ( np.zeros(feature_count, dtype=np.float64) if np is not None else [0.0 for _ in range(feature_count)] ) if np is not None: label_total = max(float(target_label_mass.sum()), 1.0) readout_bias_values = target_label_mass / label_total else: label_total = max(sum(target_label_mass), 1.0) readout_bias_values = [value / label_total for value in target_label_mass] readout_weights = ridge_regression_readout_from_diagonal_moments( masked_feature_second_moment, masked_cross, regularization=config.regularization, ) finish_stage("state_and_readout") 
model.ternary_scale = ternary_scale model.ternary_mask = ternary_mask model.readout_weights = readout_weights model.state_offset = ( state_offset_values.tolist() if hasattr(state_offset_values, "tolist") else list(state_offset_values) ) model.readout_bias = ( readout_bias_values.tolist() if hasattr(readout_bias_values, "tolist") else list(readout_bias_values) ) model.preference_bias, preference_state_pairs = _derive_preference_bias_from_pairs( model, preference_token_pairs, tokenizer, ) finish_stage("preference") reservoir_states = answer_reservoir.states + general_reservoir.states reservoir_labels = answer_reservoir.labels + general_reservoir.labels answer_intent_states = answer_intent_reservoir.states answer_intent_labels = answer_intent_reservoir.labels answer_start_states = answer_start_reservoir.states answer_start_labels = answer_start_reservoir.labels answer_sequence_states = answer_sequence_reservoir.keys answer_sequence_prompt_rows = answer_sequence_reservoir.prompt_rows answer_sequence_rows = answer_sequence_reservoir.token_rows prompt_answer_weights, prompt_answer_bias, prompt_answer_readout_examples = ( _solve_weighted_prompt_readout( answer_intent_states, answer_intent_labels, answer_intent_reservoir.weights, vocab_size=len(embedding_model.id_to_token), diagonal=diagonal, state_offset=state_offset_values, regularization=config.regularization, ) ) ( prompt_answer_start_weights, prompt_answer_start_bias, prompt_answer_start_readout_examples, ) = _solve_weighted_prompt_readout( answer_start_states, answer_start_labels, answer_start_reservoir.weights, vocab_size=len(embedding_model.id_to_token), diagonal=diagonal, state_offset=state_offset_values, regularization=config.regularization, ) model.prompt_answer_weights = prompt_answer_weights model.prompt_answer_bias = ( prompt_answer_bias.tolist() if hasattr(prompt_answer_bias, "tolist") else list(prompt_answer_bias) ) model.prompt_answer_start_weights = prompt_answer_start_weights 
model.prompt_answer_start_bias = ( prompt_answer_start_bias.tolist() if hasattr(prompt_answer_start_bias, "tolist") else list(prompt_answer_start_bias) ) if np is not None and reservoir_states: reservoir_array = np.asarray(reservoir_states, dtype=RUNTIME_ARRAY_DTYPE) mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale offset_array = np.asarray(model.state_offset, dtype=RUNTIME_ARRAY_DTYPE) associative_array = ((reservoir_array * mask_array[None, :]) - offset_array[None, :]).astype( RUNTIME_ARRAY_DTYPE, copy=False, ) model.associative_keys = associative_array model.associative_key_norms = np.linalg.norm(associative_array, axis=1).tolist() else: offset_vector = model.state_offset model.associative_keys = [ [ value - offset_vector[index] for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale)) ] for state in reservoir_states ] model.associative_key_norms = [norm(state) for state in model.associative_keys] model.associative_values = reservoir_labels[:] if np is not None and answer_intent_states: answer_intent_array = np.asarray(answer_intent_states, dtype=RUNTIME_ARRAY_DTYPE) mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale offset_array = np.asarray(model.state_offset, dtype=RUNTIME_ARRAY_DTYPE) answer_array = ((answer_intent_array * mask_array[None, :]) - offset_array[None, :]).astype( RUNTIME_ARRAY_DTYPE, copy=False, ) model.answer_keys = answer_array model.answer_key_norms = np.linalg.norm(answer_array, axis=1).tolist() else: offset_vector = model.state_offset model.answer_keys = [ [ value - offset_vector[index] for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale)) ] for state in answer_intent_states ] model.answer_key_norms = [norm(state) for state in model.answer_keys] model.answer_values = answer_intent_labels[:] if np is not None and answer_start_states: answer_start_array = np.asarray(answer_start_states, dtype=RUNTIME_ARRAY_DTYPE) 
mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale offset_array = np.asarray(model.state_offset, dtype=RUNTIME_ARRAY_DTYPE) start_array = ((answer_start_array * mask_array[None, :]) - offset_array[None, :]).astype( RUNTIME_ARRAY_DTYPE, copy=False, ) model.answer_start_keys = start_array model.answer_start_key_norms = np.linalg.norm(start_array, axis=1).tolist() else: offset_vector = model.state_offset model.answer_start_keys = [ [ value - offset_vector[index] for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale)) ] for state in answer_start_states ] model.answer_start_key_norms = [norm(state) for state in model.answer_start_keys] model.answer_start_values = answer_start_labels[:] if np is not None and answer_sequence_states: answer_sequence_array = np.asarray(answer_sequence_states, dtype=RUNTIME_ARRAY_DTYPE) mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale offset_array = np.asarray(model.state_offset, dtype=RUNTIME_ARRAY_DTYPE) sequence_array = ((answer_sequence_array * mask_array[None, :]) - offset_array[None, :]).astype( RUNTIME_ARRAY_DTYPE, copy=False, ) model.answer_sequence_keys = sequence_array model.answer_sequence_key_norms = np.linalg.norm(sequence_array, axis=1).tolist() else: offset_vector = model.state_offset model.answer_sequence_keys = [ [ value - offset_vector[index] for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale)) ] for state in answer_sequence_states ] model.answer_sequence_key_norms = [norm(state) for state in model.answer_sequence_keys] if np is not None: padded_answer_sequences = np.full( (len(answer_sequence_rows), MAX_ANSWER_SEQUENCE_TOKENS), -1, dtype=np.int32, ) for row_index, row in enumerate(answer_sequence_rows): row_width = min(len(row), MAX_ANSWER_SEQUENCE_TOKENS) if row_width > 0: padded_answer_sequences[row_index, :row_width] = row[:row_width] padded_answer_sequence_prompts = np.full( 
(len(answer_sequence_prompt_rows), MAX_ANSWER_SEQUENCE_TOKENS), -1, dtype=np.int32, ) for row_index, row in enumerate(answer_sequence_prompt_rows): row_width = min(len(row), MAX_ANSWER_SEQUENCE_TOKENS) if row_width > 0: padded_answer_sequence_prompts[row_index, :row_width] = row[:row_width] else: padded_answer_sequences = [ row + [-1 for _ in range(MAX_ANSWER_SEQUENCE_TOKENS - len(row))] for row in answer_sequence_rows ] padded_answer_sequence_prompts = [ row + [-1 for _ in range(MAX_ANSWER_SEQUENCE_TOKENS - len(row))] for row in answer_sequence_prompt_rows ] model.answer_sequence_prompt_tokens = padded_answer_sequence_prompts model.answer_sequence_tokens = padded_answer_sequences model.transition_tables = transitions.finalize( max_contexts_per_order=config.max_transition_contexts_per_order, max_next_tokens=config.max_transition_next_tokens, ) finish_stage("model_finalize") payload = { "streaming": True, "documents_processed": processed, "source_counts": source_counts, "embedding_vocab_size": len(embedding_model.id_to_token), "tokenizer_vocab_size": tokenizer.vocab_size, "examples_processed": int(round(example_weight_total)), "associative_examples": len(model.associative_keys), "answer_associative_examples": len(answer_reservoir.states), "general_associative_examples": len(general_reservoir.states), "answer_intent_examples": len(model.answer_keys), "answer_start_examples": len(model.answer_start_keys), "answer_sequence_examples": len(model.answer_sequence_keys), "prompt_answer_readout_examples": prompt_answer_readout_examples, "prompt_answer_start_readout_examples": prompt_answer_start_readout_examples, "stage_seconds": stage_seconds, "target_balance_reference": round(float(reference_label_mass), 6), "readout_solver": readout_solver, "preference_pairs": len(preference_token_pairs), "preference_state_pairs": preference_state_pairs, } return model, payload