# Reframr-RFM-v1-Base / reframr/streaming.py
# Commit author: OkeyMeta
# Release Reframr-RFM-v1-Base public checkpoint
# Commit 2147ce8 (verified)
from __future__ import annotations
import json
import random
import re
import site
import sys
import time
from collections import Counter
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from pathlib import Path
from .config import ReframrConfig
from .corpus import build_vocabulary_from_counts
from .embeddings import fit_ppmi_embedding_from_cooccurrence, fit_randomized_ppmi_embedding_from_counts
from .hippo import AnalyticalMemoryUnit
from .linalg import Matrix, Vector, norm, zeros, zeros_vector
from .model import ReframrModel, RUNTIME_ARRAY_DTYPE, TRANSITION_ORDERS, np
from .reservoir import (
ridge_regression_readout_from_diagonal_moments,
ridge_regression_readout_from_moments,
)
from .ternary import apply_ternary_mask, derive_ternary_mask_from_feature_energy
from .text_quality import clean_answer_text, clean_context_text, clean_training_text
from .tokenizer import NativeTokenizer
# SciPy is optional. When it is missing (or its native libraries fail to
# load), StreamingCooccurrenceAccumulator.to_sparse falls back to to_dense.
try:
    from scipy import sparse as scipy_sparse
except (ImportError, ModuleNotFoundError, OSError):
    scipy_sparse = None
# Row fields probed, in priority order, for a plain document body.
TEXT_FIELD_PREFERENCES = (
    "text", "content", "body", "article", "document",
    "passage", "markdown", "answer", "response",
)

# Row fields probed, in priority order, for chat-style message lists.
DIALOGUE_FIELD_PREFERENCES = (
    "messages", "conversation", "conversations", "dialogue", "dialog", "turns", "chosen",
)

# (prompt field, answer field) pairs recognised as instruction-tuning rows.
INSTRUCTION_FIELD_PAIRS = (
    ("instruction", "output"),
    ("prompt", "completion"), ("prompt", "response"),
    ("question", "answer"), ("question", "response"),
    ("query", "answer"), ("query", "response"),
)

# Speaker markers ("Human:", "Assistant:", "System:") at the start of the text
# or after a blank line, matched case-insensitively.
TRANSCRIPT_ROLE_PATTERN = re.compile(
    r"(?:^|\n\s*\n)(Human|Assistant|System)\s*:\s*",
    flags=re.IGNORECASE,
)

# Casefolded speaker labels mapped onto the canonical system/user/assistant roles.
ROLE_ALIASES = dict(
    assistant="assistant",
    assistant_response="assistant",
    bot="assistant",
    gpt="assistant",
    model="assistant",
    human="user",
    prompter="user",
    user="user",
    customer="user",
    system="system",
)
# Weights applied to the parts of a "<reason> ... <answer> ..." document.
ANSWER_READOUT_WEIGHT = 1.0  # answer tokens, for readout targets and statistics
CONTEXT_READOUT_WEIGHT = 0.0  # readout targets at or before the <answer> marker
CONTEXT_STAT_WEIGHT = 0.02  # prompt/context text in statistics
PLAIN_TEXT_READOUT_WEIGHT = 0.03  # readout targets of documents without <answer>
PREFERENCE_REJECTED_TOKENIZER_WEIGHT = 0.0  # NOTE(review): consumer not in this section

# Preference-bias shaping.
PREFERENCE_BIAS_SCALE = 0.95  # multiplier applied after z-scoring, before clipping to [-2.5, 2.5]
MAX_PREFERENCE_STATE_PAIRS = 512  # cap on pairs whose decode states are compared

# Answer-sequence settings.
ANSWER_START_TOKEN_WINDOW = 12  # NOTE(review): consumer not in this section
ANSWER_START_DECAY = 0.86  # NOTE(review): consumer not in this section
MAX_ANSWER_SEQUENCE_EXAMPLES = 196_608  # NOTE(review): consumer not in this section
MAX_ANSWER_SEQUENCE_TOKENS = 192  # per-row truncation in SequenceReservoir

# Hugging Face stream retries: exponential backoff from this base, capped at 15 s by the caller.
HF_STREAM_MAX_RETRIES = 5
HF_STREAM_RETRY_BASE_DELAY_SECONDS = 0.25

# Readout size thresholds; NOTE(review): consumers are outside this section.
FULL_READOUT_FEATURE_LIMIT = 2_304
FULL_READOUT_EXAMPLE_LIMIT = 25_000
@dataclass(slots=True)
class CorpusPlanEntry:
    """One source of a corpus plan, as built by ``load_corpus_plan``.

    ``source`` selects the row reader used by ``iter_corpus_plan_documents``:
    ``"hf"`` (Hugging Face ``datasets``), ``"file"`` or ``"inline"``.
    """

    source: str
    name: str  # label stored on emitted StreamDocuments and used in progress messages
    dataset: str = ""  # HF dataset id; also the fallback path for "file" sources
    path: str = ""  # local path for "file" sources
    config: str | None = None  # HF dataset config, passed to load_dataset as ``name=``
    split: str = "train"
    limit: int = 0  # max accepted documents; 0 or negative means unlimited
    weight: float = 1.0  # copied onto every StreamDocument from this entry
    text_field: str | None = None  # explicit text column, consulted after structured prompt/answer fields
    min_words: int = 0  # 0 disables the filter
    max_words: int = 0  # 0 disables the filter
    min_alpha_ratio: float = 0.0  # 0.0 disables the filter
    allowed_languages: tuple[str, ...] = ()  # casefolded tags; empty allows every language
    records: tuple[object, ...] = ()  # rows for "inline" sources
    streaming: bool = True  # forwarded to datasets.load_dataset
    trust_remote_code: bool = False  # forwarded to datasets.load_dataset only when True
@dataclass(slots=True)
class StreamDocument:
    """A cleaned training document yielded by ``iter_corpus_plan_documents``."""

    text: str
    weight: float  # weight of the originating corpus plan entry
    source: str  # name of the originating corpus plan entry
    language: str = ""  # row language tag, or "" when the row carries none
    preference_rejected_text: str = ""  # rejected side of a preference pair, if any
class StreamingCooccurrenceAccumulator:
    """Incrementally build a symmetric, distance-weighted co-occurrence table.

    Counts are stored sparsely as ``rows[i][j]`` keyed by vocabulary id. Tokens
    missing from ``token_to_id`` are dropped before the window is applied.
    """

    def __init__(self, token_to_id: dict[str, int], window_size: int) -> None:
        self.token_to_id = token_to_id
        self.window_size = window_size
        self.rows: dict[int, dict[int, float]] = {}

    def update_tokens(self, tokens: list[str], *, weight: float) -> None:
        """Add one sequence; ids ``d`` positions apart gain ``weight / d`` in both directions."""
        vocab = self.token_to_id
        ids = [vocab[tok] for tok in tokens if tok in vocab]
        count = len(ids)
        table = self.rows
        for left_pos, left_id in enumerate(ids):
            last_pos = min(count - 1, left_pos + self.window_size)
            for right_pos in range(left_pos + 1, last_pos + 1):
                distance = right_pos - left_pos
                right_id = ids[right_pos]
                increment = weight * (1.0 / distance)
                forward = table.setdefault(left_id, {})
                forward[right_id] = forward.get(right_id, 0.0) + increment
                backward = table.setdefault(right_id, {})
                backward[left_id] = backward.get(left_id, 0.0) + increment

    def to_dense(self) -> Matrix:
        """Return the table as a dense ``|vocab| x |vocab|`` matrix."""
        dimension = len(self.token_to_id)
        dense = zeros(dimension, dimension)
        for row_id, columns in self.rows.items():
            target = dense[row_id]
            for col_id, value in columns.items():
                target[col_id] = value
        return dense

    def to_sparse(self) -> object:
        """Return a SciPy CSR matrix, or the dense matrix when SciPy/NumPy are unavailable."""
        if scipy_sparse is None or np is None:
            return self.to_dense()
        row_ids: list[int] = []
        col_ids: list[int] = []
        values: list[float] = []
        for row_id, columns in self.rows.items():
            row_ids.extend([row_id] * len(columns))
            col_ids.extend(columns.keys())
            values.extend(columns.values())
        dimension = len(self.token_to_id)
        coordinates = (
            np.asarray(row_ids, dtype=np.int64),
            np.asarray(col_ids, dtype=np.int64),
        )
        coo = scipy_sparse.coo_matrix(
            (np.asarray(values, dtype=np.float64), coordinates),
            shape=(dimension, dimension),
            dtype=np.float64,
        )
        return coo.tocsr()
