OkeyMeta's picture
Add-openai-compatible-runtime-docs
5348cd5 verified
import argparse
import json
import sys
from dataclasses import replace
from pathlib import Path
from .checkpoint import inspect_checkpoint
from .config import ReframrConfig
from .corpus_recipes import (
build_foundation_corpus,
build_generalization_corpus,
write_corpus_package,
)
from .curriculum import CurriculumConfig, write_curriculum_package
from .datasets import load_prompt_suite, load_text_corpus
from .evaluation import (
benchmark_open_prompts,
evaluate_manifest,
load_manifest,
load_replay_sources,
)
from .hf_import import import_hf_dataset
from .materialize import DEFAULT_CACHE_BYTE_LIMIT, DEFAULT_SHARD_BYTE_LIMIT, materialize_corpus_plan
from .model import ReframrModel
from .reasoning import REASONING_PROFILES, TOKENIZER_NAME, reasoning_prefix
from .sparse_context import (
AnalyticalSparseAttention,
FaissSparseAttention,
HashedSparseAttention,
compare_selectors,
)
from .streaming import estimate_corpus_plan, fit_model_from_corpus_plan, load_corpus_plan
from .tokenizer import MAX_TOKENIZER_VOCAB_SIZE, clamp_vocab_size, recommend_vocab_size
from .v2_data import write_blind_prompt_suite, write_v2_streaming_plan
def configure_stdio() -> None:
    """Ask stdout and stderr to encode as UTF-8 when the stream supports it.

    Streams without a ``reconfigure`` method (or a missing stream, which
    yields ``None`` from ``getattr``) are skipped silently.
    """
    for stream in (sys.stdout, sys.stderr):
        reconfigure_method = getattr(stream, "reconfigure", None)
        if reconfigure_method is None:
            continue
        reconfigure_method(encoding="utf-8")
def _add_reasoning_mode_argument(
    subparser: argparse.ArgumentParser,
    help_text: str = "Override the checkpoint's default reasoning-control profile.",
) -> None:
    """Register the optional ``--reasoning-mode`` override on ``subparser``.

    The default is ``None``; the command handlers fall back to the checkpoint's
    own ``default_reasoning_profile`` when no override is given.
    """
    subparser.add_argument(
        "--reasoning-mode",
        choices=sorted(REASONING_PROFILES),
        default=None,
        help=help_text,
    )


def _add_decoding_arguments(subparser: argparse.ArgumentParser) -> None:
    """Register the sampling controls shared by the text-generation commands."""
    subparser.add_argument("--max-tokens", type=int, default=64)
    subparser.add_argument("--temperature", type=float, default=0.82)
    subparser.add_argument("--decode-top-k", type=int, default=24)
    subparser.add_argument("--decode-top-p", type=float, default=0.92)
    subparser.add_argument("--repetition-penalty", type=float, default=1.18)


def _add_analytical_shape_arguments(subparser: argparse.ArgumentParser) -> None:
    """Register checkpoint-shape and sampling-cap options shared by compute and recompute."""
    subparser.add_argument("--embedding-dim", type=int, default=16)
    subparser.add_argument("--state-dim", type=int, default=32)
    subparser.add_argument("--timescales", default="1.0,0.5,0.25,0.125")
    subparser.add_argument("--window-size", type=int, default=2)
    subparser.add_argument("--regularization", type=float, default=1e-3)
    subparser.add_argument("--min-frequency", type=int, default=1)
    subparser.add_argument(
        "--max-vocab",
        type=int,
        default=256,
        help="Cap analytical embedding vocabulary to keep weight computation fast on CPU.",
    )
    subparser.add_argument("--tokenizer-vocab-size", type=int, default=0)
    subparser.add_argument("--tokenizer-min-pair-frequency", type=int, default=2)
    subparser.add_argument(
        "--max-training-examples",
        type=int,
        default=60000,
        help="Cap sampled recurrent training states while still reading the full corpus for tokenizer, embeddings, and transitions.",
    )
    subparser.add_argument(
        "--max-memory-examples",
        type=int,
        default=-1,
        help="Cap saved associative memory examples separately from readout training. Use -1 to match --max-training-examples.",
    )
    subparser.add_argument(
        "--max-state-tokens-per-document",
        type=int,
        default=768,
        help="Cap recurrent state steps per document with a deterministic corpus sketch. Use 0 to step full documents.",
    )
    subparser.add_argument(
        "--max-transition-contexts",
        type=int,
        default=4096,
        help="Keep only the strongest learned transition contexts per order. Use 0 to disable the cap.",
    )
    subparser.add_argument(
        "--max-transition-next-tokens",
        type=int,
        default=4,
        help="Keep this many learned next-token choices per transition context.",
    )


def _add_case_arguments(subparser: argparse.ArgumentParser) -> None:
    """Register ``--lowercase`` and its hidden, mutually exclusive ``--preserve-case`` twin.

    ``--preserve-case`` is suppressed from help; handlers only read
    ``args.lowercase``, so its only visible effect is rejecting a combination
    with ``--lowercase``.
    """
    case_group = subparser.add_mutually_exclusive_group()
    case_group.add_argument(
        "--lowercase",
        action="store_true",
        help="Normalize corpus text to lowercase before tokenization.",
    )
    case_group.add_argument("--preserve-case", action="store_true", help=argparse.SUPPRESS)


def _add_checkpoint_metadata_arguments(subparser: argparse.ArgumentParser) -> None:
    """Register the reasoning/layout metadata options baked into computed checkpoints."""
    subparser.add_argument(
        "--reasoning-profile",
        choices=sorted(REASONING_PROFILES),
        default="none",
        help="Default reasoning-control profile baked into the checkpoint.",
    )
    subparser.add_argument(
        "--layout-profile",
        default="rfm-base",
        help="Structured analytical layout label to store in checkpoint metadata, such as rfm-70b-structured.",
    )
    subparser.add_argument(
        "--effective-parameter-target",
        type=int,
        default=0,
        help="Dense-equivalent structured target to store in checkpoint metadata; this does not allocate dense tensors.",
    )


