| from __future__ import annotations |
|
|
| import json |
| import random |
| import re |
| import site |
| import sys |
| import time |
| from collections import Counter |
| from collections.abc import Iterable, Iterator |
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
| from .config import ReframrConfig |
| from .corpus import build_vocabulary_from_counts |
| from .embeddings import fit_ppmi_embedding_from_cooccurrence, fit_randomized_ppmi_embedding_from_counts |
| from .hippo import AnalyticalMemoryUnit |
| from .linalg import Matrix, Vector, norm, zeros, zeros_vector |
| from .model import ReframrModel, RUNTIME_ARRAY_DTYPE, TRANSITION_ORDERS, np |
| from .reservoir import ( |
| ridge_regression_readout_from_diagonal_moments, |
| ridge_regression_readout_from_moments, |
| ) |
| from .ternary import apply_ternary_mask, derive_ternary_mask_from_feature_energy |
| from .text_quality import clean_answer_text, clean_context_text, clean_training_text |
| from .tokenizer import NativeTokenizer |
|
|
| try: |
| from scipy import sparse as scipy_sparse |
| except (ImportError, ModuleNotFoundError, OSError): |
| scipy_sparse = None |
|
|
| TEXT_FIELD_PREFERENCES = ( |
| "text", |
| "content", |
| "body", |
| "article", |
| "document", |
| "passage", |
| "markdown", |
| "answer", |
| "response", |
| ) |
|
|
| DIALOGUE_FIELD_PREFERENCES = ( |
| "messages", |
| "conversation", |
| "conversations", |
| "dialogue", |
| "dialog", |
| "turns", |
| "chosen", |
| ) |
| INSTRUCTION_FIELD_PAIRS = ( |
| ("instruction", "output"), |
| ("prompt", "completion"), |
| ("prompt", "response"), |
| ("question", "answer"), |
| ("question", "response"), |
| ("query", "answer"), |
| ("query", "response"), |
| ) |
| TRANSCRIPT_ROLE_PATTERN = re.compile(r"(?:^|\n\s*\n)(Human|Assistant|System)\s*:\s*", re.IGNORECASE) |
| ROLE_ALIASES = { |
| "assistant": "assistant", |
| "assistant_response": "assistant", |
| "bot": "assistant", |
| "gpt": "assistant", |
| "model": "assistant", |
| "human": "user", |
| "prompter": "user", |
| "user": "user", |
| "customer": "user", |
| "system": "system", |
| } |
| ANSWER_READOUT_WEIGHT = 1.0 |
| CONTEXT_READOUT_WEIGHT = 0.0 |
| CONTEXT_STAT_WEIGHT = 0.02 |
| PLAIN_TEXT_READOUT_WEIGHT = 0.03 |
| PREFERENCE_REJECTED_TOKENIZER_WEIGHT = 0.0 |
| PREFERENCE_BIAS_SCALE = 0.95 |
| MAX_PREFERENCE_STATE_PAIRS = 512 |
| ANSWER_START_TOKEN_WINDOW = 12 |
| ANSWER_START_DECAY = 0.86 |
| MAX_ANSWER_SEQUENCE_EXAMPLES = 196608 |
| MAX_ANSWER_SEQUENCE_TOKENS = 192 |
| HF_STREAM_MAX_RETRIES = 5 |
| HF_STREAM_RETRY_BASE_DELAY_SECONDS = 0.25 |
| FULL_READOUT_FEATURE_LIMIT = 2304 |
| FULL_READOUT_EXAMPLE_LIMIT = 25000 |
|
|
|
|
| @dataclass(slots=True) |
| class CorpusPlanEntry: |
| source: str |
| name: str |
| dataset: str = "" |
| path: str = "" |
| config: str | None = None |
| split: str = "train" |
| limit: int = 0 |
| weight: float = 1.0 |
| text_field: str | None = None |
| min_words: int = 0 |
| max_words: int = 0 |
| min_alpha_ratio: float = 0.0 |
| allowed_languages: tuple[str, ...] = () |
| records: tuple[object, ...] = () |
| streaming: bool = True |
| trust_remote_code: bool = False |
|
|
|
|
| @dataclass(slots=True) |
| class StreamDocument: |
| text: str |
| weight: float |
| source: str |
| language: str = "" |
| preference_rejected_text: str = "" |
|
|
|
|
| class StreamingCooccurrenceAccumulator: |
| def __init__(self, token_to_id: dict[str, int], window_size: int) -> None: |
| self.token_to_id = token_to_id |
| self.window_size = window_size |
| self.rows: dict[int, dict[int, float]] = {} |
|
|
| def update_tokens(self, tokens: list[str], *, weight: float) -> None: |
| token_ids = [self.token_to_id[token] for token in tokens if token in self.token_to_id] |
| for index, token_id in enumerate(token_ids): |
| for offset in range(1, self.window_size + 1): |
| other_index = index + offset |
| if other_index >= len(token_ids): |
| break |
| other_id = token_ids[other_index] |
| delta = weight * (1.0 / offset) |
| self.rows.setdefault(token_id, {})[other_id] = ( |
| self.rows.setdefault(token_id, {}).get(other_id, 0.0) + delta |
| ) |
| self.rows.setdefault(other_id, {})[token_id] = ( |
| self.rows.setdefault(other_id, {}).get(token_id, 0.0) + delta |
| ) |
|
|
| def to_dense(self) -> Matrix: |
| size = len(self.token_to_id) |
| matrix = zeros(size, size) |
| for row, columns in self.rows.items(): |
| for col, value in columns.items(): |
| matrix[row][col] = value |
| return matrix |
|
|
| def to_sparse(self) -> object: |
| if scipy_sparse is None or np is None: |
| return self.to_dense() |
| rows: list[int] = [] |
| cols: list[int] = [] |
| data: list[float] = [] |
| for row, columns in self.rows.items(): |
| for col, value in columns.items(): |
| rows.append(row) |
| cols.append(col) |
| data.append(value) |
| size = len(self.token_to_id) |
| return scipy_sparse.coo_matrix( |
| ( |
| np.asarray(data, dtype=np.float64), |
| (np.asarray(rows, dtype=np.int64), np.asarray(cols, dtype=np.int64)), |
| ), |
| shape=(size, size), |
| dtype=np.float64, |
| ).tocsr() |
|
|
|
|
| class TransitionAccumulator: |
| def __init__( |
| self, |
| *, |
| max_contexts_per_order: int | None = None, |
| max_next_tokens: int = 0, |
| ) -> None: |
| self.max_contexts_per_order = max_contexts_per_order |
| self.max_next_tokens = max_next_tokens |
| self.context_soft_limit = ( |
| max_contexts_per_order * 4 |
| if max_contexts_per_order is not None and max_contexts_per_order > 0 |
| else None |
| ) |
| self.next_token_soft_limit = max_next_tokens * 4 if max_next_tokens > 0 else None |
| self.counts: dict[int, dict[tuple[str, ...], dict[str, float]]] = { |
| order: {} for order in sorted(TRANSITION_ORDERS) |
| } |
|
|
| def update_tokens(self, tokens: list[str], *, weight: float) -> None: |
| for order in sorted(TRANSITION_ORDERS): |
| order_counts = self.counts[order] |
| for index in range(order - 1, len(tokens) - 1): |
| key = tuple(tokens[index - order + 1 : index + 1]) |
| nxt = tokens[index + 1] |
| if ( |
| self.context_soft_limit is not None |
| and key not in order_counts |
| and len(order_counts) >= self.context_soft_limit |
| ): |
| continue |
| bucket = order_counts.setdefault(key, {}) |
| if ( |
| self.next_token_soft_limit is not None |
| and nxt not in bucket |
| and len(bucket) >= self.next_token_soft_limit |
| ): |
| continue |
| bucket[nxt] = bucket.get(nxt, 0.0) + weight |
|
|
| def finalize( |
| self, |
| *, |
| max_contexts_per_order: int | None, |
| max_next_tokens: int, |
| ) -> dict[int, dict[tuple[str, ...], dict[str, float]]]: |
| probabilities: dict[int, dict[tuple[str, ...], dict[str, float]]] = { |
| order: {} for order in sorted(TRANSITION_ORDERS) |
| } |
| for order, mapping in self.counts.items(): |
| items = list(mapping.items()) |
| items.sort(key=lambda item: (-sum(item[1].values()), item[0])) |
| if max_contexts_per_order is not None and max_contexts_per_order >= 0: |
| items = items[:max_contexts_per_order] |
| for key, bucket in items: |
| next_items = sorted(bucket.items(), key=lambda item: (-item[1], item[0])) |
| if max_next_tokens > 0: |
| next_items = next_items[:max_next_tokens] |
| total = sum(value for _, value in next_items) |
| if total <= 0.0: |
| continue |
| probabilities[order][key] = { |
| token: value / total |
| for token, value in next_items |
| } |
| return probabilities |
|
|
|
|
| class StateReservoir: |
| def __init__(self, capacity: int | None, *, seed: int = 13) -> None: |
| self.capacity = capacity |
| self.random = random.Random(seed) |
| self.states: list[Vector] = [] |
| self.labels: list[int] = [] |
| self.weights: list[float] = [] |
| self.seen = 0 |
| self.total_weight = 0.0 |
|
|
| def reserve_slot(self, weight: float = 1.0) -> int | None: |
| if weight <= 0.0: |
| return None |
| self.seen += 1 |
| self.total_weight += weight |
| if self.capacity is None: |
| return len(self.states) |
| if self.capacity <= 0: |
| return None |
| if len(self.states) < self.capacity: |
| return len(self.states) |
| keep_probability = min(1.0, (self.capacity * weight) / max(self.total_weight, 1e-12)) |
| if self.random.random() >= keep_probability: |
| return None |
| return self.random.randrange(self.capacity) |
|
|
| def store_reserved( |
| self, |
| slot: int, |
| state: Vector, |
| label_id: int, |
| *, |
| example_weight: float = 1.0, |
| ) -> None: |
| stored_state = state.copy() if hasattr(state, "copy") else state[:] |
| if slot == len(self.states): |
| self.states.append(stored_state) |
| self.labels.append(label_id) |
| self.weights.append(example_weight) |
| elif 0 <= slot < len(self.states): |
| self.states[slot] = stored_state |
| self.labels[slot] = label_id |
| self.weights[slot] = example_weight |
|
|
| def consider(self, state: Vector, label_id: int, weight: float = 1.0) -> None: |
| slot = self.reserve_slot(weight=weight) |
| if slot is not None: |
| self.store_reserved(slot, state, label_id, example_weight=weight) |
|
|
|
|
| class SequenceReservoir: |
| def __init__(self, capacity: int | None, *, seed: int = 41) -> None: |
| self.capacity = capacity |
| self.random = random.Random(seed) |
| self.keys: list[Vector] = [] |
| self.prompt_rows: list[list[int]] = [] |
| self.token_rows: list[list[int]] = [] |
| self.weights: list[float] = [] |
| self.seen_weight = 0.0 |
|
|
| def reserve_slot(self, *, weight: float = 1.0) -> int | None: |
| if self.capacity == 0 or weight <= 0.0: |
| return None |
| self.seen_weight += weight |
| if self.capacity is None or len(self.keys) < self.capacity: |
| return len(self.keys) |
| probability = min(1.0, (self.capacity * weight) / max(self.seen_weight, 1e-12)) |
| if self.random.random() >= probability: |
| return None |
| return self.random.randrange(self.capacity) |
|
|
| def store_reserved( |
| self, |
| slot: int, |
| key: Vector, |
| prompt_token_ids: list[int], |
| token_ids: list[int], |
| *, |
| example_weight: float = 1.0, |
| ) -> None: |
| key_copy = key.tolist() if hasattr(key, "tolist") else list(key) |
| prompt_row = prompt_token_ids[:MAX_ANSWER_SEQUENCE_TOKENS] |
| row = token_ids[:MAX_ANSWER_SEQUENCE_TOKENS] |
| if self.capacity is None or slot >= len(self.keys): |
| self.keys.append(key_copy) |
| self.prompt_rows.append(prompt_row) |
| self.token_rows.append(row) |
| self.weights.append(example_weight) |
| return |
| self.keys[slot] = key_copy |
| self.prompt_rows[slot] = prompt_row |
| self.token_rows[slot] = row |
| self.weights[slot] = example_weight |
|
|
| def consider( |
| self, |
| key: Vector, |
| prompt_token_ids: list[int], |
| token_ids: list[int], |
| weight: float = 1.0, |
| ) -> None: |
| if not token_ids: |
| return |
| slot = self.reserve_slot(weight=weight) |
| if slot is not None: |
| self.store_reserved(slot, key, prompt_token_ids, token_ids, example_weight=weight) |
|
|
|
|
| def _word_count(text: str) -> int: |
| return len(text.split()) |
|
|
|
|
| def _alpha_ratio(text: str) -> float: |
| if not text: |
| return 0.0 |
| alpha_count = sum(character.isalpha() for character in text) |
| return alpha_count / len(text) |
|
|
|
|
| def _row_language(row: dict[str, object]) -> str: |
| for candidate in ("lang", "language", "locale"): |
| value = row.get(candidate) |
| if isinstance(value, str) and value.strip(): |
| return value.strip() |
| return "" |
|
|
|
|
| def _normalize_role(raw_role: object) -> str: |
| role = str(raw_role or "").strip().casefold() |
| return ROLE_ALIASES.get(role, role) |
|
|
|
|
| def _message_content(message: dict[str, object]) -> str: |
| for field in ("content", "value", "text", "message"): |
| value = message.get(field) |
| if isinstance(value, str) and value.strip(): |
| return clean_training_text(value) |
| return "" |
|
|
|
|
| def _message_role(message: dict[str, object]) -> str: |
| for field in ("role", "from", "speaker", "author"): |
| value = message.get(field) |
| if value is not None: |
| normalized = _normalize_role(value) |
| if normalized: |
| return normalized |
| return "" |
|
|
|
|
| def _parse_dialogue_messages(raw_messages: object) -> list[dict[str, str]]: |
| if not isinstance(raw_messages, list): |
| return [] |
|
|
| parsed: list[dict[str, str]] = [] |
| for message in raw_messages: |
| if not isinstance(message, dict): |
| continue |
| role = _message_role(message) |
| content = _message_content(message) |
| if role not in {"system", "user", "assistant"} or not content: |
| continue |
| parsed.append({"role": role, "content": content}) |
| return parsed |
|
|
|
|
| def _parse_transcript_messages(raw_text: object) -> list[dict[str, str]]: |
| if not isinstance(raw_text, str): |
| return [] |
|
|
| text = raw_text.strip() |
| if not text: |
| return [] |
|
|
| matches = list(TRANSCRIPT_ROLE_PATTERN.finditer(text)) |
| if not matches: |
| return [] |
|
|
| parsed: list[dict[str, str]] = [] |
| for index, match in enumerate(matches): |
| role = _normalize_role(match.group(1)) |
| start = match.end() |
| end = matches[index + 1].start() if index + 1 < len(matches) else len(text) |
| content = clean_training_text(text[start:end].strip()) |
| if role in {"system", "user", "assistant"} and content: |
| parsed.append({"role": role, "content": content}) |
| return parsed |
|
|
|
|
| def _render_prompt(messages: list[dict[str, str]]) -> str: |
| parts = [] |
| for message in messages: |
| content = clean_context_text(message["content"]) |
| if content: |
| parts.append(content) |
| return "\n".join(parts).strip() |
|
|
|
|
| def _last_user_prompt_before(messages: list[dict[str, str]], end_index: int) -> str: |
| for message in reversed(messages[:end_index]): |
| if message["role"] == "user": |
| return clean_context_text(message["content"]) |
| return _render_prompt(messages[:end_index]) |
|
|
|
|
| def _compose_training_text(context: object, answer: object) -> str: |
| prompt_text = clean_context_text(_flatten_value(context)) |
| answer_text = clean_answer_text(_flatten_value(answer)) |
| if prompt_text and answer_text: |
| return f"<reason> {prompt_text} <answer> {answer_text}".strip() |
| return clean_training_text(answer_text or prompt_text) |
|
|
|
|
| def _compose_from_messages(messages: list[dict[str, str]]) -> str: |
| assistant_index = None |
| for index in range(len(messages) - 1, -1, -1): |
| if messages[index]["role"] == "assistant": |
| assistant_index = index |
| break |
| if assistant_index is not None: |
| prompt = _last_user_prompt_before(messages, assistant_index) |
| answer = clean_answer_text(messages[assistant_index]["content"]) |
| if prompt and answer: |
| return f"<reason> {prompt} <answer> {answer}".strip() |
| return "\n".join( |
| message["content"] |
| for message in messages |
| if message.get("content") |
| ).strip() |
|
|
|
|
| def _flatten_message_list(messages: object) -> str: |
| parsed = _parse_dialogue_messages(messages) |
| if parsed: |
| return _compose_from_messages(parsed) |
| if not isinstance(messages, list): |
| return "" |
| parts: list[str] = [] |
| for message in messages: |
| if not isinstance(message, dict): |
| continue |
| content = str( |
| message.get("content", message.get("value", message.get("text", ""))) |
| ).strip() |
| if not content: |
| continue |
| parts.append(clean_training_text(content)) |
| return "\n".join(parts).strip() |
|
|
|
|
| def _flatten_value(value: object) -> str: |
| if isinstance(value, str): |
| parsed = _parse_transcript_messages(value) |
| if parsed: |
| return _compose_from_messages(parsed) |
| return clean_training_text(value.strip()) |
| if isinstance(value, list): |
| return _flatten_message_list(value) |
| if isinstance(value, dict): |
| for field in ("messages", "conversation", "conversations", "dialogue", "turns"): |
| nested_messages = value.get(field) |
| text = _flatten_message_list(nested_messages) |
| if text: |
| return text |
| for field in ("text", "content", "value", "message"): |
| nested = value.get(field) |
| if isinstance(nested, str) and nested.strip(): |
| return _flatten_value(nested) |
| return "" |
|
|
|
|
| def _safe_flag(value: object) -> bool | None: |
| if isinstance(value, bool): |
| return value |
| if isinstance(value, str): |
| normalized = value.strip().casefold() |
| if normalized in {"true", "1", "yes", "safe"}: |
| return True |
| if normalized in {"false", "0", "no", "unsafe"}: |
| return False |
| return None |
|
|
|
|
| def _selected_response_fields(row: dict[str, object]) -> tuple[str, str]: |
| if "response_0" not in row or "response_1" not in row: |
| return "", "" |
| safe_0 = _safe_flag(row.get("is_response_0_safe")) |
| safe_1 = _safe_flag(row.get("is_response_1_safe")) |
| if safe_0 is not None and safe_1 is not None: |
| if safe_0 and not safe_1: |
| return "response_0", "response_1" |
| if safe_1 and not safe_0: |
| return "response_1", "response_0" |
| if safe_0 and safe_1: |
| return "response_0", "" |
| return "", "" |
| for selector in ("safer_response_id", "better_response_id"): |
| raw_value = row.get(selector) |
| try: |
| preferred = int(raw_value) |
| except (TypeError, ValueError): |
| continue |
| chosen = "response_1" if preferred == 1 else "response_0" |
| rejected = "response_0" if chosen == "response_1" else "response_1" |
| return chosen, rejected |
| return "response_0", "response_1" |
|
|
|
|
| def _extract_preference_pair(row: dict[str, object]) -> tuple[str, str]: |
| if "chosen" in row and "rejected" in row: |
| chosen_text = clean_training_text(_flatten_value(row.get("chosen"))) |
| rejected_text = clean_training_text(_flatten_value(row.get("rejected"))) |
| if chosen_text and rejected_text: |
| return chosen_text, rejected_text |
| if "response_0" in row and "response_1" in row: |
| preferred_field, rejected_field = _selected_response_fields(row) |
| if not preferred_field or not rejected_field: |
| return "", "" |
| prompt = row.get("prompt", row.get("question", row.get("query", ""))) |
| if prompt: |
| chosen_text = _compose_training_text(prompt, row.get(preferred_field)) |
| rejected_text = _compose_training_text(prompt, row.get(rejected_field)) |
| if chosen_text and rejected_text: |
| return clean_training_text(chosen_text), clean_training_text(rejected_text) |
| chosen_text = clean_training_text(_flatten_value(row.get(preferred_field))) |
| rejected_text = clean_training_text(_flatten_value(row.get(rejected_field))) |
| if chosen_text and rejected_text: |
| return chosen_text, rejected_text |
| return "", "" |
|
|
|
|
| def _extract_preference_value(row: dict[str, object]) -> str: |
| chosen_text, _ = _extract_preference_pair(row) |
| return chosen_text |
|
|
|
|
| def _extract_row_text(row: dict[str, object], text_field: str | None) -> str: |
| if "context" in row and "answer" in row: |
| context = clean_context_text(_flatten_value(row.get("context"))) |
| answer = clean_answer_text(_flatten_value(row.get("answer"))) |
| if context and answer: |
| return f"<reason> {context} <answer> {answer}".strip() |
|
|
| if "response_0" in row and "response_1" in row: |
| preferred_field, _ = _selected_response_fields(row) |
| prompt = row.get("prompt", row.get("question", row.get("query", ""))) |
| if preferred_field and prompt: |
| text = _compose_training_text(prompt, row.get(preferred_field)) |
| if text: |
| return text |
|
|
| for prompt_field, answer_field in INSTRUCTION_FIELD_PAIRS: |
| if prompt_field in row and answer_field in row: |
| text = _compose_training_text(row.get(prompt_field), row.get(answer_field)) |
| if text: |
| return text |
|
|
| if text_field is not None: |
| return clean_training_text(_flatten_value(row.get(text_field))) |
|
|
| preferred = _extract_preference_value(row) |
| if preferred: |
| return clean_training_text(preferred) |
|
|
| for field in TEXT_FIELD_PREFERENCES: |
| text = _flatten_value(row.get(field)) |
| if text: |
| return clean_training_text(text) |
| for field in DIALOGUE_FIELD_PREFERENCES: |
| text = _flatten_value(row.get(field)) |
| if text: |
| return clean_training_text(text) |
| return "" |
|
|
|
|
| def _passes_text_quality(text: str, language: str, entry: CorpusPlanEntry) -> bool: |
| if not text: |
| return False |
| word_count = _word_count(text) |
| if entry.min_words > 0 and word_count < entry.min_words: |
| return False |
| if entry.max_words > 0 and word_count > entry.max_words: |
| return False |
| if entry.min_alpha_ratio > 0.0 and _alpha_ratio(text) < entry.min_alpha_ratio: |
| return False |
| if entry.allowed_languages: |
| if not language or language.casefold() not in entry.allowed_languages: |
| return False |
| return True |
|
|
|
|
| def load_corpus_plan(source: str | Path) -> list[CorpusPlanEntry]: |
| payload = json.loads(Path(source).read_text(encoding="utf-8-sig")) |
| raw_entries = payload.get("sources", payload.get("datasets", [])) |
| if not isinstance(raw_entries, list) or not raw_entries: |
| raise ValueError("Corpus plan must define a non-empty 'sources' list.") |
|
|
| entries: list[CorpusPlanEntry] = [] |
| for index, raw_entry in enumerate(raw_entries, start=1): |
| if not isinstance(raw_entry, dict): |
| raise ValueError("Each corpus plan entry must be an object.") |
| source = str(raw_entry.get("source", "hf")).strip() or "hf" |
| name = str( |
| raw_entry.get("name", raw_entry.get("dataset", f"source-{index}")) |
| ).strip() or f"source-{index}" |
| raw_languages = raw_entry.get("allowed_languages", []) |
| allowed_languages = tuple( |
| str(value).strip().casefold() |
| for value in raw_languages |
| if str(value).strip() |
| ) if isinstance(raw_languages, list) else () |
| raw_records = raw_entry.get("records", raw_entry.get("texts", [])) |
| if source == "inline" and not isinstance(raw_records, list): |
| raise ValueError("Inline corpus plan entries must provide a records/texts list.") |
| entries.append( |
| CorpusPlanEntry( |
| source=source, |
| name=name, |
| dataset=str(raw_entry.get("dataset", "")), |
| path=str(raw_entry.get("path", raw_entry.get("file", ""))), |
| config=( |
| str(raw_entry["config"]) |
| if raw_entry.get("config") is not None |
| else None |
| ), |
| split=str(raw_entry.get("split", "train")), |
| limit=int(raw_entry.get("limit", 0)), |
| weight=float(raw_entry.get("weight", 1.0)), |
| text_field=( |
| str(raw_entry["text_field"]) |
| if raw_entry.get("text_field") is not None |
| else None |
| ), |
| min_words=int(raw_entry.get("min_words", 0)), |
| max_words=int(raw_entry.get("max_words", 0)), |
| min_alpha_ratio=float(raw_entry.get("min_alpha_ratio", 0.0)), |
| allowed_languages=allowed_languages, |
| records=tuple(raw_records) if isinstance(raw_records, list) else (), |
| streaming=bool(raw_entry.get("streaming", True)), |
| trust_remote_code=bool(raw_entry.get("trust_remote_code", False)), |
| ) |
| ) |
| return entries |
|
|
|
|
| def _iter_hf_rows(entry: CorpusPlanEntry) -> Iterator[dict[str, object]]: |
| try: |
| from datasets import load_dataset |
| except ModuleNotFoundError: |
| user_site = site.getusersitepackages() |
| if user_site and user_site not in sys.path: |
| sys.path.append(user_site) |
| from datasets import load_dataset |
|
|
| dataset_kwargs: dict[str, object] = { |
| "split": entry.split, |
| "streaming": entry.streaming, |
| } |
| if entry.config: |
| dataset_kwargs["name"] = entry.config |
| if entry.trust_remote_code: |
| dataset_kwargs["trust_remote_code"] = True |
|
|
| for row in load_dataset(entry.dataset, **dataset_kwargs): |
| yield dict(row) |
|
|
|
|
| def _iter_file_rows(entry: CorpusPlanEntry) -> Iterator[dict[str, object]]: |
| raw_path = entry.path or entry.dataset |
| if not raw_path: |
| raise ValueError("File corpus plan entries must provide a path.") |
| path = Path(raw_path) |
| suffix = path.suffix.lower() |
| if suffix == ".jsonl": |
| with path.open("r", encoding="utf-8") as handle: |
| for line in handle: |
| if line.strip(): |
| row = json.loads(line) |
| yield row if isinstance(row, dict) else {"text": str(row)} |
| return |
| if suffix == ".json": |
| payload = json.loads(path.read_text(encoding="utf-8")) |
| if isinstance(payload, list): |
| for row in payload: |
| yield row if isinstance(row, dict) else {"text": str(row)} |
| return |
| if isinstance(payload, dict): |
| rows = payload.get("records", payload.get("texts")) |
| if isinstance(rows, list): |
| for row in rows: |
| yield row if isinstance(row, dict) else {"text": str(row)} |
| return |
| yield payload |
| return |
| if suffix in {".txt", ".md", ".text"}: |
| yield {"text": path.read_text(encoding="utf-8")} |
| return |
| raise ValueError(f"Unsupported file corpus source: {path}") |
|
|
|
|
| def iter_corpus_plan_documents(plan: Iterable[CorpusPlanEntry]) -> Iterator[StreamDocument]: |
| for entry in plan: |
| accepted = 0 |
| attempts = 0 |
| while True: |
| accepted_seen_this_attempt = 0 |
| try: |
| if entry.source == "inline": |
| row_iterator = ( |
| item if isinstance(item, dict) else {"text": str(item)} |
| for item in entry.records |
| ) |
| elif entry.source == "hf": |
| row_iterator = _iter_hf_rows(entry) |
| elif entry.source == "file": |
| row_iterator = _iter_file_rows(entry) |
| else: |
| raise ValueError(f"Unsupported corpus plan source: {entry.source}") |
|
|
| for row in row_iterator: |
| language = _row_language(row) |
| _, rejected_text = _extract_preference_pair(row) |
| text = clean_training_text(_extract_row_text(row, entry.text_field)) |
| if not _passes_text_quality(text, language, entry): |
| continue |
| accepted_seen_this_attempt += 1 |
| if accepted_seen_this_attempt <= accepted: |
| continue |
| yield StreamDocument( |
| text=text, |
| weight=entry.weight, |
| source=entry.name, |
| language=language, |
| preference_rejected_text=rejected_text, |
| ) |
| accepted += 1 |
| if entry.limit > 0 and accepted >= entry.limit: |
| break |
| break |
| except Exception as exc: |
| if entry.source != "hf": |
| raise |
| if attempts >= HF_STREAM_MAX_RETRIES: |
| print( |
| f"[source] {entry.name} skipped after {attempts} retries; " |
| f"accepted {accepted} documents before final error: {exc}" |
| ) |
| break |
| attempts += 1 |
| delay = min( |
| 15.0, |
| HF_STREAM_RETRY_BASE_DELAY_SECONDS * (2 ** (attempts - 1)), |
| ) |
| print( |
| f"[source] {entry.name} stream interrupted after {accepted} accepted " |
| f"documents; retry {attempts}/{HF_STREAM_MAX_RETRIES} in {delay:.2f}s: {exc}" |
| ) |
| time.sleep(delay) |
|
|
|
|
| def _log_progress(label: str, processed: int, log_every: int) -> None: |
| if log_every > 0 and processed % log_every == 0: |
| print(f"[{label}] processed {processed} documents") |
|
|
|
|
| def _answer_boundary(tokens: list[str]) -> int | None: |
| try: |
| return tokens.index("<answer>") |
| except ValueError: |
| return None |
|
|
|
|
| def _weighted_text_parts_for_statistics(text: str, document_weight: float) -> list[tuple[str, float]]: |
| if "<answer>" not in text: |
| return [(text, document_weight)] |
| context, answer = text.split("<answer>", 1) |
| context = clean_context_text(context.replace("<reason>", " ")) |
| answer = clean_answer_text(answer) |
| parts: list[tuple[str, float]] = [] |
| if context: |
| parts.append((context, document_weight * CONTEXT_STAT_WEIGHT)) |
| if answer: |
| parts.append((answer, document_weight * ANSWER_READOUT_WEIGHT)) |
| return parts or [(text, document_weight)] |
|
|
|
|
| def _weighted_token_sequences_for_statistics( |
| tokens: list[str], |
| tokenizer: NativeTokenizer, |
| document_weight: float, |
| ) -> list[tuple[list[str], float]]: |
| answer_index = _answer_boundary(tokens) |
| if answer_index is None: |
| sequence = [token for token in tokens if token not in tokenizer.special_tokens] |
| return [(sequence, document_weight)] if sequence else [] |
| context_tokens = [ |
| token for token in tokens[:answer_index] if token not in tokenizer.special_tokens |
| ] |
| answer_tokens = [ |
| token for token in tokens[answer_index + 1 :] if token not in tokenizer.special_tokens |
| ] |
| sequences: list[tuple[list[str], float]] = [] |
| if context_tokens: |
| sequences.append((context_tokens, document_weight * CONTEXT_STAT_WEIGHT)) |
| if answer_tokens: |
| sequences.append((answer_tokens, document_weight * ANSWER_READOUT_WEIGHT)) |
| return sequences |
|
|
|
|
| def _readout_weight_for_target( |
| answer_index: int | None, |
| target_index: int, |
| document_weight: float, |
| ) -> float: |
| if answer_index is None: |
| return document_weight * PLAIN_TEXT_READOUT_WEIGHT |
| if target_index <= answer_index: |
| return document_weight * CONTEXT_READOUT_WEIGHT |
| return document_weight * ANSWER_READOUT_WEIGHT |
|
|
|
|
| def _answer_payload_tokens(tokens: list[str], tokenizer: NativeTokenizer) -> list[str]: |
| answer_index = _answer_boundary(tokens) |
| payload = tokens[answer_index + 1 :] if answer_index is not None else tokens |
| return [token for token in payload if token not in tokenizer.special_tokens] |
|
|
|
|
| def _standardized_preference_bias(values: object, active_mask: object | None = None) -> list[float]: |
| if np is not None: |
| bias = np.asarray(values, dtype=np.float64) |
| if bias.size == 0: |
| return [] |
| active = ( |
| np.asarray(active_mask, dtype=bool) |
| if active_mask is not None |
| else np.ones(bias.shape, dtype=bool) |
| ) |
| if not np.any(active): |
| return [0.0 for _ in range(int(bias.size))] |
| active_values = bias[active] |
| spread = float(active_values.std()) |
| if spread <= 1e-12: |
| return [0.0 for _ in range(int(bias.size))] |
| standardized = np.zeros_like(bias, dtype=np.float64) |
| standardized[active] = ( |
| (active_values - float(active_values.mean())) / spread |
| ) * PREFERENCE_BIAS_SCALE |
| return np.clip(standardized, -2.5, 2.5).astype(float).tolist() |
| raw_values = [float(value) for value in values] |
| if not raw_values: |
| return [] |
| average = sum(raw_values) / len(raw_values) |
| variance = sum((value - average) * (value - average) for value in raw_values) / len(raw_values) |
| spread = variance**0.5 |
| if spread <= 1e-12: |
| return [0.0 for _ in raw_values] |
| active_indices = ( |
| [ |
| index |
| for index, active in enumerate(active_mask) |
| if active |
| ] |
| if active_mask is not None |
| else list(range(len(raw_values))) |
| ) |
| if not active_indices: |
| return [0.0 for _ in raw_values] |
| active_values = [raw_values[index] for index in active_indices] |
| average = mean(active_values) |
| spread = (mean([(value - average) * (value - average) for value in active_values])) ** 0.5 |
| if spread <= 1e-12: |
| return [0.0 for _ in raw_values] |
| standardized = [0.0 for _ in raw_values] |
| for index in active_indices: |
| standardized[index] = max( |
| -2.5, |
| min(2.5, ((raw_values[index] - average) / spread) * PREFERENCE_BIAS_SCALE), |
| ) |
| return standardized |
|
|
|
|
| def _candidate_preference_bias_from_state_vector( |
| model: ReframrModel, |
| preference_state: object, |
| ) -> object: |
| if np is None: |
| return None |
| assert model.embedding_model is not None |
| assert model.memory_units is not None |
| assert model.ternary_mask is not None |
|
|
| embeddings = np.asarray(model.embedding_model.embeddings, dtype=np.float64) |
| if embeddings.size == 0: |
| return np.zeros(0, dtype=np.float64) |
| state_vector = np.asarray(preference_state, dtype=np.float64) |
| mask = np.asarray(model.ternary_mask, dtype=np.float64) * float(model.ternary_scale) |
| if state_vector.shape[0] != mask.shape[0]: |
| return np.zeros(embeddings.shape[0], dtype=np.float64) |
|
|
| state_indices = np.arange(model.config.state_dim, dtype=np.int64) |
| drive = ( |
| embeddings[:, state_indices % model.config.embedding_dim] |
| + (0.5 * embeddings[:, (3 * state_indices + 1) % model.config.embedding_dim]) |
| - (0.25 * embeddings[:, (5 * state_indices + 2) % model.config.embedding_dim]) |
| ) |
| scores = np.zeros(embeddings.shape[0], dtype=np.float64) |
| offset = 0 |
| for unit in model.memory_units: |
| hidden_end = offset + model.config.state_dim |
| trace_end = hidden_end + model.config.embedding_dim |
| hidden_pref = state_vector[offset:hidden_end] * mask[offset:hidden_end] |
| trace_pref = state_vector[hidden_end:trace_end] * mask[hidden_end:trace_end] |
| hidden_delta_axis = np.asarray(unit.input_projection, dtype=np.float64) * hidden_pref |
| trace_gain = 1.0 - (1.0 / (1.0 + unit.timescale)) |
| scores += drive @ hidden_delta_axis |
| scores += embeddings @ (trace_gain * trace_pref) |
| offset = trace_end |
| return scores |
|
|
|
|
| def _derive_preference_bias_from_pairs( |
| model: ReframrModel, |
| preference_token_pairs: list[tuple[list[str], list[str], float]], |
| tokenizer: NativeTokenizer, |
| ) -> tuple[list[float], int]: |
| assert model.embedding_model is not None |
| vocab_size = len(model.embedding_model.id_to_token) |
| if not preference_token_pairs: |
| return [0.0 for _ in range(vocab_size)], 0 |
|
|
| if np is not None: |
| token_bias = np.zeros(vocab_size, dtype=np.float64) |
| active_token_mask = np.zeros(vocab_size, dtype=bool) |
| state_delta = np.zeros(model._combined_state_width(), dtype=np.float64) |
| else: |
| token_bias = [0.0 for _ in range(vocab_size)] |
| active_token_ids: set[int] = set() |
| state_delta = [0.0 for _ in range(model._combined_state_width())] |
| pair_weight_total = 0.0 |
| state_pair_count = 0 |
| state_stride = max( |
| 1, |
| (len(preference_token_pairs) + MAX_PREFERENCE_STATE_PAIRS - 1) |
| // MAX_PREFERENCE_STATE_PAIRS, |
| ) |
|
|
| for pair_index, (chosen_tokens, rejected_tokens, pair_weight) in enumerate(preference_token_pairs): |
| chosen_answer = _answer_payload_tokens(chosen_tokens, tokenizer) |
| rejected_answer = _answer_payload_tokens(rejected_tokens, tokenizer) |
| if chosen_answer: |
| delta = pair_weight / max(1, len(chosen_answer)) |
| for token in chosen_answer: |
| token_id = model.embedding_model.token_to_id.get(token) |
| if token_id is not None: |
| token_bias[token_id] += delta |
| if np is not None: |
| active_token_mask[token_id] = True |
| else: |
| active_token_ids.add(token_id) |
| if rejected_answer: |
| delta = pair_weight / max(1, len(rejected_answer)) |
| for token in rejected_answer: |
| token_id = model.embedding_model.token_to_id.get(token) |
| if token_id is not None: |
| token_bias[token_id] -= delta |
| if np is not None: |
| active_token_mask[token_id] = True |
| else: |
| active_token_ids.add(token_id) |
|
|
| if pair_index % state_stride != 0 or state_pair_count >= MAX_PREFERENCE_STATE_PAIRS: |
| continue |
| chosen_state = model._masked_decode_state(model._build_decode_state(chosen_tokens)) |
| rejected_state = model._masked_decode_state(model._build_decode_state(rejected_tokens)) |
| if len(chosen_state) != len(rejected_state): |
| continue |
| pair_weight_total += pair_weight |
| state_pair_count += 1 |
| if np is not None: |
| state_delta += pair_weight * ( |
| np.asarray(chosen_state, dtype=np.float64) |
| - np.asarray(rejected_state, dtype=np.float64) |
| ) |
| else: |
| for index, (chosen_value, rejected_value) in enumerate(zip(chosen_state, rejected_state)): |
| state_delta[index] += pair_weight * (chosen_value - rejected_value) |
|
|
| if pair_weight_total > 0.0: |
| if np is not None: |
| state_delta = state_delta / pair_weight_total |
| candidate_bias = _candidate_preference_bias_from_state_vector(model, state_delta) |
| if candidate_bias is not None: |
| token_bias[active_token_mask] = ( |
| token_bias[active_token_mask] + candidate_bias[active_token_mask] |
| ) |
| else: |
| state_delta = [value / pair_weight_total for value in state_delta] |
| if np is not None: |
| return _standardized_preference_bias(token_bias, active_token_mask), state_pair_count |
| active_mask = [index in active_token_ids for index in range(vocab_size)] |
| return _standardized_preference_bias(token_bias, active_mask), state_pair_count |
|
|
|
|
| def _solve_weighted_prompt_readout( |
| states: list[Vector], |
| labels: list[int], |
| weights: list[float], |
| *, |
| vocab_size: int, |
| diagonal: object, |
| state_offset: object, |
| regularization: float, |
| ) -> tuple[object, object, int]: |
| if np is None or not states or not labels or not weights: |
| return [], [0.0 for _ in range(vocab_size)], 0 |
| state_matrix = np.asarray(states, dtype=np.float64) |
| label_array = np.asarray(labels, dtype=np.int64) |
| weight_vector = np.asarray(weights, dtype=np.float64) |
| valid_mask = ( |
| (label_array >= 0) |
| & (label_array < vocab_size) |
| & (weight_vector > 0.0) |
| ) |
| if not np.any(valid_mask): |
| return [], [0.0 for _ in range(vocab_size)], 0 |
| state_matrix = state_matrix[valid_mask] |
| label_array = label_array[valid_mask] |
| weight_vector = weight_vector[valid_mask] |
| diagonal_array = np.asarray(diagonal, dtype=np.float64) |
| offset_array = np.asarray(state_offset, dtype=np.float64) |
| if ( |
| len(state_matrix.shape) != 2 |
| or diagonal_array.shape[0] != state_matrix.shape[1] |
| or offset_array.shape[0] != state_matrix.shape[1] |
| ): |
| return [], [0.0 for _ in range(vocab_size)], 0 |
| masked_states = state_matrix * diagonal_array[None, :] |
| centered_states = masked_states - offset_array[None, :] |
| weighted_centered_states = weight_vector[:, None] * centered_states |
| gram = centered_states.T @ weighted_centered_states |
| cross = np.zeros((vocab_size, centered_states.shape[1]), dtype=np.float64) |
| np.add.at(cross, label_array, weighted_centered_states) |
| total_weight = float(weight_vector.sum()) |
| if total_weight <= 0.0: |
| return [], [0.0 for _ in range(vocab_size)], 0 |
| bias = np.zeros(vocab_size, dtype=np.float64) |
| np.add.at(bias, label_array, weight_vector) |
| bias /= total_weight |
| readout = ridge_regression_readout_from_moments( |
| gram, |
| cross, |
| regularization=regularization, |
| ) |
| return readout, bias, int(label_array.shape[0]) |
|
|
|
|
| def fit_model_from_corpus_plan( |
| plan: Iterable[CorpusPlanEntry], |
| config: ReframrConfig, |
| *, |
| log_every: int = 0, |
| ) -> tuple[ReframrModel, dict[str, object]]: |
| entries = list(plan) |
| if not entries: |
| raise ValueError("Cannot fit REFRAMR without any corpus plan entries.") |
| stage_seconds: dict[str, float] = {} |
| stage_started = time.perf_counter() |
|
|
| def finish_stage(name: str) -> None: |
| nonlocal stage_started |
| now = time.perf_counter() |
| elapsed = round(now - stage_started, 6) |
| stage_seconds[name] = elapsed |
| if log_every > 0: |
| print(f"[stage] {name} finished in {elapsed:.3f}s") |
| stage_started = now |
|
|
| seed_tokenizer = NativeTokenizer( |
| merges=[], |
| vocab=[], |
| base_symbols=[], |
| lowercase=config.lowercase, |
| ) |
| segment_counts: Counter[str] = Counter() |
| source_counts: dict[str, int] = {} |
| documents: list[StreamDocument] = [] |
| processed = 0 |
| for entry in entries: |
| if log_every > 0: |
| print(f"[source] {entry.name} started") |
| source_start = processed |
| for document in iter_corpus_plan_documents([entry]): |
| documents.append(document) |
| processed += 1 |
| source_counts[document.source] = source_counts.get(document.source, 0) + 1 |
| for text_part, part_weight in _weighted_text_parts_for_statistics( |
| document.text, |
| document.weight, |
| ): |
| for segment in seed_tokenizer.pretokenize(text_part): |
| segment_counts[segment] += part_weight |
| if document.preference_rejected_text: |
| rejected_weight = document.weight * PREFERENCE_REJECTED_TOKENIZER_WEIGHT |
| for text_part, part_weight in _weighted_text_parts_for_statistics( |
| document.preference_rejected_text, |
| rejected_weight, |
| ): |
| for segment in seed_tokenizer.pretokenize(text_part): |
| segment_counts[segment] += part_weight |
| _log_progress("tokenizer", processed, log_every) |
| if log_every > 0: |
| print(f"[source] {entry.name} accepted {processed - source_start} documents") |
| if processed == 0: |
| raise ValueError("Corpus plan did not yield any usable documents after filtering.") |
| finish_stage("stream_and_segment") |
| tokenizer = NativeTokenizer.train_from_segment_counts( |
| segment_counts, |
| vocab_size=config.tokenizer_vocab_size, |
| min_pair_frequency=config.tokenizer_min_pair_frequency, |
| lowercase=config.lowercase, |
| ) |
| finish_stage("tokenizer_fit") |
|
|
| token_counts: Counter[str] = Counter() |
| raw_tokenized_documents: list[list[str]] = [] |
| raw_rejected_tokenized_documents: list[list[str]] = [] |
| processed = 0 |
| for document in documents: |
| processed += 1 |
| tokens = tokenizer.encode(document.text) |
| raw_tokenized_documents.append(tokens) |
| for token in tokens: |
| if token in tokenizer.special_tokens: |
| token_counts[token] += document.weight |
| for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics( |
| tokens, |
| tokenizer, |
| document.weight, |
| ): |
| for token in token_sequence: |
| token_counts[token] += sequence_weight |
| rejected_tokens = ( |
| tokenizer.encode(document.preference_rejected_text) |
| if document.preference_rejected_text |
| else [] |
| ) |
| raw_rejected_tokenized_documents.append(rejected_tokens) |
| rejected_weight = document.weight * PREFERENCE_REJECTED_TOKENIZER_WEIGHT |
| for token in rejected_tokens: |
| if token in tokenizer.special_tokens: |
| token_counts[token] += rejected_weight |
| for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics( |
| rejected_tokens, |
| tokenizer, |
| rejected_weight, |
| ): |
| for token in token_sequence: |
| token_counts[token] += sequence_weight |
| _log_progress("vocab", processed, log_every) |
| token_to_id, id_to_token = build_vocabulary_from_counts( |
| token_counts, |
| min_frequency=config.min_frequency, |
| max_vocab=config.max_vocab, |
| ) |
| if not id_to_token: |
| raise ValueError("Streaming recompute could not derive an embedding vocabulary.") |
| finish_stage("vocabulary") |
|
|
| cooccurrence = StreamingCooccurrenceAccumulator(token_to_id, config.window_size) |
| tokenized_documents: list[list[str]] = [] |
| preference_token_pairs: list[tuple[list[str], list[str], float]] = [] |
| processed = 0 |
| for document, raw_tokens, raw_rejected_tokens in zip( |
| documents, |
| raw_tokenized_documents, |
| raw_rejected_tokenized_documents, |
| ): |
| processed += 1 |
| tokens = [token for token in raw_tokens if token in token_to_id] |
| tokenized_documents.append(tokens) |
| rejected_tokens = [token for token in raw_rejected_tokens if token in token_to_id] |
| if len(tokens) > 1 and len(rejected_tokens) > 1: |
| preference_token_pairs.append((tokens, rejected_tokens, document.weight)) |
| for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics( |
| tokens, |
| tokenizer, |
| document.weight, |
| ): |
| if len(token_sequence) > 1: |
| cooccurrence.update_tokens(token_sequence, weight=sequence_weight) |
| _log_progress("cooccurrence", processed, log_every) |
| finish_stage("cooccurrence") |
| if np is not None: |
| embedding_model = fit_randomized_ppmi_embedding_from_counts( |
| id_to_token, |
| cooccurrence.rows, |
| embedding_dim=config.embedding_dim, |
| ) |
| else: |
| embedding_model = fit_ppmi_embedding_from_cooccurrence( |
| id_to_token, |
| cooccurrence.to_sparse(), |
| embedding_dim=config.embedding_dim, |
| ) |
| finish_stage("embedding") |
|
|
| model = ReframrModel(config) |
| model.tokenizer = tokenizer |
| model.embedding_model = embedding_model |
| model.memory_units = [ |
| AnalyticalMemoryUnit(config.state_dim, timescale) |
| for timescale in config.timescales |
| ] |
| model.trace_token_weights = model._derive_trace_token_weights_from_counts(token_counts) |
|
|
| feature_count = len(model._zero_combined_state()) |
| if np is not None: |
| feature_second_moment = np.zeros(feature_count, dtype=np.float64) |
| raw_cross = np.zeros((len(embedding_model.id_to_token), feature_count), dtype=np.float64) |
| else: |
| feature_second_moment = zeros_vector(feature_count) |
| raw_cross = zeros(len(embedding_model.id_to_token), feature_count) |
| example_weight_total = 0.0 |
| has_answer_targets = any(_answer_boundary(tokens) is not None for tokens in tokenized_documents) |
| if config.max_training_examples is None: |
| answer_reservoir_capacity = None |
| general_reservoir_capacity = None |
| elif config.max_training_examples <= 0: |
| answer_reservoir_capacity = 0 |
| general_reservoir_capacity = 0 |
| elif has_answer_targets: |
| answer_reservoir_capacity = max(1, int(config.max_training_examples * 0.75)) |
| general_reservoir_capacity = max(0, config.max_training_examples - answer_reservoir_capacity) |
| else: |
| answer_reservoir_capacity = 0 |
| general_reservoir_capacity = config.max_training_examples |
| answer_sequence_capacity = MAX_ANSWER_SEQUENCE_EXAMPLES if has_answer_targets else 0 |
| answer_reservoir = StateReservoir(answer_reservoir_capacity, seed=17) |
| general_reservoir = StateReservoir(general_reservoir_capacity, seed=13) |
| answer_intent_reservoir = StateReservoir(answer_reservoir_capacity, seed=29) |
| answer_start_reservoir = StateReservoir(answer_reservoir_capacity, seed=37) |
| answer_sequence_reservoir = SequenceReservoir(answer_sequence_capacity, seed=41) |
| moment_reservoir = StateReservoir( |
| config.max_training_examples if config.max_training_examples is not None else None, |
| seed=31, |
| ) |
| transitions = TransitionAccumulator( |
| max_contexts_per_order=config.max_transition_contexts_per_order, |
| max_next_tokens=config.max_transition_next_tokens, |
| ) |
| if np is not None: |
| target_label_mass = np.zeros(len(embedding_model.id_to_token), dtype=np.float64) |
| else: |
| target_label_mass = zeros_vector(len(embedding_model.id_to_token)) |
| for document, tokens in zip(documents, tokenized_documents): |
| answer_index = _answer_boundary(tokens) |
| for index in range(len(tokens) - 1): |
| next_token = tokens[index + 1] |
| if tokenizer is not None and next_token in tokenizer.special_tokens: |
| continue |
| next_token_id = embedding_model.token_to_id.get(next_token, -1) |
| if next_token_id < 0: |
| continue |
| label_weight = _readout_weight_for_target(answer_index, index + 1, document.weight) |
| if label_weight > 0.0: |
| target_label_mass[next_token_id] += label_weight |
| if np is not None: |
| positive_label_mass = target_label_mass[target_label_mass > 0.0] |
| reference_label_mass = ( |
| float(np.median(positive_label_mass)) |
| if positive_label_mass.size |
| else 1.0 |
| ) |
| target_balance = np.ones(len(embedding_model.id_to_token), dtype=np.float64) |
| np.divide( |
| reference_label_mass, |
| np.maximum(target_label_mass, 1e-12), |
| out=target_balance, |
| where=target_label_mass > 0.0, |
| ) |
| target_balance = np.clip(np.sqrt(target_balance), 0.25, 4.0) |
| else: |
| positive_label_mass = [value for value in target_label_mass if value > 0.0] |
| if positive_label_mass: |
| sorted_mass = sorted(positive_label_mass) |
| reference_label_mass = sorted_mass[len(sorted_mass) // 2] |
| else: |
| reference_label_mass = 1.0 |
| target_balance = [ |
| max(0.25, min(4.0, (reference_label_mass / max(value, 1e-12)) ** 0.5)) |
| if value > 0.0 |
| else 1.0 |
| for value in target_label_mass |
| ] |
| processed = 0 |
| embedding_array = ( |
| np.asarray(embedding_model.embeddings, dtype=RUNTIME_ARRAY_DTYPE) |
| if np is not None |
| else None |
| ) |
| trace_embedding_array = ( |
| model._build_trace_embedding_table_array(embedding_array) |
| if np is not None and embedding_array is not None |
| else None |
| ) |
| if np is not None: |
| trace_decay = np.asarray( |
| [1.0 / (1.0 + unit.timescale) for unit in model.memory_units], |
| dtype=RUNTIME_ARRAY_DTYPE, |
| ) |
| trace_gain = 1.0 - trace_decay |
| transition_stack = np.asarray( |
| [unit.transition for unit in model.memory_units], |
| dtype=RUNTIME_ARRAY_DTYPE, |
| ) |
| input_projection_stack = np.asarray( |
| [unit.input_projection for unit in model.memory_units], |
| dtype=RUNTIME_ARRAY_DTYPE, |
| ) |
| drive_indices = np.arange(config.state_dim, dtype=np.int64) |
| drive_primary = drive_indices % config.embedding_dim |
| drive_secondary = (3 * drive_indices + 1) % config.embedding_dim |
| drive_tertiary = (5 * drive_indices + 2) % config.embedding_dim |
| else: |
| trace_decay = None |
| trace_gain = None |
| transition_stack = None |
| input_projection_stack = None |
| drive_primary = None |
| drive_secondary = None |
| drive_tertiary = None |
| for document, tokens in zip(documents, tokenized_documents): |
| processed += 1 |
| if len(tokens) < 2: |
| _log_progress("state", processed, log_every) |
| continue |
|
|
| answer_index = _answer_boundary(tokens) |
| for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics( |
| tokens, |
| tokenizer, |
| document.weight, |
| ): |
| if len(token_sequence) > 1: |
| transitions.update_tokens(token_sequence, weight=sequence_weight) |
| if np is not None: |
| hidden_state_matrix = np.zeros((len(config.timescales), config.state_dim), dtype=RUNTIME_ARRAY_DTYPE) |
| context_trace_matrix = np.zeros((len(config.timescales), config.embedding_dim), dtype=RUNTIME_ARRAY_DTYPE) |
| else: |
| hidden_states = [zeros_vector(config.state_dim) for _ in config.timescales] |
| context_traces = [zeros_vector(config.embedding_dim) for _ in config.timescales] |
| answer_anchor_state = None |
| for index in range(len(tokens) - 1): |
| token = tokens[index] |
| token_id = embedding_model.token_to_id.get(token, -1) |
| if ( |
| np is not None |
| and embedding_array is not None |
| and trace_decay is not None |
| and trace_gain is not None |
| and transition_stack is not None |
| and input_projection_stack is not None |
| and drive_primary is not None |
| and drive_secondary is not None |
| and drive_tertiary is not None |
| and trace_embedding_array is not None |
| and token_id >= 0 |
| ): |
| embedding = embedding_array[token_id] |
| trace_embedding = trace_embedding_array[token_id] |
| drive = ( |
| embedding[drive_primary] |
| + (0.5 * embedding[drive_secondary]) |
| - (0.25 * embedding[drive_tertiary]) |
| ) |
| hidden_state_matrix = ( |
| (transition_stack @ hidden_state_matrix[:, :, None])[:, :, 0] |
| + (input_projection_stack * drive[None, :]) |
| ) |
| context_trace_matrix = ( |
| context_trace_matrix + (trace_gain[:, None] * trace_embedding[None, :]) |
| ) |
| else: |
| hidden_states, context_traces, combined_state = model._step_hidden_states( |
| hidden_states, |
| context_traces, |
| token, |
| ) |
| if token == "<answer>": |
| if np is not None: |
| answer_anchor_state = np.concatenate( |
| (hidden_state_matrix, context_trace_matrix), |
| axis=1, |
| ).reshape(-1).copy() |
| else: |
| answer_anchor_state = combined_state.copy() if hasattr(combined_state, "copy") else combined_state[:] |
| next_token = tokens[index + 1] |
| if next_token in tokenizer.special_tokens: |
| continue |
| next_token_id = embedding_model.token_to_id.get(next_token, -1) |
| if next_token_id < 0: |
| continue |
| raw_readout_weight = _readout_weight_for_target(answer_index, index + 1, document.weight) |
| readout_weight = raw_readout_weight * float(target_balance[next_token_id]) |
| if readout_weight <= 0.0: |
| continue |
| moment_slot = moment_reservoir.reserve_slot(weight=readout_weight) |
| is_answer_target = answer_index is not None and index + 1 > answer_index |
| target_reservoir = answer_reservoir if is_answer_target else general_reservoir |
| memory_weight = readout_weight * float(target_balance[next_token_id]) |
| answer_token_offset = ( |
| index - answer_index |
| if is_answer_target and answer_index is not None |
| else None |
| ) |
| intent_slot = ( |
| answer_intent_reservoir.reserve_slot(weight=memory_weight) |
| if is_answer_target and answer_anchor_state is not None |
| else None |
| ) |
| answer_start_weight = ( |
| raw_readout_weight * (ANSWER_START_DECAY ** answer_token_offset) |
| if ( |
| answer_token_offset is not None |
| and answer_token_offset < ANSWER_START_TOKEN_WINDOW |
| ) |
| else 0.0 |
| ) |
| answer_start_slot = ( |
| answer_start_reservoir.reserve_slot(weight=answer_start_weight) |
| if answer_start_weight > 0.0 and answer_anchor_state is not None |
| else None |
| ) |
| if np is not None: |
| reservoir_slot = target_reservoir.reserve_slot(weight=memory_weight) |
| if moment_slot is not None or reservoir_slot is not None: |
| combined_state = np.concatenate( |
| (hidden_state_matrix, context_trace_matrix), |
| axis=1, |
| ).reshape(-1).copy() |
| if moment_slot is not None: |
| moment_reservoir.store_reserved( |
| moment_slot, |
| combined_state, |
| next_token_id, |
| example_weight=readout_weight, |
| ) |
| if reservoir_slot is not None: |
| target_reservoir.store_reserved(reservoir_slot, combined_state, next_token_id) |
| if intent_slot is not None: |
| answer_intent_reservoir.store_reserved( |
| intent_slot, |
| answer_anchor_state, |
| next_token_id, |
| example_weight=memory_weight, |
| ) |
| if answer_start_slot is not None: |
| answer_start_reservoir.store_reserved( |
| answer_start_slot, |
| answer_anchor_state, |
| next_token_id, |
| example_weight=answer_start_weight * float(target_balance[next_token_id]), |
| ) |
| else: |
| reservoir_slot = target_reservoir.reserve_slot(weight=memory_weight) |
| if moment_slot is None and reservoir_slot is None: |
| continue |
| if moment_slot is not None: |
| moment_reservoir.store_reserved( |
| moment_slot, |
| combined_state, |
| next_token_id, |
| example_weight=readout_weight, |
| ) |
| if reservoir_slot is not None: |
| target_reservoir.store_reserved(reservoir_slot, combined_state, next_token_id) |
| if intent_slot is not None: |
| answer_intent_reservoir.store_reserved( |
| intent_slot, |
| answer_anchor_state, |
| next_token_id, |
| example_weight=memory_weight, |
| ) |
| if answer_start_slot is not None: |
| answer_start_reservoir.store_reserved( |
| answer_start_slot, |
| answer_anchor_state, |
| next_token_id, |
| example_weight=answer_start_weight * target_balance[next_token_id], |
| ) |
| if answer_anchor_state is not None and answer_index is not None: |
| prompt_token_ids = [ |
| embedding_model.token_to_id[token] |
| for token in tokens[:answer_index] |
| if token not in tokenizer.special_tokens |
| and token in embedding_model.token_to_id |
| ] |
| answer_token_ids = [ |
| embedding_model.token_to_id[token] |
| for token in tokens[answer_index + 1 :] |
| if token not in tokenizer.special_tokens |
| and token in embedding_model.token_to_id |
| ] |
| answer_sequence_reservoir.consider( |
| answer_anchor_state, |
| prompt_token_ids, |
| answer_token_ids, |
| weight=document.weight * ANSWER_READOUT_WEIGHT, |
| ) |
| _log_progress("state", processed, log_every) |
|
|
| moment_states = moment_reservoir.states |
| moment_labels = moment_reservoir.labels |
| moment_weights = moment_reservoir.weights |
| example_weight_total = sum(moment_weights) |
| if np is not None and moment_states: |
| state_matrix = np.asarray(moment_states, dtype=np.float64) |
| weight_vector = np.asarray(moment_weights, dtype=np.float64) |
| weighted_states = weight_vector[:, None] * state_matrix |
| feature_second_moment += (weighted_states * state_matrix).sum(axis=0) |
| np.add.at(raw_cross, moment_labels, weighted_states) |
| elif moment_states: |
| for state, label_id, readout_weight in zip(moment_states, moment_labels, moment_weights): |
| for feature, value in enumerate(state): |
| weighted_value = readout_weight * value |
| feature_second_moment[feature] += weighted_value * value |
| raw_cross[label_id][feature] += weighted_value |
|
|
| if example_weight_total <= 0.0: |
| raise ValueError("Streaming recompute did not collect any next-token training examples.") |
|
|
| if np is not None: |
| feature_energy = (feature_second_moment / example_weight_total).tolist() |
| else: |
| feature_energy = [ |
| feature_second_moment[index] / example_weight_total |
| for index in range(feature_count) |
| ] |
| ternary_scale, ternary_mask = derive_ternary_mask_from_feature_energy(feature_energy) |
| if np is not None: |
| diagonal = np.asarray([ternary_scale * value for value in ternary_mask], dtype=np.float64) |
| masked_feature_second_moment = feature_second_moment * diagonal * diagonal |
| masked_cross = raw_cross * diagonal[None, :] |
| else: |
| diagonal = [ternary_scale * value for value in ternary_mask] |
| masked_feature_second_moment = [ |
| feature_second_moment[index] * diagonal[index] * diagonal[index] |
| for index in range(feature_count) |
| ] |
| masked_cross = [ |
| [ |
| raw_cross[row][col] * diagonal[col] |
| for col in range(feature_count) |
| ] |
| for row in range(len(raw_cross)) |
| ] |
| readout_solver = "diagonal" |
| state_offset_values: object |
| readout_bias_values: object |
| if ( |
| np is not None |
| and moment_states |
| and feature_count <= FULL_READOUT_FEATURE_LIMIT |
| and len(moment_states) <= FULL_READOUT_EXAMPLE_LIMIT |
| ): |
| state_matrix = np.asarray(moment_states, dtype=np.float64) |
| weight_vector = np.asarray(moment_weights, dtype=np.float64) |
| label_array = np.asarray(moment_labels, dtype=np.int64) |
| masked_states = state_matrix * diagonal[None, :] |
| total_weight = float(weight_vector.sum()) |
| if total_weight <= 0.0: |
| total_weight = 1.0 |
| state_offset_values = (weight_vector[:, None] * masked_states).sum(axis=0) / total_weight |
| centered_states = masked_states - state_offset_values[None, :] |
| weighted_centered_states = weight_vector[:, None] * centered_states |
| gram = centered_states.T @ weighted_centered_states |
| full_cross = np.zeros((len(embedding_model.id_to_token), feature_count), dtype=np.float64) |
| np.add.at(full_cross, label_array, weighted_centered_states) |
| readout_bias_values = np.zeros(len(embedding_model.id_to_token), dtype=np.float64) |
| np.add.at(readout_bias_values, label_array, weight_vector) |
| readout_bias_values /= total_weight |
| readout_weights = ridge_regression_readout_from_moments( |
| gram, |
| full_cross, |
| regularization=config.regularization, |
| ) |
| readout_solver = "full" |
| else: |
| state_offset_values = ( |
| np.zeros(feature_count, dtype=np.float64) |
| if np is not None |
| else [0.0 for _ in range(feature_count)] |
| ) |
| if np is not None: |
| label_total = max(float(target_label_mass.sum()), 1.0) |
| readout_bias_values = target_label_mass / label_total |
| else: |
| label_total = max(sum(target_label_mass), 1.0) |
| readout_bias_values = [value / label_total for value in target_label_mass] |
| readout_weights = ridge_regression_readout_from_diagonal_moments( |
| masked_feature_second_moment, |
| masked_cross, |
| regularization=config.regularization, |
| ) |
| finish_stage("state_and_readout") |
|
|
| model.ternary_scale = ternary_scale |
| model.ternary_mask = ternary_mask |
| model.readout_weights = readout_weights |
| model.state_offset = ( |
| state_offset_values.tolist() |
| if hasattr(state_offset_values, "tolist") |
| else list(state_offset_values) |
| ) |
| model.readout_bias = ( |
| readout_bias_values.tolist() |
| if hasattr(readout_bias_values, "tolist") |
| else list(readout_bias_values) |
| ) |
| model.preference_bias, preference_state_pairs = _derive_preference_bias_from_pairs( |
| model, |
| preference_token_pairs, |
| tokenizer, |
| ) |
| finish_stage("preference") |
| reservoir_states = answer_reservoir.states + general_reservoir.states |
| reservoir_labels = answer_reservoir.labels + general_reservoir.labels |
| answer_intent_states = answer_intent_reservoir.states |
| answer_intent_labels = answer_intent_reservoir.labels |
| answer_start_states = answer_start_reservoir.states |
| answer_start_labels = answer_start_reservoir.labels |
| answer_sequence_states = answer_sequence_reservoir.keys |
| answer_sequence_prompt_rows = answer_sequence_reservoir.prompt_rows |
| answer_sequence_rows = answer_sequence_reservoir.token_rows |
| prompt_answer_weights, prompt_answer_bias, prompt_answer_readout_examples = ( |
| _solve_weighted_prompt_readout( |
| answer_intent_states, |
| answer_intent_labels, |
| answer_intent_reservoir.weights, |
| vocab_size=len(embedding_model.id_to_token), |
| diagonal=diagonal, |
| state_offset=state_offset_values, |
| regularization=config.regularization, |
| ) |
| ) |
| ( |
| prompt_answer_start_weights, |
| prompt_answer_start_bias, |
| prompt_answer_start_readout_examples, |
| ) = _solve_weighted_prompt_readout( |
| answer_start_states, |
| answer_start_labels, |
| answer_start_reservoir.weights, |
| vocab_size=len(embedding_model.id_to_token), |
| diagonal=diagonal, |
| state_offset=state_offset_values, |
| regularization=config.regularization, |
| ) |
| model.prompt_answer_weights = prompt_answer_weights |
| model.prompt_answer_bias = ( |
| prompt_answer_bias.tolist() |
| if hasattr(prompt_answer_bias, "tolist") |
| else list(prompt_answer_bias) |
| ) |
| model.prompt_answer_start_weights = prompt_answer_start_weights |
| model.prompt_answer_start_bias = ( |
| prompt_answer_start_bias.tolist() |
| if hasattr(prompt_answer_start_bias, "tolist") |
| else list(prompt_answer_start_bias) |
| ) |
| if np is not None and reservoir_states: |
| reservoir_array = np.asarray(reservoir_states, dtype=RUNTIME_ARRAY_DTYPE) |
| mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale |
| offset_array = np.asarray(model.state_offset, dtype=RUNTIME_ARRAY_DTYPE) |
| associative_array = ((reservoir_array * mask_array[None, :]) - offset_array[None, :]).astype( |
| RUNTIME_ARRAY_DTYPE, |
| copy=False, |
| ) |
| model.associative_keys = associative_array |
| model.associative_key_norms = np.linalg.norm(associative_array, axis=1).tolist() |
| else: |
| offset_vector = model.state_offset |
| model.associative_keys = [ |
| [ |
| value - offset_vector[index] |
| for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale)) |
| ] |
| for state in reservoir_states |
| ] |
| model.associative_key_norms = [norm(state) for state in model.associative_keys] |
| model.associative_values = reservoir_labels[:] |
| if np is not None and answer_intent_states: |
| answer_intent_array = np.asarray(answer_intent_states, dtype=RUNTIME_ARRAY_DTYPE) |
| mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale |
| offset_array = np.asarray(model.state_offset, dtype=RUNTIME_ARRAY_DTYPE) |
| answer_array = ((answer_intent_array * mask_array[None, :]) - offset_array[None, :]).astype( |
| RUNTIME_ARRAY_DTYPE, |
| copy=False, |
| ) |
| model.answer_keys = answer_array |
| model.answer_key_norms = np.linalg.norm(answer_array, axis=1).tolist() |
| else: |
| offset_vector = model.state_offset |
| model.answer_keys = [ |
| [ |
| value - offset_vector[index] |
| for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale)) |
| ] |
| for state in answer_intent_states |
| ] |
| model.answer_key_norms = [norm(state) for state in model.answer_keys] |
| model.answer_values = answer_intent_labels[:] |
| if np is not None and answer_start_states: |
| answer_start_array = np.asarray(answer_start_states, dtype=RUNTIME_ARRAY_DTYPE) |
| mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale |
| offset_array = np.asarray(model.state_offset, dtype=RUNTIME_ARRAY_DTYPE) |
| start_array = ((answer_start_array * mask_array[None, :]) - offset_array[None, :]).astype( |
| RUNTIME_ARRAY_DTYPE, |
| copy=False, |
| ) |
| model.answer_start_keys = start_array |
| model.answer_start_key_norms = np.linalg.norm(start_array, axis=1).tolist() |
| else: |
| offset_vector = model.state_offset |
| model.answer_start_keys = [ |
| [ |
| value - offset_vector[index] |
| for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale)) |
| ] |
| for state in answer_start_states |
| ] |
| model.answer_start_key_norms = [norm(state) for state in model.answer_start_keys] |
| model.answer_start_values = answer_start_labels[:] |
| if np is not None and answer_sequence_states: |
| answer_sequence_array = np.asarray(answer_sequence_states, dtype=RUNTIME_ARRAY_DTYPE) |
| mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale |
| offset_array = np.asarray(model.state_offset, dtype=RUNTIME_ARRAY_DTYPE) |
| sequence_array = ((answer_sequence_array * mask_array[None, :]) - offset_array[None, :]).astype( |
| RUNTIME_ARRAY_DTYPE, |
| copy=False, |
| ) |
| model.answer_sequence_keys = sequence_array |
| model.answer_sequence_key_norms = np.linalg.norm(sequence_array, axis=1).tolist() |
| else: |
| offset_vector = model.state_offset |
| model.answer_sequence_keys = [ |
| [ |
| value - offset_vector[index] |
| for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale)) |
| ] |
| for state in answer_sequence_states |
| ] |
| model.answer_sequence_key_norms = [norm(state) for state in model.answer_sequence_keys] |
| if np is not None: |
| padded_answer_sequences = np.full( |
| (len(answer_sequence_rows), MAX_ANSWER_SEQUENCE_TOKENS), |
| -1, |
| dtype=np.int32, |
| ) |
| for row_index, row in enumerate(answer_sequence_rows): |
| row_width = min(len(row), MAX_ANSWER_SEQUENCE_TOKENS) |
| if row_width > 0: |
| padded_answer_sequences[row_index, :row_width] = row[:row_width] |
| padded_answer_sequence_prompts = np.full( |
| (len(answer_sequence_prompt_rows), MAX_ANSWER_SEQUENCE_TOKENS), |
| -1, |
| dtype=np.int32, |
| ) |
| for row_index, row in enumerate(answer_sequence_prompt_rows): |
| row_width = min(len(row), MAX_ANSWER_SEQUENCE_TOKENS) |
| if row_width > 0: |
| padded_answer_sequence_prompts[row_index, :row_width] = row[:row_width] |
| else: |
| padded_answer_sequences = [ |
| row + [-1 for _ in range(MAX_ANSWER_SEQUENCE_TOKENS - len(row))] |
| for row in answer_sequence_rows |
| ] |
| padded_answer_sequence_prompts = [ |
| row + [-1 for _ in range(MAX_ANSWER_SEQUENCE_TOKENS - len(row))] |
| for row in answer_sequence_prompt_rows |
| ] |
| model.answer_sequence_prompt_tokens = padded_answer_sequence_prompts |
| model.answer_sequence_tokens = padded_answer_sequences |
| model.transition_tables = transitions.finalize( |
| max_contexts_per_order=config.max_transition_contexts_per_order, |
| max_next_tokens=config.max_transition_next_tokens, |
| ) |
| finish_stage("model_finalize") |
|
|
| payload = { |
| "streaming": True, |
| "documents_processed": processed, |
| "source_counts": source_counts, |
| "embedding_vocab_size": len(embedding_model.id_to_token), |
| "tokenizer_vocab_size": tokenizer.vocab_size, |
| "examples_processed": int(round(example_weight_total)), |
| "associative_examples": len(model.associative_keys), |
| "answer_associative_examples": len(answer_reservoir.states), |
| "general_associative_examples": len(general_reservoir.states), |
| "answer_intent_examples": len(model.answer_keys), |
| "answer_start_examples": len(model.answer_start_keys), |
| "answer_sequence_examples": len(model.answer_sequence_keys), |
| "prompt_answer_readout_examples": prompt_answer_readout_examples, |
| "prompt_answer_start_readout_examples": prompt_answer_start_readout_examples, |
| "stage_seconds": stage_seconds, |
| "target_balance_reference": round(float(reference_label_mass), 6), |
| "readout_solver": readout_solver, |
| "preference_pairs": len(preference_token_pairs), |
| "preference_state_pairs": preference_state_pairs, |
| } |
| return model, payload |
|
|