class TransitionAccumulator:
    """Accumulate weighted next-token counts for every order in ``TRANSITION_ORDERS``.

    During streaming, a new context (or a new follower) is refused once its
    bucket reaches four times the requested final limit. This bounds memory
    while leaving ``finalize`` headroom to keep the heaviest entries.
    """

    def __init__(
        self,
        *,
        max_contexts_per_order: int | None = None,
        max_next_tokens: int = 0,
    ) -> None:
        self.max_contexts_per_order = max_contexts_per_order
        self.max_next_tokens = max_next_tokens
        context_capped = max_contexts_per_order is not None and max_contexts_per_order > 0
        self.context_soft_limit = max_contexts_per_order * 4 if context_capped else None
        self.next_token_soft_limit: int | None = None
        if max_next_tokens > 0:
            self.next_token_soft_limit = max_next_tokens * 4
        self.counts: dict[int, dict[tuple[str, ...], dict[str, float]]] = {}
        for order in sorted(TRANSITION_ORDERS):
            self.counts[order] = {}

    def update_tokens(self, tokens: list[str], *, weight: float) -> None:
        """Add ``weight`` to every (context of length ``order``, next token) pair in ``tokens``."""
        context_cap = self.context_soft_limit
        follower_cap = self.next_token_soft_limit
        stop = len(tokens) - 1
        for order in sorted(TRANSITION_ORDERS):
            contexts = self.counts[order]
            for end in range(order - 1, stop):
                context = tuple(tokens[end - order + 1 : end + 1])
                follower = tokens[end + 1]
                if context_cap is not None and context not in contexts and len(contexts) >= context_cap:
                    continue
                followers = contexts.setdefault(context, {})
                if follower_cap is not None and follower not in followers and len(followers) >= follower_cap:
                    continue
                followers[follower] = followers.get(follower, 0.0) + weight

    def finalize(
        self,
        *,
        max_contexts_per_order: int | None,
        max_next_tokens: int,
    ) -> dict[int, dict[tuple[str, ...], dict[str, float]]]:
        """Prune the counts and normalise each context's followers to probabilities.

        Contexts are ranked by total weight (ties broken by key). The top
        ``max_contexts_per_order`` are kept: ``None`` or a negative value keeps
        all of them, and 0 keeps none. Followers are ranked by weight and the top
        ``max_next_tokens`` are kept; a value ``<= 0`` keeps all of them.
        Contexts whose kept weight is not positive are dropped.
        """
        result: dict[int, dict[tuple[str, ...], dict[str, float]]] = {
            order: {} for order in sorted(TRANSITION_ORDERS)
        }
        truncate_contexts = max_contexts_per_order is not None and max_contexts_per_order >= 0
        for order, mapping in self.counts.items():
            ranked = sorted(mapping.items(), key=lambda kv: (-sum(kv[1].values()), kv[0]))
            if truncate_contexts:
                ranked = ranked[:max_contexts_per_order]
            for context, followers in ranked:
                ranked_next = sorted(followers.items(), key=lambda kv: (-kv[1], kv[0]))
                if max_next_tokens > 0:
                    ranked_next = ranked_next[:max_next_tokens]
                mass = sum(count for _, count in ranked_next)
                if mass <= 0.0:
                    continue
                result[order][context] = {token: count / mass for token, count in ranked_next}
        return result
class StateReservoir:
    """Weighted reservoir of (state, label, weight) examples.

    With ``capacity=None`` every example is kept. With a positive capacity, once
    the reservoir is full, an incoming example replaces a random slot with
    probability ``capacity * weight / total_weight`` (capped at 1). A capacity
    ``<= 0`` keeps nothing.
    """

    def __init__(self, capacity: int | None, *, seed: int = 13) -> None:
        self.capacity = capacity
        self.random = random.Random(seed)
        self.states: list[Vector] = []
        self.labels: list[int] = []
        self.weights: list[float] = []
        self.seen = 0
        self.total_weight = 0.0

    def reserve_slot(self, weight: float = 1.0) -> int | None:
        """Return the slot an example of ``weight`` should occupy, or ``None`` to drop it."""
        if weight <= 0.0:
            return None
        self.seen += 1
        self.total_weight += weight
        limit = self.capacity
        filled = len(self.states)
        if limit is None:
            return filled
        if limit <= 0:
            return None
        if filled < limit:
            return filled
        chance = min(1.0, (limit * weight) / max(self.total_weight, 1e-12))
        if self.random.random() >= chance:
            return None
        return self.random.randrange(limit)

    def store_reserved(
        self,
        slot: int,
        state: Vector,
        label_id: int,
        *,
        example_weight: float = 1.0,
    ) -> None:
        """Copy an example into ``slot``: append at the end, overwrite inside, ignore otherwise."""
        snapshot = state.copy() if hasattr(state, "copy") else state[:]
        size = len(self.states)
        if slot == size:
            self.states.append(snapshot)
            self.labels.append(label_id)
            self.weights.append(example_weight)
            return
        if 0 <= slot < size:
            self.states[slot] = snapshot
            self.labels[slot] = label_id
            self.weights[slot] = example_weight

    def consider(self, state: Vector, label_id: int, weight: float = 1.0) -> None:
        """Offer one example to the reservoir."""
        target = self.reserve_slot(weight=weight)
        if target is None:
            return
        self.store_reserved(target, state, label_id, example_weight=weight)
class SequenceReservoir:
    """Weighted reservoir of (key, prompt token ids, answer token ids) examples.

    Sampling mirrors ``StateReservoir``: slots fill in order. Once full, a new
    example replaces a random slot with probability
    ``capacity * weight / seen_weight``. Stored id rows are truncated to
    ``MAX_ANSWER_SEQUENCE_TOKENS``.
    """

    def __init__(self, capacity: int | None, *, seed: int = 41) -> None:
        self.capacity = capacity
        self.random = random.Random(seed)
        self.keys: list[Vector] = []
        self.prompt_rows: list[list[int]] = []
        self.token_rows: list[list[int]] = []
        self.weights: list[float] = []
        self.seen_weight = 0.0

    def reserve_slot(self, *, weight: float = 1.0) -> int | None:
        """Return the slot for an example of ``weight``, or ``None`` to drop it."""
        if weight <= 0.0 or self.capacity == 0:
            return None
        self.seen_weight += weight
        held = len(self.keys)
        if self.capacity is None or held < self.capacity:
            return held
        odds = min(1.0, (self.capacity * weight) / max(self.seen_weight, 1e-12))
        if self.random.random() >= odds:
            return None
        return self.random.randrange(self.capacity)

    def store_reserved(
        self,
        slot: int,
        key: Vector,
        prompt_token_ids: list[int],
        token_ids: list[int],
        *,
        example_weight: float = 1.0,
    ) -> None:
        """Write an example into ``slot``, appending when the slot is past the end."""
        key_row = key.tolist() if hasattr(key, "tolist") else list(key)
        prompt_row = prompt_token_ids[:MAX_ANSWER_SEQUENCE_TOKENS]
        answer_row = token_ids[:MAX_ANSWER_SEQUENCE_TOKENS]
        if self.capacity is not None and slot < len(self.keys):
            self.keys[slot] = key_row
            self.prompt_rows[slot] = prompt_row
            self.token_rows[slot] = answer_row
            self.weights[slot] = example_weight
            return
        self.keys.append(key_row)
        self.prompt_rows.append(prompt_row)
        self.token_rows.append(answer_row)
        self.weights.append(example_weight)

    def consider(
        self,
        key: Vector,
        prompt_token_ids: list[int],
        token_ids: list[int],
        weight: float = 1.0,
    ) -> None:
        """Offer one example; empty answers are ignored without touching the weight total."""
        if not token_ids:
            return
        slot = self.reserve_slot(weight=weight)
        if slot is None:
            return
        self.store_reserved(slot, key, prompt_token_ids, token_ids, example_weight=weight)
def _word_count(text: str) -> int:
return len(text.split())
def _alpha_ratio(text: str) -> float:
if not text:
return 0.0
alpha_count = sum(character.isalpha() for character in text)
return alpha_count / len(text)
def _row_language(row: dict[str, object]) -> str:
for candidate in ("lang", "language", "locale"):
value = row.get(candidate)
if isinstance(value, str) and value.strip():
return value.strip()
return ""
def _normalize_role(raw_role: object) -> str:
    """Casefold a speaker label and map known aliases onto system/user/assistant.

    Falsy inputs become ""; unknown labels are returned casefolded and stripped.
    """
    if not raw_role:
        raw_role = ""
    normalized = str(raw_role).strip().casefold()
    canonical = ROLE_ALIASES.get(normalized)
    return normalized if canonical is None else canonical
def _message_content(message: dict[str, object]) -> str:
for field in ("content", "value", "text", "message"):
value = message.get(field)
if isinstance(value, str) and value.strip():
return clean_training_text(value)
return ""
def _message_role(message: dict[str, object]) -> str:
for field in ("role", "from", "speaker", "author"):
value = message.get(field)
if value is not None:
normalized = _normalize_role(value)
if normalized:
return normalized
return ""
def _parse_dialogue_messages(raw_messages: object) -> list[dict[str, str]]:
if not isinstance(raw_messages, list):
return []
parsed: list[dict[str, str]] = []
for message in raw_messages:
if not isinstance(message, dict):
continue
role = _message_role(message)
content = _message_content(message)
if role not in {"system", "user", "assistant"} or not content:
continue
parsed.append({"role": role, "content": content})
return parsed
def _parse_transcript_messages(raw_text: object) -> list[dict[str, str]]:
if not isinstance(raw_text, str):
return []
text = raw_text.strip()
if not text:
return []
matches = list(TRANSCRIPT_ROLE_PATTERN.finditer(text))
if not matches:
return []
parsed: list[dict[str, str]] = []
for index, match in enumerate(matches):
role = _normalize_role(match.group(1))
start = match.end()
end = matches[index + 1].start() if index + 1 < len(matches) else len(text)
content = clean_training_text(text[start:end].strip())
if role in {"system", "user", "assistant"} and content:
parsed.append({"role": role, "content": content})
return parsed
def _render_prompt(messages: list[dict[str, str]]) -> str:
parts = []
for message in messages:
content = clean_context_text(message["content"])
if content:
parts.append(content)
return "\n".join(parts).strip()
def _last_user_prompt_before(messages: list[dict[str, str]], end_index: int) -> str:
    """Return the cleaned latest user turn before ``end_index``.

    Falls back to rendering every earlier turn when no user turn exists.
    """
    earlier = messages[:end_index]
    latest_user = next((turn for turn in reversed(earlier) if turn["role"] == "user"), None)
    if latest_user is None:
        return _render_prompt(earlier)
    return clean_context_text(latest_user["content"])
def _compose_training_text(context: object, answer: object) -> str:
    """Build "<reason> prompt <answer> answer" text.

    Falls back to whichever side is non-empty when the other is missing.
    """
    prompt_text = clean_context_text(_flatten_value(context))
    answer_text = clean_answer_text(_flatten_value(answer))
    if not (prompt_text and answer_text):
        return clean_training_text(answer_text or prompt_text)
    return f"<reason> {prompt_text} <answer> {answer_text}".strip()