def build_parser() -> argparse.ArgumentParser:
    """Build the ``reframr`` command-line parser.

    Each workflow is a required sub-command whose name lands in
    ``args.command``; ``compute`` is also reachable through the ``train``
    alias. Shared option groups are added through the ``_add_*`` helpers
    above so every sub-command that exposes them uses identical names,
    defaults, and help text.
    """
    parser = argparse.ArgumentParser(
        prog="reframr",
        description="Compute and query REFRAMR analytical language model checkpoints.",
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    # --- checkpoint computation ------------------------------------------
    compute = subparsers.add_parser(
        "compute",
        aliases=["train"],
        help="Compute a REFRAMR checkpoint from a text corpus with no epoch loop.",
    )
    compute.add_argument(
        "--input",
        required=True,
        help="Path to a text, JSON, or JSONL corpus file, or a directory of such files.",
    )
    compute.add_argument("--output", required=True, help="Path to write the .safetensors checkpoint.")
    _add_analytical_shape_arguments(compute)
    _add_case_arguments(compute)
    _add_checkpoint_metadata_arguments(compute)

    recompute = subparsers.add_parser(
        "recompute",
        help="Compute a REFRAMR checkpoint from a streaming corpus plan with no raw-text cache.",
    )
    recompute.add_argument("--plan", required=True, help="Path to a streaming corpus plan JSON file.")
    recompute.add_argument("--output", required=True, help="Path to write the .safetensors checkpoint.")
    _add_analytical_shape_arguments(recompute)
    recompute.add_argument("--log-every", type=int, default=0)
    recompute.add_argument(
        "--dry-run",
        action="store_true",
        help="Estimate accepted rows and compute shape without fitting or saving a checkpoint.",
    )
    recompute.add_argument(
        "--estimate-max-rows-per-source",
        type=int,
        default=0,
        help="Optional cap for preflight row scanning per local source.",
    )
    recompute.add_argument(
        "--calibrate-rows",
        type=int,
        default=0,
        help="Run a bounded representative fit first and estimate full-run wall-clock time.",
    )
    recompute.add_argument(
        "--calibrate-only",
        action="store_true",
        help="Stop after calibration instead of computing and saving the full checkpoint.",
    )
    _add_case_arguments(recompute)
    _add_checkpoint_metadata_arguments(recompute)

    # --- inference ---------------------------------------------------------
    predict = subparsers.add_parser("predict", help="Predict the next-token distribution from a saved model.")
    predict.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    predict.add_argument("--context", required=True, help="Input context text.")
    predict.add_argument("--top-k", type=int, default=5)
    _add_reasoning_mode_argument(predict)

    generate = subparsers.add_parser("generate", help="Generate long-form text from a saved model.")
    generate.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    generate.add_argument("--context", required=True, help="Prompt or starting context text.")
    generate.add_argument("--system", default="", help="Optional system instruction to prepend as learned context.")
    _add_decoding_arguments(generate)
    _add_reasoning_mode_argument(generate)

    generate_batch = subparsers.add_parser(
        "generate-batch",
        help="Generate answers for a prompt file while keeping one checkpoint loaded.",
    )
    generate_batch.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    generate_batch.add_argument("--prompts", required=True, help="Path to a TXT, JSON, or JSONL prompt suite.")
    generate_batch.add_argument("--output", required=True, help="Path to write JSONL generations.")
    _add_decoding_arguments(generate_batch)
    _add_reasoning_mode_argument(generate_batch)

    serve = subparsers.add_parser(
        "serve",
        help="Keep one checkpoint loaded and answer JSONL generation requests from stdin.",
    )
    serve.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    _add_decoding_arguments(serve)
    serve.add_argument(
        "--memory-turns",
        type=int,
        default=16,
        help="Number of prior JSONL session turns to prepend as conversation memory.",
    )
    _add_reasoning_mode_argument(serve)

    chat_completion = subparsers.add_parser(
        "chat-completion",
        help="Run one OpenAI-compatible chat completion request from stdin or a JSON file.",
    )
    chat_completion.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    chat_completion.add_argument(
        "--request",
        default="",
        help="Optional path to a JSON request. Defaults to stdin.",
    )

    # ``trace`` has its own, shorter decoding set (no --decode-top-k, 8 tokens).
    trace = subparsers.add_parser("trace", help="Trace REFRAMR reasoning components through generation steps.")
    trace.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    trace.add_argument("--context", required=True, help="Prompt or starting context text.")
    trace.add_argument("--max-tokens", type=int, default=8)
    trace.add_argument("--top-k", type=int, default=5)
    trace.add_argument("--temperature", type=float, default=0.82)
    trace.add_argument("--decode-top-p", type=float, default=0.92)
    trace.add_argument("--repetition-penalty", type=float, default=1.18)
    _add_reasoning_mode_argument(trace)

    inspect_cmd = subparsers.add_parser("inspect", help="Inspect a REFRAMR safetensors checkpoint.")
    inspect_cmd.add_argument("--model", required=True, help="Path to a .safetensors checkpoint.")

    # --- corpus and plan authoring -------------------------------------------
    craft = subparsers.add_parser(
        "craft-corpus",
        help="Generate a JSON-first bootstrap corpus, manifest, and generalization prompt suite.",
    )
    craft.add_argument("--output-dir", required=True, help="Directory to write corpus and manifest files.")
    craft.add_argument(
        "--variant",
        choices=("foundation", "generalization"),
        default="foundation",
        help="Choose between the mixed foundation corpus and the language-first generalization corpus.",
    )

    craft_curriculum = subparsers.add_parser(
        "craft-curriculum",
        help="Generate the OkeyMeta JSON curriculum shard, manifest, holdout prompts, and recompute plan.",
    )
    craft_curriculum.add_argument("--output-dir", required=True, help="Directory to write curriculum files.")
    craft_curriculum.add_argument(
        "--records-per-category",
        type=int,
        default=1000,
        help="How many JSON records to generate for each curriculum category.",
    )
    craft_curriculum.add_argument("--seed", type=int, default=7)
    craft_curriculum.add_argument("--train-ratio", type=float, default=0.92)
    craft_curriculum.add_argument(
        "--effective-token-target",
        type=int,
        default=0,
        help="Set plan weighting so compact curriculum statistics represent this many effective tokens.",
    )

    craft_v2_plan = subparsers.add_parser(
        "craft-v2-plan",
        help="Write a strict streaming Hugging Face recompute plan for the v2 data mix.",
    )
    craft_v2_plan.add_argument("--output", required=True, help="Path to write the streaming plan JSON.")
    craft_v2_plan.add_argument(
        "--rows-per-source",
        type=int,
        default=10_000,
        help="Base accepted row target per source before per-domain multipliers.",
    )
    craft_v2_plan.add_argument(
        "--effective-token-target",
        type=int,
        default=0,
        help="Optional effective token target recorded in the plan metadata.",
    )
    craft_v2_plan.add_argument(
        "--wikipedia-mode",
        choices=("skip", "hf", "viewer"),
        default="skip",
        help="Use skip for fast smoke runs; hf/viewer include Wikipedia through the fast HF viewer adapter.",
    )
    craft_v2_plan.add_argument(
        "--local-curriculum",
        action="append",
        default=[],
        help="Local JSON/JSONL curriculum shard to blend before HF sources.",
    )
    craft_v2_plan.add_argument(
        "--local-curriculum-limit",
        type=int,
        default=0,
        help="Maximum accepted rows per local curriculum shard. Use 0 for all rows.",
    )

    materialize_plan = subparsers.add_parser(
        "materialize-plan",
        help="Write bounded normalized JSONL shards from a corpus plan, then emit a local recompute plan.",
    )
    materialize_plan.add_argument("--plan", required=True, help="Path to a streaming corpus plan JSON file.")
    materialize_plan.add_argument("--output-dir", required=True, help="Directory for normalized JSONL shards.")
    materialize_plan.add_argument(
        "--max-gb",
        type=float,
        default=DEFAULT_CACHE_BYTE_LIMIT / (1024 ** 3),
        help="Maximum normalized cache size in GB. Defaults to 3GB.",
    )
    materialize_plan.add_argument(
        "--shard-mb",
        type=int,
        default=DEFAULT_SHARD_BYTE_LIMIT // (1024 ** 2),
        help="Maximum size per JSONL shard in MB.",
    )
    materialize_plan.add_argument("--log-every", type=int, default=0)

    craft_blind_prompts = subparsers.add_parser(
        "craft-blind-prompts",
        help="Write a blind open-prompt JSONL suite for v2 generalization checks.",
    )
    craft_blind_prompts.add_argument("--output", required=True, help="Path to write JSONL prompts.")
    craft_blind_prompts.add_argument("--seed", type=int, default=2026)
    craft_blind_prompts.add_argument(
        "--variants-per-intent",
        type=int,
        default=4,
        help="How many prompt variants to generate per evaluation intent.",
    )

    # --- evaluation and benchmarks -------------------------------------------
    evaluate = subparsers.add_parser(
        "evaluate",
        help="Evaluate memorization and held-out generalization from a benchmark manifest.",
    )
    evaluate.add_argument("--model", required=True, help="Path to a REFRAMR .safetensors checkpoint.")
    evaluate.add_argument("--manifest", required=True, help="Path to a corpus benchmark manifest JSON file.")
    _add_reasoning_mode_argument(
        evaluate,
        "Override the checkpoint's default reasoning-control profile during evaluation.",
    )
    evaluate.add_argument("--top-k", type=int, default=5)

    benchmark_open = subparsers.add_parser(
        "benchmark-open",
        help="Run arbitrary prompt files through a checkpoint with open-ended output metrics.",
    )
    benchmark_open.add_argument("--model", required=True, help="Path to a REFRAMR .safetensors checkpoint.")
    benchmark_open.add_argument("--prompts", required=True, help="Path to a TXT, JSON, or JSONL prompt suite.")
    _add_decoding_arguments(benchmark_open)
    benchmark_open.add_argument(
        "--replay-source",
        action="append",
        default=[],
        help="JSON/JSONL/TXT corpus path used only to flag generated source-row replay.",
    )
    benchmark_open.add_argument(
        "--replay-source-limit",
        type=int,
        default=10_000,
        help="Maximum source rows to load for replay checks.",
    )
    benchmark_open.add_argument("--replay-ngram-size", type=int, default=8)
    benchmark_open.add_argument("--replay-overlap-threshold", type=float, default=0.70)
    benchmark_open.add_argument(
        "--output",
        default="",
        help="Optional UTF-8 JSON path for benchmark results.",
    )
    _add_reasoning_mode_argument(
        benchmark_open,
        "Override the checkpoint's default reasoning-control profile during benchmarking.",
    )

    sparse_benchmark = subparsers.add_parser(
        "sparse-context-benchmark",
        help="Measure analytical sparse-context selection speed on a checkpoint embedding table.",
    )
    sparse_benchmark.add_argument("--model", required=True, help="Path to a REFRAMR .safetensors checkpoint.")
    sparse_benchmark.add_argument("--context-tokens", type=int, default=100_000)
    sparse_benchmark.add_argument("--query-count", type=int, default=64)
    sparse_benchmark.add_argument("--top-k", type=int, default=64)
    sparse_benchmark.add_argument("--seed", type=int, default=2026)
    sparse_benchmark.add_argument(
        "--selector",
        choices=("exact", "hashed", "faiss"),
        default="hashed",
        # The help previously omitted the accepted "faiss" choice.
        help="Use exact cosine scan, hashed approximate sparse selection, or FAISS-backed selection (see --faiss-hnsw).",
    )
    sparse_benchmark.add_argument("--hash-bits", type=int, default=12)
    sparse_benchmark.add_argument("--probe-radius", type=int, default=1)
    sparse_benchmark.add_argument("--candidate-multiplier", type=int, default=12)
    sparse_benchmark.add_argument("--faiss-hnsw", action="store_true")
    sparse_benchmark.add_argument("--hnsw-neighbors", type=int, default=32)
    sparse_benchmark.add_argument("--ef-search", type=int, default=64)
    sparse_benchmark.add_argument(
        "--compare-exact",
        action="store_true",
        help="Also compute exact top-k recall for the selected query set.",
    )
    sparse_benchmark.add_argument("--output", default="", help="Optional UTF-8 JSON path for benchmark results.")

    # --- dataset import ------------------------------------------------------
    import_hf = subparsers.add_parser(
        "import-hf",
        help="Import Hugging Face dataset text into the REFRAMR JSON record standard.",
    )
    import_hf.add_argument("--dataset", required=True, help="Hugging Face dataset id.")
    import_hf.add_argument("--output", required=True, help="Path to write the JSONL corpus.")
    import_hf.add_argument("--config", default=None, help="Optional dataset config/subset.")
    import_hf.add_argument("--split", default="train", help="Dataset split to import.")
    import_hf.add_argument("--text-field", default=None, help="Explicit text column name.")
    import_hf.add_argument("--limit", type=int, default=1000, help="Maximum records to import.")
    import_hf.add_argument(
        "--min-words",
        type=int,
        default=0,
        help="Drop imported records shorter than this many words.",
    )
    import_hf.add_argument(
        "--max-words",
        type=int,
        default=0,
        help="Drop imported records longer than this many words. Use 0 to disable.",
    )
    import_hf.add_argument(
        "--min-alpha-ratio",
        type=float,
        default=0.0,
        help="Drop imported records whose alphabetic-character ratio falls below this threshold.",
    )
    import_hf.add_argument(
        "--allowed-languages",
        default="",
        help="Optional comma-separated language codes to keep, such as en,yo,ig,ha.",
    )
    import_hf.add_argument(
        "--preference-target",
        choices=("both", "chosen", "rejected"),
        default="chosen",
        help="When importing preference datasets, keep both sides or only the chosen/rejected side.",
    )
    import_hf.add_argument(
        "--no-streaming",
        action="store_true",
        help="Disable streaming dataset reads.",
    )
    return parser
def parse_timescales(raw_timescales: str) -> tuple[float, ...]:
    """Parse a comma-separated list such as ``"1.0,0.5"`` into a float tuple.

    Blank segments (e.g. from ``"1.0,,0.5,"``) are ignored.

    Raises:
        ValueError: if no non-blank segment remains, or a segment is not a
            valid float literal.
    """
    stripped_segments = (segment.strip() for segment in raw_timescales.split(","))
    kept = [segment for segment in stripped_segments if segment]
    if len(kept) == 0:
        raise ValueError("At least one timescale is required.")
    return tuple(map(float, kept))
def command_compute(args: argparse.Namespace) -> int:
    """Fit a checkpoint from an in-memory corpus, save it, and print a JSON summary.

    A ``--tokenizer-vocab-size`` of 0 asks ``recommend_vocab_size`` for a
    corpus-derived budget; either way the budget passes through
    ``clamp_vocab_size``. Negative ``--max-memory-examples`` and non-positive
    ``--max-state-tokens-per-document`` / ``--max-transition-contexts`` are
    mapped to ``None`` in the config (no explicit cap).

    Returns:
        0 after the summary has been printed.
    """
    corpus_text = load_text_corpus(args.input)

    vocab_request = args.tokenizer_vocab_size
    if not vocab_request:
        vocab_request = recommend_vocab_size(corpus_text, lowercase=args.lowercase)
    tokenizer_budget = clamp_vocab_size(vocab_request)

    memory_cap = args.max_memory_examples if args.max_memory_examples >= 0 else None
    state_token_cap = (
        args.max_state_tokens_per_document if args.max_state_tokens_per_document > 0 else None
    )
    transition_cap = None if args.max_transition_contexts <= 0 else args.max_transition_contexts

    config = ReframrConfig(
        embedding_dim=args.embedding_dim,
        state_dim=args.state_dim,
        timescales=parse_timescales(args.timescales),
        window_size=args.window_size,
        regularization=args.regularization,
        min_frequency=args.min_frequency,
        max_vocab=args.max_vocab,
        tokenizer_vocab_size=tokenizer_budget,
        tokenizer_min_pair_frequency=args.tokenizer_min_pair_frequency,
        max_training_examples=args.max_training_examples,
        max_memory_examples=memory_cap,
        max_state_tokens_per_document=state_token_cap,
        max_transition_contexts_per_order=transition_cap,
        max_transition_next_tokens=args.max_transition_next_tokens,
        lowercase=args.lowercase,
        default_reasoning_profile=args.reasoning_profile,
        layout_profile=args.layout_profile,
        effective_parameter_target=args.effective_parameter_target,
    )

    model = ReframrModel(config).fit(corpus_text)
    model.save(args.output)

    # Type-narrowing only: a fitted model is expected to carry both parts.
    fitted_tokenizer = model.tokenizer
    fitted_embeddings = model.embedding_model
    assert fitted_tokenizer is not None
    assert fitted_embeddings is not None

    profile = config.default_reasoning_profile
    summary = dict(
        status="computed",
        format="safetensors",
        model_path=str(Path(args.output).resolve()),
        tokenizer_name=TOKENIZER_NAME,
        vocab_size=len(fitted_embeddings.id_to_token),
        tokenizer_vocab_budget=config.tokenizer_vocab_size,
        tokenizer_vocab_budget_max=MAX_TOKENIZER_VOCAB_SIZE,
        tokenizer_vocab_size=fitted_tokenizer.vocab_size,
        reasoning_profile=profile,
        reasoning_tokens=reasoning_prefix(profile),
        lowercase=config.lowercase,
        max_training_examples=config.max_training_examples,
        max_memory_examples=config.max_memory_examples,
        max_state_tokens_per_document=config.max_state_tokens_per_document,
        max_transition_contexts_per_order=config.max_transition_contexts_per_order,
        max_transition_next_tokens=config.max_transition_next_tokens,
        embedding_dim=config.embedding_dim,
        state_dim=config.state_dim,
        timescales=list(config.timescales),
        layout_profile=config.layout_profile,
        effective_parameter_target=config.effective_parameter_target,
    )
    print(json.dumps(summary))
    return 0
def command_recompute(args: argparse.Namespace) -> int:
    """Compute a checkpoint from a streaming corpus plan and print JSON to stdout.

    Modes:

    * ``--dry-run``: estimate the plan with ``estimate_corpus_plan``, print a
      ``"dry_run"`` summary, and return without fitting or saving.
    * ``--calibrate-rows N`` (N > 0): fit a bounded slice of the plan first and
      print a ``"calibration"`` record with a wall-clock projection. With
      ``--calibrate-only`` the command stops after that record.
    * Otherwise: fit the full plan, save it to ``args.output``, and print a
      ``"recomputed"`` summary.

    A ``--tokenizer-vocab-size`` of 0 falls back to a budget of 1024 before
    ``clamp_vocab_size`` is applied.

    Returns:
        0 on success; 2 when ``--calibrate-only`` is requested without a
        positive ``--calibrate-rows`` (nothing would be calibrated, and the
        previous behaviour silently ran the full fit instead of stopping).
    """
    plan = load_corpus_plan(args.plan)
    requested_vocab_size = args.tokenizer_vocab_size or 1024
    tokenizer_vocab_size = clamp_vocab_size(requested_vocab_size)
    config = ReframrConfig(
        embedding_dim=args.embedding_dim,
        state_dim=args.state_dim,
        timescales=parse_timescales(args.timescales),
        window_size=args.window_size,
        regularization=args.regularization,
        min_frequency=args.min_frequency,
        max_vocab=args.max_vocab,
        tokenizer_vocab_size=tokenizer_vocab_size,
        tokenizer_min_pair_frequency=args.tokenizer_min_pair_frequency,
        max_training_examples=args.max_training_examples,
        max_memory_examples=(
            None
            if args.max_memory_examples < 0
            else args.max_memory_examples
        ),
        max_state_tokens_per_document=(
            None
            if args.max_state_tokens_per_document <= 0
            else args.max_state_tokens_per_document
        ),
        max_transition_contexts_per_order=(
            args.max_transition_contexts if args.max_transition_contexts > 0 else None
        ),
        max_transition_next_tokens=args.max_transition_next_tokens,
        lowercase=args.lowercase,
        default_reasoning_profile=args.reasoning_profile,
        layout_profile=args.layout_profile,
        effective_parameter_target=args.effective_parameter_target,
    )
    if args.dry_run:
        estimate = estimate_corpus_plan(
            plan,
            max_rows_per_source=args.estimate_max_rows_per_source,
        )
        accepted = int(estimate.get("accepted_documents", 0) or 0)
        # With no per-document cap (None), 768 is used as the per-document
        # estimate so the budget stays a finite number.
        state_cap = config.max_state_tokens_per_document or 768
        estimated_state_tokens = accepted * state_cap
        summary = {
            "status": "dry_run",
            "plan_path": str(Path(args.plan).resolve()),
            "output_path": str(Path(args.output).resolve()),
            "accepted_documents": accepted,
            "seen_texts": estimate.get("seen_texts", 0),
            "rejected_texts": estimate.get("rejected_texts", 0),
            "estimated_words": estimate.get("estimated_words", 0),
            "estimated_state_token_budget": estimated_state_tokens,
            "embedding_dim": config.embedding_dim,
            "state_dim": config.state_dim,
            "tokenizer_vocab_budget": config.tokenizer_vocab_size,
            "max_vocab": config.max_vocab,
            "max_training_examples": config.max_training_examples,
            "max_memory_examples": config.max_memory_examples,
            "max_state_tokens_per_document": config.max_state_tokens_per_document,
            "max_transition_contexts_per_order": config.max_transition_contexts_per_order,
            "max_transition_next_tokens": config.max_transition_next_tokens,
            "layout_profile": config.layout_profile,
            "effective_parameter_target": config.effective_parameter_target,
            "estimate_seconds": estimate.get("seconds", 0),
            "sources": estimate.get("sources", []),
        }
        print(json.dumps(summary))
        return 0
    if args.calibrate_only and args.calibrate_rows <= 0:
        # Without calibration rows there is nothing to stop after; refuse
        # instead of falling through to the full (expensive) fit and save.
        print(
            "reframr recompute: --calibrate-only requires --calibrate-rows greater than 0.",
            file=sys.stderr,
        )
        return 2
    if args.calibrate_rows > 0:
        calibration = _calibrate_recompute_plan(
            plan,
            config,
            target_rows=args.calibrate_rows,
            estimate_max_rows_per_source=args.estimate_max_rows_per_source,
            log_every=args.log_every,
        )
        # Flush so the calibration record is visible before the long full fit.
        print(json.dumps(calibration), flush=True)
        if args.calibrate_only:
            return 0
    model, payload = fit_model_from_corpus_plan(
        plan,
        config,
        log_every=args.log_every,
    )
    model.save(args.output)
    summary = {
        "status": "recomputed",
        "format": "safetensors",
        "streaming": True,
        "plan_path": str(Path(args.plan).resolve()),
        "model_path": str(Path(args.output).resolve()),
        "tokenizer_name": TOKENIZER_NAME,
        "tokenizer_vocab_budget": config.tokenizer_vocab_size,
        "tokenizer_vocab_budget_max": MAX_TOKENIZER_VOCAB_SIZE,
        "tokenizer_vocab_size": payload["tokenizer_vocab_size"],
        "vocab_size": payload["embedding_vocab_size"],
        "documents_processed": payload["documents_processed"],
        "source_counts": payload["source_counts"],
        "examples_processed": payload["examples_processed"],
        "associative_examples": payload["associative_examples"],
        "answer_associative_examples": payload.get("answer_associative_examples", 0),
        "general_associative_examples": payload.get("general_associative_examples", 0),
        "answer_intent_examples": payload.get("answer_intent_examples", 0),
        "answer_start_examples": payload.get("answer_start_examples", 0),
        "answer_sequence_examples": payload.get("answer_sequence_examples", 0),
        "prompt_answer_readout_examples": payload.get("prompt_answer_readout_examples", 0),
        "prompt_answer_start_readout_examples": payload.get("prompt_answer_start_readout_examples", 0),
        "preference_pairs": payload.get("preference_pairs", 0),
        "preference_state_pairs": payload.get("preference_state_pairs", 0),
        "stage_seconds": payload.get("stage_seconds", {}),
        "readout_solver": payload.get("readout_solver"),
        "reasoning_profile": config.default_reasoning_profile,
        "reasoning_tokens": reasoning_prefix(config.default_reasoning_profile),
        "lowercase": config.lowercase,
        "max_training_examples": config.max_training_examples,
        "max_memory_examples": config.max_memory_examples,
        "max_state_tokens_per_document": config.max_state_tokens_per_document,
        "state_tokens_before_sketch": payload.get("state_tokens_before_sketch", 0),
        "state_tokens_after_sketch": payload.get("state_tokens_after_sketch", 0),
        "max_transition_contexts_per_order": config.max_transition_contexts_per_order,
        "max_transition_next_tokens": config.max_transition_next_tokens,
        "embedding_dim": config.embedding_dim,
        "state_dim": config.state_dim,
        "timescales": list(config.timescales),
        "layout_profile": config.layout_profile,
        "effective_parameter_target": config.effective_parameter_target,
    }
    print(json.dumps(summary))
    return 0
def _limited_calibration_plan(
plan: list[object],
*,
target_rows: int,
full_accepted: int,
) -> list[object]:
if target_rows <= 0:
return plan
ratio = min(1.0, target_rows / max(1, full_accepted))
limited: list[object] = []
fallback_limit = max(1, target_rows // max(1, len(plan)))
for entry in plan:
raw_limit = int(getattr(entry, "limit", 0) or 0)
if raw_limit > 0:
next_limit = max(1, min(raw_limit, int((raw_limit * ratio) + 0.999999)))
else:
record_count = len(getattr(entry, "records", ()) or ())
source_cap = record_count if record_count > 0 else fallback_limit
next_limit = max(1, min(source_cap, fallback_limit))
limited.append(replace(entry, limit=next_limit))
return limited
def _estimate_full_seconds_from_calibration(
*,
full_documents: int,
full_state_tokens: int,
calibration_payload: dict[str, object],
) -> dict[str, object]:
calibration_documents = max(1, int(calibration_payload.get("documents_processed", 0) or 0))
calibration_state_tokens = max(
1,
int(calibration_payload.get("state_tokens_after_sketch", 0) or 0),
)
document_scale = full_documents / calibration_documents
state_scale = full_state_tokens / calibration_state_tokens
stage_seconds = calibration_payload.get("stage_seconds", {})
if not isinstance(stage_seconds, dict):
stage_seconds = {}
fixed_weighted = {"tokenizer_fit", "embedding", "kernel_warmup", "preference"}
state_weighted = {"state_and_readout", "finalize_prompt_readouts", "finalize_memory_arrays"}
document_weighted = {
"stream_and_segment",
"vocabulary",
"cooccurrence",
"model_finalize",
"finalize_answer_sequences",
"finalize_transition_tables",
}
stage_estimates: dict[str, float] = {}
for stage, raw_seconds in stage_seconds.items():
seconds = float(raw_seconds)
if stage in fixed_weighted:
scale = 1.0
elif stage in state_weighted:
scale = state_scale
elif stage in document_weighted:
scale = document_scale
else:
scale = max(document_scale, state_scale)
stage_estimates[str(stage)] = round(seconds * scale, 3)
total_seconds = round(sum(stage_estimates.values()), 3)
return {
"estimated_full_seconds": total_seconds,
"estimated_full_minutes": round(total_seconds / 60.0, 3),
"scale_documents": round(document_scale, 4),
"scale_state_tokens": round(state_scale, 4),
"stage_estimates": stage_estimates,
}
def _calibrate_recompute_plan(
    plan: list[object],
    config: ReframrConfig,
    *,
    target_rows: int,
    estimate_max_rows_per_source: int,
    log_every: int,
) -> dict[str, object]:
    """Run a bounded calibration fit and project the full recompute's runtime.

    Steps: estimate the full plan with ``estimate_corpus_plan``, shrink it with
    ``_limited_calibration_plan``, fit that slice with
    ``fit_model_from_corpus_plan`` (the fitted model is discarded), then scale
    the slice's stage timings with ``_estimate_full_seconds_from_calibration``.

    The full state-token budget is ``accepted documents * per-document cap``,
    using 768 per document when ``config.max_state_tokens_per_document`` is
    unset.

    Returns:
        A ``"calibration"`` record whose trailing keys come from the runtime
        projection.
    """
    preflight = estimate_corpus_plan(
        plan,
        max_rows_per_source=estimate_max_rows_per_source,
    )
    accepted_documents = int(preflight.get("accepted_documents", 0) or 0)
    tokens_per_document = config.max_state_tokens_per_document or 768
    state_token_budget = accepted_documents * tokens_per_document

    sample_plan = _limited_calibration_plan(
        plan,
        target_rows=target_rows,
        full_accepted=accepted_documents,
    )
    _unused_model, sample_payload = fit_model_from_corpus_plan(
        sample_plan,
        config,
        log_every=log_every,
    )
    projection = _estimate_full_seconds_from_calibration(
        full_documents=accepted_documents,
        full_state_tokens=state_token_budget,
        calibration_payload=sample_payload,
    )

    report: dict[str, object] = {
        "status": "calibration",
        "target_rows": target_rows,
        "full_accepted_documents": accepted_documents,
        "full_estimated_words": preflight.get("estimated_words", 0),
        "full_estimated_state_token_budget": state_token_budget,
        "calibration_documents": sample_payload.get("documents_processed", 0),
        "calibration_state_tokens": sample_payload.get("state_tokens_after_sketch", 0),
        "calibration_stage_seconds": sample_payload.get("stage_seconds", {}),
    }
    report.update(projection)
    return report
def command_predict(args: argparse.Namespace) -> int:
    """Print the ``--top-k`` most probable next tokens for ``--context`` as JSON.

    Predictions are ordered by descending probability; ties keep the order of
    ``distribution.items()`` because ``sorted`` is stable. A negative
    ``--top-k`` is treated as 0 (no predictions). Previously it was used as a
    negative slice bound and silently dropped the *least* likely ``|k|``
    tokens instead.

    Returns:
        0 after the payload has been printed.
    """
    model = ReframrModel.load(args.model)
    distribution = model.predict_next_distribution(
        args.context,
        reasoning_mode=args.reasoning_mode,
    )
    top_k = max(0, args.top_k)
    ranked = sorted(
        distribution.items(),
        key=lambda item: item[1],
        reverse=True,
    )
    predictions = ranked[:top_k]
    # Report the profile actually in effect: the override, else the checkpoint default.
    active_mode = args.reasoning_mode or model.config.default_reasoning_profile
    payload = {
        "context": args.context,
        "reasoning_mode": active_mode,
        "reasoning_tokens": reasoning_prefix(active_mode),
        "predictions": [
            {"token": token, "probability": probability}
            for token, probability in predictions
        ],
    }
    print(json.dumps(payload))
    return 0
def command_generate(args: argparse.Namespace) -> int:
    """Generate a continuation for ``args.context`` and print it as JSON.

    ``args.system`` is folded into the context via ``compose_generation_context``.
    ``generated_token_count`` is the whitespace-split word count of the output,
    not a tokenizer token count.
    """
    model = ReframrModel.load(args.model)
    full_context = compose_generation_context(args.context, system=args.system)
    text = model.generate_text(
        full_context,
        max_tokens=args.max_tokens,
        reasoning_mode=args.reasoning_mode,
        temperature=args.temperature,
        top_k=args.decode_top_k,
        top_p=args.decode_top_p,
        repetition_penalty=args.repetition_penalty,
    )
    effective_mode = args.reasoning_mode or model.config.default_reasoning_profile
    report = {
        "context": full_context,
        "reasoning_mode": effective_mode,
        "reasoning_tokens": reasoning_prefix(effective_mode),
        "generated_token_count": len(text.split()),
        "generated_text": text,
    }
    print(json.dumps(report))
    return 0
def _content_to_text(content: object) -> str:
if content is None:
return ""
if isinstance(content, str):
return content.strip()
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, dict):
text = item.get("text", item.get("content", item.get("input_text", "")))
if text:
parts.append(str(text).strip())
elif item is not None:
parts.append(str(item).strip())
return "\n".join(part for part in parts if part)
if isinstance(content, (dict, tuple)):
return json.dumps(content, ensure_ascii=False, separators=(",", ":"))
return str(content).strip()
def _coerce_json_payload(payload: object) -> object:
if not isinstance(payload, str):
return payload
stripped = payload.strip()
if not stripped:
return ""
try:
return json.loads(stripped)
except json.JSONDecodeError:
return stripped
def _render_source_lines(payload: object) -> list[str]:
if not isinstance(payload, dict):
return []
nested_content = payload.get("content")
if isinstance(nested_content, dict):
nested_lines = _render_source_lines(nested_content)
if nested_lines:
return nested_lines
raw_sources = payload.get("sources", payload.get("source", []))
if isinstance(raw_sources, dict):
sources = [raw_sources]
elif isinstance(raw_sources, list):
sources = raw_sources
elif raw_sources:
sources = [raw_sources]
else:
sources = []
lines: list[str] = []
for source in sources:
if isinstance(source, dict):
title = str(source.get("title", source.get("name", "source"))).strip()
url = str(source.get("url", source.get("uri", ""))).strip()
snippet = str(source.get("snippet", source.get("text", source.get("content", "")))).strip()
parts = [part for part in (title, url, snippet) if part]
if parts:
lines.append(f"<source> {' | '.join(parts)}")
elif source:
lines.append(f"<source> {str(source).strip()}")
return lines
def _render_tool_result(name: str, payload: object) -> list[str]:
    """Render one tool result as ``<tool_result>`` lines for the generation context.

    ``payload`` is first passed through ``_coerce_json_payload``, so JSON strings
    are decoded.

    For dict payloads:

    - An explicit ``name``/``tool`` field overrides ``name``.
    - The result counts as failed when ``ok`` is ``False``, ``status`` is
      error-like, or a non-empty ``error``/``message`` is present.
    - Otherwise it renders as ``ok``. A ``summary``/``content``/``text`` summary
      is appended only when no ``<source>`` lines are available, and any source
      lines follow the header.
    - JSON ``null`` fields are treated as absent. A successful result carrying
      ``"error": null`` is therefore not reported as failed.

    Non-dict payloads render as a single line:

    - Lists are written as compact JSON.
    - Numbers and booleans (including ``0``/``False``) are stringified.
    - Other falsy values render as ``empty``.
    """
    tool_name = name.strip() or "tool"
    parsed = _coerce_json_payload(payload)
    if isinstance(parsed, dict):

        def field(*keys: str) -> str:
            # First non-null value among ``keys``, stringified and stripped.
            for key in keys:
                value = parsed.get(key)
                if value is not None:
                    return str(value).strip()
            return ""

        explicit_name = field("name", "tool")
        if explicit_name:
            tool_name = explicit_name
        status = field("status").casefold()
        error = field("error", "message")
        failed = (
            parsed.get("ok") is False
            or status in {"error", "failed", "failure", "timeout"}
            or bool(error)
        )
        source_lines = _render_source_lines(parsed)
        if failed:
            header = f"<tool_result> {tool_name} failed: {error or status or 'unknown error'}"
        else:
            header = f"<tool_result> {tool_name} ok"
            summary = field("summary", "content", "text")
            # Sources already carry the content; avoid duplicating it in the header.
            if summary and not source_lines:
                header = f"{header}: {summary}"
        return [header, *source_lines]
    # 0 / False are real results, not "empty".
    if not parsed and not isinstance(parsed, (int, float)):
        return [f"<tool_result> {tool_name} empty"]
    if isinstance(parsed, list):
        rendered = json.dumps(parsed, ensure_ascii=False, separators=(",", ":"))
    else:
        rendered = str(parsed).strip()
    return [f"<tool_result> {tool_name} {rendered}"]
def _render_tool_call(call: object) -> str:
if not isinstance(call, dict):
return f"<tool_call> {str(call).strip()}"
function_payload = call.get("function", {})
function = function_payload if isinstance(function_payload, dict) else {}
name = str(call.get("name", function.get("name", "tool"))).strip() or "tool"
arguments = call.get("arguments", function.get("arguments", {}))
if not isinstance(arguments, str):
arguments = json.dumps(arguments, ensure_ascii=False, separators=(",", ":"))
return f"<tool_call> {name} {arguments}".strip()
def compose_generation_context(
    prompt: str,
    *,
    system: str = "",
    messages: object | None = None,
    tool_results: object | None = None,
) -> str:
    """Flatten a prompt plus optional chat history and tool output into one context string.

    Lines are emitted in this order:

    1. ``system`` (stripped) when non-empty.
    2. Each dict in ``messages`` (non-dict entries are skipped):

       - ``system``/``user``/``assistant`` turns become ``System instruction:``,
         ``User:`` and ``Assistant:`` lines.
       - Assistant ``tool_calls`` become ``<tool_call>`` lines.
       - ``tool`` turns become ``<tool_result>`` lines.
       - Any other role is rendered as ``Role: text``.

    3. ``prompt`` (stripped), prefixed with ``User:`` only when ``messages`` is a list.
    4. ``tool_results`` (a list, or a single truthy payload) as ``<tool_result>`` lines.

    A trailing ``<final>`` marker is appended whenever a tool call or tool result
    was emitted, or when assistant text contains ``<tool_call>``. Empty lines are
    dropped. JSON ``null`` roles and tool names are treated as absent, so they
    never render as ``"None"``.
    """
    clean_prompt = prompt.strip()
    clean_system = system.strip()
    lines: list[str] = []
    tool_protocol_seen = False
    if clean_system:
        lines.append(clean_system)
    if isinstance(messages, list):
        for message in messages:
            if not isinstance(message, dict):
                continue
            role = str(message.get("role") or "").casefold()
            content = _content_to_text(message.get("content", ""))
            if role == "system":
                if content:
                    lines.append(f"System instruction: {content}")
            elif role == "user":
                if content:
                    lines.append(f"User: {content}")
            elif role == "assistant":
                if content:
                    lines.append(f"Assistant: {content}")
                    # Inline tool-call markup in assistant text also counts as tool use.
                    if "<tool_call>" in content:
                        tool_protocol_seen = True
                tool_calls = message.get("tool_calls", [])
                if isinstance(tool_calls, list):
                    for call in tool_calls:
                        lines.append(_render_tool_call(call))
                        tool_protocol_seen = True
            elif role == "tool":
                tool_name = message.get("name") or message.get("tool_call_id") or "tool"
                lines.extend(_render_tool_result(str(tool_name), message.get("content", "")))
                tool_protocol_seen = True
            elif content:
                lines.append(f"{role.capitalize()}: {content}")
    if clean_prompt:
        lines.append(f"User: {clean_prompt}" if isinstance(messages, list) else clean_prompt)
    if isinstance(tool_results, list):
        for result in tool_results:
            tool_name = "tool"
            if isinstance(result, dict):
                tool_name = str(result.get("name") or result.get("tool") or "tool")
            lines.extend(_render_tool_result(tool_name, result))
            tool_protocol_seen = True
    elif tool_results:
        lines.extend(_render_tool_result("tool", tool_results))
        tool_protocol_seen = True
    if tool_protocol_seen:
        # Marks where the model should produce its final answer after tool use.
        lines.append("<final>")
    return "\n".join(line for line in lines if line).strip()
def command_generate_batch(args: argparse.Namespace) -> int:
    """Generate completions for every record of ``args.prompts`` with one model load.

    Each record becomes one JSON line in ``args.output``, written as soon as it
    is generated. A summary is then printed to stdout.

    Per-record ``reasoning_mode``, ``max_tokens``, ``system``, ``messages`` and
    ``tool_results`` override the CLI defaults. JSON ``null`` overrides are treated
    as unset, so they never become the literal string ``"None"``.
    ``generated_token_count`` is a whitespace-split word count.
    """
    model = ReframrModel.load(args.model)
    prompts = load_prompt_suite(args.prompts)
    default_mode = args.reasoning_mode or model.config.default_reasoning_profile
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Only the count is reported, so rows are streamed to disk rather than retained.
    sample_count = 0
    with output_path.open("w", encoding="utf-8") as handle:
        for index, record in enumerate(prompts):
            prompt = str(record["prompt"])
            raw_mode = record.get("reasoning_mode")
            record_mode = str(raw_mode if raw_mode is not None else default_mode)
            raw_system = record.get("system")
            context = compose_generation_context(
                prompt,
                system=str(raw_system) if raw_system is not None else "",
                messages=record.get("messages"),
                tool_results=record.get("tool_results"),
            )
            raw_max_tokens = record.get("max_tokens")
            max_tokens = int(raw_max_tokens if raw_max_tokens is not None else args.max_tokens)
            generated_text = model.generate_text(
                context,
                max_tokens=max_tokens,
                reasoning_mode=record_mode,
                temperature=args.temperature,
                top_k=args.decode_top_k,
                top_p=args.decode_top_p,
                repetition_penalty=args.repetition_penalty,
            )
            row = {
                "index": index,
                "prompt": prompt,
                "context": context,
                "system": record.get("system", ""),
                "tags": record.get("tags", []),
                "reasoning_mode": record_mode,
                "reasoning_tokens": reasoning_prefix(record_mode),
                "generated_token_count": len(generated_text.split()),
                "generated_text": generated_text,
            }
            handle.write(json.dumps(row, ensure_ascii=False, separators=(",", ":")) + "\n")
            sample_count += 1
    payload = {
        "status": "generated",
        "sample_count": sample_count,
        "model_path": str(Path(args.model).resolve()),
        "prompts_path": str(Path(args.prompts).resolve()),
        "output_path": str(output_path.resolve()),
        "model_loads": 1,
    }
    print(json.dumps(payload))
    return 0
def _serve_emit(response: dict[str, object]) -> None:
    """Write ``response`` to stdout as one compact JSON line and flush immediately."""
    sys.stdout.write(json.dumps(response, ensure_ascii=False, separators=(",", ":")) + "\n")
    sys.stdout.flush()


def _serve_option(payload: dict[str, object], key: str, fallback: object) -> object:
    """Return ``payload[key]``, or ``fallback`` when the key is missing or JSON ``null``."""
    value = payload.get(key)
    return fallback if value is None else value


def command_serve(args: argparse.Namespace) -> int:
    """Serve generation requests as JSON lines over stdin/stdout with one model load.

    Each non-blank stdin line must be one of:

    - A JSON string, used verbatim as the context.
    - A JSON object with ``prompt`` (else ``context``) plus optional fields:
      ``system``, ``messages``, ``tool_results``/``toolResults``, sampling
      overrides (``reasoning_mode``, ``max_tokens``, ``temperature``,
      ``decode_top_k``, ``decode_top_p``, ``repetition_penalty``),
      ``memory_turns``, and a session key (``session_id``/``conversation_id``/
      ``thread_id``) that enables per-session conversation memory.

    JSON ``null`` fields are treated as unset.

    Exactly one JSON line is written per request. Malformed JSON, non-object
    requests and options that cannot be converted to numbers produce an error
    line instead of stopping the server. Returns 0 when stdin is exhausted.
    """
    model = ReframrModel.load(args.model)
    default_mode = args.reasoning_mode or model.config.default_reasoning_profile
    default_memory_turns = getattr(args, "memory_turns", None)
    if default_memory_turns is None:
        default_memory_turns = 16
    # Recent outputs per normalized base context, passed back as avoid_texts.
    # NOTE(review): both dicts grow with the number of distinct contexts/sessions
    # for the lifetime of the process; consider bounding them for long-running use.
    generated_history_by_context: dict[str, list[str]] = {}
    session_turns_by_id: dict[str, list[tuple[str, str]]] = {}
    for index, raw_line in enumerate(sys.stdin):
        line = raw_line.strip()
        if not line:
            continue
        try:
            request = json.loads(line)
        except json.JSONDecodeError as exc:
            _serve_emit({"index": index, "error": "invalid_json", "message": str(exc), "model_loads": 1})
            continue

        if isinstance(request, str):
            raw_context = request
            base_context = request
            request_payload: dict[str, object] = {}
        elif isinstance(request, dict):
            request_payload = request
            raw_context = str(
                _serve_option(request_payload, "prompt", _serve_option(request_payload, "context", ""))
            )
            base_context = compose_generation_context(
                raw_context,
                system=str(_serve_option(request_payload, "system", "")),
                messages=request_payload.get("messages"),
                tool_results=_serve_option(
                    request_payload, "tool_results", request_payload.get("toolResults")
                ),
            )
        else:
            _serve_emit(
                {
                    "index": index,
                    "error": "invalid_request",
                    "message": "request must be a JSON object or string",
                    "model_loads": 1,
                }
            )
            continue

        # Parse every numeric option up front so a bad value rejects only this request.
        try:
            memory_turn_limit = max(
                0, int(_serve_option(request_payload, "memory_turns", default_memory_turns))
            )
            active_mode = str(_serve_option(request_payload, "reasoning_mode", default_mode))
            max_tokens = int(_serve_option(request_payload, "max_tokens", args.max_tokens))
            temperature = float(_serve_option(request_payload, "temperature", args.temperature))
            top_k = int(_serve_option(request_payload, "decode_top_k", args.decode_top_k))
            top_p = float(_serve_option(request_payload, "decode_top_p", args.decode_top_p))
            repetition_penalty = float(
                _serve_option(request_payload, "repetition_penalty", args.repetition_penalty)
            )
        except (TypeError, ValueError) as exc:
            _serve_emit(
                {
                    "index": index,
                    "error": "invalid_request",
                    "message": f"invalid request option: {exc}",
                    "model_loads": 1,
                }
            )
            continue

        # A null session id must not collapse into a shared "None" session.
        session_id = str(
            _serve_option(
                request_payload,
                "session_id",
                _serve_option(
                    request_payload,
                    "conversation_id",
                    _serve_option(request_payload, "thread_id", ""),
                ),
            )
        ).strip()
        session_turns = session_turns_by_id.get(session_id, []) if session_id else []
        recalled_turns = session_turns[-memory_turn_limit:] if memory_turn_limit > 0 else []
        memory_context = ""
        if recalled_turns:
            memory_lines = ["Conversation memory:"]
            for prior_user, prior_assistant in recalled_turns:
                if prior_user.strip():
                    memory_lines.append(f"Previous user: {prior_user.strip()}")
                if prior_assistant.strip():
                    memory_lines.append(f"Previous assistant: {prior_assistant.strip()}")
            memory_context = "\n".join(memory_lines)
        context = f"{memory_context}\nCurrent user: {base_context}" if memory_context else base_context

        # History is keyed on whitespace-normalized base context (without memory).
        history_key = " ".join(base_context.split())
        avoid_texts = generated_history_by_context.get(history_key, [])
        generated_text = model.generate_text(
            context,
            max_tokens=max_tokens,
            reasoning_mode=active_mode,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            avoid_texts=avoid_texts,
        )
        if generated_text.strip():
            generated_history_by_context[history_key] = [*avoid_texts, generated_text][-8:]
        if session_id:
            user_memory_text = raw_context if raw_context.strip() else base_context
            next_session_turns = [*session_turns, (user_memory_text, generated_text)]
            session_turns_by_id[session_id] = next_session_turns[-max(1, memory_turn_limit):]
        _serve_emit(
            {
                "index": index,
                "context": context,
                "reasoning_mode": active_mode,
                "reasoning_tokens": reasoning_prefix(active_mode),
                "generated_token_count": len(generated_text.split()),
                "generated_text": generated_text,
                "memory_turn_count": len(recalled_turns),
                "model_loads": 1,
            }
        )
    return 0
def command_chat_completion(args: argparse.Namespace) -> int:
    """Answer one OpenAI-style chat-completion request.

    The request JSON is read from the ``args.request`` path when it is set, else
    from stdin. With a truthy ``stream`` field, events from
    ``iter_sse_chat_completion`` are written and flushed one by one. Otherwise one
    JSON response line from ``build_chat_completion_response`` is written.

    Raises:
        ValueError: if the request is not a JSON object. Malformed JSON raises
            ``json.JSONDecodeError``, which is a ``ValueError`` subclass.
    """
    from .openai_compat import build_chat_completion_response, iter_sse_chat_completion

    # argparse stores None for an unset optional argument; `or ""` keeps that
    # from turning into the file path "None" via str(None).
    request_path = str(getattr(args, "request", None) or "").strip()
    if request_path:
        request_text = Path(request_path).read_text(encoding="utf-8")
    else:
        request_text = sys.stdin.read()
    request = json.loads(request_text)
    if not isinstance(request, dict):
        raise ValueError("chat-completion request must be a JSON object")
    # The request is validated before paying for the model load.
    model = ReframrModel.load(args.model)
    if bool(request.get("stream", False)):
        for event in iter_sse_chat_completion(model, request):
            sys.stdout.write(event)
            sys.stdout.flush()
        return 0
    response = build_chat_completion_response(model, request)
    sys.stdout.write(json.dumps(response, ensure_ascii=False, separators=(",", ":")) + "\n")
    sys.stdout.flush()
    return 0
def command_trace(args: argparse.Namespace) -> int:
    """Print the payload returned by ``ReframrModel.trace_generation`` as JSON.

    Note that the trace uses ``args.top_k`` for ``top_k``, while ``top_p`` comes
    from ``args.decode_top_p``.
    """
    checkpoint = ReframrModel.load(args.model)
    trace = checkpoint.trace_generation(
        args.context,
        max_tokens=args.max_tokens,
        reasoning_mode=args.reasoning_mode,
        top_k=args.top_k,
        temperature=args.temperature,
        top_p=args.decode_top_p,
        repetition_penalty=args.repetition_penalty,
    )
    print(json.dumps(trace))
    return 0
def command_inspect(args: argparse.Namespace) -> int:
    """Print the ``inspect_checkpoint`` summary of ``args.model`` as JSON."""
    summary = inspect_checkpoint(args.model)
    print(json.dumps(summary))
    return 0
def command_craft_corpus(args: argparse.Namespace) -> int:
    """Build a bundled corpus package, write it to ``args.output_dir``, and print a summary.

    ``args.variant == "generalization"`` selects the generalization corpus. Any
    other value selects the foundation corpus. ``token_count_estimate`` is a
    whitespace-split word count of the corpus text.
    """
    if args.variant == "generalization":
        builder = build_generalization_corpus
    else:
        builder = build_foundation_corpus
    package = builder()
    written = write_corpus_package(package, args.output_dir)
    summary = {
        "name": package.name,
        "corpus_path": written["corpus_path"],
        "manifest_path": written["manifest_path"],
        "prompt_suite_path": written["prompt_suite_path"],
        "token_count_estimate": len(package.text.split()),
        "memorization_samples": len(package.memorization_samples),
        "generalization_samples": len(package.generalization_samples),
        "generalization_prompt_count": len(package.open_ended_samples),
        "variant": args.variant,
        "section_counts": package.section_counts,
    }
    print(json.dumps(summary))
    return 0
def command_craft_curriculum(args: argparse.Namespace) -> int:
    """Write a curriculum package to ``args.output_dir`` and print its summary as JSON.

    A falsy ``args.effective_token_target`` (e.g. 0) is passed on as ``None``.
    """
    curriculum = CurriculumConfig(
        records_per_category=args.records_per_category,
        seed=args.seed,
        train_ratio=args.train_ratio,
    )
    token_target = args.effective_token_target or None
    summary = write_curriculum_package(
        args.output_dir,
        curriculum,
        effective_token_target=token_target,
    )
    print(json.dumps(summary))
    return 0
def command_craft_v2_plan(args: argparse.Namespace) -> int:
    """Write a v2 streaming corpus plan to ``args.output`` and print its summary as JSON."""
    summary = write_v2_streaming_plan(
        args.output,
        rows_per_source=args.rows_per_source,
        effective_token_target=args.effective_token_target,
        wikipedia_mode=args.wikipedia_mode,
        local_curriculum_paths=args.local_curriculum,
        local_curriculum_limit=args.local_curriculum_limit,
    )
    print(json.dumps(summary))
    return 0
def command_materialize_plan(args: argparse.Namespace) -> int:
    """Materialize the corpus plan at ``args.plan`` into ``args.output_dir`` and print the result.

    ``args.max_gb`` (clamped at 0) is converted to a total byte budget in GiB.
    ``args.shard_mb`` (clamped at 1) is converted to a per-shard size in MiB.
    """
    bytes_per_gib = 1024 ** 3
    bytes_per_mib = 1024 ** 2
    total_budget = int(max(0.0, float(args.max_gb)) * bytes_per_gib)
    shard_size = int(max(1, int(args.shard_mb)) * bytes_per_mib)
    plan = load_corpus_plan(args.plan)
    summary = materialize_corpus_plan(
        plan,
        args.output_dir,
        max_bytes=total_budget,
        shard_bytes=shard_size,
        log_every=args.log_every,
    )
    print(json.dumps(summary))
    return 0
def command_craft_blind_prompts(args: argparse.Namespace) -> int:
    """Write a blind prompt suite to ``args.output`` and print its summary as JSON."""
    summary = write_blind_prompt_suite(
        args.output,
        seed=args.seed,
        variants_per_intent=args.variants_per_intent,
    )
    print(json.dumps(summary))
    return 0
def command_evaluate(args: argparse.Namespace) -> int:
    """Evaluate ``args.model`` against the manifest at ``args.manifest`` and print the report."""
    checkpoint = ReframrModel.load(args.model)
    evaluation_manifest = load_manifest(args.manifest)
    report = evaluate_manifest(
        checkpoint,
        evaluation_manifest,
        reasoning_mode=args.reasoning_mode,
        top_k=args.top_k,
    )
    print(json.dumps(report))
    return 0
def command_benchmark_open(args: argparse.Namespace) -> int:
    """Run the open-prompt benchmark and print (and optionally save) its JSON report.

    Replay sources from ``args.replay_source`` (capped by
    ``args.replay_source_limit``) are passed to ``benchmark_open_prompts``. When
    ``args.output`` is set, the same JSON is also written there, and the parent
    directories are created as needed.
    """
    model = ReframrModel.load(args.model)
    prompts = load_prompt_suite(args.prompts)
    replay_sources = load_replay_sources(
        args.replay_source,
        limit=args.replay_source_limit,
    )
    payload = benchmark_open_prompts(
        model,
        prompts,
        reasoning_mode=args.reasoning_mode,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.decode_top_k,
        top_p=args.decode_top_p,
        repetition_penalty=args.repetition_penalty,
        replay_sources=replay_sources,
        replay_ngram_size=args.replay_ngram_size,
        replay_overlap_threshold=args.replay_overlap_threshold,
    )
    serialized = json.dumps(payload, ensure_ascii=False)
    # argparse stores None for an unset optional; `or ""` keeps str(None) from
    # producing a file literally named "None".
    output_path = str(getattr(args, "output", None) or "").strip()
    if output_path:
        target = Path(output_path)
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(serialized + "\n", encoding="utf-8")
    print(serialized)
    return 0
def command_sparse_context_benchmark(args: argparse.Namespace) -> int:
    """Benchmark a sparse-context selector on random token ids and print the JSON report.

    ``args.selector`` picks the kernel:

    - ``"hashed"`` selects ``HashedSparseAttention``.
    - ``"faiss"`` selects ``FaissSparseAttention``, optionally HNSW.
    - Any other value selects ``AnalyticalSparseAttention``.

    Context and query token ids are drawn from the checkpoint vocabulary with a
    ``random.Random`` seeded by ``args.seed``, so runs are reproducible. With
    ``args.compare_exact`` and the hashed selector, ``exact_recall`` from
    ``compare_selectors`` is included. When ``args.output`` is set, the JSON is
    also written there.

    Raises:
        RuntimeError: if the checkpoint has no embedding model.
    """
    import random

    model = ReframrModel.load(args.model)
    if model.embedding_model is None:
        raise RuntimeError("checkpoint does not contain embeddings")
    embeddings = model.embedding_model.embeddings
    if args.selector == "hashed":
        kernel = HashedSparseAttention(
            embeddings,
            k_neighbors=args.top_k,
            hash_bits=args.hash_bits,
            probe_radius=args.probe_radius,
            seed=args.seed,
            candidate_multiplier=args.candidate_multiplier,
        )
    elif args.selector == "faiss":
        kernel = FaissSparseAttention(
            embeddings,
            k_neighbors=args.top_k,
            approximate=args.faiss_hnsw,
            hnsw_neighbors=args.hnsw_neighbors,
            ef_search=args.ef_search,
        )
    else:
        kernel = AnalyticalSparseAttention(
            embeddings,
            k_neighbors=args.top_k,
        )
    vocab_size = len(model.embedding_model.id_to_token)
    rng = random.Random(int(args.seed))
    context_tokens = [rng.randrange(vocab_size) for _ in range(max(0, int(args.context_tokens)))]
    query_tokens = [rng.randrange(vocab_size) for _ in range(max(0, int(args.query_count)))]
    payload = kernel.benchmark_selection(
        context_tokens,
        query_tokens,
        top_k=args.top_k,
    )
    if args.compare_exact and args.selector == "hashed":
        payload["exact_recall"] = compare_selectors(
            embeddings,
            context_tokens,
            query_tokens,
            top_k=args.top_k,
            hash_bits=args.hash_bits,
            probe_radius=args.probe_radius,
            seed=args.seed,
        )
    is_hashed = args.selector == "hashed"
    uses_hnsw = args.selector == "faiss" and args.faiss_hnsw
    payload.update(
        {
            "schema_version": "reframr.sparse_context_benchmark.v1",
            "model": str(Path(args.model).resolve()),
            "selector": args.selector,
            "hash_bits": int(args.hash_bits) if is_hashed else 0,
            "probe_radius": int(args.probe_radius) if is_hashed else 0,
            "candidate_multiplier": int(args.candidate_multiplier) if is_hashed else 0,
            "faiss_approximate": bool(uses_hnsw),
            "hnsw_neighbors": int(args.hnsw_neighbors) if uses_hnsw else 0,
            "ef_search": int(args.ef_search) if uses_hnsw else 0,
            "tokenizer_vocab_size": vocab_size,
            "embedding_dim": kernel.embedding_dim,
        }
    )
    serialized = json.dumps(payload, ensure_ascii=False)
    # argparse stores None for an unset optional; `or ""` keeps str(None) from
    # producing a file literally named "None".
    output_path = str(getattr(args, "output", None) or "").strip()
    if output_path:
        target = Path(output_path)
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(serialized + "\n", encoding="utf-8")
    print(serialized)
    return 0
def command_import_hf(args: argparse.Namespace) -> int:
    """Import a Hugging Face dataset into a local corpus file and print the import summary.

    ``args.allowed_languages`` is a comma-separated list. Blank entries are
    dropped. Streaming is enabled unless ``args.no_streaming`` is set.
    """
    languages = tuple(
        code.strip() for code in args.allowed_languages.split(",") if code.strip()
    )
    summary = import_hf_dataset(
        dataset=args.dataset,
        output_path=args.output,
        config=args.config,
        split=args.split,
        text_field=args.text_field,
        limit=args.limit,
        streaming=not args.no_streaming,
        preference_target=args.preference_target,
        min_words=args.min_words,
        max_words=args.max_words,
        min_alpha_ratio=args.min_alpha_ratio,
        allowed_languages=languages,
    )
    print(json.dumps(summary))
    return 0
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: parse ``argv`` and dispatch to the matching ``command_*`` handler.

    Returns the handler's exit code. An unrecognized command is reported through
    ``parser.error``, which exits the process with status 2.
    """
    configure_stdio()
    parser = build_parser()
    args = parser.parse_args(argv)
    handlers = {
        "compute": command_compute,
        "train": command_compute,
        "recompute": command_recompute,
        "predict": command_predict,
        "generate": command_generate,
        "generate-batch": command_generate_batch,
        "serve": command_serve,
        "chat-completion": command_chat_completion,
        "trace": command_trace,
        "inspect": command_inspect,
        "craft-corpus": command_craft_corpus,
        "craft-curriculum": command_craft_curriculum,
        "craft-v2-plan": command_craft_v2_plan,
        "materialize-plan": command_materialize_plan,
        "craft-blind-prompts": command_craft_blind_prompts,
        "evaluate": command_evaluate,
        "benchmark-open": command_benchmark_open,
        "sparse-context-benchmark": command_sparse_context_benchmark,
        "import-hf": command_import_hf,
    }
    handler = handlers.get(args.command)
    if handler is not None:
        return handler(args)
    parser.error(f"Unknown command: {args.command}")
    return 2