# Reframr-RFM-v2-Base / reframr/streaming.py
from __future__ import annotations
import hashlib
import heapq
import json
import os
import random
import re
import site
import sys
import tempfile
import time
from collections import Counter
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from pathlib import Path
from .config import ReframrConfig
from .corpus import build_vocabulary_from_counts
from .embeddings import (
complete_id_to_token,
fit_ppmi_embedding_from_cooccurrence,
fit_randomized_ppmi_embedding_from_counts,
)
from .hippo import (
HAS_COMPILED_HIPPO_KERNEL,
AnalyticalMemoryUnit,
hippo_document_combined_states_fast,
hippo_legs_propagate_stack_fast,
)
from .linalg import Matrix, Vector, norm, zeros, zeros_vector
from .model import ReframrModel, RUNTIME_ARRAY_DTYPE, TRANSITION_ORDERS, np
from .reasoning import TOOL_PROTOCOL_TOKENS
from .reservoir import (
ridge_regression_readout_from_diagonal_moments,
ridge_regression_readout_from_moments,
)
from .ternary import apply_ternary_mask, derive_ternary_mask_from_feature_energy
from .text_quality import (
clean_answer_text,
clean_context_text,
clean_training_text,
has_machine_artifacts,
)
from .tokenizer import NativeTokenizer
try:
from scipy import sparse as scipy_sparse
except (ImportError, ModuleNotFoundError, OSError):
scipy_sparse = None
TEXT_FIELD_PREFERENCES = (
"text",
"content",
"body",
"article",
"document",
"passage",
"markdown",
"answer",
"response",
)
DIALOGUE_FIELD_PREFERENCES = (
"messages",
"conversation",
"conversations",
"dialogue",
"dialog",
"turns",
"chosen",
"chat",
)
INSTRUCTION_FIELD_PAIRS = (
("instruction", "output"),
("prompt", "completion"),
("prompt", "response"),
("question", "answer"),
("question", "response"),
("query", "answer"),
("query", "response"),
)
TRANSCRIPT_ROLE_PATTERN = re.compile(
r"(?:^|\n\s*\n)(Human|Assistant|System|User|Function Response|Function|Tool)\s*:\s*",
re.IGNORECASE,
)
ROLE_ALIASES = {
"assistant": "assistant",
"assistant_response": "assistant",
"bot": "assistant",
"gpt": "assistant",
"model": "assistant",
"human": "user",
"prompter": "user",
"user": "user",
"customer": "user",
"system": "system",
"function": "tool",
"function response": "tool",
"tool": "tool",
"tool_result": "tool",
}
TOOL_DEFINITION_FIELDS = ("tools_json", "tools", "functions", "available_tools")
ANSWER_READOUT_WEIGHT = 1.0
CONTEXT_READOUT_WEIGHT = 0.0
CONTEXT_STAT_WEIGHT = 0.02
PLAIN_TEXT_READOUT_WEIGHT = 0.03
PREFERENCE_REJECTED_TOKENIZER_WEIGHT = 0.0
PREFERENCE_BIAS_SCALE = 0.95
MAX_PREFERENCE_STATE_PAIRS = 512
ANSWER_START_TOKEN_WINDOW = 12
ANSWER_START_DECAY = 0.86
MAX_ANSWER_SEQUENCE_EXAMPLES = 196608
MAX_ANSWER_SEQUENCE_TOKENS = 192
TARGET_BALANCE_MIN = 0.55
TARGET_BALANCE_MAX = 2.0
PREALLOCATED_RESERVOIR_MIN_CAPACITY = 4096
HF_STREAM_MAX_RETRIES = 5
HF_STREAM_RETRY_BASE_DELAY_SECONDS = 0.25
HF_VIEWER_REQUEST_TIMEOUT_SECONDS = 10.0
HF_VIEWER_MIN_REQUEST_TIMEOUT_SECONDS = 1.0
FULL_READOUT_FEATURE_LIMIT = 2304
FULL_READOUT_EXAMPLE_LIMIT = 25000
def _is_learned_output_token(token: str, tokenizer: NativeTokenizer) -> bool:
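    """Treat ordinary tokens as learnable outputs; among special tokens, only the tool-protocol markers qualify."""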
if token not in tokenizer.special_tokens:
return True
return token in TOOL_PROTOCOL_TOKENS
@dataclass(slots=True)
class CorpusPlanEntry:
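    """One entry in a corpus plan: where to read rows from and how to filter, limit, and weight them."""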
source: str
name: str
dataset: str = ""
path: str = ""
config: str | None = None
split: str = "train"
limit: int = 0
weight: float = 1.0
text_field: str | None = None
min_words: int = 0
max_words: int = 0
min_alpha_ratio: float = 0.0
allowed_languages: tuple[str, ...] = ()
records: tuple[object, ...] = ()
streaming: bool = True
trust_remote_code: bool = False
max_seconds: float = 0.0
readout_weight: float = 1.0
transition_weight: float = 1.0
@dataclass(slots=True)
class StreamDocument:
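    """A single cleaned training text plus its sampling, readout, and transition weights."""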
text: str
weight: float
source: str
language: str = ""
preference_rejected_text: str = ""
readout_weight: float = 1.0
transition_weight: float = 1.0
class StreamingCooccurrenceAccumulator:
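    """Accumulate a symmetric, distance-weighted token co-occurrence matrix in sparse form.

    Every in-vocabulary token pair within ``window_size`` positions contributes
    ``weight / offset`` to both (i, j) and (j, i), so nearer neighbours count more.
    """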
def __init__(self, token_to_id: dict[str, int], window_size: int) -> None:
self.token_to_id = token_to_id
self.window_size = window_size
self.rows: dict[int, dict[int, float]] = {}
def update_tokens(self, tokens: list[str], *, weight: float) -> None:
token_ids = [self.token_to_id[token] for token in tokens if token in self.token_to_id]
for index, token_id in enumerate(token_ids):
for offset in range(1, self.window_size + 1):
other_index = index + offset
if other_index >= len(token_ids):
break
other_id = token_ids[other_index]
delta = weight * (1.0 / offset)
                row = self.rows.setdefault(token_id, {})
                row[other_id] = row.get(other_id, 0.0) + delta
                other_row = self.rows.setdefault(other_id, {})
                other_row[token_id] = other_row.get(token_id, 0.0) + delta
def to_dense(self) -> Matrix:
size = len(self.token_to_id)
matrix = zeros(size, size)
for row, columns in self.rows.items():
for col, value in columns.items():
matrix[row][col] = value
return matrix
def to_sparse(self) -> object:
if scipy_sparse is None or np is None:
return self.to_dense()
rows: list[int] = []
cols: list[int] = []
data: list[float] = []
for row, columns in self.rows.items():
for col, value in columns.items():
rows.append(row)
cols.append(col)
data.append(value)
size = len(self.token_to_id)
return scipy_sparse.coo_matrix(
(
np.asarray(data, dtype=np.float64),
(np.asarray(rows, dtype=np.int64), np.asarray(cols, dtype=np.int64)),
),
shape=(size, size),
dtype=np.float64,
).tocsr()
class TransitionAccumulator:
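    """Accumulate weighted n-gram transition counts for each order in TRANSITION_ORDERS.

    While streaming, contexts and next-token buckets are soft-capped at four times
    the final limits; ``finalize`` prunes to the heaviest contexts and next tokens
    and normalizes the surviving counts into conditional probabilities.
    """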
def __init__(
self,
*,
max_contexts_per_order: int | None = None,
max_next_tokens: int = 0,
) -> None:
self.max_contexts_per_order = max_contexts_per_order
self.max_next_tokens = max_next_tokens
self.context_soft_limit = (
max_contexts_per_order * 4
if max_contexts_per_order is not None and max_contexts_per_order > 0
else None
)
self.next_token_soft_limit = max_next_tokens * 4 if max_next_tokens > 0 else None
self.counts: dict[int, dict[tuple[str, ...], dict[str, float]]] = {
order: {} for order in sorted(TRANSITION_ORDERS)
}
self.context_totals: dict[int, dict[tuple[str, ...], float]] = {
order: {} for order in sorted(TRANSITION_ORDERS)
}
def update_tokens(self, tokens: list[str], *, weight: float) -> None:
for order in sorted(TRANSITION_ORDERS):
order_counts = self.counts[order]
for index in range(order - 1, len(tokens) - 1):
key = tuple(tokens[index - order + 1 : index + 1])
nxt = tokens[index + 1]
if (
self.context_soft_limit is not None
and key not in order_counts
and len(order_counts) >= self.context_soft_limit
):
continue
bucket = order_counts.setdefault(key, {})
self.context_totals[order][key] = self.context_totals[order].get(key, 0.0) + weight
if (
self.next_token_soft_limit is not None
and nxt not in bucket
and len(bucket) >= self.next_token_soft_limit
):
continue
bucket[nxt] = bucket.get(nxt, 0.0) + weight
def finalize(
self,
*,
max_contexts_per_order: int | None,
max_next_tokens: int,
) -> dict[int, dict[tuple[str, ...], dict[str, float]]]:
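        """Prune to the top contexts (by total mass) and top next tokens, then normalize to probabilities."""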
probabilities: dict[int, dict[tuple[str, ...], dict[str, float]]] = {
order: {} for order in sorted(TRANSITION_ORDERS)
}
for order, mapping in self.counts.items():
totals = self.context_totals.get(order, {})
if max_contexts_per_order is not None and max_contexts_per_order >= 0:
keys = heapq.nlargest(
max_contexts_per_order,
mapping.keys(),
key=lambda key: totals.get(key, 0.0),
)
items = ((key, mapping[key]) for key in keys)
else:
items = mapping.items()
for key, bucket in items:
if max_next_tokens > 0 and len(bucket) > max_next_tokens:
next_items = heapq.nlargest(
max_next_tokens,
bucket.items(),
key=lambda item: item[1],
)
else:
next_items = bucket.items()
total = sum(value for _, value in next_items)
if total <= 0.0:
continue
probabilities[order][key] = {
token: value / total
for token, value in next_items
}
return probabilities
def finalize_tensor_cache(
self,
*,
token_to_id: dict[str, int],
max_contexts_per_order: int | None,
max_next_tokens: int,
) -> dict[str, object]:
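        """Flatten the pruned transition table into offset-indexed arrays for fast lookup.

        Contexts containing tokens missing from ``token_to_id`` are dropped, and
        probabilities are renormalized over the surviving next tokens.
        """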
orders: list[int] = []
key_offsets: list[int] = [0]
key_token_ids: list[int] = []
next_offsets: list[int] = [0]
next_token_ids: list[int] = []
next_probabilities: list[float] = []
for order in sorted(TRANSITION_ORDERS):
mapping = self.counts.get(order, {})
totals = self.context_totals.get(order, {})
if max_contexts_per_order is not None and max_contexts_per_order >= 0:
keys = heapq.nlargest(
max_contexts_per_order,
mapping.keys(),
key=lambda key: totals.get(key, 0.0),
)
items = ((key, mapping[key]) for key in keys)
else:
items = mapping.items()
for key, bucket in items:
key_ids = [token_to_id.get(token, -1) for token in key]
if len(key_ids) != order or any(token_id < 0 for token_id in key_ids):
continue
if max_next_tokens > 0 and len(bucket) > max_next_tokens:
next_items = heapq.nlargest(
max_next_tokens,
bucket.items(),
key=lambda item: item[1],
)
else:
next_items = bucket.items()
filtered_next_items = [
(token_to_id[token], float(value))
for token, value in next_items
if token in token_to_id and value > 0.0
]
total = sum(value for _, value in filtered_next_items)
if total <= 0.0:
continue
orders.append(order)
key_token_ids.extend(key_ids)
key_offsets.append(len(key_token_ids))
for token_id, value in filtered_next_items:
next_token_ids.append(token_id)
next_probabilities.append(value / total)
next_offsets.append(len(next_token_ids))
if np is None:
return {
"orders": orders,
"key_offsets": key_offsets,
"key_token_ids": key_token_ids,
"next_offsets": next_offsets,
"next_token_ids": next_token_ids,
"next_probabilities": next_probabilities,
"order_spans": {},
}
return {
"orders": np.asarray(orders, dtype=np.int32),
"key_offsets": np.asarray(key_offsets, dtype=np.int32),
"key_token_ids": np.asarray(key_token_ids, dtype=np.int32),
"next_offsets": np.asarray(next_offsets, dtype=np.int32),
"next_token_ids": np.asarray(next_token_ids, dtype=np.int32),
"next_probabilities": np.asarray(next_probabilities, dtype=np.float64),
"order_spans": {},
}
class StateReservoir:
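    """Fixed-capacity sample of (state, label, weight) rows via weighted reservoir sampling.

    Once full, a new item replaces a uniformly random slot with probability
    ``min(1, capacity * weight / total_weight)``; large reservoirs preallocate a
    NumPy state matrix to avoid per-row copies.
    """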
def __init__(self, capacity: int | None, *, seed: int = 13) -> None:
self.capacity = capacity
self.random = random.Random(seed)
self.states: list[Vector] = []
self.labels: list[int] = []
self.weights: list[float] = []
self.seen = 0
self.total_weight = 0.0
self._state_array: object | None = None
def _should_preallocate(self) -> bool:
return (
np is not None
and self.capacity is not None
and self.capacity >= PREALLOCATED_RESERVOIR_MIN_CAPACITY
)
def _ensure_state_array(self, width: int) -> object | None:
if not self._should_preallocate():
return None
if self._state_array is None:
self._state_array = np.empty((int(self.capacity), width), dtype=RUNTIME_ARRAY_DTYPE)
return self._state_array
def reserve_slot(self, weight: float = 1.0) -> int | None:
if weight <= 0.0:
return None
self.seen += 1
self.total_weight += weight
if self.capacity is None:
return len(self.states)
if self.capacity <= 0:
return None
if len(self.states) < self.capacity:
return len(self.states)
keep_probability = min(1.0, (self.capacity * weight) / max(self.total_weight, 1e-12))
if self.random.random() >= keep_probability:
return None
return self.random.randrange(self.capacity)
def store_reserved(
self,
slot: int,
state: Vector,
label_id: int,
*,
example_weight: float = 1.0,
) -> None:
state_array = self._ensure_state_array(len(state))
if state_array is not None:
state_array[slot, :] = state
stored_state = None
else:
stored_state = state.copy() if hasattr(state, "copy") else state[:]
if slot == len(self.states):
self.states.append(stored_state)
self.labels.append(label_id)
self.weights.append(example_weight)
elif 0 <= slot < len(self.states):
self.states[slot] = stored_state
self.labels[slot] = label_id
self.weights[slot] = example_weight
def consider(self, state: Vector, label_id: int, weight: float = 1.0) -> None:
slot = self.reserve_slot(weight=weight)
if slot is not None:
self.store_reserved(slot, state, label_id, example_weight=weight)
def state_matrix(self, *, dtype: object | None = None) -> object:
if np is None or not self.states:
return self.states
if self._state_array is not None:
matrix = self._state_array[: len(self.states)]
if dtype is not None and matrix.dtype != dtype:
return matrix.astype(dtype, copy=False)
return matrix
return np.asarray(self.states, dtype=dtype or np.float64)
class SequenceReservoir:
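    """Weighted reservoir of (key state, prompt tokens, answer tokens) training sequences.

    Token rows are truncated to MAX_ANSWER_SEQUENCE_TOKENS on storage; key vectors
    may be preallocated in a NumPy array exactly as in StateReservoir.
    """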
def __init__(self, capacity: int | None, *, seed: int = 41) -> None:
self.capacity = capacity
self.random = random.Random(seed)
self.keys: list[Vector] = []
self.prompt_rows: list[list[int]] = []
self.token_rows: list[list[int]] = []
self.weights: list[float] = []
self.seen_weight = 0.0
self._key_array: object | None = None
def _should_preallocate(self) -> bool:
return (
np is not None
and self.capacity is not None
and self.capacity >= PREALLOCATED_RESERVOIR_MIN_CAPACITY
)
def _ensure_key_array(self, width: int) -> object | None:
if not self._should_preallocate():
return None
if self._key_array is None:
self._key_array = np.empty((int(self.capacity), width), dtype=RUNTIME_ARRAY_DTYPE)
return self._key_array
def reserve_slot(self, *, weight: float = 1.0) -> int | None:
if self.capacity == 0 or weight <= 0.0:
return None
self.seen_weight += weight
if self.capacity is None or len(self.keys) < self.capacity:
return len(self.keys)
probability = min(1.0, (self.capacity * weight) / max(self.seen_weight, 1e-12))
if self.random.random() >= probability:
return None
return self.random.randrange(self.capacity)
def store_reserved(
self,
slot: int,
key: Vector,
prompt_token_ids: list[int],
token_ids: list[int],
*,
example_weight: float = 1.0,
) -> None:
key_array = self._ensure_key_array(len(key))
if key_array is not None:
key_array[slot, :] = key
key_copy = None
else:
key_copy = key.copy() if hasattr(key, "copy") else list(key)
prompt_row = prompt_token_ids[:MAX_ANSWER_SEQUENCE_TOKENS]
row = token_ids[:MAX_ANSWER_SEQUENCE_TOKENS]
if self.capacity is None or slot >= len(self.keys):
self.keys.append(key_copy)
self.prompt_rows.append(prompt_row)
self.token_rows.append(row)
self.weights.append(example_weight)
return
self.keys[slot] = key_copy
self.prompt_rows[slot] = prompt_row
self.token_rows[slot] = row
self.weights[slot] = example_weight
def consider(
self,
key: Vector,
prompt_token_ids: list[int],
token_ids: list[int],
weight: float = 1.0,
) -> None:
if not token_ids:
return
slot = self.reserve_slot(weight=weight)
if slot is not None:
self.store_reserved(slot, key, prompt_token_ids, token_ids, example_weight=weight)
def state_matrix(self, *, dtype: object | None = None) -> object:
if np is None or not self.keys:
return self.keys
if self._key_array is not None:
matrix = self._key_array[: len(self.keys)]
if dtype is not None and matrix.dtype != dtype:
return matrix.astype(dtype, copy=False)
return matrix
return np.asarray(self.keys, dtype=dtype or np.float64)
def _word_count(text: str) -> int:
return len(text.split())
def _alpha_ratio(text: str) -> float:
if not text:
return 0.0
alpha_count = sum(character.isalpha() for character in text)
return alpha_count / len(text)
def _row_language(row: dict[str, object]) -> str:
for candidate in ("lang", "language", "locale"):
value = row.get(candidate)
if isinstance(value, str) and value.strip():
return value.strip()
return ""
def _normalize_role(raw_role: object) -> str:
role = str(raw_role or "").strip().casefold()
return ROLE_ALIASES.get(role, role)
def _coerce_json_payload(payload: object) -> object:
if not isinstance(payload, str):
return payload
stripped = payload.strip()
if not stripped:
return ""
try:
return json.loads(stripped)
except json.JSONDecodeError:
return stripped
def _compact_json(payload: object) -> str:
if isinstance(payload, str):
return payload.strip()
return json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
def _render_tool_call(call: object) -> str:
if not isinstance(call, dict):
return f"<tool_call> {str(call).strip()}".strip()
function_payload = call.get("function", {})
function = function_payload if isinstance(function_payload, dict) else {}
name = str(call.get("name", function.get("name", "tool"))).strip() or "tool"
arguments = call.get("arguments", function.get("arguments", {}))
return f"<tool_call> {name} {_compact_json(arguments)}".strip()
def _render_source_lines(payload: object) -> list[str]:
if not isinstance(payload, dict):
return []
raw_sources = payload.get("sources", payload.get("source", []))
if isinstance(raw_sources, dict):
sources = [raw_sources]
elif isinstance(raw_sources, list):
sources = raw_sources
elif raw_sources:
sources = [raw_sources]
else:
sources = []
lines: list[str] = []
for source in sources:
if isinstance(source, dict):
title = str(source.get("title", source.get("name", "source"))).strip()
url = str(source.get("url", source.get("uri", ""))).strip()
snippet = str(source.get("snippet", source.get("text", source.get("content", "")))).strip()
parts = [part for part in (title, url, snippet) if part]
if parts:
lines.append(f"<source> {' | '.join(parts)}")
elif source:
lines.append(f"<source> {str(source).strip()}")
return lines
def _render_tool_result(name: str, payload: object) -> list[str]:
tool_name = name.strip() or "tool"
parsed = _coerce_json_payload(payload)
if isinstance(parsed, dict):
explicit_name = str(parsed.get("name", parsed.get("tool", ""))).strip()
if explicit_name:
tool_name = explicit_name
status = str(parsed.get("status", "")).casefold()
ok_value = parsed.get("ok", None)
error = str(parsed.get("error", parsed.get("message", ""))).strip()
        failed = ok_value is False or status in {"error", "failed", "failure", "timeout"} or bool(error)
        source_lines = _render_source_lines(parsed)
        if failed:
            first = f"<tool_result> {tool_name} failed: {error or status or 'unknown error'}"
        else:
            summary = str(parsed.get("summary", parsed.get("content", parsed.get("text", "")))).strip()
            first = f"<tool_result> {tool_name} ok"
            if summary and not source_lines:
                first = f"{first}: {summary}"
        return [first, *source_lines]
if parsed:
return [f"<tool_result> {tool_name} {str(parsed).strip()}"]
return [f"<tool_result> {tool_name} empty"]
def _message_content(message: dict[str, object], role: str = "") -> str:
if role == "tool":
name = str(message.get("name", message.get("tool_call_id", "tool"))).strip() or "tool"
payload = message.get("content", message.get("value", message.get("text", message)))
return clean_training_text("\n".join(_render_tool_result(name, payload)))
parts: list[str] = []
for field in ("content", "value", "text", "message"):
value = message.get(field)
if isinstance(value, str) and value.strip():
parts.append(clean_training_text(value))
break
tool_calls = message.get("tool_calls", message.get("function_calls", message.get("tools")))
if isinstance(tool_calls, str):
tool_calls = _coerce_json_payload(tool_calls)
if isinstance(tool_calls, dict):
tool_calls = [tool_calls]
if isinstance(tool_calls, list):
for call in tool_calls:
parts.append(_render_tool_call(call))
return "\n".join(part for part in parts if part).strip()
def _message_role(message: dict[str, object]) -> str:
for field in ("role", "from", "speaker", "author"):
value = message.get(field)
if value is not None:
normalized = _normalize_role(value)
if normalized:
return normalized
return ""
def _parse_dialogue_messages(raw_messages: object) -> list[dict[str, str]]:
if isinstance(raw_messages, str):
parsed_json = _coerce_json_payload(raw_messages)
if parsed_json is not raw_messages:
raw_messages = parsed_json
if not isinstance(raw_messages, list):
return []
parsed: list[dict[str, str]] = []
for message in raw_messages:
if not isinstance(message, dict):
continue
role = _message_role(message)
content = _message_content(message, role)
if role not in {"system", "user", "assistant", "tool"} or not content:
continue
parsed.append({"role": role, "content": content})
return parsed
def _parse_transcript_messages(raw_text: object) -> list[dict[str, str]]:
if not isinstance(raw_text, str):
return []
text = raw_text.strip()
if not text:
return []
matches = list(TRANSCRIPT_ROLE_PATTERN.finditer(text))
if not matches:
return []
parsed: list[dict[str, str]] = []
for index, match in enumerate(matches):
role = _normalize_role(match.group(1))
start = match.end()
end = matches[index + 1].start() if index + 1 < len(matches) else len(text)
raw_content = text[start:end].strip()
if role == "tool":
content = clean_training_text("\n".join(_render_tool_result("tool", raw_content)))
else:
content = clean_training_text(raw_content)
if role in {"system", "user", "assistant", "tool"} and content:
parsed.append({"role": role, "content": content})
return parsed
def _render_prompt(messages: list[dict[str, str]]) -> str:
parts = []
for message in messages:
raw_content = message["content"]
if message["role"] in {"system", "tool"} or any(
token in raw_content for token in TOOL_PROTOCOL_TOKENS
):
content = clean_training_text(raw_content)
else:
content = clean_context_text(raw_content)
if content:
parts.append(content)
return "\n".join(parts).strip()
def _last_user_prompt_before(messages: list[dict[str, str]], end_index: int) -> str:
for message in reversed(messages[:end_index]):
if message["role"] == "user":
return clean_context_text(message["content"])
return _render_prompt(messages[:end_index])
def _compose_training_text(context: object, answer: object) -> str:
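    """Render a supervised example as ``<reason> {context} <answer> {answer}``, falling back to whichever side is non-empty."""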
prompt_text = clean_context_text(_flatten_value(context))
answer_text = clean_answer_text(_flatten_value(answer))
if prompt_text and answer_text:
return f"<reason> {prompt_text} <answer> {answer_text}".strip()
return clean_training_text(answer_text or prompt_text)
def _compose_from_messages(messages: list[dict[str, str]]) -> str:
assistant_index = None
for index in range(len(messages) - 1, -1, -1):
if messages[index]["role"] == "assistant":
assistant_index = index
break
if assistant_index is not None:
history = messages[:assistant_index]
if any(
message["role"] in {"system", "tool"}
or any(token in message["content"] for token in TOOL_PROTOCOL_TOKENS)
for message in history
):
prompt = _render_prompt(history)
else:
prompt = _last_user_prompt_before(messages, assistant_index)
answer = clean_answer_text(messages[assistant_index]["content"])
if prompt and answer:
return f"<reason> {prompt} <answer> {answer}".strip()
return "\n".join(
message["content"]
for message in messages
if message.get("content")
).strip()
def _history_needs_full_prompt(history: list[dict[str, str]]) -> bool:
return any(
message["role"] in {"system", "tool"}
or any(token in message["content"] for token in TOOL_PROTOCOL_TOKENS)
for message in history
)
def _compose_training_texts_from_messages(messages: list[dict[str, str]]) -> list[str]:
rows: list[str] = []
history: list[dict[str, str]] = []
for message in messages:
if message["role"] != "assistant":
history.append(message)
continue
prompt = (
_render_prompt(history)
if _history_needs_full_prompt(history)
else _last_user_prompt_before(history, len(history))
)
answer = clean_answer_text(message["content"])
if prompt and answer:
rows.append(f"<reason> {prompt} <answer> {answer}".strip())
history.append(message)
return rows
def _flatten_message_list(messages: object) -> str:
parsed = _parse_dialogue_messages(messages)
if parsed:
return _compose_from_messages(parsed)
if not isinstance(messages, list):
return ""
parts: list[str] = []
for message in messages:
if not isinstance(message, dict):
continue
content = str(
message.get("content", message.get("value", message.get("text", "")))
).strip()
if not content:
continue
parts.append(clean_training_text(content))
return "\n".join(parts).strip()
def _flatten_value(value: object) -> str:
if isinstance(value, str):
parsed = _parse_transcript_messages(value)
if parsed:
return _compose_from_messages(parsed)
return clean_training_text(value.strip())
if isinstance(value, list):
return _flatten_message_list(value)
if isinstance(value, dict):
for field in ("messages", "conversation", "conversations", "dialogue", "turns"):
nested_messages = value.get(field)
text = _flatten_message_list(nested_messages)
if text:
return text
for field in ("text", "content", "value", "message"):
nested = value.get(field)
if isinstance(nested, str) and nested.strip():
return _flatten_value(nested)
return ""
def _tool_definition_text(row: dict[str, object]) -> str:
parts: list[str] = []
for field in TOOL_DEFINITION_FIELDS:
value = row.get(field)
if value in (None, ""):
continue
parts.append(_compact_json(_coerce_json_payload(value)))
if not parts:
return ""
return clean_training_text("Available tools: " + "\n".join(parts))
def _tool_call_answer_text(value: object) -> str:
payload = _coerce_json_payload(value)
calls = payload if isinstance(payload, list) else [payload]
rendered = [
_render_tool_call(call)
for call in calls
if call not in (None, "")
]
return clean_answer_text("\n".join(part for part in rendered if part))
def _tool_call_training_text_from_row(row: dict[str, object]) -> str:
if "query" not in row or "answers" not in row:
return ""
answer_text = _tool_call_answer_text(row.get("answers"))
if not answer_text:
return ""
query_text = clean_context_text(_flatten_value(row.get("query")))
if not query_text:
return ""
tool_definition = _tool_definition_text(row)
context = "\n".join(part for part in (tool_definition, query_text) if part)
return _compose_training_text(context, answer_text)
def _prepend_dialogue_context(text: str, row: dict[str, object]) -> str:
prefixes: list[str] = []
system_text = clean_training_text(str(row.get("system", "")).strip())
if system_text:
prefixes.append(system_text)
tool_definition = _tool_definition_text(row)
if tool_definition and tool_definition != system_text:
prefixes.append(tool_definition)
if not prefixes:
return text
context = "\n".join(prefixes)
if text.startswith("<reason> "):
return text.replace("<reason> ", f"<reason> {context}\n", 1)
return f"{context}\n{text}".strip()
def _dialogue_texts_from_row(row: dict[str, object]) -> list[str]:
prefix_messages: list[dict[str, str]] = []
system_text = clean_training_text(str(row.get("system", "")).strip())
if system_text:
prefix_messages.append({"role": "system", "content": system_text})
tool_definition = _tool_definition_text(row)
if tool_definition and tool_definition != system_text:
prefix_messages.append({"role": "system", "content": tool_definition})
rows: list[str] = []
for field in DIALOGUE_FIELD_PREFERENCES:
messages = _parse_dialogue_messages(row.get(field))
if messages:
tool_related = bool(tool_definition) or any(
message["role"] == "tool"
or any(token in message["content"] for token in TOOL_PROTOCOL_TOKENS)
for message in messages
)
if tool_related:
rows.extend(_compose_training_texts_from_messages([*prefix_messages, *messages]))
else:
text = _compose_from_messages(messages)
if text:
rows.append(_prepend_dialogue_context(text, row))
return rows
def _safe_flag(value: object) -> bool | None:
if isinstance(value, bool):
return value
if isinstance(value, str):
normalized = value.strip().casefold()
if normalized in {"true", "1", "yes", "safe"}:
return True
if normalized in {"false", "0", "no", "unsafe"}:
return False
return None
def _selected_response_fields(row: dict[str, object]) -> tuple[str, str]:
if "response_0" not in row or "response_1" not in row:
return "", ""
safe_0 = _safe_flag(row.get("is_response_0_safe"))
safe_1 = _safe_flag(row.get("is_response_1_safe"))
if safe_0 is not None and safe_1 is not None:
if safe_0 and not safe_1:
return "response_0", "response_1"
if safe_1 and not safe_0:
return "response_1", "response_0"
if safe_0 and safe_1:
return "response_0", ""
return "", ""
for selector in ("safer_response_id", "better_response_id"):
raw_value = row.get(selector)
try:
preferred = int(raw_value)
except (TypeError, ValueError):
continue
chosen = "response_1" if preferred == 1 else "response_0"
rejected = "response_0" if chosen == "response_1" else "response_1"
return chosen, rejected
return "response_0", "response_1"
def _extract_preference_pair(row: dict[str, object]) -> tuple[str, str]:
if "chosen" in row and "rejected" in row:
chosen_text = clean_training_text(_flatten_value(row.get("chosen")))
rejected_text = clean_training_text(_flatten_value(row.get("rejected")))
if chosen_text and rejected_text:
return chosen_text, rejected_text
if "response_0" in row and "response_1" in row:
preferred_field, rejected_field = _selected_response_fields(row)
if not preferred_field or not rejected_field:
return "", ""
prompt = row.get("prompt", row.get("question", row.get("query", "")))
if prompt:
chosen_text = _compose_training_text(prompt, row.get(preferred_field))
rejected_text = _compose_training_text(prompt, row.get(rejected_field))
if chosen_text and rejected_text:
return clean_training_text(chosen_text), clean_training_text(rejected_text)
chosen_text = clean_training_text(_flatten_value(row.get(preferred_field)))
rejected_text = clean_training_text(_flatten_value(row.get(rejected_field)))
if chosen_text and rejected_text:
return chosen_text, rejected_text
return "", ""
def _extract_preference_value(row: dict[str, object]) -> str:
chosen_text, _ = _extract_preference_pair(row)
return chosen_text
def _extract_row_text(row: dict[str, object], text_field: str | None) -> str:
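    """Extract one training text from a heterogeneous dataset row.

    Priority order: tool-call rows, explicit context/answer pairs, safety-filtered
    response_0/response_1 rows, instruction/output field pairs, the explicit
    ``text_field``, preference ``chosen`` text, then generic text and dialogue fields.
    """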
tool_call_text = _tool_call_training_text_from_row(row)
if tool_call_text:
return tool_call_text
if "context" in row and "answer" in row:
context = clean_context_text(_flatten_value(row.get("context")))
answer = clean_answer_text(_flatten_value(row.get("answer")))
if context and answer:
return f"<reason> {context} <answer> {answer}".strip()
if "response_0" in row and "response_1" in row:
preferred_field, _ = _selected_response_fields(row)
prompt = row.get("prompt", row.get("question", row.get("query", "")))
if preferred_field and prompt:
text = _compose_training_text(prompt, row.get(preferred_field))
if text:
return text
for prompt_field, answer_field in INSTRUCTION_FIELD_PAIRS:
if prompt_field in row and answer_field in row:
prompt_value = row.get(prompt_field)
system_value = row.get("system_prompt", row.get("system", ""))
system_text = clean_context_text(_flatten_value(system_value))
if system_text:
prompt_value = "\n".join(
part
for part in (system_text, clean_context_text(_flatten_value(prompt_value)))
if part
)
text = _compose_training_text(prompt_value, row.get(answer_field))
if text:
return text
if text_field is not None:
return clean_training_text(_flatten_value(row.get(text_field)))
preferred = _extract_preference_value(row)
if preferred:
return clean_training_text(preferred)
for field in TEXT_FIELD_PREFERENCES:
text = _flatten_value(row.get(field))
if text:
return clean_training_text(text)
for field in DIALOGUE_FIELD_PREFERENCES:
text = _flatten_value(row.get(field))
if text:
return clean_training_text(_prepend_dialogue_context(text, row))
return ""
def _extract_row_texts(row: dict[str, object], text_field: str | None) -> list[str]:
if text_field is None:
dialogue_texts = _dialogue_texts_from_row(row)
if dialogue_texts:
return dialogue_texts
text = _extract_row_text(row, text_field)
return [text] if text else []
def _passes_text_quality(text: str, language: str, entry: CorpusPlanEntry) -> bool:
if not text:
return False
if has_machine_artifacts(text):
return False
word_count = _word_count(text)
if entry.min_words > 0 and word_count < entry.min_words:
return False
if entry.max_words > 0 and word_count > entry.max_words:
return False
if entry.min_alpha_ratio > 0.0 and _alpha_ratio(text) < entry.min_alpha_ratio:
return False
if entry.allowed_languages:
if not language or language.casefold() not in entry.allowed_languages:
return False
return True
def load_corpus_plan(source: str | Path) -> list[CorpusPlanEntry]:
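    """Parse a JSON corpus plan (a ``sources``/``datasets`` list) into CorpusPlanEntry objects."""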
payload = json.loads(Path(source).read_text(encoding="utf-8-sig"))
raw_entries = payload.get("sources", payload.get("datasets", []))
if not isinstance(raw_entries, list) or not raw_entries:
raise ValueError("Corpus plan must define a non-empty 'sources' list.")
entries: list[CorpusPlanEntry] = []
for index, raw_entry in enumerate(raw_entries, start=1):
if not isinstance(raw_entry, dict):
raise ValueError("Each corpus plan entry must be an object.")
source = str(raw_entry.get("source", "hf")).strip() or "hf"
name = str(
raw_entry.get("name", raw_entry.get("dataset", f"source-{index}"))
).strip() or f"source-{index}"
raw_languages = raw_entry.get("allowed_languages", [])
allowed_languages = tuple(
str(value).strip().casefold()
for value in raw_languages
if str(value).strip()
) if isinstance(raw_languages, list) else ()
raw_records = raw_entry.get("records", raw_entry.get("texts", []))
if source == "inline" and not isinstance(raw_records, list):
raise ValueError("Inline corpus plan entries must provide a records/texts list.")
entries.append(
CorpusPlanEntry(
source=source,
name=name,
dataset=str(raw_entry.get("dataset", "")),
path=str(raw_entry.get("path", raw_entry.get("file", ""))),
config=(
str(raw_entry["config"])
if raw_entry.get("config") is not None
else None
),
split=str(raw_entry.get("split", "train")),
limit=int(raw_entry.get("limit", 0)),
weight=float(raw_entry.get("weight", 1.0)),
text_field=(
str(raw_entry["text_field"])
if raw_entry.get("text_field") is not None
else None
),
min_words=int(raw_entry.get("min_words", 0)),
max_words=int(raw_entry.get("max_words", 0)),
min_alpha_ratio=float(raw_entry.get("min_alpha_ratio", 0.0)),
allowed_languages=allowed_languages,
records=tuple(raw_records) if isinstance(raw_records, list) else (),
streaming=bool(raw_entry.get("streaming", True)),
trust_remote_code=bool(raw_entry.get("trust_remote_code", False)),
max_seconds=float(raw_entry.get("max_seconds", 0.0)),
readout_weight=float(raw_entry.get("readout_weight", 1.0)),
transition_weight=float(raw_entry.get("transition_weight", 1.0)),
)
)
return entries
def _iter_hf_rows(entry: CorpusPlanEntry) -> Iterator[dict[str, object]]:
try:
from datasets import load_dataset
except ModuleNotFoundError:
user_site = site.getusersitepackages()
if user_site and user_site not in sys.path:
sys.path.append(user_site)
from datasets import load_dataset
dataset_kwargs: dict[str, object] = {
"split": entry.split,
"streaming": entry.streaming,
}
if entry.config:
dataset_kwargs["name"] = entry.config
if entry.trust_remote_code:
dataset_kwargs["trust_remote_code"] = True
for row in load_dataset(entry.dataset, **dataset_kwargs):
yield dict(row)
def _hf_auth_headers() -> dict[str, str]:
token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
if not token:
try:
from huggingface_hub import get_token
token = get_token()
except (ImportError, ModuleNotFoundError, OSError):
token = None
return {"Authorization": f"Bearer {token}"} if token else {}
def _remaining_source_budget(entry: CorpusPlanEntry, started_at: float) -> float | None:
if entry.max_seconds <= 0.0:
return None
return max(0.0, float(entry.max_seconds) - (time.monotonic() - started_at))
def _hf_viewer_request_timeout(entry: CorpusPlanEntry, started_at: float) -> float:
remaining = _remaining_source_budget(entry, started_at)
if remaining is None:
return HF_VIEWER_REQUEST_TIMEOUT_SECONDS
if remaining <= 0.0:
return 0.0
return max(
HF_VIEWER_MIN_REQUEST_TIMEOUT_SECONDS,
min(HF_VIEWER_REQUEST_TIMEOUT_SECONDS, remaining),
)
def _hf_viewer_retry_delay(entry: CorpusPlanEntry, started_at: float, attempt: int) -> float:
delay = min(20.0, HF_STREAM_RETRY_BASE_DELAY_SECONDS * (2 ** attempt))
remaining = _remaining_source_budget(entry, started_at)
if remaining is None:
return delay
return max(0.0, min(delay, remaining))
def _iter_hf_viewer_rows(entry: CorpusPlanEntry) -> Iterator[dict[str, object]]:
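    """Page through the Hugging Face datasets-server ``/rows`` endpoint.

    Transient failures (429/5xx and request errors) are retried with exponential
    backoff, and iteration stops early once the entry's ``max_seconds`` budget
    is spent.
    """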
import requests
if not entry.dataset:
return
started_at = time.monotonic()
offset = 0
page_size = min(100, max(1, entry.limit if entry.limit > 0 else 100))
while True:
if entry.max_seconds > 0.0 and _remaining_source_budget(entry, started_at) <= 0.0:
print(
f"[hf_viewer] {entry.name or entry.dataset}: source budget reached "
f"at offset {offset}",
flush=True,
)
return
params: dict[str, object] = {
"dataset": entry.dataset,
"split": entry.split,
"offset": offset,
"length": page_size,
}
if entry.config is not None:
params["config"] = entry.config
response = None
for attempt in range(HF_STREAM_MAX_RETRIES + 1):
request_timeout = _hf_viewer_request_timeout(entry, started_at)
if request_timeout <= 0.0:
print(
f"[hf_viewer] {entry.name or entry.dataset}: source budget reached "
f"before offset {offset}",
flush=True,
)
return
try:
response = requests.get(
"https://datasets-server.huggingface.co/rows",
params=params,
headers=_hf_auth_headers(),
timeout=request_timeout,
)
except Exception as exc:
if attempt >= HF_STREAM_MAX_RETRIES:
print(
f"[hf_viewer] {entry.name or entry.dataset}: skipped at offset "
f"{offset} after request error: {exc}",
flush=True,
)
return
delay = _hf_viewer_retry_delay(entry, started_at, attempt)
if delay <= 0.0:
print(
f"[hf_viewer] {entry.name or entry.dataset}: source budget reached "
f"after request error at offset {offset}: {exc}",
flush=True,
)
return
print(
f"[hf_viewer] {entry.name or entry.dataset}: request error at "
f"offset {offset}; retry {attempt + 1}/{HF_STREAM_MAX_RETRIES} "
f"in {delay:.2f}s: {exc}",
flush=True,
)
time.sleep(delay)
continue
status_code = int(getattr(response, "status_code", 200))
if status_code not in {429, 500, 502, 503, 504}:
break
if attempt >= HF_STREAM_MAX_RETRIES:
print(
f"[hf_viewer] {entry.name or entry.dataset}: skipped at offset "
f"{offset} after status {status_code}",
flush=True,
)
return
delay = _hf_viewer_retry_delay(entry, started_at, attempt)
if delay <= 0.0:
print(
f"[hf_viewer] {entry.name or entry.dataset}: source budget reached "
f"after status {status_code} at offset {offset}",
flush=True,
)
return
print(
f"[hf_viewer] {entry.name or entry.dataset}: status {status_code} "
f"at offset {offset}; retry {attempt + 1}/{HF_STREAM_MAX_RETRIES} in {delay:.2f}s",
flush=True,
)
time.sleep(delay)
assert response is not None
response.raise_for_status()
payload = response.json()
raw_rows = payload.get("rows", [])
if not isinstance(raw_rows, list) or not raw_rows:
return
for item in raw_rows:
if isinstance(item, dict):
row = item.get("row", item)
if isinstance(row, dict):
yield row
offset += len(raw_rows)
if len(raw_rows) < page_size:
return
total = payload.get("num_rows_total")
if isinstance(total, int) and offset >= total:
return
def _hf_parquet_infos(entry: CorpusPlanEntry) -> list[dict[str, object]]:
import requests
response = requests.get(
"https://datasets-server.huggingface.co/parquet",
params={"dataset": entry.dataset},
headers=_hf_auth_headers(),
timeout=60,
)
response.raise_for_status()
payload = response.json()
raw_files = payload.get("parquet_files", [])
if not isinstance(raw_files, list):
return []
infos: list[dict[str, object]] = []
for info in raw_files:
if not isinstance(info, dict):
continue
split = str(info.get("split", "")).strip()
config = str(info.get("config", "")).strip()
url = str(info.get("url", "")).strip()
if not url:
continue
if split != entry.split:
continue
if entry.config is not None and config != entry.config:
continue
infos.append(info)
infos.sort(key=lambda item: str(item.get("filename", "")))
return infos
def _download_parquet_shard(url: str, target: Path) -> None:
import requests
with requests.get(url, headers=_hf_auth_headers(), stream=True, timeout=120) as response:
response.raise_for_status()
with target.open("wb") as handle:
for chunk in response.iter_content(chunk_size=1024 * 1024):
if chunk:
handle.write(chunk)
def _iter_parquet_file_rows(path: Path) -> Iterator[dict[str, object]]:
import pyarrow.parquet as pq
parquet_file = pq.ParquetFile(path)
for batch in parquet_file.iter_batches(batch_size=512):
for row in batch.to_pylist():
if isinstance(row, dict):
yield row
def _iter_hf_parquet_rows(entry: CorpusPlanEntry) -> Iterator[dict[str, object]]:
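    """Stream rows from the dataset's published parquet shards.

    Remote shards are downloaded one at a time into a local cache directory,
    iterated in 512-row batches, and deleted after use.
    """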
infos = _hf_parquet_infos(entry)
if not infos:
return
print(
f"[hf_parquet] {entry.name or entry.dataset}: {len(infos)} shard(s) for "
f"{entry.dataset} {entry.config or 'default'}:{entry.split}",
flush=True,
)
    cache_override = os.environ.get("REFRAMR_HF_PARQUET_CACHE")
    # Fall back to the system temp directory so the default cache is portable.
    temp_root = (
        Path(cache_override)
        if cache_override
        else Path(tempfile.gettempdir()) / "reframr-hf-parquet"
    )
    temp_root.mkdir(parents=True, exist_ok=True)
safe_name = re.sub(r"[^A-Za-z0-9_.-]+", "-", entry.name or entry.dataset).strip("-") or "hf-parquet"
for shard_index, info in enumerate(infos):
url = str(info.get("url", "")).strip()
if not url:
continue
local_source = Path(url)
downloaded = False
if url.startswith(("http://", "https://")):
suffix = Path(str(info.get("filename", f"{shard_index}.parquet"))).suffix or ".parquet"
local_source = temp_root / f"{safe_name}-{shard_index:05d}{suffix}"
print(
f"[hf_parquet] {entry.name or entry.dataset}: downloading shard "
f"{shard_index + 1}/{len(infos)} ({int(info.get('size', 0) or 0)} bytes)",
flush=True,
)
_download_parquet_shard(url, local_source)
downloaded = True
try:
yield from _iter_parquet_file_rows(local_source)
finally:
if downloaded:
local_source.unlink(missing_ok=True)
def _iter_file_rows(entry: CorpusPlanEntry) -> Iterator[dict[str, object]]:
raw_path = entry.path or entry.dataset
if not raw_path:
raise ValueError("File corpus plan entries must provide a path.")
path = Path(raw_path)
suffix = path.suffix.lower()
if suffix == ".jsonl":
with path.open("r", encoding="utf-8") as handle:
for line in handle:
if line.strip():
row = json.loads(line)
yield row if isinstance(row, dict) else {"text": str(row)}
return
if suffix == ".json":
payload = json.loads(path.read_text(encoding="utf-8"))
if isinstance(payload, list):
for row in payload:
yield row if isinstance(row, dict) else {"text": str(row)}
return
if isinstance(payload, dict):
rows = payload.get("records", payload.get("texts"))
if isinstance(rows, list):
for row in rows:
yield row if isinstance(row, dict) else {"text": str(row)}
return
yield payload
return
if suffix in {".txt", ".md", ".text"}:
yield {"text": path.read_text(encoding="utf-8")}
return
raise ValueError(f"Unsupported file corpus source: {path}")
def iter_corpus_plan_documents(plan: Iterable[CorpusPlanEntry]) -> Iterator[StreamDocument]:
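    """Yield quality-filtered StreamDocument rows for every plan entry in order.

    For ``hf`` sources, interrupted streams are retried with exponential backoff;
    on each retry, documents accepted in earlier attempts are skipped so the
    overall yield sequence stays free of duplicates.
    """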
for entry in plan:
accepted = 0
attempts = 0
source_started_at = time.monotonic()
while True:
accepted_seen_this_attempt = 0
try:
if entry.source == "inline":
row_iterator = (
item if isinstance(item, dict) else {"text": str(item)}
for item in entry.records
)
elif entry.source == "hf":
row_iterator = _iter_hf_rows(entry)
elif entry.source == "hf_viewer":
row_iterator = _iter_hf_viewer_rows(entry)
elif entry.source == "hf_parquet":
row_iterator = _iter_hf_parquet_rows(entry)
elif entry.source == "file":
row_iterator = _iter_file_rows(entry)
else:
raise ValueError(f"Unsupported corpus plan source: {entry.source}")
for row in row_iterator:
if (
entry.max_seconds > 0.0
and time.monotonic() - source_started_at >= entry.max_seconds
):
print(
f"[source] {entry.name} time budget reached after "
f"{accepted} accepted documents; moving on",
flush=True,
)
break
language = _row_language(row)
_, rejected_text = _extract_preference_pair(row)
for raw_text in _extract_row_texts(row, entry.text_field):
text = clean_training_text(raw_text)
if not _passes_text_quality(text, language, entry):
continue
accepted_seen_this_attempt += 1
if accepted_seen_this_attempt <= accepted:
continue
yield StreamDocument(
text=text,
weight=entry.weight,
source=entry.name,
language=language,
preference_rejected_text=rejected_text,
readout_weight=max(0.0, entry.readout_weight),
transition_weight=max(0.0, entry.transition_weight),
)
accepted += 1
if entry.limit > 0 and accepted >= entry.limit:
break
                    if (
                        (entry.limit > 0 and accepted >= entry.limit)
                        or (
                            entry.max_seconds > 0.0
                            and time.monotonic() - source_started_at >= entry.max_seconds
                        )
                    ):
                        break
break
except Exception as exc:
if entry.source != "hf":
raise
if attempts >= HF_STREAM_MAX_RETRIES:
print(
f"[source] {entry.name} skipped after {attempts} retries; "
f"accepted {accepted} documents before final error: {exc}"
)
break
attempts += 1
delay = min(
15.0,
HF_STREAM_RETRY_BASE_DELAY_SECONDS * (2 ** (attempts - 1)),
)
print(
f"[source] {entry.name} stream interrupted after {accepted} accepted "
f"documents; retry {attempts}/{HF_STREAM_MAX_RETRIES} in {delay:.2f}s: {exc}"
)
time.sleep(delay)
def estimate_corpus_plan(
plan: Iterable[CorpusPlanEntry],
*,
max_rows_per_source: int = 0,
) -> dict[str, object]:
"""Fast preflight pass that applies text-quality gates without tokenizing/training."""
source_summaries: list[dict[str, object]] = []
total_accepted = 0
total_seen = 0
total_words = 0
started = time.perf_counter()
for entry in plan:
accepted = 0
seen = 0
rejected = 0
words = 0
source_started = time.perf_counter()
if entry.source == "inline":
row_iterator = (
item if isinstance(item, dict) else {"text": str(item)}
for item in entry.records
)
elif entry.source == "file":
row_iterator = _iter_file_rows(entry)
elif entry.source in {"hf", "hf_viewer", "hf_parquet"}:
source_summaries.append(
{
"name": entry.name,
"source": entry.source,
"limit": entry.limit,
"accepted": None,
"seen": None,
"rejected": None,
"estimated_words": None,
"seconds": 0.0,
"note": "remote source; materialize first for exact preflight",
}
)
continue
else:
raise ValueError(f"Unsupported corpus plan source: {entry.source}")
for row in row_iterator:
if max_rows_per_source > 0 and seen >= max_rows_per_source:
break
language = _row_language(row)
for raw_text in _extract_row_texts(row, entry.text_field):
seen += 1
text = clean_training_text(raw_text)
if _passes_text_quality(text, language, entry):
accepted += 1
word_count = _word_count(text)
words += word_count
if entry.limit > 0 and accepted >= entry.limit:
break
else:
rejected += 1
if entry.limit > 0 and accepted >= entry.limit:
break
total_accepted += accepted
total_seen += seen
total_words += words
source_summaries.append(
{
"name": entry.name,
"source": entry.source,
"limit": entry.limit,
"accepted": accepted,
"seen": seen,
"rejected": rejected,
"estimated_words": words,
"seconds": round(time.perf_counter() - source_started, 6),
}
)
return {
"sources": source_summaries,
"accepted_documents": total_accepted,
"seen_texts": total_seen,
"rejected_texts": sum(
int(source.get("rejected") or 0) for source in source_summaries
),
"estimated_words": total_words,
"seconds": round(time.perf_counter() - started, 6),
}
def _log_progress(label: str, processed: int, log_every: int) -> None:
if log_every > 0 and processed % log_every == 0:
print(f"[{label}] processed {processed} documents", flush=True)
def _answer_boundary(tokens: list[str]) -> int | None:
try:
return tokens.index("<answer>")
except ValueError:
return None
def _weighted_text_parts_for_statistics(text: str, document_weight: float) -> list[tuple[str, float]]:
if "<answer>" not in text:
return [(text, document_weight)]
context, answer = text.split("<answer>", 1)
context = clean_context_text(context.replace("<reason>", " "))
answer = clean_answer_text(answer)
parts: list[tuple[str, float]] = []
if context:
parts.append((context, document_weight * CONTEXT_STAT_WEIGHT))
if answer:
parts.append((answer, document_weight * ANSWER_READOUT_WEIGHT))
return parts or [(text, document_weight)]
def _weighted_token_sequences_for_statistics(
tokens: list[str],
tokenizer: NativeTokenizer,
document_weight: float,
) -> list[tuple[list[str], float]]:
answer_index = _answer_boundary(tokens)
if answer_index is None:
sequence = [token for token in tokens if _is_learned_output_token(token, tokenizer)]
return [(sequence, document_weight)] if sequence else []
context_tokens = [
token
for token in tokens[:answer_index]
if _is_learned_output_token(token, tokenizer)
]
answer_tokens = [
token
for token in tokens[answer_index + 1 :]
if _is_learned_output_token(token, tokenizer)
]
sequences: list[tuple[list[str], float]] = []
if context_tokens:
sequences.append((context_tokens, document_weight * CONTEXT_STAT_WEIGHT))
if answer_tokens:
sequences.append((answer_tokens, document_weight * ANSWER_READOUT_WEIGHT))
return sequences
def _readout_weight_for_target(
answer_index: int | None,
target_index: int,
document_weight: float,
) -> float:
if answer_index is None:
return document_weight * PLAIN_TEXT_READOUT_WEIGHT
if target_index <= answer_index:
return document_weight * CONTEXT_READOUT_WEIGHT
return document_weight * ANSWER_READOUT_WEIGHT
def _target_balance_from_label_mass(target_label_mass: object) -> tuple[object, float]:
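    """Compute per-label rebalance factors clipped to [TARGET_BALANCE_MIN, TARGET_BALANCE_MAX].

    Each positive-mass label gets ``sqrt(median_mass / mass)``; zero-mass labels
    get 1.0. Returns the factors together with the reference (median) label mass.
    """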
if np is not None:
masses = np.asarray(target_label_mass, dtype=np.float64)
positive_label_mass = masses[masses > 0.0]
reference_label_mass = (
float(np.median(positive_label_mass))
if positive_label_mass.size
else 1.0
)
target_balance = np.ones(masses.shape, dtype=np.float64)
np.divide(
reference_label_mass,
np.maximum(masses, 1e-12),
out=target_balance,
where=masses > 0.0,
)
return (
np.clip(
np.sqrt(target_balance),
TARGET_BALANCE_MIN,
TARGET_BALANCE_MAX,
),
reference_label_mass,
)
raw_masses = [float(value) for value in target_label_mass]
positive_label_mass = [value for value in raw_masses if value > 0.0]
if positive_label_mass:
sorted_mass = sorted(positive_label_mass)
reference_label_mass = sorted_mass[len(sorted_mass) // 2]
else:
reference_label_mass = 1.0
target_balance = [
max(
TARGET_BALANCE_MIN,
min(TARGET_BALANCE_MAX, (reference_label_mass / max(value, 1e-12)) ** 0.5),
)
if value > 0.0
else 1.0
for value in raw_masses
]
return target_balance, reference_label_mass
def _answer_payload_tokens(tokens: list[str], tokenizer: NativeTokenizer) -> list[str]:
answer_index = _answer_boundary(tokens)
payload = tokens[answer_index + 1 :] if answer_index is not None else tokens
return [token for token in payload if _is_learned_output_token(token, tokenizer)]
def _stable_window_start(tokens: list[str], width: int) -> int:
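    """Pick a deterministic window start by hashing a small sketch of the token list.

    Hashing the length plus head/middle/tail snippets with blake2b makes the crop
    reproducible across runs without storing any per-document state.
    """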
if width >= len(tokens):
return 0
sketch = "\n".join(
[
str(len(tokens)),
" ".join(tokens[:32]),
" ".join(tokens[len(tokens) // 2 : len(tokens) // 2 + 32]),
" ".join(tokens[-32:]),
]
)
digest = hashlib.blake2b(sketch.encode("utf-8"), digest_size=8).digest()
return int.from_bytes(digest, "little") % (len(tokens) - width + 1)
def _state_training_tokens(tokens: list[str], max_tokens: int | None) -> list[str]:
if max_tokens is None or max_tokens <= 0 or len(tokens) <= max_tokens:
return tokens
answer_index = _answer_boundary(tokens)
if answer_index is None:
start = _stable_window_start(tokens, max_tokens)
return tokens[start : start + max_tokens]
context_budget = max(1, max_tokens // 3)
context_tokens = tokens[max(0, answer_index - context_budget) : answer_index]
answer_budget = max(1, max_tokens - len(context_tokens) - 1)
answer_tokens = tokens[answer_index + 1 : answer_index + 1 + answer_budget]
return [*context_tokens, "<answer>", *answer_tokens]
def _standardized_preference_bias(values: object, active_mask: object | None = None) -> list[float]:
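    """Z-score the values over the active entries, scale by PREFERENCE_BIAS_SCALE, and clip to [-2.5, 2.5].

    Inactive entries stay at 0.0; a degenerate (near-zero) spread yields all zeros.
    """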
if np is not None:
bias = np.asarray(values, dtype=np.float64)
if bias.size == 0:
return []
active = (
np.asarray(active_mask, dtype=bool)
if active_mask is not None
else np.ones(bias.shape, dtype=bool)
)
if not np.any(active):
return [0.0 for _ in range(int(bias.size))]
active_values = bias[active]
spread = float(active_values.std())
if spread <= 1e-12:
return [0.0 for _ in range(int(bias.size))]
standardized = np.zeros_like(bias, dtype=np.float64)
standardized[active] = (
(active_values - float(active_values.mean())) / spread
) * PREFERENCE_BIAS_SCALE
return np.clip(standardized, -2.5, 2.5).astype(float).tolist()
raw_values = [float(value) for value in values]
if not raw_values:
return []
average = sum(raw_values) / len(raw_values)
variance = sum((value - average) * (value - average) for value in raw_values) / len(raw_values)
spread = variance**0.5
if spread <= 1e-12:
return [0.0 for _ in raw_values]
active_indices = (
[
index
for index, active in enumerate(active_mask)
if active
]
if active_mask is not None
else list(range(len(raw_values)))
)
if not active_indices:
return [0.0 for _ in raw_values]
active_values = [raw_values[index] for index in active_indices]
    average = sum(active_values) / len(active_values)
    spread = (
        sum((value - average) * (value - average) for value in active_values)
        / len(active_values)
    ) ** 0.5
if spread <= 1e-12:
return [0.0 for _ in raw_values]
standardized = [0.0 for _ in raw_values]
for index in active_indices:
standardized[index] = max(
-2.5,
min(2.5, ((raw_values[index] - average) / spread) * PREFERENCE_BIAS_SCALE),
)
return standardized
def _candidate_preference_bias_from_state_vector(
model: ReframrModel,
preference_state: object,
) -> object:
if np is None:
return None
assert model.embedding_model is not None
assert model.memory_units is not None
assert model.ternary_mask is not None
embeddings = np.asarray(model.embedding_model.embeddings, dtype=np.float64)
if embeddings.size == 0:
return np.zeros(0, dtype=np.float64)
state_vector = np.asarray(preference_state, dtype=np.float64)
mask = np.asarray(model.ternary_mask, dtype=np.float64) * float(model.ternary_scale)
if state_vector.shape[0] != mask.shape[0]:
return np.zeros(embeddings.shape[0], dtype=np.float64)
state_indices = np.arange(model.config.state_dim, dtype=np.int64)
drive = (
embeddings[:, state_indices % model.config.embedding_dim]
+ (0.5 * embeddings[:, (3 * state_indices + 1) % model.config.embedding_dim])
- (0.25 * embeddings[:, (5 * state_indices + 2) % model.config.embedding_dim])
)
scores = np.zeros(embeddings.shape[0], dtype=np.float64)
offset = 0
for unit in model.memory_units:
hidden_end = offset + model.config.state_dim
trace_end = hidden_end + model.config.embedding_dim
hidden_pref = state_vector[offset:hidden_end] * mask[offset:hidden_end]
trace_pref = state_vector[hidden_end:trace_end] * mask[hidden_end:trace_end]
hidden_delta_axis = np.asarray(unit.input_projection, dtype=np.float64) * hidden_pref
trace_gain = 1.0 - (1.0 / (1.0 + unit.timescale))
scores += drive @ hidden_delta_axis
scores += embeddings @ (trace_gain * trace_pref)
offset = trace_end
return scores
def _derive_preference_bias_from_pairs(
model: ReframrModel,
preference_token_pairs: list[tuple[list[str], list[str], float]],
tokenizer: NativeTokenizer,
) -> tuple[list[float], int]:
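    """Turn chosen/rejected token pairs into a standardized per-token bias vector.

    Answer tokens from chosen sequences add weight and those from rejected
    sequences subtract it; a strided subset of at most MAX_PREFERENCE_STATE_PAIRS
    pairs also contributes a decode-state delta that is projected back onto the
    vocabulary before standardization. Returns the bias and the state-pair count.
    """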
assert model.embedding_model is not None
vocab_size = len(model.embedding_model.id_to_token)
if not preference_token_pairs:
return [0.0 for _ in range(vocab_size)], 0
if np is not None:
token_bias = np.zeros(vocab_size, dtype=np.float64)
active_token_mask = np.zeros(vocab_size, dtype=bool)
state_delta = np.zeros(model._combined_state_width(), dtype=np.float64)
else:
token_bias = [0.0 for _ in range(vocab_size)]
active_token_ids: set[int] = set()
state_delta = [0.0 for _ in range(model._combined_state_width())]
pair_weight_total = 0.0
state_pair_count = 0
state_stride = max(
1,
(len(preference_token_pairs) + MAX_PREFERENCE_STATE_PAIRS - 1)
// MAX_PREFERENCE_STATE_PAIRS,
)
for pair_index, (chosen_tokens, rejected_tokens, pair_weight) in enumerate(preference_token_pairs):
chosen_answer = _answer_payload_tokens(chosen_tokens, tokenizer)
rejected_answer = _answer_payload_tokens(rejected_tokens, tokenizer)
if chosen_answer:
delta = pair_weight / max(1, len(chosen_answer))
for token in chosen_answer:
token_id = model.embedding_model.token_to_id.get(token)
if token_id is not None:
token_bias[token_id] += delta
if np is not None:
active_token_mask[token_id] = True
else:
active_token_ids.add(token_id)
if rejected_answer:
delta = pair_weight / max(1, len(rejected_answer))
for token in rejected_answer:
token_id = model.embedding_model.token_to_id.get(token)
if token_id is not None:
token_bias[token_id] -= delta
if np is not None:
active_token_mask[token_id] = True
else:
active_token_ids.add(token_id)
if pair_index % state_stride != 0 or state_pair_count >= MAX_PREFERENCE_STATE_PAIRS:
continue
chosen_state = model._masked_decode_state(model._build_decode_state(chosen_tokens))
rejected_state = model._masked_decode_state(model._build_decode_state(rejected_tokens))
if len(chosen_state) != len(rejected_state):
continue
pair_weight_total += pair_weight
state_pair_count += 1
if np is not None:
state_delta += pair_weight * (
np.asarray(chosen_state, dtype=np.float64)
- np.asarray(rejected_state, dtype=np.float64)
)
else:
for index, (chosen_value, rejected_value) in enumerate(zip(chosen_state, rejected_state)):
state_delta[index] += pair_weight * (chosen_value - rejected_value)
if pair_weight_total > 0.0:
if np is not None:
state_delta = state_delta / pair_weight_total
candidate_bias = _candidate_preference_bias_from_state_vector(model, state_delta)
if candidate_bias is not None:
token_bias[active_token_mask] = (
token_bias[active_token_mask] + candidate_bias[active_token_mask]
)
else:
state_delta = [value / pair_weight_total for value in state_delta]
if np is not None:
return _standardized_preference_bias(token_bias, active_token_mask), state_pair_count
active_mask = [index in active_token_ids for index in range(vocab_size)]
return _standardized_preference_bias(token_bias, active_mask), state_pair_count
def _solve_weighted_prompt_readout(
states: object,
labels: list[int],
weights: list[float],
*,
vocab_size: int,
diagonal: object,
state_offset: object,
regularization: float,
full_solver: bool = True,
) -> tuple[object, object, int]:
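    """Fit a weighted ridge readout mapping masked, centered states to next-token labels.

    Rows with out-of-range labels or non-positive weights are dropped. With
    ``full_solver`` the full Gram matrix is used; otherwise only per-feature
    second moments are kept (a diagonal approximation).
    """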
if np is None or not labels or not weights:
return [], [0.0 for _ in range(vocab_size)], 0
state_matrix = np.asarray(states, dtype=np.float64)
if state_matrix.ndim != 2 or int(state_matrix.shape[0]) == 0:
return [], [0.0 for _ in range(vocab_size)], 0
label_array = np.asarray(labels, dtype=np.int64)
weight_vector = np.asarray(weights, dtype=np.float64)
row_count = min(int(state_matrix.shape[0]), int(label_array.shape[0]), int(weight_vector.shape[0]))
if row_count <= 0:
return [], [0.0 for _ in range(vocab_size)], 0
state_matrix = state_matrix[:row_count]
label_array = label_array[:row_count]
weight_vector = weight_vector[:row_count]
valid_mask = (
(label_array >= 0)
& (label_array < vocab_size)
& (weight_vector > 0.0)
)
if not np.any(valid_mask):
return [], [0.0 for _ in range(vocab_size)], 0
state_matrix = state_matrix[valid_mask]
label_array = label_array[valid_mask]
weight_vector = weight_vector[valid_mask]
diagonal_array = np.asarray(diagonal, dtype=np.float64)
offset_array = np.asarray(state_offset, dtype=np.float64)
if (
state_matrix.ndim != 2
or diagonal_array.shape[0] != state_matrix.shape[1]
or offset_array.shape[0] != state_matrix.shape[1]
):
return [], [0.0 for _ in range(vocab_size)], 0
masked_states = state_matrix * diagonal_array[None, :]
centered_states = masked_states - offset_array[None, :]
weighted_centered_states = weight_vector[:, None] * centered_states
cross = np.zeros((vocab_size, centered_states.shape[1]), dtype=np.float64)
np.add.at(cross, label_array, weighted_centered_states)
total_weight = float(weight_vector.sum())
if total_weight <= 0.0:
return [], [0.0 for _ in range(vocab_size)], 0
bias = np.zeros(vocab_size, dtype=np.float64)
np.add.at(bias, label_array, weight_vector)
bias /= total_weight
if full_solver:
gram = centered_states.T @ weighted_centered_states
readout = ridge_regression_readout_from_moments(
gram,
cross,
regularization=regularization,
)
else:
feature_second_moment = (weighted_centered_states * centered_states).sum(axis=0)
readout = ridge_regression_readout_from_diagonal_moments(
feature_second_moment,
cross,
regularization=regularization,
)
return readout, bias, int(label_array.shape[0])
def _solve_prompt_readout_from_diagonal_moments(
feature_second_moment: object,
cross: object,
bias_counts: object,
*,
diagonal: object,
regularization: float,
total_weight: float,
) -> tuple[object, object, int]:
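"""Solve a prompt readout from pre-accumulated diagonal moments.

Because only per-feature second moments and label/feature cross terms
were accumulated during the state pass, the ridge solution reduces to
an elementwise division per feature; rows for labels with no observed
mass stay zero.
"""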
if np is None or total_weight <= 0.0:
vocab_size = len(bias_counts) if hasattr(bias_counts, "__len__") else 0
return [], [0.0 for _ in range(vocab_size)], 0
diagonal_array = np.asarray(diagonal, dtype=np.float64)
masked_feature_second = np.asarray(feature_second_moment, dtype=np.float64) * diagonal_array * diagonal_array
cross_array = np.asarray(cross, dtype=np.float64)
bias_counts_array = np.asarray(bias_counts, dtype=np.float64)
bias = bias_counts_array / max(float(total_weight), 1e-12)
denominator = masked_feature_second + regularization
denominator = np.where(np.abs(denominator) > 1e-12, denominator, regularization)
readout = np.zeros_like(cross_array, dtype=np.float64)
active_rows = bias_counts_array > 0.0
if np.any(active_rows):
readout[active_rows] = (cross_array[active_rows] * diagonal_array[None, :]) / denominator[None, :]
return readout, bias, int(round(float(total_weight)))
def fit_model_from_corpus_plan(
plan: Iterable[CorpusPlanEntry],
config: ReframrConfig,
*,
log_every: int = 0,
) -> tuple[ReframrModel, dict[str, object]]:
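"""Fit a ReframrModel end-to-end from a streamed corpus plan.

The fit runs in timed stages (reported as ``stage_seconds`` in the
returned payload): stream and segment documents, train the tokenizer,
build the embedding vocabulary, accumulate co-occurrence counts, fit
the PPMI embedding, run the HiPPO state pass with weighted reservoir
sampling, solve the ridge readouts, derive the preference bias, and
finalize the associative memories and transition tables.

Illustrative usage only (the dataset name is a placeholder and the
config is assumed to be default-constructible):

    plan = [CorpusPlanEntry(source="hf", name="demo", dataset="org/demo-corpus")]
    model, stats = fit_model_from_corpus_plan(plan, ReframrConfig())

Returns ``(model, payload)`` where ``payload`` summarizes the run.
"""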
entries = list(plan)
if not entries:
raise ValueError("Cannot fit REFRAMR without any corpus plan entries.")
stage_seconds: dict[str, float] = {}
stage_started = time.perf_counter()
def finish_stage(name: str) -> None:
nonlocal stage_started
now = time.perf_counter()
elapsed = round(now - stage_started, 6)
stage_seconds[name] = elapsed
if log_every > 0:
print(f"[stage] {name} finished in {elapsed:.3f}s", flush=True)
stage_started = now
seed_tokenizer = NativeTokenizer(
merges=[],
vocab=[],
base_symbols=[],
lowercase=config.lowercase,
)
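# Pass 1: stream every planned source once, buffering documents for the
# later passes and counting weighted pre-token segments for tokenizer
# training.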
segment_counts: Counter[str] = Counter()
source_counts: dict[str, int] = {}
documents: list[StreamDocument] = []
processed = 0
for entry in entries:
if log_every > 0:
print(f"[source] {entry.name} started", flush=True)
source_start = processed
for document in iter_corpus_plan_documents([entry]):
documents.append(document)
processed += 1
source_counts[document.source] = source_counts.get(document.source, 0) + 1
for text_part, part_weight in _weighted_text_parts_for_statistics(
document.text,
document.weight,
):
for segment in seed_tokenizer.pretokenize(text_part):
segment_counts[segment] += part_weight
if document.preference_rejected_text:
rejected_weight = document.weight * PREFERENCE_REJECTED_TOKENIZER_WEIGHT
for text_part, part_weight in _weighted_text_parts_for_statistics(
document.preference_rejected_text,
rejected_weight,
):
for segment in seed_tokenizer.pretokenize(text_part):
segment_counts[segment] += part_weight
_log_progress("tokenizer", processed, log_every)
if log_every > 0:
print(f"[source] {entry.name} accepted {processed - source_start} documents", flush=True)
if processed == 0:
raise ValueError("Corpus plan did not yield any usable documents after filtering.")
finish_stage("stream_and_segment")
tokenizer = NativeTokenizer.train_from_segment_counts(
segment_counts,
vocab_size=config.tokenizer_vocab_size,
min_pair_frequency=config.tokenizer_min_pair_frequency,
lowercase=config.lowercase,
)
finish_stage("tokenizer_fit")
token_counts: Counter[str] = Counter()
raw_tokenized_documents: list[list[str]] = []
raw_rejected_tokenized_documents: list[list[str]] = []
processed = 0
for document in documents:
processed += 1
tokens = tokenizer.encode(document.text)
raw_tokenized_documents.append(tokens)
for token in tokens:
if token in tokenizer.special_tokens:
token_counts[token] += document.weight
for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics(
tokens,
tokenizer,
document.weight,
):
for token in token_sequence:
token_counts[token] += sequence_weight
rejected_tokens = (
tokenizer.encode(document.preference_rejected_text)
if document.preference_rejected_text
else []
)
raw_rejected_tokenized_documents.append(rejected_tokens)
rejected_weight = document.weight * PREFERENCE_REJECTED_TOKENIZER_WEIGHT
for token in rejected_tokens:
if token in tokenizer.special_tokens:
token_counts[token] += rejected_weight
for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics(
rejected_tokens,
tokenizer,
rejected_weight,
):
for token in token_sequence:
token_counts[token] += sequence_weight
_log_progress("vocab", processed, log_every)
token_to_id, id_to_token = build_vocabulary_from_counts(
token_counts,
min_frequency=config.min_frequency,
max_vocab=config.max_vocab,
)
id_to_token = complete_id_to_token(id_to_token, tokenizer.vocab)
token_to_id = {token: index for index, token in enumerate(id_to_token)}
if not id_to_token:
raise ValueError("Streaming recompute could not derive an embedding vocabulary.")
finish_stage("vocabulary")
state_tokenized_documents = [
_state_training_tokens(tokens, config.max_state_tokens_per_document)
for tokens in raw_tokenized_documents
]
state_tokens_before = sum(len(tokens) for tokens in raw_tokenized_documents)
state_tokens_after = sum(len(tokens) for tokens in state_tokenized_documents)
cooccurrence = StreamingCooccurrenceAccumulator(token_to_id, config.window_size)
tokenized_documents: list[list[str]] = []
state_documents: list[list[str]] = []
preference_token_pairs: list[tuple[list[str], list[str], float]] = []
processed = 0
for document, raw_tokens, state_raw_tokens, raw_rejected_tokens in zip(
documents,
raw_tokenized_documents,
state_tokenized_documents,
raw_rejected_tokenized_documents,
):
processed += 1
tokens = [token for token in raw_tokens if token in token_to_id]
tokenized_documents.append(tokens)
state_tokens = [token for token in state_raw_tokens if token in token_to_id]
state_documents.append(state_tokens)
rejected_tokens = [token for token in raw_rejected_tokens if token in token_to_id]
if len(tokens) > 1 and len(rejected_tokens) > 1:
preference_weight = document.weight * document.readout_weight
if preference_weight > 0.0:
preference_token_pairs.append((tokens, rejected_tokens, preference_weight))
for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics(
tokens,
tokenizer,
document.weight,
):
if len(token_sequence) > 1:
cooccurrence.update_tokens(token_sequence, weight=sequence_weight)
_log_progress("cooccurrence", processed, log_every)
finish_stage("cooccurrence")
if np is not None:
embedding_model = fit_randomized_ppmi_embedding_from_counts(
id_to_token,
cooccurrence.rows,
embedding_dim=config.embedding_dim,
)
else:
embedding_model = fit_ppmi_embedding_from_cooccurrence(
id_to_token,
cooccurrence.to_sparse(),
embedding_dim=config.embedding_dim,
)
finish_stage("embedding")
model = ReframrModel(config)
model.tokenizer = tokenizer
model.embedding_model = embedding_model
model.memory_units = [
AnalyticalMemoryUnit(config.state_dim, timescale)
for timescale in config.timescales
]
model.trace_token_weights = model._derive_trace_token_weights_from_counts(token_counts)
feature_count = len(model._zero_combined_state())
if np is not None:
feature_second_moment = np.zeros(feature_count, dtype=np.float64)
raw_cross = np.zeros((len(embedding_model.id_to_token), feature_count), dtype=np.float64)
prompt_answer_feature_second_moment = np.zeros(feature_count, dtype=np.float64)
prompt_answer_cross = np.zeros((len(embedding_model.id_to_token), feature_count), dtype=np.float64)
prompt_answer_bias_counts = np.zeros(len(embedding_model.id_to_token), dtype=np.float64)
prompt_answer_weight_total = 0.0
prompt_answer_start_feature_second_moment = np.zeros(feature_count, dtype=np.float64)
prompt_answer_start_cross = np.zeros((len(embedding_model.id_to_token), feature_count), dtype=np.float64)
prompt_answer_start_bias_counts = np.zeros(len(embedding_model.id_to_token), dtype=np.float64)
prompt_answer_start_weight_total = 0.0
else:
feature_second_moment = zeros_vector(feature_count)
raw_cross = zeros(len(embedding_model.id_to_token), feature_count)
prompt_answer_feature_second_moment = zeros_vector(feature_count)
prompt_answer_cross = zeros(len(embedding_model.id_to_token), feature_count)
prompt_answer_bias_counts = zeros_vector(len(embedding_model.id_to_token))
prompt_answer_weight_total = 0.0
prompt_answer_start_feature_second_moment = zeros_vector(feature_count)
prompt_answer_start_cross = zeros(len(embedding_model.id_to_token), feature_count)
prompt_answer_start_bias_counts = zeros_vector(len(embedding_model.id_to_token))
prompt_answer_start_weight_total = 0.0
example_weight_total = 0.0
has_answer_targets = any(_answer_boundary(tokens) is not None for tokens in state_documents)
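# Split the memory-example budget: when answer-delimited documents are
# present, roughly 75% of the capacity goes to answer-span states and
# the remainder to general states.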
memory_example_cap = (
config.max_memory_examples
if config.max_memory_examples is not None
else config.max_training_examples
)
if memory_example_cap is None:
answer_reservoir_capacity = None
general_reservoir_capacity = None
elif memory_example_cap <= 0:
answer_reservoir_capacity = 0
general_reservoir_capacity = 0
elif has_answer_targets:
answer_reservoir_capacity = max(1, int(memory_example_cap * 0.75))
general_reservoir_capacity = max(0, memory_example_cap - answer_reservoir_capacity)
else:
answer_reservoir_capacity = 0
general_reservoir_capacity = memory_example_cap
answer_sequence_capacity = MAX_ANSWER_SEQUENCE_EXAMPLES if has_answer_targets else 0
answer_reservoir = StateReservoir(answer_reservoir_capacity, seed=17)
general_reservoir = StateReservoir(general_reservoir_capacity, seed=13)
answer_intent_reservoir = StateReservoir(answer_reservoir_capacity, seed=29)
answer_start_reservoir = StateReservoir(answer_reservoir_capacity, seed=37)
answer_sequence_reservoir = SequenceReservoir(answer_sequence_capacity, seed=41)
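# Exact Gram-matrix readouts get expensive for wide states or large
# example budgets; beyond FULL_READOUT_FEATURE_LIMIT /
# FULL_READOUT_EXAMPLE_LIMIT the second moments and cross terms are
# accumulated directly during the state pass instead of reservoir
# sampling full states for a later solve.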
direct_moment_readout = (
np is not None
and (
feature_count > FULL_READOUT_FEATURE_LIMIT
or (
config.max_training_examples is not None
and config.max_training_examples > FULL_READOUT_EXAMPLE_LIMIT
)
)
)
moment_reservoir = StateReservoir(
0 if direct_moment_readout else config.max_training_examples,
seed=31,
)
transitions = TransitionAccumulator(
max_contexts_per_order=config.max_transition_contexts_per_order,
max_next_tokens=config.max_transition_next_tokens,
)
if np is not None:
target_label_mass = np.zeros(len(embedding_model.id_to_token), dtype=np.float64)
else:
target_label_mass = zeros_vector(len(embedding_model.id_to_token))
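# Cheap pre-pass over the sketched documents: tally label mass per
# target token so _target_balance_from_label_mass can rebalance readout
# weights against overly frequent targets.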
for document, tokens in zip(documents, state_documents):
answer_index = _answer_boundary(tokens)
for index in range(len(tokens) - 1):
next_token = tokens[index + 1]
if not _is_learned_output_token(next_token, tokenizer):
continue
next_token_id = embedding_model.token_to_id.get(next_token, -1)
if next_token_id < 0:
continue
label_weight = _readout_weight_for_target(
answer_index,
index + 1,
document.weight * document.readout_weight,
)
if label_weight > 0.0:
target_label_mass[next_token_id] += label_weight
target_balance, reference_label_mass = _target_balance_from_label_mass(target_label_mass)
processed = 0
embedding_array = (
np.asarray(embedding_model.embeddings, dtype=RUNTIME_ARRAY_DTYPE)
if np is not None
else None
)
trace_embedding_array = (
model._build_trace_embedding_table_array(embedding_array)
if np is not None and embedding_array is not None
else None
)
if np is not None:
trace_decay = np.asarray(
[1.0 / (1.0 + unit.timescale) for unit in model.memory_units],
dtype=RUNTIME_ARRAY_DTYPE,
)
timescale_array = np.asarray(
[unit.timescale for unit in model.memory_units],
dtype=RUNTIME_ARRAY_DTYPE,
)
trace_gain = 1.0 - trace_decay
transition_stack = np.asarray(
[unit.transition for unit in model.memory_units],
dtype=RUNTIME_ARRAY_DTYPE,
)
input_projection_stack = np.asarray(
[unit.input_projection for unit in model.memory_units],
dtype=RUNTIME_ARRAY_DTYPE,
)
drive_indices = np.arange(config.state_dim, dtype=np.int64)
drive_primary = drive_indices % config.embedding_dim
drive_secondary = (3 * drive_indices + 1) % config.embedding_dim
drive_tertiary = (5 * drive_indices + 2) % config.embedding_dim
hidden_feature_width = len(config.timescales) * config.state_dim
if HAS_COMPILED_HIPPO_KERNEL:
hippo_legs_propagate_stack_fast(
np.zeros((len(config.timescales), config.state_dim), dtype=RUNTIME_ARRAY_DTYPE),
timescale_array,
)
finish_stage("kernel_warmup")
else:
trace_decay = None
timescale_array = None
trace_gain = None
transition_stack = None
input_projection_stack = None
drive_primary = None
drive_secondary = None
drive_tertiary = None
hidden_feature_width = 0
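# Main state pass: advance the stacked HiPPO memory units token by
# token (compiled kernel, vectorized NumPy, or pure-Python fallback),
# then harvest (state, next-token) examples into the weighted
# reservoirs and prompt-readout moment accumulators.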
for document, tokens in zip(documents, state_documents):
processed += 1
if len(tokens) < 2:
_log_progress("state", processed, log_every)
continue
answer_index = _answer_boundary(tokens)
readout_document_weight = document.weight * document.readout_weight
transition_document_weight = document.weight * document.transition_weight
for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics(
tokens,
tokenizer,
transition_document_weight,
):
if len(token_sequence) > 1:
transitions.update_tokens(token_sequence, weight=sequence_weight)
compiled_combined_states = None
token_id_sequence = None
if np is not None:
hidden_state_matrix = np.zeros((len(config.timescales), config.state_dim), dtype=RUNTIME_ARRAY_DTYPE)
context_trace_matrix = np.zeros((len(config.timescales), config.embedding_dim), dtype=RUNTIME_ARRAY_DTYPE)
combined_state_buffer = np.empty(feature_count, dtype=RUNTIME_ARRAY_DTYPE)
if (
embedding_array is not None
and trace_embedding_array is not None
and timescale_array is not None
and trace_gain is not None
and input_projection_stack is not None
and drive_primary is not None
and drive_secondary is not None
and drive_tertiary is not None
):
token_id_sequence = np.asarray(
[embedding_model.token_to_id.get(token, -1) for token in tokens],
dtype=np.int64,
)
if token_id_sequence.size and int(token_id_sequence.min()) >= 0:
compiled_combined_states = hippo_document_combined_states_fast(
token_id_sequence,
embedding_array,
trace_embedding_array,
timescale_array,
trace_gain,
input_projection_stack,
drive_primary,
drive_secondary,
drive_tertiary,
state_dim=config.state_dim,
embedding_dim=config.embedding_dim,
)
else:
hidden_states = [zeros_vector(config.state_dim) for _ in config.timescales]
context_traces = [zeros_vector(config.embedding_dim) for _ in config.timescales]
combined_state_buffer = None
token_id_sequence = None
answer_anchor_state = None
for index in range(len(tokens) - 1):
token = tokens[index]
token_id = (
int(token_id_sequence[index])
if token_id_sequence is not None
else embedding_model.token_to_id.get(token, -1)
)
if compiled_combined_states is not None:
combined_state = None
elif (
np is not None
and embedding_array is not None
and trace_decay is not None
and timescale_array is not None
and trace_gain is not None
and transition_stack is not None
and input_projection_stack is not None
and drive_primary is not None
and drive_secondary is not None
and drive_tertiary is not None
and trace_embedding_array is not None
and token_id >= 0
):
embedding = embedding_array[token_id]
trace_embedding = trace_embedding_array[token_id]
drive = (
embedding[drive_primary]
+ (0.5 * embedding[drive_secondary])
- (0.25 * embedding[drive_tertiary])
)
if HAS_COMPILED_HIPPO_KERNEL:
hidden_state_matrix = hippo_legs_propagate_stack_fast(
hidden_state_matrix,
timescale_array,
) + (input_projection_stack * drive[None, :])
else:
hidden_state_matrix = (
(transition_stack @ hidden_state_matrix[:, :, None])[:, :, 0]
+ (input_projection_stack * drive[None, :])
)
context_trace_matrix = (
context_trace_matrix + (trace_gain[:, None] * trace_embedding[None, :])
)
else:
hidden_states, context_traces, combined_state = model._step_hidden_states(
hidden_states,
context_traces,
token,
)
if token == "<answer>":
if compiled_combined_states is not None:
answer_anchor_state = compiled_combined_states[index].copy()
elif np is not None:
combined_state_buffer[:hidden_feature_width] = hidden_state_matrix.reshape(-1)
combined_state_buffer[hidden_feature_width:] = context_trace_matrix.reshape(-1)
answer_anchor_state = combined_state_buffer.copy()
else:
answer_anchor_state = combined_state.copy() if hasattr(combined_state, "copy") else combined_state[:]
next_token = tokens[index + 1]
if not _is_learned_output_token(next_token, tokenizer):
continue
next_token_id = (
int(token_id_sequence[index + 1])
if token_id_sequence is not None
else embedding_model.token_to_id.get(next_token, -1)
)
if next_token_id < 0:
continue
raw_readout_weight = _readout_weight_for_target(
answer_index,
index + 1,
readout_document_weight,
)
readout_weight = raw_readout_weight * float(target_balance[next_token_id])
if readout_weight <= 0.0:
continue
moment_slot = (
None
if direct_moment_readout
else moment_reservoir.reserve_slot(weight=readout_weight)
)
is_answer_target = answer_index is not None and index + 1 > answer_index
target_reservoir = answer_reservoir if is_answer_target else general_reservoir
memory_weight = readout_weight
answer_token_offset = (
index - answer_index
if is_answer_target and answer_index is not None
else None
)
intent_slot = (
answer_intent_reservoir.reserve_slot(weight=memory_weight)
if is_answer_target and answer_anchor_state is not None
else None
)
answer_start_weight = (
raw_readout_weight * (ANSWER_START_DECAY ** answer_token_offset)
if (
answer_token_offset is not None
and answer_token_offset < ANSWER_START_TOKEN_WINDOW
)
else 0.0
)
answer_start_slot = (
answer_start_reservoir.reserve_slot(weight=answer_start_weight)
if answer_start_weight > 0.0 and answer_anchor_state is not None
else None
)
if np is not None:
reservoir_slot = target_reservoir.reserve_slot(weight=memory_weight)
if is_answer_target and answer_anchor_state is not None:
prompt_answer_feature_second_moment += memory_weight * answer_anchor_state * answer_anchor_state
prompt_answer_cross[next_token_id] += memory_weight * answer_anchor_state
prompt_answer_bias_counts[next_token_id] += memory_weight
prompt_answer_weight_total += memory_weight
if answer_start_weight > 0.0:
answer_start_example_weight = answer_start_weight * float(target_balance[next_token_id])
prompt_answer_start_feature_second_moment += (
answer_start_example_weight * answer_anchor_state * answer_anchor_state
)
prompt_answer_start_cross[next_token_id] += (
answer_start_example_weight * answer_anchor_state
)
prompt_answer_start_bias_counts[next_token_id] += answer_start_example_weight
prompt_answer_start_weight_total += answer_start_example_weight
if direct_moment_readout or moment_slot is not None or reservoir_slot is not None:
if compiled_combined_states is not None:
combined_state = compiled_combined_states[index].copy()
else:
combined_state_buffer[:hidden_feature_width] = hidden_state_matrix.reshape(-1)
combined_state_buffer[hidden_feature_width:] = context_trace_matrix.reshape(-1)
combined_state = combined_state_buffer
if direct_moment_readout:
feature_second_moment += readout_weight * combined_state * combined_state
raw_cross[next_token_id] += readout_weight * combined_state
example_weight_total += readout_weight
if moment_slot is not None:
moment_reservoir.store_reserved(
moment_slot,
combined_state,
next_token_id,
example_weight=readout_weight,
)
if reservoir_slot is not None:
target_reservoir.store_reserved(reservoir_slot, combined_state, next_token_id)
if intent_slot is not None:
answer_intent_reservoir.store_reserved(
intent_slot,
answer_anchor_state,
next_token_id,
example_weight=memory_weight,
)
if answer_start_slot is not None:
answer_start_example_weight = answer_start_weight * float(target_balance[next_token_id])
answer_start_reservoir.store_reserved(
answer_start_slot,
answer_anchor_state,
next_token_id,
example_weight=answer_start_example_weight,
)
else:
reservoir_slot = target_reservoir.reserve_slot(weight=memory_weight)
if moment_slot is None and reservoir_slot is None:
continue
if moment_slot is not None:
moment_reservoir.store_reserved(
moment_slot,
combined_state,
next_token_id,
example_weight=readout_weight,
)
if reservoir_slot is not None:
target_reservoir.store_reserved(reservoir_slot, combined_state, next_token_id)
if intent_slot is not None:
answer_intent_reservoir.store_reserved(
intent_slot,
answer_anchor_state,
next_token_id,
example_weight=memory_weight,
)
if answer_start_slot is not None:
answer_start_reservoir.store_reserved(
answer_start_slot,
answer_anchor_state,
next_token_id,
example_weight=answer_start_weight * target_balance[next_token_id],
)
if answer_anchor_state is not None and answer_index is not None:
prompt_token_ids = [
embedding_model.token_to_id[token]
for token in tokens[:answer_index]
if _is_learned_output_token(token, tokenizer)
and token in embedding_model.token_to_id
]
answer_token_ids = [
embedding_model.token_to_id[token]
for token in tokens[answer_index + 1 :]
if _is_learned_output_token(token, tokenizer)
and token in embedding_model.token_to_id
]
answer_sequence_reservoir.consider(
answer_anchor_state,
prompt_token_ids,
answer_token_ids,
weight=readout_document_weight * ANSWER_READOUT_WEIGHT,
)
_log_progress("state", processed, log_every)
moment_states = moment_reservoir.states
moment_labels = moment_reservoir.labels
moment_weights = moment_reservoir.weights
if not direct_moment_readout:
example_weight_total = sum(moment_weights)
if np is not None and moment_states:
state_matrix = moment_reservoir.state_matrix(dtype=np.float64)
weight_vector = np.asarray(moment_weights, dtype=np.float64)
weighted_states = weight_vector[:, None] * state_matrix
feature_second_moment += (weighted_states * state_matrix).sum(axis=0)
np.add.at(raw_cross, moment_labels, weighted_states)
elif moment_states:
for state, label_id, readout_weight in zip(moment_states, moment_labels, moment_weights):
for feature, value in enumerate(state):
weighted_value = readout_weight * value
feature_second_moment[feature] += weighted_value * value
raw_cross[label_id][feature] += weighted_value
if example_weight_total <= 0.0:
raise ValueError("Streaming recompute did not collect any next-token training examples.")
if np is not None:
feature_energy = (feature_second_moment / example_weight_total).tolist()
else:
feature_energy = [
feature_second_moment[index] / example_weight_total
for index in range(feature_count)
]
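# Derive the ternary feature mask from mean feature energy and fold it
# into the readout moments as a diagonal preconditioner.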
ternary_scale, ternary_mask = derive_ternary_mask_from_feature_energy(feature_energy)
if np is not None:
diagonal = np.asarray([ternary_scale * value for value in ternary_mask], dtype=np.float64)
masked_feature_second_moment = feature_second_moment * diagonal * diagonal
masked_cross = raw_cross * diagonal[None, :]
else:
diagonal = [ternary_scale * value for value in ternary_mask]
masked_feature_second_moment = [
feature_second_moment[index] * diagonal[index] * diagonal[index]
for index in range(feature_count)
]
masked_cross = [
[
raw_cross[row][col] * diagonal[col]
for col in range(feature_count)
]
for row in range(len(raw_cross))
]
readout_solver = "diagonal"
state_offset_values: object
readout_bias_values: object
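# Use the exact centered ridge solve when the feature and example
# counts are small enough; otherwise fall back to the diagonal-moment
# solve with a zero state offset.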
if (
np is not None
and moment_states
and feature_count <= FULL_READOUT_FEATURE_LIMIT
and len(moment_states) <= FULL_READOUT_EXAMPLE_LIMIT
and (
config.max_training_examples is None
or config.max_training_examples <= FULL_READOUT_EXAMPLE_LIMIT
)
):
state_matrix = moment_reservoir.state_matrix(dtype=np.float64)
weight_vector = np.asarray(moment_weights, dtype=np.float64)
label_array = np.asarray(moment_labels, dtype=np.int64)
masked_states = state_matrix * diagonal[None, :]
total_weight = float(weight_vector.sum())
if total_weight <= 0.0:
total_weight = 1.0
state_offset_values = (weight_vector[:, None] * masked_states).sum(axis=0) / total_weight
centered_states = masked_states - state_offset_values[None, :]
weighted_centered_states = weight_vector[:, None] * centered_states
gram = centered_states.T @ weighted_centered_states
full_cross = np.zeros((len(embedding_model.id_to_token), feature_count), dtype=np.float64)
np.add.at(full_cross, label_array, weighted_centered_states)
readout_bias_values = np.zeros(len(embedding_model.id_to_token), dtype=np.float64)
np.add.at(readout_bias_values, label_array, weight_vector)
readout_bias_values /= total_weight
readout_weights = ridge_regression_readout_from_moments(
gram,
full_cross,
regularization=config.regularization,
)
readout_solver = "full"
else:
state_offset_values = (
np.zeros(feature_count, dtype=np.float64)
if np is not None
else [0.0 for _ in range(feature_count)]
)
if np is not None:
readout_bias_values = np.zeros(len(embedding_model.id_to_token), dtype=np.float64)
else:
readout_bias_values = [0.0 for _ in target_label_mass]
readout_weights = ridge_regression_readout_from_diagonal_moments(
masked_feature_second_moment,
masked_cross,
regularization=config.regularization,
)
finish_stage("state_and_readout")
model.ternary_scale = ternary_scale
model.ternary_mask = ternary_mask
model.readout_weights = readout_weights
model.state_offset = (
state_offset_values.tolist()
if hasattr(state_offset_values, "tolist")
else list(state_offset_values)
)
model.readout_bias = (
readout_bias_values.tolist()
if hasattr(readout_bias_values, "tolist")
else list(readout_bias_values)
)
model.preference_bias, preference_state_pairs = _derive_preference_bias_from_pairs(
model,
preference_token_pairs,
tokenizer,
)
finish_stage("preference")
reservoir_states = answer_reservoir.states + general_reservoir.states
reservoir_labels = answer_reservoir.labels + general_reservoir.labels
answer_intent_states = answer_intent_reservoir.states
answer_intent_labels = answer_intent_reservoir.labels
answer_start_states = answer_start_reservoir.states
answer_start_labels = answer_start_reservoir.labels
answer_sequence_states = answer_sequence_reservoir.keys
answer_sequence_prompt_rows = answer_sequence_reservoir.prompt_rows
answer_sequence_rows = answer_sequence_reservoir.token_rows
prompt_readout_full_solver = (
config.max_training_examples is not None
and config.max_training_examples <= FULL_READOUT_EXAMPLE_LIMIT
)
if np is not None and not prompt_readout_full_solver:
prompt_answer_weights, prompt_answer_bias, prompt_answer_readout_examples = (
_solve_prompt_readout_from_diagonal_moments(
prompt_answer_feature_second_moment,
prompt_answer_cross,
prompt_answer_bias_counts,
diagonal=diagonal,
regularization=config.regularization,
total_weight=prompt_answer_weight_total,
)
)
(
prompt_answer_start_weights,
prompt_answer_start_bias,
prompt_answer_start_readout_examples,
) = _solve_prompt_readout_from_diagonal_moments(
prompt_answer_start_feature_second_moment,
prompt_answer_start_cross,
prompt_answer_start_bias_counts,
diagonal=diagonal,
regularization=config.regularization,
total_weight=prompt_answer_start_weight_total,
)
else:
# Evaluate the dtype lazily: this branch also runs when NumPy is
# unavailable, in which case the raw state lists are passed and the
# solver degrades to empty results.
prompt_answer_weights, prompt_answer_bias, prompt_answer_readout_examples = (
_solve_weighted_prompt_readout(
answer_intent_reservoir.state_matrix(dtype=np.float64)
if np is not None
else answer_intent_reservoir.states,
answer_intent_labels,
answer_intent_reservoir.weights,
vocab_size=len(embedding_model.id_to_token),
diagonal=diagonal,
state_offset=state_offset_values,
regularization=config.regularization,
full_solver=prompt_readout_full_solver,
)
)
(
prompt_answer_start_weights,
prompt_answer_start_bias,
prompt_answer_start_readout_examples,
) = _solve_weighted_prompt_readout(
answer_start_reservoir.state_matrix(dtype=np.float64)
if np is not None
else answer_start_reservoir.states,
answer_start_labels,
answer_start_reservoir.weights,
vocab_size=len(embedding_model.id_to_token),
diagonal=diagonal,
state_offset=state_offset_values,
regularization=config.regularization,
full_solver=prompt_readout_full_solver,
)
finish_stage("finalize_prompt_readouts")
model.prompt_answer_weights = prompt_answer_weights
model.prompt_answer_bias = (
prompt_answer_bias.tolist()
if hasattr(prompt_answer_bias, "tolist")
else list(prompt_answer_bias)
)
model.prompt_answer_start_weights = prompt_answer_start_weights
model.prompt_answer_start_bias = (
prompt_answer_start_bias.tolist()
if hasattr(prompt_answer_start_bias, "tolist")
else list(prompt_answer_start_bias)
)
if np is not None and reservoir_states:
reservoir_parts = [
array
for array in (
answer_reservoir.state_matrix(dtype=RUNTIME_ARRAY_DTYPE),
general_reservoir.state_matrix(dtype=RUNTIME_ARRAY_DTYPE),
)
if hasattr(array, "shape") and int(array.shape[0]) > 0
]
reservoir_array = np.concatenate(reservoir_parts, axis=0)
mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale
offset_array = np.asarray(model.state_offset, dtype=RUNTIME_ARRAY_DTYPE)
associative_array = ((reservoir_array * mask_array[None, :]) - offset_array[None, :]).astype(
RUNTIME_ARRAY_DTYPE,
copy=False,
)
model.associative_keys = associative_array
model.associative_key_norms = np.linalg.norm(associative_array, axis=1).tolist()
else:
offset_vector = model.state_offset
model.associative_keys = [
[
value - offset_vector[index]
for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale))
]
for state in reservoir_states
]
model.associative_key_norms = [norm(state) for state in model.associative_keys]
model.associative_values = reservoir_labels[:]
if np is not None and answer_intent_states:
answer_intent_array = answer_intent_reservoir.state_matrix(dtype=RUNTIME_ARRAY_DTYPE)
mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale
offset_array = np.asarray(model.state_offset, dtype=RUNTIME_ARRAY_DTYPE)
answer_array = ((answer_intent_array * mask_array[None, :]) - offset_array[None, :]).astype(
RUNTIME_ARRAY_DTYPE,
copy=False,
)
model.answer_keys = answer_array
model.answer_key_norms = np.linalg.norm(answer_array, axis=1).tolist()
else:
offset_vector = model.state_offset
model.answer_keys = [
[
value - offset_vector[index]
for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale))
]
for state in answer_intent_states
]
model.answer_key_norms = [norm(state) for state in model.answer_keys]
model.answer_values = answer_intent_labels[:]
if np is not None and answer_start_states:
answer_start_array = answer_start_reservoir.state_matrix(dtype=RUNTIME_ARRAY_DTYPE)
mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale
offset_array = np.asarray(model.state_offset, dtype=RUNTIME_ARRAY_DTYPE)
start_array = ((answer_start_array * mask_array[None, :]) - offset_array[None, :]).astype(
RUNTIME_ARRAY_DTYPE,
copy=False,
)
model.answer_start_keys = start_array
model.answer_start_key_norms = np.linalg.norm(start_array, axis=1).tolist()
else:
offset_vector = model.state_offset
model.answer_start_keys = [
[
value - offset_vector[index]
for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale))
]
for state in answer_start_states
]
model.answer_start_key_norms = [norm(state) for state in model.answer_start_keys]
model.answer_start_values = answer_start_labels[:]
finish_stage("finalize_memory_arrays")
if np is not None and answer_sequence_states:
answer_sequence_array = answer_sequence_reservoir.state_matrix(dtype=RUNTIME_ARRAY_DTYPE)
mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale
offset_array = np.asarray(model.state_offset, dtype=RUNTIME_ARRAY_DTYPE)
sequence_array = ((answer_sequence_array * mask_array[None, :]) - offset_array[None, :]).astype(
RUNTIME_ARRAY_DTYPE,
copy=False,
)
model.answer_sequence_keys = sequence_array
model.answer_sequence_key_norms = np.linalg.norm(sequence_array, axis=1).tolist()
else:
offset_vector = model.state_offset
model.answer_sequence_keys = [
[
value - offset_vector[index]
for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale))
]
for state in answer_sequence_states
]
model.answer_sequence_key_norms = [norm(state) for state in model.answer_sequence_keys]
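# Pad the stored prompt/answer token-id rows to a fixed width with -1
# sentinels so they pack into dense matrices.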
if np is not None:
padded_answer_sequences = np.full(
(len(answer_sequence_rows), MAX_ANSWER_SEQUENCE_TOKENS),
-1,
dtype=np.int32,
)
for row_index, row in enumerate(answer_sequence_rows):
row_width = min(len(row), MAX_ANSWER_SEQUENCE_TOKENS)
if row_width > 0:
padded_answer_sequences[row_index, :row_width] = row[:row_width]
padded_answer_sequence_prompts = np.full(
(len(answer_sequence_prompt_rows), MAX_ANSWER_SEQUENCE_TOKENS),
-1,
dtype=np.int32,
)
for row_index, row in enumerate(answer_sequence_prompt_rows):
row_width = min(len(row), MAX_ANSWER_SEQUENCE_TOKENS)
if row_width > 0:
padded_answer_sequence_prompts[row_index, :row_width] = row[:row_width]
else:
# Mirror the NumPy branch: truncate over-length rows before padding.
padded_answer_sequences = [
row[:MAX_ANSWER_SEQUENCE_TOKENS]
+ [-1] * max(0, MAX_ANSWER_SEQUENCE_TOKENS - len(row))
for row in answer_sequence_rows
]
padded_answer_sequence_prompts = [
row[:MAX_ANSWER_SEQUENCE_TOKENS]
+ [-1] * max(0, MAX_ANSWER_SEQUENCE_TOKENS - len(row))
for row in answer_sequence_prompt_rows
]
model.answer_sequence_prompt_tokens = padded_answer_sequence_prompts
model.answer_sequence_tokens = padded_answer_sequences
finish_stage("finalize_answer_sequences")
model._refresh_answer_fingerprint_hashes()
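# With NumPy, transition statistics compile into a dense tensor cache
# keyed by token ids; the pure-Python path keeps nested dict tables.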
if np is not None:
model.transition_tensor_cache = transitions.finalize_tensor_cache(
token_to_id=embedding_model.token_to_id,
max_contexts_per_order=config.max_transition_contexts_per_order,
max_next_tokens=config.max_transition_next_tokens,
)
model.transition_id_tables = {order: {} for order in sorted(TRANSITION_ORDERS)}
model.transition_tables = {order: {} for order in sorted(TRANSITION_ORDERS)}
else:
model.transition_tables = transitions.finalize(
max_contexts_per_order=config.max_transition_contexts_per_order,
max_next_tokens=config.max_transition_next_tokens,
)
finish_stage("finalize_transition_tables")
payload = {
"streaming": True,
"documents_processed": processed,
"source_counts": source_counts,
"embedding_vocab_size": len(embedding_model.id_to_token),
"tokenizer_vocab_size": tokenizer.vocab_size,
"examples_processed": int(round(example_weight_total)),
"associative_examples": len(model.associative_keys),
"answer_associative_examples": len(answer_reservoir.states),
"general_associative_examples": len(general_reservoir.states),
"answer_intent_examples": len(model.answer_keys),
"answer_start_examples": len(model.answer_start_keys),
"answer_sequence_examples": len(model.answer_sequence_keys),
"prompt_answer_readout_examples": prompt_answer_readout_examples,
"prompt_answer_start_readout_examples": prompt_answer_start_readout_examples,
"stage_seconds": stage_seconds,
"target_balance_reference": round(float(reference_label_mass), 6),
"readout_solver": readout_solver,
"preference_pairs": len(preference_token_pairs),
"preference_state_pairs": preference_state_pairs,
"max_memory_examples": memory_example_cap,
"max_state_tokens_per_document": config.max_state_tokens_per_document,
"state_tokens_before_sketch": state_tokens_before,
"state_tokens_after_sketch": state_tokens_after,
}
return model, payload