def _compose_from_messages(messages: list[dict[str, str]]) -> str:
assistant_index = None
for index in range(len(messages) - 1, -1, -1):
if messages[index]["role"] == "assistant":
assistant_index = index
break
if assistant_index is not None:
prompt = _last_user_prompt_before(messages, assistant_index)
answer = clean_answer_text(messages[assistant_index]["content"])
if prompt and answer:
return f"<reason> {prompt} <answer> {answer}".strip()
return "\n".join(
message["content"]
for message in messages
if message.get("content")
).strip()
def _flatten_message_list(messages: object) -> str:
    """Flatten a message list into training text.

    Lists that parse into role/content turns are composed with
    ``_compose_from_messages``. Otherwise the first non-``None`` value among
    ``content``/``value``/``text`` of every dict item is cleaned and the
    results are newline-joined. Non-list input yields ``""``.
    """
    parsed = _parse_dialogue_messages(messages)
    if parsed:
        return _compose_from_messages(parsed)
    if not isinstance(messages, list):
        return ""
    parts: list[str] = []
    for message in messages:
        if not isinstance(message, dict):
            continue
        # Skip keys holding JSON null: str(None) would inject the literal text "None".
        raw_content = next(
            (
                message[field]
                for field in ("content", "value", "text")
                if message.get(field) is not None
            ),
            "",
        )
        content = str(raw_content).strip()
        if not content:
            continue
        parts.append(clean_training_text(content))
    return "\n".join(parts).strip()
def _flatten_value(value: object) -> str:
if isinstance(value, str):
parsed = _parse_transcript_messages(value)
if parsed:
return _compose_from_messages(parsed)
return clean_training_text(value.strip())
if isinstance(value, list):
return _flatten_message_list(value)
if isinstance(value, dict):
for field in ("messages", "conversation", "conversations", "dialogue", "turns"):
nested_messages = value.get(field)
text = _flatten_message_list(nested_messages)
if text:
return text
for field in ("text", "content", "value", "message"):
nested = value.get(field)
if isinstance(nested, str) and nested.strip():
return _flatten_value(nested)
return ""
def _safe_flag(value: object) -> bool | None:
if isinstance(value, bool):
return value
if isinstance(value, str):
normalized = value.strip().casefold()
if normalized in {"true", "1", "yes", "safe"}:
return True
if normalized in {"false", "0", "no", "unsafe"}:
return False
return None
def _selected_response_fields(row: dict[str, object]) -> tuple[str, str]:
if "response_0" not in row or "response_1" not in row:
return "", ""
safe_0 = _safe_flag(row.get("is_response_0_safe"))
safe_1 = _safe_flag(row.get("is_response_1_safe"))
if safe_0 is not None and safe_1 is not None:
if safe_0 and not safe_1:
return "response_0", "response_1"
if safe_1 and not safe_0:
return "response_1", "response_0"
if safe_0 and safe_1:
return "response_0", ""
return "", ""
for selector in ("safer_response_id", "better_response_id"):
raw_value = row.get(selector)
try:
preferred = int(raw_value)
except (TypeError, ValueError):
continue
chosen = "response_1" if preferred == 1 else "response_0"
rejected = "response_0" if chosen == "response_1" else "response_1"
return chosen, rejected
return "response_0", "response_1"
def _extract_preference_pair(row: dict[str, object]) -> tuple[str, str]:
if "chosen" in row and "rejected" in row:
chosen_text = clean_training_text(_flatten_value(row.get("chosen")))
rejected_text = clean_training_text(_flatten_value(row.get("rejected")))
if chosen_text and rejected_text:
return chosen_text, rejected_text
if "response_0" in row and "response_1" in row:
preferred_field, rejected_field = _selected_response_fields(row)
if not preferred_field or not rejected_field:
return "", ""
prompt = row.get("prompt", row.get("question", row.get("query", "")))
if prompt:
chosen_text = _compose_training_text(prompt, row.get(preferred_field))
rejected_text = _compose_training_text(prompt, row.get(rejected_field))
if chosen_text and rejected_text:
return clean_training_text(chosen_text), clean_training_text(rejected_text)
chosen_text = clean_training_text(_flatten_value(row.get(preferred_field)))
rejected_text = clean_training_text(_flatten_value(row.get(rejected_field)))
if chosen_text and rejected_text:
return chosen_text, rejected_text
return "", ""
def _extract_preference_value(row: dict[str, object]) -> str:
    """Return only the chosen side of the row's preference pair ("" when absent)."""
    return _extract_preference_pair(row)[0]
def _extract_row_text(row: dict[str, object], text_field: str | None) -> str:
    """Extract training text from one dataset row, trying formats in priority order.

    The order is: context/answer, then preferred response_0/response_1 with a
    prompt, then instruction pairs, then the explicit ``text_field``, then the
    chosen preference side, then generic text fields, then dialogue fields.
    Returns "" when nothing usable is found.
    """
    if "context" in row and "answer" in row:
        context = clean_context_text(_flatten_value(row.get("context")))
        answer = clean_answer_text(_flatten_value(row.get("answer")))
        if context and answer:
            return f"<reason> {context} <answer> {answer}".strip()
    if "response_0" in row and "response_1" in row:
        preferred_field = _selected_response_fields(row)[0]
        prompt = row.get("prompt", row.get("question", row.get("query", "")))
        if preferred_field and prompt:
            composed = _compose_training_text(prompt, row.get(preferred_field))
            if composed:
                return composed
    for prompt_field, answer_field in INSTRUCTION_FIELD_PAIRS:
        if prompt_field not in row or answer_field not in row:
            continue
        composed = _compose_training_text(row.get(prompt_field), row.get(answer_field))
        if composed:
            return composed
    if text_field is not None:
        return clean_training_text(_flatten_value(row.get(text_field)))
    preferred = _extract_preference_value(row)
    if preferred:
        return clean_training_text(preferred)
    for field in (*TEXT_FIELD_PREFERENCES, *DIALOGUE_FIELD_PREFERENCES):
        flattened = _flatten_value(row.get(field))
        if flattened:
            return clean_training_text(flattened)
    return ""
def _passes_text_quality(text: str, language: str, entry: CorpusPlanEntry) -> bool:
if not text:
return False
word_count = _word_count(text)
if entry.min_words > 0 and word_count < entry.min_words:
return False
if entry.max_words > 0 and word_count > entry.max_words:
return False
if entry.min_alpha_ratio > 0.0 and _alpha_ratio(text) < entry.min_alpha_ratio:
return False
if entry.allowed_languages:
if not language or language.casefold() not in entry.allowed_languages:
return False
return True
def load_corpus_plan(source: str | Path) -> list[CorpusPlanEntry]:
    """Parse a JSON corpus plan file into ``CorpusPlanEntry`` objects.

    The file (UTF-8, optional BOM) must contain an object with a non-empty
    ``sources`` (or ``datasets``) list of entry objects. ``allowed_languages``
    may be a list of tags or a single tag string.

    Raises:
        ValueError: the payload is not a JSON object, the entry list is missing
            or empty, an entry is not an object, or an inline entry lacks a
            records/texts list.
    """
    payload = json.loads(Path(source).read_text(encoding="utf-8-sig"))
    if not isinstance(payload, dict):
        # Without this check a list/scalar payload fails with AttributeError on .get.
        raise ValueError("Corpus plan must be a JSON object with a non-empty 'sources' list.")
    raw_entries = payload.get("sources", payload.get("datasets", []))
    if not isinstance(raw_entries, list) or not raw_entries:
        raise ValueError("Corpus plan must define a non-empty 'sources' list.")
    entries: list[CorpusPlanEntry] = []
    for index, raw_entry in enumerate(raw_entries, start=1):
        if not isinstance(raw_entry, dict):
            raise ValueError("Each corpus plan entry must be an object.")
        default_name = f"source-{index}"
        source_kind = str(raw_entry.get("source", "hf")).strip() or "hf"
        name = str(raw_entry.get("name", raw_entry.get("dataset", default_name))).strip() or default_name
        raw_languages = raw_entry.get("allowed_languages", [])
        if isinstance(raw_languages, str):
            # A bare tag used to be ignored silently, disabling language filtering.
            raw_languages = [raw_languages]
        allowed_languages = (
            tuple(
                str(value).strip().casefold()
                for value in raw_languages
                if str(value).strip()
            )
            if isinstance(raw_languages, list)
            else ()
        )
        raw_records = raw_entry.get("records", raw_entry.get("texts", []))
        if source_kind == "inline" and not isinstance(raw_records, list):
            raise ValueError("Inline corpus plan entries must provide a records/texts list.")
        config = raw_entry.get("config")
        text_field = raw_entry.get("text_field")
        entries.append(
            CorpusPlanEntry(
                source=source_kind,
                name=name,
                dataset=str(raw_entry.get("dataset", "")),
                path=str(raw_entry.get("path", raw_entry.get("file", ""))),
                config=str(config) if config is not None else None,
                split=str(raw_entry.get("split", "train")),
                limit=int(raw_entry.get("limit", 0)),
                weight=float(raw_entry.get("weight", 1.0)),
                text_field=str(text_field) if text_field is not None else None,
                min_words=int(raw_entry.get("min_words", 0)),
                max_words=int(raw_entry.get("max_words", 0)),
                min_alpha_ratio=float(raw_entry.get("min_alpha_ratio", 0.0)),
                allowed_languages=allowed_languages,
                records=tuple(raw_records) if isinstance(raw_records, list) else (),
                streaming=bool(raw_entry.get("streaming", True)),
                trust_remote_code=bool(raw_entry.get("trust_remote_code", False)),
            )
        )
    return entries
def _iter_hf_rows(entry: CorpusPlanEntry) -> Iterator[dict[str, object]]:
    """Yield the rows of a Hugging Face dataset split as plain dicts.

    If ``datasets`` cannot be imported, the per-user site-packages directory is
    appended to ``sys.path`` and the import is retried once.
    """
    try:
        from datasets import load_dataset
    except ModuleNotFoundError:
        user_site = site.getusersitepackages()
        if user_site and user_site not in sys.path:
            sys.path.append(user_site)
        from datasets import load_dataset
    options: dict[str, object] = {"split": entry.split, "streaming": entry.streaming}
    if entry.config:
        options["name"] = entry.config
    if entry.trust_remote_code:
        options["trust_remote_code"] = True
    yield from (dict(row) for row in load_dataset(entry.dataset, **options))
def _iter_file_rows(entry: CorpusPlanEntry) -> Iterator[dict[str, object]]:
raw_path = entry.path or entry.dataset
if not raw_path:
raise ValueError("File corpus plan entries must provide a path.")
path = Path(raw_path)
suffix = path.suffix.lower()
if suffix == ".jsonl":
with path.open("r", encoding="utf-8") as handle:
for line in handle:
if line.strip():
row = json.loads(line)
yield row if isinstance(row, dict) else {"text": str(row)}
return
if suffix == ".json":
payload = json.loads(path.read_text(encoding="utf-8"))
if isinstance(payload, list):
for row in payload:
yield row if isinstance(row, dict) else {"text": str(row)}
return
if isinstance(payload, dict):
rows = payload.get("records", payload.get("texts"))
if isinstance(rows, list):
for row in rows:
yield row if isinstance(row, dict) else {"text": str(row)}
return
yield payload
return
if suffix in {".txt", ".md", ".text"}:
yield {"text": path.read_text(encoding="utf-8")}
return
raise ValueError(f"Unsupported file corpus source: {path}")
def iter_corpus_plan_documents(plan: Iterable[CorpusPlanEntry]) -> Iterator[StreamDocument]:
    """Stream cleaned, quality-filtered documents from every entry of ``plan``.

    Each entry stops after ``entry.limit`` accepted documents (when positive).
    Errors from Hugging Face sources are retried up to ``HF_STREAM_MAX_RETRIES``
    times, with exponential backoff capped at 15 s. On a retry the stream is
    reopened and the documents already accepted are skipped by count. This
    assumes the source replays rows in the same order. After the final failed
    retry the entry is skipped with a message. Errors from other sources
    propagate.

    Raises:
        ValueError: an entry uses an unsupported ``source``.
    """
    for entry in plan:
        accepted = 0
        attempts = 0
        while True:
            accepted_seen_this_attempt = 0
            try:
                if entry.source == "inline":
                    row_iterator = (
                        item if isinstance(item, dict) else {"text": str(item)}
                        for item in entry.records
                    )
                elif entry.source == "hf":
                    row_iterator = _iter_hf_rows(entry)
                elif entry.source == "file":
                    row_iterator = _iter_file_rows(entry)
                else:
                    raise ValueError(f"Unsupported corpus plan source: {entry.source}")
                for row in row_iterator:
                    language = _row_language(row)
                    text = clean_training_text(_extract_row_text(row, entry.text_field))
                    if not _passes_text_quality(text, language, entry):
                        continue
                    accepted_seen_this_attempt += 1
                    if accepted_seen_this_attempt <= accepted:
                        # Already emitted before the stream was interrupted.
                        continue
                    # Extract the rejected side only for documents that are actually emitted;
                    # it used to run for every filtered or replayed row as well.
                    _, rejected_text = _extract_preference_pair(row)
                    yield StreamDocument(
                        text=text,
                        weight=entry.weight,
                        source=entry.name,
                        language=language,
                        preference_rejected_text=rejected_text,
                    )
                    accepted += 1
                    if entry.limit > 0 and accepted >= entry.limit:
                        break
                break
            except Exception as exc:
                if entry.source != "hf":
                    raise
                if attempts >= HF_STREAM_MAX_RETRIES:
                    print(
                        f"[source] {entry.name} skipped after {attempts} retries; "
                        f"accepted {accepted} documents before final error: {exc}"
                    )
                    break
                attempts += 1
                delay = min(
                    15.0,
                    HF_STREAM_RETRY_BASE_DELAY_SECONDS * (2 ** (attempts - 1)),
                )
                print(
                    f"[source] {entry.name} stream interrupted after {accepted} accepted "
                    f"documents; retry {attempts}/{HF_STREAM_MAX_RETRIES} in {delay:.2f}s: {exc}"
                )
                time.sleep(delay)
def _log_progress(label: str, processed: int, log_every: int) -> None:
if log_every > 0 and processed % log_every == 0:
print(f"[{label}] processed {processed} documents")
def _answer_boundary(tokens: list[str]) -> int | None:
try:
return tokens.index("<answer>")
except ValueError:
return None
def _weighted_text_parts_for_statistics(text: str, document_weight: float) -> list[tuple[str, float]]:
if "<answer>" not in text:
return [(text, document_weight)]
context, answer = text.split("<answer>", 1)
context = clean_context_text(context.replace("<reason>", " "))
answer = clean_answer_text(answer)
parts: list[tuple[str, float]] = []
if context:
parts.append((context, document_weight * CONTEXT_STAT_WEIGHT))
if answer:
parts.append((answer, document_weight * ANSWER_READOUT_WEIGHT))
return parts or [(text, document_weight)]
def _weighted_token_sequences_for_statistics(
tokens: list[str],
tokenizer: NativeTokenizer,
document_weight: float,
) -> list[tuple[list[str], float]]:
answer_index = _answer_boundary(tokens)
if answer_index is None:
sequence = [token for token in tokens if token not in tokenizer.special_tokens]
return [(sequence, document_weight)] if sequence else []
context_tokens = [
token for token in tokens[:answer_index] if token not in tokenizer.special_tokens
]
answer_tokens = [
token for token in tokens[answer_index + 1 :] if token not in tokenizer.special_tokens
]
sequences: list[tuple[list[str], float]] = []
if context_tokens:
sequences.append((context_tokens, document_weight * CONTEXT_STAT_WEIGHT))
if answer_tokens:
sequences.append((answer_tokens, document_weight * ANSWER_READOUT_WEIGHT))
return sequences
def _readout_weight_for_target(
    answer_index: int | None,
    target_index: int,
    document_weight: float,
) -> float:
    """Return the readout weight of predicting the token at ``target_index``.

    Targets after the ``<answer>`` marker use the answer weight, targets at or
    before it use the context weight, and documents without a marker use the
    plain-text weight.
    """
    if answer_index is None:
        factor = PLAIN_TEXT_READOUT_WEIGHT
    elif target_index > answer_index:
        factor = ANSWER_READOUT_WEIGHT
    else:
        factor = CONTEXT_READOUT_WEIGHT
    return document_weight * factor
def _answer_payload_tokens(tokens: list[str], tokenizer: NativeTokenizer) -> list[str]:
answer_index = _answer_boundary(tokens)
payload = tokens[answer_index + 1 :] if answer_index is not None else tokens
return [token for token in payload if token not in tokenizer.special_tokens]
def _standardized_preference_bias(values: object, active_mask: object | None = None) -> list[float]:
if np is not None:
bias = np.asarray(values, dtype=np.float64)
if bias.size == 0:
return []
active = (
np.asarray(active_mask, dtype=bool)
if active_mask is not None
else np.ones(bias.shape, dtype=bool)
)
if not np.any(active):
return [0.0 for _ in range(int(bias.size))]
active_values = bias[active]
spread = float(active_values.std())
if spread <= 1e-12:
return [0.0 for _ in range(int(bias.size))]
standardized = np.zeros_like(bias, dtype=np.float64)
standardized[active] = (
(active_values - float(active_values.mean())) / spread
) * PREFERENCE_BIAS_SCALE
return np.clip(standardized, -2.5, 2.5).astype(float).tolist()
raw_values = [float(value) for value in values]
if not raw_values:
return []
average = sum(raw_values) / len(raw_values)
variance = sum((value - average) * (value - average) for value in raw_values) / len(raw_values)
spread = variance**0.5
if spread <= 1e-12:
return [0.0 for _ in raw_values]
active_indices = (
[
index
for index, active in enumerate(active_mask)
if active
]
if active_mask is not None
else list(range(len(raw_values)))
)
if not active_indices:
return [0.0 for _ in raw_values]
active_values = [raw_values[index] for index in active_indices]
average = mean(active_values)
spread = (mean([(value - average) * (value - average) for value in active_values])) ** 0.5
if spread <= 1e-12:
return [0.0 for _ in raw_values]
standardized = [0.0 for _ in raw_values]
for index in active_indices:
standardized[index] = max(
-2.5,
min(2.5, ((raw_values[index] - average) / spread) * PREFERENCE_BIAS_SCALE),
)
return standardized
def _candidate_preference_bias_from_state_vector(
    model: ReframrModel,
    preference_state: object,
) -> object:
    """Project a combined decode-state direction onto one score per vocabulary embedding.

    ``_derive_preference_bias_from_pairs`` calls this with the weighted mean of
    (chosen - rejected) masked decode states. For each memory unit the state is
    masked by ``ternary_mask * ternary_scale``. Its hidden part is then scored
    through a per-token "drive" built from embedding coordinates and the unit's
    ``input_projection``. Its trace part is scored directly against the
    embeddings, scaled by ``1 - 1 / (1 + timescale)``.

    Returns:
        ``None`` when NumPy is unavailable. An empty float64 vector when the
        model has no embeddings. A zero vector (one entry per embedding row)
        when the state width differs from the mask width. Otherwise the summed
        per-token scores.
    """
    if np is None:
        return None
    # Preconditions only: ``assert`` is stripped under ``python -O``.
    assert model.embedding_model is not None
    assert model.memory_units is not None
    assert model.ternary_mask is not None
    # Indexed as [:, cols] below, so this must be 2-D: one row per token. Columns are
    # presumably ``config.embedding_dim`` wide -- TODO confirm.
    embeddings = np.asarray(model.embedding_model.embeddings, dtype=np.float64)
    if embeddings.size == 0:
        return np.zeros(0, dtype=np.float64)
    state_vector = np.asarray(preference_state, dtype=np.float64)
    mask = np.asarray(model.ternary_mask, dtype=np.float64) * float(model.ternary_scale)
    if state_vector.shape[0] != mask.shape[0]:
        return np.zeros(embeddings.shape[0], dtype=np.float64)
    state_indices = np.arange(model.config.state_dim, dtype=np.int64)
    # Per-token input drive for each hidden dimension: three embedding coordinates
    # picked by fixed modular index maps. NOTE(review): presumably mirrors the
    # model's own input mixing; verify against ReframrModel.
    drive = (
        embeddings[:, state_indices % model.config.embedding_dim]
        + (0.5 * embeddings[:, (3 * state_indices + 1) % model.config.embedding_dim])
        - (0.25 * embeddings[:, (5 * state_indices + 2) % model.config.embedding_dim])
    )
    scores = np.zeros(embeddings.shape[0], dtype=np.float64)
    # The combined state is consumed unit by unit as
    # [hidden (state_dim) | trace (embedding_dim)] segments.
    offset = 0
    for unit in model.memory_units:
        hidden_end = offset + model.config.state_dim
        trace_end = hidden_end + model.config.embedding_dim
        hidden_pref = state_vector[offset:hidden_end] * mask[offset:hidden_end]
        trace_pref = state_vector[hidden_end:trace_end] * mask[hidden_end:trace_end]
        # Elementwise product; assumes input_projection has length state_dim -- TODO confirm.
        hidden_delta_axis = np.asarray(unit.input_projection, dtype=np.float64) * hidden_pref
        trace_gain = 1.0 - (1.0 / (1.0 + unit.timescale))
        scores += drive @ hidden_delta_axis
        scores += embeddings @ (trace_gain * trace_pref)
        offset = trace_end
    return scores
def _derive_preference_bias_from_pairs(
    model: ReframrModel,
    preference_token_pairs: list[tuple[list[str], list[str], float]],
    tokenizer: NativeTokenizer,
) -> tuple[list[float], int]:
    """Turn (chosen, rejected) token pairs into a standardized per-token preference bias.

    Args:
        model: model whose embedding vocabulary defines the bias length.
        preference_token_pairs: ``(chosen_tokens, rejected_tokens, pair_weight)`` triples.
        tokenizer: used only to strip special tokens from the answer payloads.

    Returns:
        ``(bias, state_pair_count)``. ``bias`` is the output of
        ``_standardized_preference_bias`` over the vocabulary; only tokens that
        appear in some answer payload are active. ``state_pair_count`` is the
        number of pairs that contributed to the decode-state delta.

    Token evidence: every chosen answer token gains ``pair_weight / len(answer)``
    and every rejected answer token loses the same share. State evidence: a
    strided subset of at most ``MAX_PREFERENCE_STATE_PAIRS`` pairs contributes
    the weighted mean of (chosen - rejected) masked decode states. With NumPy
    that mean is projected to per-token scores and added to the active tokens.
    NOTE(review): the pure-Python path averages the delta but never uses it --
    confirm this is intended.
    """
    assert model.embedding_model is not None
    vocab_size = len(model.embedding_model.id_to_token)
    if not preference_token_pairs:
        return [0.0 for _ in range(vocab_size)], 0
    if np is not None:
        token_bias = np.zeros(vocab_size, dtype=np.float64)
        active_token_mask = np.zeros(vocab_size, dtype=bool)
        state_delta = np.zeros(model._combined_state_width(), dtype=np.float64)
    else:
        token_bias = [0.0 for _ in range(vocab_size)]
        active_token_ids: set[int] = set()
        state_delta = [0.0 for _ in range(model._combined_state_width())]
    pair_weight_total = 0.0
    state_pair_count = 0
    # Ceiling division: visiting every ``state_stride``-th pair keeps the number of
    # decode-state comparisons near MAX_PREFERENCE_STATE_PAIRS.
    state_stride = max(
        1,
        (len(preference_token_pairs) + MAX_PREFERENCE_STATE_PAIRS - 1)
        // MAX_PREFERENCE_STATE_PAIRS,
    )
    for pair_index, (chosen_tokens, rejected_tokens, pair_weight) in enumerate(preference_token_pairs):
        chosen_answer = _answer_payload_tokens(chosen_tokens, tokenizer)
        rejected_answer = _answer_payload_tokens(rejected_tokens, tokenizer)
        if chosen_answer:
            # Spread the pair weight evenly so long answers do not dominate.
            delta = pair_weight / max(1, len(chosen_answer))
            for token in chosen_answer:
                token_id = model.embedding_model.token_to_id.get(token)
                if token_id is not None:
                    token_bias[token_id] += delta
                    if np is not None:
                        active_token_mask[token_id] = True
                    else:
                        active_token_ids.add(token_id)
        if rejected_answer:
            delta = pair_weight / max(1, len(rejected_answer))
            for token in rejected_answer:
                token_id = model.embedding_model.token_to_id.get(token)
                if token_id is not None:
                    token_bias[token_id] -= delta
                    if np is not None:
                        active_token_mask[token_id] = True
                    else:
                        active_token_ids.add(token_id)
        if pair_index % state_stride != 0 or state_pair_count >= MAX_PREFERENCE_STATE_PAIRS:
            continue
        chosen_state = model._masked_decode_state(model._build_decode_state(chosen_tokens))
        rejected_state = model._masked_decode_state(model._build_decode_state(rejected_tokens))
        if len(chosen_state) != len(rejected_state):
            continue
        pair_weight_total += pair_weight
        state_pair_count += 1
        if np is not None:
            state_delta += pair_weight * (
                np.asarray(chosen_state, dtype=np.float64)
                - np.asarray(rejected_state, dtype=np.float64)
            )
        else:
            for index, (chosen_value, rejected_value) in enumerate(zip(chosen_state, rejected_state)):
                state_delta[index] += pair_weight * (chosen_value - rejected_value)
    if pair_weight_total > 0.0:
        if np is not None:
            state_delta = state_delta / pair_weight_total
            candidate_bias = _candidate_preference_bias_from_state_vector(model, state_delta)
            if candidate_bias is not None:
                # Only tokens with direct answer evidence receive the state-derived score.
                token_bias[active_token_mask] = (
                    token_bias[active_token_mask] + candidate_bias[active_token_mask]
                )
        else:
            state_delta = [value / pair_weight_total for value in state_delta]
    if np is not None:
        return _standardized_preference_bias(token_bias, active_token_mask), state_pair_count
    active_mask = [index in active_token_ids for index in range(vocab_size)]
    return _standardized_preference_bias(token_bias, active_mask), state_pair_count
def _solve_weighted_prompt_readout(
    states: list[Vector],
    labels: list[int],
    weights: list[float],
    *,
    vocab_size: int,
    diagonal: object,
    state_offset: object,
    regularization: float,
) -> tuple[object, object, int]:
    """Solve a weighted ridge readout from prompt-conditioned states.

    Each state is scaled per feature by ``diagonal`` and then shifted by
    ``state_offset``. A ridge regression is then fitted from those
    centered states to one-hot next-token labels. Examples are dropped
    when their label lies outside ``[0, vocab_size)`` or their weight is
    not positive.

    Args:
        states: One feature vector per example.
        labels: Target token id for each example.
        weights: Example weight for each example.
        vocab_size: Number of readout rows (token ids).
        diagonal: Per-feature scale, expected as a 1-D vector with one
            entry per state feature.
        state_offset: Per-feature offset subtracted after scaling, also a
            1-D vector with one entry per state feature.
        regularization: Ridge penalty forwarded to
            ``ridge_regression_readout_from_moments``.

    Returns:
        ``(readout, bias, example_count)``, where ``bias`` is the weighted
        label distribution over ``vocab_size`` tokens. The empty result
        ``([], [0.0] * vocab_size, 0)`` is returned in any of these cases:
        NumPy is unavailable, no example survives filtering, or the
        state / diagonal / offset shapes are inconsistent.
    """
    empty_result = ([], [0.0 for _ in range(vocab_size)], 0)
    if np is None or not states or not labels or not weights:
        return empty_result
    state_matrix = np.asarray(states, dtype=np.float64)
    label_array = np.asarray(labels, dtype=np.int64)
    weight_vector = np.asarray(weights, dtype=np.float64)
    valid_mask = (
        (label_array >= 0)
        & (label_array < vocab_size)
        & (weight_vector > 0.0)
    )
    if not np.any(valid_mask):
        return empty_result
    state_matrix = state_matrix[valid_mask]
    label_array = label_array[valid_mask]
    weight_vector = weight_vector[valid_mask]
    diagonal_array = np.asarray(diagonal, dtype=np.float64)
    offset_array = np.asarray(state_offset, dtype=np.float64)
    # Require an (examples, features) matrix and 1-D per-feature vectors of
    # matching length. Comparing full shapes (not ``.shape[0]``) means a
    # 0-d or 2-D diagonal/offset is rejected instead of raising IndexError
    # or silently broadcasting into a higher-rank tensor.
    if (
        state_matrix.ndim != 2
        or diagonal_array.shape != (state_matrix.shape[1],)
        or offset_array.shape != (state_matrix.shape[1],)
    ):
        return empty_result
    total_weight = float(weight_vector.sum())
    if total_weight <= 0.0:
        return empty_result
    masked_states = state_matrix * diagonal_array[None, :]
    centered_states = masked_states - offset_array[None, :]
    weighted_centered_states = weight_vector[:, None] * centered_states
    gram = centered_states.T @ weighted_centered_states
    cross = np.zeros((vocab_size, centered_states.shape[1]), dtype=np.float64)
    # Scatter-add each weighted state into its label's row (repeated labels accumulate).
    np.add.at(cross, label_array, weighted_centered_states)
    bias = np.zeros(vocab_size, dtype=np.float64)
    np.add.at(bias, label_array, weight_vector)
    bias /= total_weight
    readout = ridge_regression_readout_from_moments(
        gram,
        cross,
        regularization=regularization,
    )
    return readout, bias, int(label_array.shape[0])
def fit_model_from_corpus_plan(
    plan: Iterable[CorpusPlanEntry],
    config: ReframrConfig,
    *,
    log_every: int = 0,
) -> tuple[ReframrModel, dict[str, object]]:
    """Fit a complete REFRAMR model by streaming the documents of a corpus plan.

    The fit runs in timed stages, recorded in ``payload["stage_seconds"]``:

    1. ``stream_and_segment``: read every plan entry, keep its documents and
       count weighted pre-tokenized segments. Preference-rejected text is
       counted at ``PREFERENCE_REJECTED_TOKENIZER_WEIGHT``.
    2. ``tokenizer_fit``: train a ``NativeTokenizer`` from the segment counts.
    3. ``vocabulary``: encode every document and build the embedding
       vocabulary from weighted token counts.
    4. ``cooccurrence``: accumulate co-occurrence statistics and collect
       ``(chosen, rejected, weight)`` preference token pairs.
    5. ``embedding``: fit PPMI embeddings. The randomized variant is used
       when NumPy is available.
    6. ``state_and_readout``: replay the memory units over every document,
       sample states into reservoirs and solve the next-token ridge readout.
    7. ``preference``: derive the preference bias from the token pairs.
    8. ``model_finalize``: build the associative / answer key tables, the
       prompt-conditioned readouts and the transition tables.

    Args:
        plan: Corpus plan entries. The iterable is consumed once and
            materialized as a list.
        config: Model and training configuration.
        log_every: When positive, print stage and per-source progress. The
            value is also forwarded to ``_log_progress``.

    Returns:
        ``(model, payload)``. ``payload`` summarizes document, vocabulary
        and example counts, stage timings and the readout solver used.

    Raises:
        ValueError: Raised in any of these cases: the plan is empty, it
            yields no documents, no embedding vocabulary can be derived, or
            no next-token training example is collected.
    """
    entries = list(plan)
    if not entries:
        raise ValueError("Cannot fit REFRAMR without any corpus plan entries.")
    stage_seconds: dict[str, float] = {}
    stage_started = time.perf_counter()

    def finish_stage(name: str) -> None:
        # Record wall-clock seconds since the previous stage boundary.
        nonlocal stage_started
        now = time.perf_counter()
        elapsed = round(now - stage_started, 6)
        stage_seconds[name] = elapsed
        if log_every > 0:
            print(f"[stage] {name} finished in {elapsed:.3f}s")
        stage_started = now

    # ---- Stage 1: stream documents and count pre-tokenized segments. ----
    seed_tokenizer = NativeTokenizer(
        merges=[],
        vocab=[],
        base_symbols=[],
        lowercase=config.lowercase,
    )
    segment_counts: Counter[str] = Counter()
    source_counts: dict[str, int] = {}
    documents: list[StreamDocument] = []
    processed = 0
    for entry in entries:
        if log_every > 0:
            print(f"[source] {entry.name} started")
        source_start = processed
        for document in iter_corpus_plan_documents([entry]):
            documents.append(document)
            processed += 1
            source_counts[document.source] = source_counts.get(document.source, 0) + 1
            for text_part, part_weight in _weighted_text_parts_for_statistics(
                document.text,
                document.weight,
            ):
                for segment in seed_tokenizer.pretokenize(text_part):
                    segment_counts[segment] += part_weight
            if document.preference_rejected_text:
                rejected_weight = document.weight * PREFERENCE_REJECTED_TOKENIZER_WEIGHT
                for text_part, part_weight in _weighted_text_parts_for_statistics(
                    document.preference_rejected_text,
                    rejected_weight,
                ):
                    for segment in seed_tokenizer.pretokenize(text_part):
                        segment_counts[segment] += part_weight
            _log_progress("tokenizer", processed, log_every)
        if log_every > 0:
            print(f"[source] {entry.name} accepted {processed - source_start} documents")
    if processed == 0:
        raise ValueError("Corpus plan did not yield any usable documents after filtering.")
    finish_stage("stream_and_segment")

    # ---- Stage 2: tokenizer. ----
    tokenizer = NativeTokenizer.train_from_segment_counts(
        segment_counts,
        vocab_size=config.tokenizer_vocab_size,
        min_pair_frequency=config.tokenizer_min_pair_frequency,
        lowercase=config.lowercase,
    )
    finish_stage("tokenizer_fit")

    # ---- Stage 3: encode documents and count tokens for the vocabulary. ----
    token_counts: Counter[str] = Counter()
    raw_tokenized_documents: list[list[str]] = []
    raw_rejected_tokenized_documents: list[list[str]] = []
    processed = 0
    for document in documents:
        processed += 1
        tokens = tokenizer.encode(document.text)
        raw_tokenized_documents.append(tokens)
        # Special tokens are counted here at the document weight; the
        # statistics sequences below add weighted counts for their tokens.
        for token in tokens:
            if token in tokenizer.special_tokens:
                token_counts[token] += document.weight
        for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics(
            tokens,
            tokenizer,
            document.weight,
        ):
            for token in token_sequence:
                token_counts[token] += sequence_weight
        rejected_tokens = (
            tokenizer.encode(document.preference_rejected_text)
            if document.preference_rejected_text
            else []
        )
        raw_rejected_tokenized_documents.append(rejected_tokens)
        rejected_weight = document.weight * PREFERENCE_REJECTED_TOKENIZER_WEIGHT
        for token in rejected_tokens:
            if token in tokenizer.special_tokens:
                token_counts[token] += rejected_weight
        for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics(
            rejected_tokens,
            tokenizer,
            rejected_weight,
        ):
            for token in token_sequence:
                token_counts[token] += sequence_weight
        _log_progress("vocab", processed, log_every)
    token_to_id, id_to_token = build_vocabulary_from_counts(
        token_counts,
        min_frequency=config.min_frequency,
        max_vocab=config.max_vocab,
    )
    if not id_to_token:
        raise ValueError("Streaming recompute could not derive an embedding vocabulary.")
    finish_stage("vocabulary")

    # ---- Stage 4: co-occurrence and preference pairs (in-vocabulary tokens only). ----
    cooccurrence = StreamingCooccurrenceAccumulator(token_to_id, config.window_size)
    tokenized_documents: list[list[str]] = []
    preference_token_pairs: list[tuple[list[str], list[str], float]] = []
    processed = 0
    for document, raw_tokens, raw_rejected_tokens in zip(
        documents,
        raw_tokenized_documents,
        raw_rejected_tokenized_documents,
    ):
        processed += 1
        tokens = [token for token in raw_tokens if token in token_to_id]
        tokenized_documents.append(tokens)
        rejected_tokens = [token for token in raw_rejected_tokens if token in token_to_id]
        if len(tokens) > 1 and len(rejected_tokens) > 1:
            preference_token_pairs.append((tokens, rejected_tokens, document.weight))
        for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics(
            tokens,
            tokenizer,
            document.weight,
        ):
            if len(token_sequence) > 1:
                cooccurrence.update_tokens(token_sequence, weight=sequence_weight)
        _log_progress("cooccurrence", processed, log_every)
    finish_stage("cooccurrence")

    # ---- Stage 5: embeddings. ----
    if np is not None:
        embedding_model = fit_randomized_ppmi_embedding_from_counts(
            id_to_token,
            cooccurrence.rows,
            embedding_dim=config.embedding_dim,
        )
    else:
        embedding_model = fit_ppmi_embedding_from_cooccurrence(
            id_to_token,
            cooccurrence.to_sparse(),
            embedding_dim=config.embedding_dim,
        )
    finish_stage("embedding")

    # ---- Stage 6: replay memory units, sample states, solve the readout. ----
    model = ReframrModel(config)
    model.tokenizer = tokenizer
    model.embedding_model = embedding_model
    model.memory_units = [
        AnalyticalMemoryUnit(config.state_dim, timescale)
        for timescale in config.timescales
    ]
    model.trace_token_weights = model._derive_trace_token_weights_from_counts(token_counts)
    feature_count = len(model._zero_combined_state())
    embedding_vocab_size = len(embedding_model.id_to_token)
    if np is not None:
        feature_second_moment = np.zeros(feature_count, dtype=np.float64)
        raw_cross = np.zeros((embedding_vocab_size, feature_count), dtype=np.float64)
    else:
        feature_second_moment = zeros_vector(feature_count)
        raw_cross = zeros(embedding_vocab_size, feature_count)
    has_answer_targets = any(_answer_boundary(tokens) is not None for tokens in tokenized_documents)
    if config.max_training_examples is None:
        answer_reservoir_capacity = None
        general_reservoir_capacity = None
    elif config.max_training_examples <= 0:
        answer_reservoir_capacity = 0
        general_reservoir_capacity = 0
    elif has_answer_targets:
        # Three quarters of the example budget go to answer-span targets.
        answer_reservoir_capacity = max(1, int(config.max_training_examples * 0.75))
        general_reservoir_capacity = max(0, config.max_training_examples - answer_reservoir_capacity)
    else:
        answer_reservoir_capacity = 0
        general_reservoir_capacity = config.max_training_examples
    answer_sequence_capacity = MAX_ANSWER_SEQUENCE_EXAMPLES if has_answer_targets else 0
    answer_reservoir = StateReservoir(answer_reservoir_capacity, seed=17)
    general_reservoir = StateReservoir(general_reservoir_capacity, seed=13)
    answer_intent_reservoir = StateReservoir(answer_reservoir_capacity, seed=29)
    answer_start_reservoir = StateReservoir(answer_reservoir_capacity, seed=37)
    answer_sequence_reservoir = SequenceReservoir(answer_sequence_capacity, seed=41)
    moment_reservoir = StateReservoir(config.max_training_examples, seed=31)
    transitions = TransitionAccumulator(
        max_contexts_per_order=config.max_transition_contexts_per_order,
        max_next_tokens=config.max_transition_next_tokens,
    )

    # First pass: total readout weight per target token, used to rebalance labels.
    if np is not None:
        target_label_mass = np.zeros(embedding_vocab_size, dtype=np.float64)
    else:
        target_label_mass = zeros_vector(embedding_vocab_size)
    for document, tokens in zip(documents, tokenized_documents):
        answer_index = _answer_boundary(tokens)
        for index in range(len(tokens) - 1):
            next_token = tokens[index + 1]
            if next_token in tokenizer.special_tokens:
                continue
            next_token_id = embedding_model.token_to_id.get(next_token, -1)
            if next_token_id < 0:
                continue
            label_weight = _readout_weight_for_target(answer_index, index + 1, document.weight)
            if label_weight > 0.0:
                target_label_mass[next_token_id] += label_weight
    target_balance, reference_label_mass = _readout_target_balance(target_label_mass)

    processed = 0
    embedding_array = (
        np.asarray(embedding_model.embeddings, dtype=RUNTIME_ARRAY_DTYPE)
        if np is not None
        else None
    )
    trace_embedding_array = (
        model._build_trace_embedding_table_array(embedding_array)
        if np is not None and embedding_array is not None
        else None
    )
    if np is not None:
        trace_decay = np.asarray(
            [1.0 / (1.0 + unit.timescale) for unit in model.memory_units],
            dtype=RUNTIME_ARRAY_DTYPE,
        )
        trace_gain = 1.0 - trace_decay
        transition_stack = np.asarray(
            [unit.transition for unit in model.memory_units],
            dtype=RUNTIME_ARRAY_DTYPE,
        )
        input_projection_stack = np.asarray(
            [unit.input_projection for unit in model.memory_units],
            dtype=RUNTIME_ARRAY_DTYPE,
        )
        drive_indices = np.arange(config.state_dim, dtype=np.int64)
        drive_primary = drive_indices % config.embedding_dim
        drive_secondary = (3 * drive_indices + 1) % config.embedding_dim
        drive_tertiary = (5 * drive_indices + 2) % config.embedding_dim
    else:
        trace_decay = None
        trace_gain = None
        transition_stack = None
        input_projection_stack = None
        drive_primary = None
        drive_secondary = None
        drive_tertiary = None
    # The vectorized step needs NumPy (which also populates every stack above)
    # and a trace embedding table; this does not change inside the loop.
    fast_path_ready = np is not None and trace_embedding_array is not None

    for document, tokens in zip(documents, tokenized_documents):
        processed += 1
        if len(tokens) < 2:
            _log_progress("state", processed, log_every)
            continue
        answer_index = _answer_boundary(tokens)
        for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics(
            tokens,
            tokenizer,
            document.weight,
        ):
            if len(token_sequence) > 1:
                transitions.update_tokens(token_sequence, weight=sequence_weight)
        if np is not None:
            hidden_state_matrix = np.zeros((len(config.timescales), config.state_dim), dtype=RUNTIME_ARRAY_DTYPE)
            context_trace_matrix = np.zeros((len(config.timescales), config.embedding_dim), dtype=RUNTIME_ARRAY_DTYPE)
        else:
            hidden_states = [zeros_vector(config.state_dim) for _ in config.timescales]
            context_traces = [zeros_vector(config.embedding_dim) for _ in config.timescales]
        answer_anchor_state = None
        for index in range(len(tokens) - 1):
            token = tokens[index]
            token_id = embedding_model.token_to_id.get(token, -1)
            if fast_path_ready and token_id >= 0:
                embedding = embedding_array[token_id]
                trace_embedding = trace_embedding_array[token_id]
                drive = (
                    embedding[drive_primary]
                    + (0.5 * embedding[drive_secondary])
                    - (0.25 * embedding[drive_tertiary])
                )
                hidden_state_matrix = (
                    (transition_stack @ hidden_state_matrix[:, :, None])[:, :, 0]
                    + (input_projection_stack * drive[None, :])
                )
                # NOTE(review): the previous trace is not multiplied by
                # trace_decay here (trace_decay only derives trace_gain);
                # confirm this matches ReframrModel._step_hidden_states.
                context_trace_matrix = (
                    context_trace_matrix + (trace_gain[:, None] * trace_embedding[None, :])
                )
            else:
                hidden_states, context_traces, combined_state = model._step_hidden_states(
                    hidden_states,
                    context_traces,
                    token,
                )
            if token == "<answer>":
                # Snapshot the state right after the answer marker; it keys the
                # answer-intent / answer-start / answer-sequence memories.
                if np is not None:
                    answer_anchor_state = np.concatenate(
                        (hidden_state_matrix, context_trace_matrix),
                        axis=1,
                    ).reshape(-1).copy()
                else:
                    answer_anchor_state = combined_state.copy() if hasattr(combined_state, "copy") else combined_state[:]
            next_token = tokens[index + 1]
            if next_token in tokenizer.special_tokens:
                continue
            next_token_id = embedding_model.token_to_id.get(next_token, -1)
            if next_token_id < 0:
                continue
            balance = float(target_balance[next_token_id])
            raw_readout_weight = _readout_weight_for_target(answer_index, index + 1, document.weight)
            readout_weight = raw_readout_weight * balance
            if readout_weight <= 0.0:
                continue
            moment_slot = moment_reservoir.reserve_slot(weight=readout_weight)
            is_answer_target = answer_index is not None and index + 1 > answer_index
            target_reservoir = answer_reservoir if is_answer_target else general_reservoir
            # Memory reservoirs apply the target balance a second time on top of readout_weight.
            memory_weight = readout_weight * balance
            answer_token_offset = (
                index - answer_index
                if is_answer_target and answer_index is not None
                else None
            )
            intent_slot = (
                answer_intent_reservoir.reserve_slot(weight=memory_weight)
                if is_answer_target and answer_anchor_state is not None
                else None
            )
            answer_start_weight = (
                raw_readout_weight * (ANSWER_START_DECAY ** answer_token_offset)
                if (
                    answer_token_offset is not None
                    and answer_token_offset < ANSWER_START_TOKEN_WINDOW
                )
                else 0.0
            )
            answer_start_slot = (
                answer_start_reservoir.reserve_slot(weight=answer_start_weight)
                if answer_start_weight > 0.0 and answer_anchor_state is not None
                else None
            )
            reservoir_slot = target_reservoir.reserve_slot(weight=memory_weight)
            # In the NumPy path the flat combined state is only materialized
            # when a reservoir actually needs it.
            if np is not None and (moment_slot is not None or reservoir_slot is not None):
                combined_state = np.concatenate(
                    (hidden_state_matrix, context_trace_matrix),
                    axis=1,
                ).reshape(-1).copy()
            # Every slot reserved above is filled here, in both the NumPy and
            # the pure-Python paths.
            if moment_slot is not None:
                moment_reservoir.store_reserved(
                    moment_slot,
                    combined_state,
                    next_token_id,
                    example_weight=readout_weight,
                )
            if reservoir_slot is not None:
                target_reservoir.store_reserved(reservoir_slot, combined_state, next_token_id)
            if intent_slot is not None:
                answer_intent_reservoir.store_reserved(
                    intent_slot,
                    answer_anchor_state,
                    next_token_id,
                    example_weight=memory_weight,
                )
            if answer_start_slot is not None:
                answer_start_reservoir.store_reserved(
                    answer_start_slot,
                    answer_anchor_state,
                    next_token_id,
                    example_weight=answer_start_weight * balance,
                )
        if answer_anchor_state is not None and answer_index is not None:
            prompt_token_ids = [
                embedding_model.token_to_id[token]
                for token in tokens[:answer_index]
                if token not in tokenizer.special_tokens
                and token in embedding_model.token_to_id
            ]
            answer_token_ids = [
                embedding_model.token_to_id[token]
                for token in tokens[answer_index + 1 :]
                if token not in tokenizer.special_tokens
                and token in embedding_model.token_to_id
            ]
            answer_sequence_reservoir.consider(
                answer_anchor_state,
                prompt_token_ids,
                answer_token_ids,
                weight=document.weight * ANSWER_READOUT_WEIGHT,
            )
        _log_progress("state", processed, log_every)

    # Weighted second moments and label cross-moments of the sampled states.
    moment_states = moment_reservoir.states
    moment_labels = moment_reservoir.labels
    moment_weights = moment_reservoir.weights
    example_weight_total = sum(moment_weights)
    if np is not None and moment_states:
        state_matrix = np.asarray(moment_states, dtype=np.float64)
        weight_vector = np.asarray(moment_weights, dtype=np.float64)
        weighted_states = weight_vector[:, None] * state_matrix
        feature_second_moment += (weighted_states * state_matrix).sum(axis=0)
        np.add.at(raw_cross, moment_labels, weighted_states)
    elif moment_states:
        for state, label_id, example_weight in zip(moment_states, moment_labels, moment_weights):
            for feature, value in enumerate(state):
                weighted_value = example_weight * value
                feature_second_moment[feature] += weighted_value * value
                raw_cross[label_id][feature] += weighted_value
    if example_weight_total <= 0.0:
        raise ValueError("Streaming recompute did not collect any next-token training examples.")
    if np is not None:
        feature_energy = (feature_second_moment / example_weight_total).tolist()
    else:
        feature_energy = [
            feature_second_moment[index] / example_weight_total
            for index in range(feature_count)
        ]
    ternary_scale, ternary_mask = derive_ternary_mask_from_feature_energy(feature_energy)
    if np is not None:
        diagonal = np.asarray([ternary_scale * value for value in ternary_mask], dtype=np.float64)
    else:
        diagonal = [ternary_scale * value for value in ternary_mask]

    readout_solver = "diagonal"
    state_offset_values: object
    readout_bias_values: object
    if (
        np is not None
        and moment_states
        and feature_count <= FULL_READOUT_FEATURE_LIMIT
        and len(moment_states) <= FULL_READOUT_EXAMPLE_LIMIT
    ):
        # Small problems: exact weighted ridge on centered, masked states.
        state_matrix = np.asarray(moment_states, dtype=np.float64)
        weight_vector = np.asarray(moment_weights, dtype=np.float64)
        label_array = np.asarray(moment_labels, dtype=np.int64)
        masked_states = state_matrix * diagonal[None, :]
        total_weight = float(weight_vector.sum())
        if total_weight <= 0.0:
            total_weight = 1.0
        state_offset_values = (weight_vector[:, None] * masked_states).sum(axis=0) / total_weight
        centered_states = masked_states - state_offset_values[None, :]
        weighted_centered_states = weight_vector[:, None] * centered_states
        gram = centered_states.T @ weighted_centered_states
        full_cross = np.zeros((embedding_vocab_size, feature_count), dtype=np.float64)
        np.add.at(full_cross, label_array, weighted_centered_states)
        readout_bias_values = np.zeros(embedding_vocab_size, dtype=np.float64)
        np.add.at(readout_bias_values, label_array, weight_vector)
        readout_bias_values /= total_weight
        readout_weights = ridge_regression_readout_from_moments(
            gram,
            full_cross,
            regularization=config.regularization,
        )
        readout_solver = "full"
    else:
        # Large problems: diagonal approximation of the Gram matrix.
        if np is not None:
            masked_feature_second_moment = feature_second_moment * diagonal * diagonal
            masked_cross = raw_cross * diagonal[None, :]
            state_offset_values = np.zeros(feature_count, dtype=np.float64)
            label_total = max(float(target_label_mass.sum()), 1.0)
            readout_bias_values = target_label_mass / label_total
        else:
            masked_feature_second_moment = [
                feature_second_moment[index] * diagonal[index] * diagonal[index]
                for index in range(feature_count)
            ]
            masked_cross = [
                [
                    raw_cross[row][col] * diagonal[col]
                    for col in range(feature_count)
                ]
                for row in range(len(raw_cross))
            ]
            state_offset_values = [0.0 for _ in range(feature_count)]
            label_total = max(sum(target_label_mass), 1.0)
            readout_bias_values = [value / label_total for value in target_label_mass]
        readout_weights = ridge_regression_readout_from_diagonal_moments(
            masked_feature_second_moment,
            masked_cross,
            regularization=config.regularization,
        )
    finish_stage("state_and_readout")
    model.ternary_scale = ternary_scale
    model.ternary_mask = ternary_mask
    model.readout_weights = readout_weights
    model.state_offset = (
        state_offset_values.tolist()
        if hasattr(state_offset_values, "tolist")
        else list(state_offset_values)
    )
    model.readout_bias = (
        readout_bias_values.tolist()
        if hasattr(readout_bias_values, "tolist")
        else list(readout_bias_values)
    )

    # ---- Stage 7: preference bias. ----
    model.preference_bias, preference_state_pairs = _derive_preference_bias_from_pairs(
        model,
        preference_token_pairs,
        tokenizer,
    )
    finish_stage("preference")

    # ---- Stage 8: memories, prompt-conditioned readouts and transitions. ----
    reservoir_states = answer_reservoir.states + general_reservoir.states
    reservoir_labels = answer_reservoir.labels + general_reservoir.labels
    answer_intent_states = answer_intent_reservoir.states
    answer_intent_labels = answer_intent_reservoir.labels
    answer_start_states = answer_start_reservoir.states
    answer_start_labels = answer_start_reservoir.labels
    answer_sequence_states = answer_sequence_reservoir.keys
    answer_sequence_prompt_rows = answer_sequence_reservoir.prompt_rows
    answer_sequence_rows = answer_sequence_reservoir.token_rows
    prompt_answer_weights, prompt_answer_bias, prompt_answer_readout_examples = (
        _solve_weighted_prompt_readout(
            answer_intent_states,
            answer_intent_labels,
            answer_intent_reservoir.weights,
            vocab_size=embedding_vocab_size,
            diagonal=diagonal,
            state_offset=state_offset_values,
            regularization=config.regularization,
        )
    )
    (
        prompt_answer_start_weights,
        prompt_answer_start_bias,
        prompt_answer_start_readout_examples,
    ) = _solve_weighted_prompt_readout(
        answer_start_states,
        answer_start_labels,
        answer_start_reservoir.weights,
        vocab_size=embedding_vocab_size,
        diagonal=diagonal,
        state_offset=state_offset_values,
        regularization=config.regularization,
    )
    model.prompt_answer_weights = prompt_answer_weights
    model.prompt_answer_bias = (
        prompt_answer_bias.tolist()
        if hasattr(prompt_answer_bias, "tolist")
        else list(prompt_answer_bias)
    )
    model.prompt_answer_start_weights = prompt_answer_start_weights
    model.prompt_answer_start_bias = (
        prompt_answer_start_bias.tolist()
        if hasattr(prompt_answer_start_bias, "tolist")
        else list(prompt_answer_start_bias)
    )
    model.associative_keys, model.associative_key_norms = _masked_offset_keys_with_norms(
        reservoir_states,
        ternary_mask,
        ternary_scale,
        model.state_offset,
    )
    model.associative_values = reservoir_labels[:]
    model.answer_keys, model.answer_key_norms = _masked_offset_keys_with_norms(
        answer_intent_states,
        ternary_mask,
        ternary_scale,
        model.state_offset,
    )
    model.answer_values = answer_intent_labels[:]
    model.answer_start_keys, model.answer_start_key_norms = _masked_offset_keys_with_norms(
        answer_start_states,
        ternary_mask,
        ternary_scale,
        model.state_offset,
    )
    model.answer_start_values = answer_start_labels[:]
    model.answer_sequence_keys, model.answer_sequence_key_norms = _masked_offset_keys_with_norms(
        answer_sequence_states,
        ternary_mask,
        ternary_scale,
        model.state_offset,
    )
    model.answer_sequence_prompt_tokens = _pad_answer_sequence_rows(
        answer_sequence_prompt_rows,
        MAX_ANSWER_SEQUENCE_TOKENS,
    )
    model.answer_sequence_tokens = _pad_answer_sequence_rows(
        answer_sequence_rows,
        MAX_ANSWER_SEQUENCE_TOKENS,
    )
    model.transition_tables = transitions.finalize(
        max_contexts_per_order=config.max_transition_contexts_per_order,
        max_next_tokens=config.max_transition_next_tokens,
    )
    finish_stage("model_finalize")
    payload = {
        "streaming": True,
        "documents_processed": processed,
        "source_counts": source_counts,
        "embedding_vocab_size": embedding_vocab_size,
        "tokenizer_vocab_size": tokenizer.vocab_size,
        "examples_processed": int(round(example_weight_total)),
        "associative_examples": len(model.associative_keys),
        "answer_associative_examples": len(answer_reservoir.states),
        "general_associative_examples": len(general_reservoir.states),
        "answer_intent_examples": len(model.answer_keys),
        "answer_start_examples": len(model.answer_start_keys),
        "answer_sequence_examples": len(model.answer_sequence_keys),
        "prompt_answer_readout_examples": prompt_answer_readout_examples,
        "prompt_answer_start_readout_examples": prompt_answer_start_readout_examples,
        "stage_seconds": stage_seconds,
        "target_balance_reference": round(float(reference_label_mass), 6),
        "readout_solver": readout_solver,
        "preference_pairs": len(preference_token_pairs),
        "preference_state_pairs": preference_state_pairs,
    }
    return model, payload


def _readout_target_balance(target_label_mass: object) -> tuple[object, float]:
    """Return per-token readout balance factors and the reference label mass.

    Each token with positive mass gets ``sqrt(reference / mass)`` clipped to
    ``[0.25, 4.0]``. Tokens with zero mass keep a factor of 1.0. The
    reference is ``np.median`` of the positive masses on the NumPy path, and
    the upper-middle element of the sorted positive masses on the pure-Python
    path. If no mass is positive, the reference is 1.0.
    """
    if np is not None:
        positive_label_mass = target_label_mass[target_label_mass > 0.0]
        reference_label_mass = (
            float(np.median(positive_label_mass))
            if positive_label_mass.size
            else 1.0
        )
        target_balance = np.ones(len(target_label_mass), dtype=np.float64)
        np.divide(
            reference_label_mass,
            np.maximum(target_label_mass, 1e-12),
            out=target_balance,
            where=target_label_mass > 0.0,
        )
        return np.clip(np.sqrt(target_balance), 0.25, 4.0), reference_label_mass
    positive_label_mass = [value for value in target_label_mass if value > 0.0]
    if positive_label_mass:
        sorted_mass = sorted(positive_label_mass)
        reference_label_mass = sorted_mass[len(sorted_mass) // 2]
    else:
        reference_label_mass = 1.0
    target_balance = [
        max(0.25, min(4.0, (reference_label_mass / max(value, 1e-12)) ** 0.5))
        if value > 0.0
        else 1.0
        for value in target_label_mass
    ]
    return target_balance, reference_label_mass


def _masked_offset_keys_with_norms(
    states: object,
    ternary_mask: object,
    ternary_scale: float,
    state_offset: list[float],
) -> tuple[object, list[float]]:
    """Build memory keys as ``state * (mask * scale) - offset`` and return ``(keys, norms)``.

    When NumPy is available and ``states`` is non-empty, the keys form a
    ``RUNTIME_ARRAY_DTYPE`` matrix with one row per state. Otherwise they
    are a list of lists built with ``apply_ternary_mask``. ``norms`` holds
    the Euclidean norm of each key.
    """
    if np is not None and states:
        state_array = np.asarray(states, dtype=RUNTIME_ARRAY_DTYPE)
        mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale
        offset_array = np.asarray(state_offset, dtype=RUNTIME_ARRAY_DTYPE)
        key_array = ((state_array * mask_array[None, :]) - offset_array[None, :]).astype(
            RUNTIME_ARRAY_DTYPE,
            copy=False,
        )
        return key_array, np.linalg.norm(key_array, axis=1).tolist()
    keys = [
        [
            value - state_offset[index]
            for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale))
        ]
        for state in states
    ]
    return keys, [norm(key) for key in keys]


def _pad_answer_sequence_rows(rows: list[list[int]], width: int) -> object:
    """Truncate or pad each token-id row with ``-1`` to exactly ``width`` columns.

    With NumPy the result is an ``int32`` matrix of shape ``(len(rows), width)``.
    Otherwise it is a list of lists. Both paths truncate rows longer than
    ``width`` so the two representations agree.
    """
    if np is not None:
        padded = np.full((len(rows), width), -1, dtype=np.int32)
        for row_index, row in enumerate(rows):
            row_width = min(len(row), width)
            if row_width > 0:
                padded[row_index, :row_width] = row[:row_width]
        return padded
    padded_rows = []
    for row in rows:
        kept = list(row[:width])
        padded_rows.append(kept + [-1] * (width - len(kept)))
    return padded_rows