import argparse
import json
import sys
from dataclasses import replace
from pathlib import Path

from .checkpoint import inspect_checkpoint
from .config import ReframrConfig
from .corpus_recipes import (
    build_foundation_corpus,
    build_generalization_corpus,
    write_corpus_package,
)
from .curriculum import CurriculumConfig, write_curriculum_package
from .datasets import load_prompt_suite, load_text_corpus
from .evaluation import (
    benchmark_open_prompts,
    evaluate_manifest,
    load_manifest,
    load_replay_sources,
)
from .hf_import import import_hf_dataset
from .materialize import DEFAULT_CACHE_BYTE_LIMIT, DEFAULT_SHARD_BYTE_LIMIT, materialize_corpus_plan
from .model import ReframrModel
from .reasoning import REASONING_PROFILES, TOKENIZER_NAME, reasoning_prefix
from .sparse_context import (
    AnalyticalSparseAttention,
    FaissSparseAttention,
    HashedSparseAttention,
    compare_selectors,
)
from .streaming import estimate_corpus_plan, fit_model_from_corpus_plan, load_corpus_plan
from .tokenizer import MAX_TOKENIZER_VOCAB_SIZE, clamp_vocab_size, recommend_vocab_size
from .v2_data import write_blind_prompt_suite, write_v2_streaming_plan


def configure_stdio() -> None:
    """Reconfigure stdout/stderr to UTF-8 so JSON output survives legacy consoles."""
    for stream in (sys.stdout, sys.stderr):
        reconfigure = getattr(stream, "reconfigure", None)
        if reconfigure is not None:
            reconfigure(encoding="utf-8")


def build_parser() -> argparse.ArgumentParser:
    """Build the argument parser covering every REFRAMR subcommand."""
    parser = argparse.ArgumentParser(
        prog="reframr",
        description="Compute and query REFRAMR analytical language model checkpoints.",
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    compute = subparsers.add_parser(
        "compute",
        aliases=["train"],
        help="Compute a REFRAMR checkpoint from a text corpus with no epoch loop.",
    )
    compute.add_argument(
        "--input",
        required=True,
        help="Path to a text, JSON, or JSONL corpus file, or a directory of such files.",
    )
    compute.add_argument("--output", required=True, help="Path to write the .safetensors checkpoint.")
    compute.add_argument("--embedding-dim", type=int, default=16)
    compute.add_argument("--state-dim", type=int, default=32)
    compute.add_argument("--timescales", default="1.0,0.5,0.25,0.125")
    compute.add_argument("--window-size", type=int, default=2)
    compute.add_argument("--regularization", type=float, default=1e-3)
    compute.add_argument("--min-frequency", type=int, default=1)
    compute.add_argument(
        "--max-vocab",
        type=int,
        default=256,
        help="Cap analytical embedding vocabulary to keep weight computation fast on CPU.",
    )
    compute.add_argument("--tokenizer-vocab-size", type=int, default=0)
    compute.add_argument("--tokenizer-min-pair-frequency", type=int, default=2)
    compute.add_argument(
        "--max-training-examples",
        type=int,
        default=60000,
        help="Cap sampled recurrent training states while still reading the full corpus for tokenizer, embeddings, and transitions.",
    )
    compute.add_argument(
        "--max-memory-examples",
        type=int,
        default=-1,
        help="Cap saved associative memory examples separately from readout training. Use -1 to match --max-training-examples.",
    )
    compute.add_argument(
        "--max-state-tokens-per-document",
        type=int,
        default=768,
        help="Cap recurrent state steps per document with a deterministic corpus sketch. Use 0 to step full documents.",
    )
    compute.add_argument(
        "--max-transition-contexts",
        type=int,
        default=4096,
        help="Keep only the strongest learned transition contexts per order. Use 0 to disable the cap.",
    )
    compute.add_argument(
        "--max-transition-next-tokens",
        type=int,
        default=4,
        help="Keep this many learned next-token choices per transition context.",
    )
    case_group = compute.add_mutually_exclusive_group()
    case_group.add_argument(
        "--lowercase",
        action="store_true",
        help="Normalize corpus text to lowercase before tokenization.",
    )
    case_group.add_argument("--preserve-case", action="store_true", help=argparse.SUPPRESS)
    compute.add_argument(
        "--reasoning-profile",
        choices=sorted(REASONING_PROFILES),
        default="none",
        help="Default reasoning-control profile baked into the checkpoint.",
    )
    compute.add_argument(
        "--layout-profile",
        default="rfm-base",
        help="Structured analytical layout label to store in checkpoint metadata, such as rfm-70b-structured.",
    )
    compute.add_argument(
        "--effective-parameter-target",
        type=int,
        default=0,
        help="Dense-equivalent structured target to store in checkpoint metadata; this does not allocate dense tensors.",
    )

    recompute = subparsers.add_parser(
        "recompute",
        help="Compute a REFRAMR checkpoint from a streaming corpus plan with no raw-text cache.",
    )
    recompute.add_argument("--plan", required=True, help="Path to a streaming corpus plan JSON file.")
    recompute.add_argument("--output", required=True, help="Path to write the .safetensors checkpoint.")
    recompute.add_argument("--embedding-dim", type=int, default=16)
    recompute.add_argument("--state-dim", type=int, default=32)
    recompute.add_argument("--timescales", default="1.0,0.5,0.25,0.125")
    recompute.add_argument("--window-size", type=int, default=2)
    recompute.add_argument("--regularization", type=float, default=1e-3)
    recompute.add_argument("--min-frequency", type=int, default=1)
    recompute.add_argument("--max-vocab", type=int, default=256)
    recompute.add_argument("--tokenizer-vocab-size", type=int, default=0)
    recompute.add_argument("--tokenizer-min-pair-frequency", type=int, default=2)
    recompute.add_argument("--max-training-examples", type=int, default=60000)
    recompute.add_argument("--max-memory-examples", type=int, default=-1)
    recompute.add_argument("--max-state-tokens-per-document", type=int, default=768)
    recompute.add_argument("--max-transition-contexts", type=int, default=4096)
    recompute.add_argument("--max-transition-next-tokens", type=int, default=4)
    recompute.add_argument("--log-every", type=int, default=0)
    recompute.add_argument(
        "--dry-run",
        action="store_true",
        help="Estimate accepted rows and compute shape without fitting or saving a checkpoint.",
    )
    recompute.add_argument(
        "--estimate-max-rows-per-source",
        type=int,
        default=0,
        help="Optional cap for preflight row scanning per local source.",
    )
    recompute.add_argument(
        "--calibrate-rows",
        type=int,
        default=0,
        help="Run a bounded representative fit first and estimate full-run wall-clock time.",
    )
    recompute.add_argument(
        "--calibrate-only",
        action="store_true",
        help="Stop after calibration instead of computing and saving the full checkpoint.",
    )
    recompute_case_group = recompute.add_mutually_exclusive_group()
    recompute_case_group.add_argument("--lowercase", action="store_true")
    recompute_case_group.add_argument("--preserve-case", action="store_true", help=argparse.SUPPRESS)
    recompute.add_argument(
        "--reasoning-profile",
        choices=sorted(REASONING_PROFILES),
        default="none",
        help="Default reasoning-control profile baked into the checkpoint.",
    )
    recompute.add_argument(
        "--layout-profile",
        default="rfm-base",
        help="Structured analytical layout label to store in checkpoint metadata, such as rfm-70b-structured.",
    )
    recompute.add_argument(
        "--effective-parameter-target",
        type=int,
        default=0,
        help="Dense-equivalent structured target to store in checkpoint metadata; this does not allocate dense tensors.",
    )

    predict = subparsers.add_parser("predict", help="Predict the next-token distribution from a saved model.")
    predict.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    predict.add_argument("--context", required=True, help="Input context text.")
    predict.add_argument("--top-k", type=int, default=5)
    predict.add_argument(
        "--reasoning-mode",
        choices=sorted(REASONING_PROFILES),
        default=None,
        help="Override the checkpoint's default reasoning-control profile.",
    )

    generate = subparsers.add_parser("generate", help="Generate long-form text from a saved model.")
    generate.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    generate.add_argument("--context", required=True, help="Prompt or starting context text.")
    generate.add_argument("--system", default="", help="Optional system instruction to prepend as learned context.")
    generate.add_argument("--max-tokens", type=int, default=64)
    generate.add_argument("--temperature", type=float, default=0.82)
    generate.add_argument("--decode-top-k", type=int, default=24)
    generate.add_argument("--decode-top-p", type=float, default=0.92)
    generate.add_argument("--repetition-penalty", type=float, default=1.18)
    generate.add_argument(
        "--reasoning-mode",
        choices=sorted(REASONING_PROFILES),
        default=None,
        help="Override the checkpoint's default reasoning-control profile.",
    )

    generate_batch = subparsers.add_parser(
        "generate-batch",
        help="Generate answers for a prompt file while keeping one checkpoint loaded.",
    )
    generate_batch.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    generate_batch.add_argument("--prompts", required=True, help="Path to a TXT, JSON, or JSONL prompt suite.")
    generate_batch.add_argument("--output", required=True, help="Path to write JSONL generations.")
    generate_batch.add_argument("--max-tokens", type=int, default=64)
    generate_batch.add_argument("--temperature", type=float, default=0.82)
    generate_batch.add_argument("--decode-top-k", type=int, default=24)
    generate_batch.add_argument("--decode-top-p", type=float, default=0.92)
    generate_batch.add_argument("--repetition-penalty", type=float, default=1.18)
    generate_batch.add_argument(
        "--reasoning-mode",
        choices=sorted(REASONING_PROFILES),
        default=None,
        help="Override the checkpoint's default reasoning-control profile.",
    )

    serve = subparsers.add_parser(
        "serve",
        help="Keep one checkpoint loaded and answer JSONL generation requests from stdin.",
    )
    serve.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    serve.add_argument("--max-tokens", type=int, default=64)
    serve.add_argument("--temperature", type=float, default=0.82)
    serve.add_argument("--decode-top-k", type=int, default=24)
    serve.add_argument("--decode-top-p", type=float, default=0.92)
    serve.add_argument("--repetition-penalty", type=float, default=1.18)
    serve.add_argument(
        "--memory-turns",
        type=int,
        default=16,
        help="Number of prior JSONL session turns to prepend as conversation memory.",
    )
    serve.add_argument(
        "--reasoning-mode",
        choices=sorted(REASONING_PROFILES),
        default=None,
        help="Override the checkpoint's default reasoning-control profile.",
    )

    chat_completion = subparsers.add_parser(
        "chat-completion",
        help="Run one OpenAI-compatible chat completion request from stdin or a JSON file.",
    )
    chat_completion.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    chat_completion.add_argument(
        "--request",
        default="",
        help="Optional path to a JSON request. Defaults to stdin.",
    )

    trace = subparsers.add_parser("trace", help="Trace REFRAMR reasoning components through generation steps.")
    trace.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    trace.add_argument("--context", required=True, help="Prompt or starting context text.")
    trace.add_argument("--max-tokens", type=int, default=8)
    trace.add_argument("--top-k", type=int, default=5)
    trace.add_argument("--temperature", type=float, default=0.82)
    trace.add_argument("--decode-top-p", type=float, default=0.92)
    trace.add_argument("--repetition-penalty", type=float, default=1.18)
    trace.add_argument(
        "--reasoning-mode",
        choices=sorted(REASONING_PROFILES),
        default=None,
        help="Override the checkpoint's default reasoning-control profile.",
    )

    inspect = subparsers.add_parser("inspect", help="Inspect a REFRAMR safetensors checkpoint.")
    inspect.add_argument("--model", required=True, help="Path to a .safetensors checkpoint.")

    craft = subparsers.add_parser(
        "craft-corpus",
        help="Generate a JSON-first bootstrap corpus, manifest, and generalization prompt suite.",
    )
    craft.add_argument("--output-dir", required=True, help="Directory to write corpus and manifest files.")
    craft.add_argument(
        "--variant",
        choices=("foundation", "generalization"),
        default="foundation",
        help="Choose between the mixed foundation corpus and the language-first generalization corpus.",
    )

    craft_curriculum = subparsers.add_parser(
        "craft-curriculum",
        help="Generate the OkeyMeta JSON curriculum shard, manifest, holdout prompts, and recompute plan.",
    )
    craft_curriculum.add_argument("--output-dir", required=True, help="Directory to write curriculum files.")
    craft_curriculum.add_argument(
        "--records-per-category",
        type=int,
        default=1000,
        help="How many JSON records to generate for each curriculum category.",
    )
    craft_curriculum.add_argument("--seed", type=int, default=7)
    craft_curriculum.add_argument("--train-ratio", type=float, default=0.92)
    craft_curriculum.add_argument(
        "--effective-token-target",
        type=int,
        default=0,
        help="Set plan weighting so compact curriculum statistics represent this many effective tokens.",
    )

    craft_v2_plan = subparsers.add_parser(
        "craft-v2-plan",
        help="Write a strict streaming Hugging Face recompute plan for the v2 data mix.",
    )
    craft_v2_plan.add_argument("--output", required=True, help="Path to write the streaming plan JSON.")
    craft_v2_plan.add_argument(
        "--rows-per-source",
        type=int,
        default=10_000,
        help="Base accepted row target per source before per-domain multipliers.",
    )
    craft_v2_plan.add_argument(
        "--effective-token-target",
        type=int,
        default=0,
        help="Optional effective token target recorded in the plan metadata.",
    )
    craft_v2_plan.add_argument(
        "--wikipedia-mode",
        choices=("skip", "hf", "viewer"),
        default="skip",
        help="Use skip for fast smoke runs; hf and viewer both include Wikipedia through the fast HF viewer adapter.",
    )
    craft_v2_plan.add_argument(
        "--local-curriculum",
        action="append",
        default=[],
        help="Local JSON/JSONL curriculum shard to blend before HF sources.",
    )
    craft_v2_plan.add_argument(
        "--local-curriculum-limit",
        type=int,
        default=0,
        help="Maximum accepted rows per local curriculum shard. Use 0 for all rows.",
    )

    materialize_plan = subparsers.add_parser(
        "materialize-plan",
        help="Write bounded normalized JSONL shards from a corpus plan, then emit a local recompute plan.",
    )
    materialize_plan.add_argument("--plan", required=True, help="Path to a streaming corpus plan JSON file.")
    materialize_plan.add_argument("--output-dir", required=True, help="Directory for normalized JSONL shards.")
    materialize_plan.add_argument(
        "--max-gb",
        type=float,
        default=DEFAULT_CACHE_BYTE_LIMIT / (1024 ** 3),
        help="Maximum normalized cache size in GB. Defaults to 3 GB.",
    )
    materialize_plan.add_argument(
        "--shard-mb",
        type=int,
        default=DEFAULT_SHARD_BYTE_LIMIT // (1024 ** 2),
        help="Maximum size per JSONL shard in MB.",
    )
    materialize_plan.add_argument("--log-every", type=int, default=0)

    craft_blind_prompts = subparsers.add_parser(
        "craft-blind-prompts",
        help="Write a blind open-prompt JSONL suite for v2 generalization checks.",
    )
    craft_blind_prompts.add_argument("--output", required=True, help="Path to write JSONL prompts.")
    craft_blind_prompts.add_argument("--seed", type=int, default=2026)
    craft_blind_prompts.add_argument(
        "--variants-per-intent",
        type=int,
        default=4,
        help="How many prompt variants to generate per evaluation intent.",
    )

    evaluate = subparsers.add_parser(
        "evaluate",
        help="Evaluate memorization and held-out generalization from a benchmark manifest.",
    )
    evaluate.add_argument("--model", required=True, help="Path to a REFRAMR .safetensors checkpoint.")
    evaluate.add_argument("--manifest", required=True, help="Path to a corpus benchmark manifest JSON file.")
    evaluate.add_argument(
        "--reasoning-mode",
        choices=sorted(REASONING_PROFILES),
        default=None,
        help="Override the checkpoint's default reasoning-control profile during evaluation.",
    )
    evaluate.add_argument("--top-k", type=int, default=5)

    benchmark_open = subparsers.add_parser(
        "benchmark-open",
        help="Run arbitrary prompt files through a checkpoint with open-ended output metrics.",
    )
    benchmark_open.add_argument("--model", required=True, help="Path to a REFRAMR .safetensors checkpoint.")
    benchmark_open.add_argument("--prompts", required=True, help="Path to a TXT, JSON, or JSONL prompt suite.")
    benchmark_open.add_argument("--max-tokens", type=int, default=64)
    benchmark_open.add_argument("--temperature", type=float, default=0.82)
    benchmark_open.add_argument("--decode-top-k", type=int, default=24)
    benchmark_open.add_argument("--decode-top-p", type=float, default=0.92)
    benchmark_open.add_argument("--repetition-penalty", type=float, default=1.18)
    benchmark_open.add_argument(
        "--replay-source",
        action="append",
        default=[],
        help="JSON/JSONL/TXT corpus path used only to flag generated source-row replay.",
    )
    benchmark_open.add_argument(
        "--replay-source-limit",
        type=int,
        default=10_000,
        help="Maximum source rows to load for replay checks.",
    )
    benchmark_open.add_argument("--replay-ngram-size", type=int, default=8)
    benchmark_open.add_argument("--replay-overlap-threshold", type=float, default=0.70)
    benchmark_open.add_argument(
        "--output",
        default="",
        help="Optional UTF-8 JSON path for benchmark results.",
    )
    benchmark_open.add_argument(
        "--reasoning-mode",
        choices=sorted(REASONING_PROFILES),
        default=None,
        help="Override the checkpoint's default reasoning-control profile during benchmarking.",
    )

    sparse_benchmark = subparsers.add_parser(
        "sparse-context-benchmark",
        help="Measure analytical sparse-context selection speed on a checkpoint embedding table.",
    )
    sparse_benchmark.add_argument("--model", required=True, help="Path to a REFRAMR .safetensors checkpoint.")
    sparse_benchmark.add_argument("--context-tokens", type=int, default=100_000)
    sparse_benchmark.add_argument("--query-count", type=int, default=64)
    sparse_benchmark.add_argument("--top-k", type=int, default=64)
    sparse_benchmark.add_argument("--seed", type=int, default=2026)
    sparse_benchmark.add_argument(
        "--selector",
        choices=("exact", "hashed", "faiss"),
        default="hashed",
        help="Choose the exact cosine scan, hashed approximate selection, or FAISS-backed selection.",
    )
    sparse_benchmark.add_argument("--hash-bits", type=int, default=12)
    sparse_benchmark.add_argument("--probe-radius", type=int, default=1)
    sparse_benchmark.add_argument("--candidate-multiplier", type=int, default=12)
    sparse_benchmark.add_argument("--faiss-hnsw", action="store_true")
    sparse_benchmark.add_argument("--hnsw-neighbors", type=int, default=32)
    sparse_benchmark.add_argument("--ef-search", type=int, default=64)
    sparse_benchmark.add_argument(
        "--compare-exact",
        action="store_true",
        help="Also compute exact top-k recall for the selected query set.",
    )
    sparse_benchmark.add_argument("--output", default="", help="Optional UTF-8 JSON path for benchmark results.")

    import_hf = subparsers.add_parser(
        "import-hf",
        help="Import Hugging Face dataset text into the REFRAMR JSON record standard.",
    )
    import_hf.add_argument("--dataset", required=True, help="Hugging Face dataset id.")
    import_hf.add_argument("--output", required=True, help="Path to write the JSONL corpus.")
    import_hf.add_argument("--config", default=None, help="Optional dataset config/subset.")
    import_hf.add_argument("--split", default="train", help="Dataset split to import.")
    import_hf.add_argument("--text-field", default=None, help="Explicit text column name.")
    import_hf.add_argument("--limit", type=int, default=1000, help="Maximum records to import.")
    import_hf.add_argument(
        "--min-words",
        type=int,
        default=0,
        help="Drop imported records shorter than this many words.",
    )
    import_hf.add_argument(
        "--max-words",
        type=int,
        default=0,
        help="Drop imported records longer than this many words. Use 0 to disable.",
    )
    import_hf.add_argument(
        "--min-alpha-ratio",
        type=float,
        default=0.0,
        help="Drop imported records whose alphabetic-character ratio falls below this threshold.",
    )
    import_hf.add_argument(
        "--allowed-languages",
        default="",
        help="Optional comma-separated language codes to keep, such as en,yo,ig,ha.",
    )
    import_hf.add_argument(
        "--preference-target",
        choices=("both", "chosen", "rejected"),
        default="chosen",
        help="When importing preference datasets, keep both sides or only the chosen/rejected side.",
    )
    import_hf.add_argument(
        "--no-streaming",
        action="store_true",
        help="Disable streaming dataset reads.",
    )

    return parser


def parse_timescales(raw_timescales: str) -> tuple[float, ...]:
    """Parse a comma-separated timescale string into a tuple of floats."""
    values = [segment.strip() for segment in raw_timescales.split(",") if segment.strip()]
    if not values:
        raise ValueError("At least one timescale is required.")
    return tuple(float(value) for value in values)
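
# Illustrative behavior of parse_timescales (a sketch, not a test suite):
#
#     parse_timescales("1.0, 0.5,0.25")  -> (1.0, 0.5, 0.25)
#     parse_timescales(" , ")            -> raises ValueError
#
# Whitespace around segments is stripped and empty segments are dropped
# before the remaining values are converted to floats.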


def command_compute(args: argparse.Namespace) -> int:
    """Fit a checkpoint from a local corpus file or directory and print a JSON summary."""
    text = load_text_corpus(args.input)
    requested_vocab_size = args.tokenizer_vocab_size or recommend_vocab_size(
        text,
        lowercase=args.lowercase,
    )
    tokenizer_vocab_size = clamp_vocab_size(requested_vocab_size)
    config = ReframrConfig(
        embedding_dim=args.embedding_dim,
        state_dim=args.state_dim,
        timescales=parse_timescales(args.timescales),
        window_size=args.window_size,
        regularization=args.regularization,
        min_frequency=args.min_frequency,
        max_vocab=args.max_vocab,
        tokenizer_vocab_size=tokenizer_vocab_size,
        tokenizer_min_pair_frequency=args.tokenizer_min_pair_frequency,
        max_training_examples=args.max_training_examples,
        max_memory_examples=(
            None
            if args.max_memory_examples < 0
            else args.max_memory_examples
        ),
        max_state_tokens_per_document=(
            None
            if args.max_state_tokens_per_document <= 0
            else args.max_state_tokens_per_document
        ),
        max_transition_contexts_per_order=(
            args.max_transition_contexts if args.max_transition_contexts > 0 else None
        ),
        max_transition_next_tokens=args.max_transition_next_tokens,
        lowercase=args.lowercase,
        default_reasoning_profile=args.reasoning_profile,
        layout_profile=args.layout_profile,
        effective_parameter_target=args.effective_parameter_target,
    )
    model = ReframrModel(config).fit(text)
    model.save(args.output)

    assert model.tokenizer is not None
    assert model.embedding_model is not None
    summary = {
        "status": "computed",
        "format": "safetensors",
        "model_path": str(Path(args.output).resolve()),
        "tokenizer_name": TOKENIZER_NAME,
        "vocab_size": len(model.embedding_model.id_to_token),
        "tokenizer_vocab_budget": config.tokenizer_vocab_size,
        "tokenizer_vocab_budget_max": MAX_TOKENIZER_VOCAB_SIZE,
        "tokenizer_vocab_size": model.tokenizer.vocab_size,
        "reasoning_profile": config.default_reasoning_profile,
        "reasoning_tokens": reasoning_prefix(config.default_reasoning_profile),
        "lowercase": config.lowercase,
        "max_training_examples": config.max_training_examples,
        "max_memory_examples": config.max_memory_examples,
        "max_state_tokens_per_document": config.max_state_tokens_per_document,
        "max_transition_contexts_per_order": config.max_transition_contexts_per_order,
        "max_transition_next_tokens": config.max_transition_next_tokens,
        "embedding_dim": config.embedding_dim,
        "state_dim": config.state_dim,
        "timescales": list(config.timescales),
        "layout_profile": config.layout_profile,
        "effective_parameter_target": config.effective_parameter_target,
    }
    print(json.dumps(summary))
    return 0
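
# Example invocation, assuming the `reframr` console script named by `prog`
# above is installed (adjust the entry point for your environment):
#
#     reframr compute \
#         --input corpus.jsonl \
#         --output model.safetensors \
#         --lowercase
#
# The command prints one JSON summary line describing the saved checkpoint.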


def command_recompute(args: argparse.Namespace) -> int:
    """Fit a checkpoint from a streaming corpus plan, with optional dry-run and calibration."""
    plan = load_corpus_plan(args.plan)
    requested_vocab_size = args.tokenizer_vocab_size or 1024
    tokenizer_vocab_size = clamp_vocab_size(requested_vocab_size)
    config = ReframrConfig(
        embedding_dim=args.embedding_dim,
        state_dim=args.state_dim,
        timescales=parse_timescales(args.timescales),
        window_size=args.window_size,
        regularization=args.regularization,
        min_frequency=args.min_frequency,
        max_vocab=args.max_vocab,
        tokenizer_vocab_size=tokenizer_vocab_size,
        tokenizer_min_pair_frequency=args.tokenizer_min_pair_frequency,
        max_training_examples=args.max_training_examples,
        max_memory_examples=(
            None
            if args.max_memory_examples < 0
            else args.max_memory_examples
        ),
        max_state_tokens_per_document=(
            None
            if args.max_state_tokens_per_document <= 0
            else args.max_state_tokens_per_document
        ),
        max_transition_contexts_per_order=(
            args.max_transition_contexts if args.max_transition_contexts > 0 else None
        ),
        max_transition_next_tokens=args.max_transition_next_tokens,
        lowercase=args.lowercase,
        default_reasoning_profile=args.reasoning_profile,
        layout_profile=args.layout_profile,
        effective_parameter_target=args.effective_parameter_target,
    )
    if args.dry_run:
        estimate = estimate_corpus_plan(
            plan,
            max_rows_per_source=args.estimate_max_rows_per_source,
        )
        accepted = int(estimate.get("accepted_documents", 0) or 0)
        state_cap = config.max_state_tokens_per_document or 768
        estimated_state_tokens = accepted * state_cap
        summary = {
            "status": "dry_run",
            "plan_path": str(Path(args.plan).resolve()),
            "output_path": str(Path(args.output).resolve()),
            "accepted_documents": accepted,
            "seen_texts": estimate.get("seen_texts", 0),
            "rejected_texts": estimate.get("rejected_texts", 0),
            "estimated_words": estimate.get("estimated_words", 0),
            "estimated_state_token_budget": estimated_state_tokens,
            "embedding_dim": config.embedding_dim,
            "state_dim": config.state_dim,
            "tokenizer_vocab_budget": config.tokenizer_vocab_size,
            "max_vocab": config.max_vocab,
            "max_training_examples": config.max_training_examples,
            "max_memory_examples": config.max_memory_examples,
            "max_state_tokens_per_document": config.max_state_tokens_per_document,
            "max_transition_contexts_per_order": config.max_transition_contexts_per_order,
            "max_transition_next_tokens": config.max_transition_next_tokens,
            "layout_profile": config.layout_profile,
            "effective_parameter_target": config.effective_parameter_target,
            "estimate_seconds": estimate.get("seconds", 0),
            "sources": estimate.get("sources", []),
        }
        print(json.dumps(summary))
        return 0
    if args.calibrate_rows > 0:
        calibration = _calibrate_recompute_plan(
            plan,
            config,
            target_rows=args.calibrate_rows,
            estimate_max_rows_per_source=args.estimate_max_rows_per_source,
            log_every=args.log_every,
        )
        print(json.dumps(calibration), flush=True)
        if args.calibrate_only:
            return 0
    model, payload = fit_model_from_corpus_plan(
        plan,
        config,
        log_every=args.log_every,
    )
    model.save(args.output)

    summary = {
        "status": "recomputed",
        "format": "safetensors",
        "streaming": True,
        "plan_path": str(Path(args.plan).resolve()),
        "model_path": str(Path(args.output).resolve()),
        "tokenizer_name": TOKENIZER_NAME,
        "tokenizer_vocab_budget": config.tokenizer_vocab_size,
        "tokenizer_vocab_budget_max": MAX_TOKENIZER_VOCAB_SIZE,
        "tokenizer_vocab_size": payload["tokenizer_vocab_size"],
        "vocab_size": payload["embedding_vocab_size"],
        "documents_processed": payload["documents_processed"],
        "source_counts": payload["source_counts"],
        "examples_processed": payload["examples_processed"],
        "associative_examples": payload["associative_examples"],
        "answer_associative_examples": payload.get("answer_associative_examples", 0),
        "general_associative_examples": payload.get("general_associative_examples", 0),
        "answer_intent_examples": payload.get("answer_intent_examples", 0),
        "answer_start_examples": payload.get("answer_start_examples", 0),
        "answer_sequence_examples": payload.get("answer_sequence_examples", 0),
        "prompt_answer_readout_examples": payload.get("prompt_answer_readout_examples", 0),
        "prompt_answer_start_readout_examples": payload.get("prompt_answer_start_readout_examples", 0),
        "preference_pairs": payload.get("preference_pairs", 0),
        "preference_state_pairs": payload.get("preference_state_pairs", 0),
        "stage_seconds": payload.get("stage_seconds", {}),
        "readout_solver": payload.get("readout_solver"),
        "reasoning_profile": config.default_reasoning_profile,
        "reasoning_tokens": reasoning_prefix(config.default_reasoning_profile),
        "lowercase": config.lowercase,
        "max_training_examples": config.max_training_examples,
        "max_memory_examples": config.max_memory_examples,
        "max_state_tokens_per_document": config.max_state_tokens_per_document,
        "state_tokens_before_sketch": payload.get("state_tokens_before_sketch", 0),
        "state_tokens_after_sketch": payload.get("state_tokens_after_sketch", 0),
        "max_transition_contexts_per_order": config.max_transition_contexts_per_order,
        "max_transition_next_tokens": config.max_transition_next_tokens,
        "embedding_dim": config.embedding_dim,
        "state_dim": config.state_dim,
        "timescales": list(config.timescales),
        "layout_profile": config.layout_profile,
        "effective_parameter_target": config.effective_parameter_target,
    }
    print(json.dumps(summary))
    return 0


def _limited_calibration_plan(
    plan: list[object],
    *,
    target_rows: int,
    full_accepted: int,
) -> list[object]:
    """Scale per-source limits down so a calibration fit sees roughly target_rows rows."""
    if target_rows <= 0:
        return plan
    ratio = min(1.0, target_rows / max(1, full_accepted))
    limited: list[object] = []
    fallback_limit = max(1, target_rows // max(1, len(plan)))
    for entry in plan:
        raw_limit = int(getattr(entry, "limit", 0) or 0)
        if raw_limit > 0:
            # Ceiling of raw_limit * ratio without importing math, clamped to raw_limit.
            next_limit = max(1, min(raw_limit, int((raw_limit * ratio) + 0.999999)))
        else:
            record_count = len(getattr(entry, "records", ()) or ())
            source_cap = record_count if record_count > 0 else fallback_limit
            next_limit = max(1, min(source_cap, fallback_limit))
        limited.append(replace(entry, limit=next_limit))
    return limited
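
# Worked example of the scaling above (illustrative numbers only): with
# target_rows=2_000 and full_accepted=50_000, ratio = 0.04. A source whose
# configured limit is 10_000 is reduced to ceil(10_000 * 0.04) = 400 rows,
# while an unlimited record-backed source falls back to roughly
# target_rows // len(plan) rows, so the calibration fit stays bounded.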


def _estimate_full_seconds_from_calibration(
    *,
    full_documents: int,
    full_state_tokens: int,
    calibration_payload: dict[str, object],
) -> dict[str, object]:
    """Extrapolate full-run wall-clock time from calibration stage timings."""
    calibration_documents = max(1, int(calibration_payload.get("documents_processed", 0) or 0))
    calibration_state_tokens = max(
        1,
        int(calibration_payload.get("state_tokens_after_sketch", 0) or 0),
    )
    document_scale = full_documents / calibration_documents
    state_scale = full_state_tokens / calibration_state_tokens
    stage_seconds = calibration_payload.get("stage_seconds", {})
    if not isinstance(stage_seconds, dict):
        stage_seconds = {}
    fixed_weighted = {"tokenizer_fit", "embedding", "kernel_warmup", "preference"}
    state_weighted = {"state_and_readout", "finalize_prompt_readouts", "finalize_memory_arrays"}
    document_weighted = {
        "stream_and_segment",
        "vocabulary",
        "cooccurrence",
        "model_finalize",
        "finalize_answer_sequences",
        "finalize_transition_tables",
    }
    stage_estimates: dict[str, float] = {}
    for stage, raw_seconds in stage_seconds.items():
        seconds = float(raw_seconds)
        if stage in fixed_weighted:
            scale = 1.0
        elif stage in state_weighted:
            scale = state_scale
        elif stage in document_weighted:
            scale = document_scale
        else:
            scale = max(document_scale, state_scale)
        stage_estimates[str(stage)] = round(seconds * scale, 3)
    total_seconds = round(sum(stage_estimates.values()), 3)
    return {
        "estimated_full_seconds": total_seconds,
        "estimated_full_minutes": round(total_seconds / 60.0, 3),
        "scale_documents": round(document_scale, 4),
        "scale_state_tokens": round(state_scale, 4),
        "stage_estimates": stage_estimates,
    }
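
# Worked example of the extrapolation (illustrative numbers only): if the
# calibration run processed 2_000 documents and the full plan holds 50_000,
# document_scale = 25.0. A document-weighted stage such as "cooccurrence"
# that took 4.0s during calibration is then estimated at 100.0s, while a
# fixed-weighted stage such as "tokenizer_fit" keeps its calibration time.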


def _calibrate_recompute_plan(
    plan: list[object],
    config: ReframrConfig,
    *,
    target_rows: int,
    estimate_max_rows_per_source: int,
    log_every: int,
) -> dict[str, object]:
    """Run a bounded representative fit and estimate the full recompute runtime."""
    full_estimate = estimate_corpus_plan(
        plan,
        max_rows_per_source=estimate_max_rows_per_source,
    )
    full_documents = int(full_estimate.get("accepted_documents", 0) or 0)
    state_cap = config.max_state_tokens_per_document or 768
    full_state_tokens = full_documents * state_cap
    calibration_plan = _limited_calibration_plan(
        plan,
        target_rows=target_rows,
        full_accepted=full_documents,
    )
    _, calibration_payload = fit_model_from_corpus_plan(
        calibration_plan,
        config,
        log_every=log_every,
    )
    runtime_estimate = _estimate_full_seconds_from_calibration(
        full_documents=full_documents,
        full_state_tokens=full_state_tokens,
        calibration_payload=calibration_payload,
    )
    return {
        "status": "calibration",
        "target_rows": target_rows,
        "full_accepted_documents": full_documents,
        "full_estimated_words": full_estimate.get("estimated_words", 0),
        "full_estimated_state_token_budget": full_state_tokens,
        "calibration_documents": calibration_payload.get("documents_processed", 0),
        "calibration_state_tokens": calibration_payload.get("state_tokens_after_sketch", 0),
        "calibration_stage_seconds": calibration_payload.get("stage_seconds", {}),
        **runtime_estimate,
    }


def command_predict(args: argparse.Namespace) -> int:
    """Print the top-k next-token distribution for a context as JSON."""
    model = ReframrModel.load(args.model)
    distribution = model.predict_next_distribution(
        args.context,
        reasoning_mode=args.reasoning_mode,
    )
    predictions = sorted(
        distribution.items(),
        key=lambda item: item[1],
        reverse=True,
    )[: args.top_k]
    payload = {
        "context": args.context,
        "reasoning_mode": args.reasoning_mode or model.config.default_reasoning_profile,
        "reasoning_tokens": reasoning_prefix(args.reasoning_mode or model.config.default_reasoning_profile),
        "predictions": [
            {"token": token, "probability": probability}
            for token, probability in predictions
        ],
    }
    print(json.dumps(payload))
    return 0


def command_generate(args: argparse.Namespace) -> int:
    """Generate text for one prompt and print a JSON payload."""
    model = ReframrModel.load(args.model)
    context = compose_generation_context(args.context, system=args.system)
    generated_text = model.generate_text(
        context,
        max_tokens=args.max_tokens,
        reasoning_mode=args.reasoning_mode,
        temperature=args.temperature,
        top_k=args.decode_top_k,
        top_p=args.decode_top_p,
        repetition_penalty=args.repetition_penalty,
    )
    payload = {
        "context": context,
        "reasoning_mode": args.reasoning_mode or model.config.default_reasoning_profile,
        "reasoning_tokens": reasoning_prefix(args.reasoning_mode or model.config.default_reasoning_profile),
        "generated_token_count": len(generated_text.split()),
        "generated_text": generated_text,
    }
    print(json.dumps(payload))
    return 0


def _content_to_text(content: object) -> str:
    """Flatten chat-message content (string, list, or mapping) into plain text."""
    if content is None:
        return ""
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, list):
        parts: list[str] = []
        for item in content:
            if isinstance(item, dict):
                text = item.get("text", item.get("content", item.get("input_text", "")))
                if text:
                    parts.append(str(text).strip())
            elif item is not None:
                parts.append(str(item).strip())
        return "\n".join(part for part in parts if part)
    if isinstance(content, (dict, tuple)):
        return json.dumps(content, ensure_ascii=False, separators=(",", ":"))
    return str(content).strip()


def _coerce_json_payload(payload: object) -> object:
    """Parse a JSON string payload when possible; otherwise return it unchanged."""
    if not isinstance(payload, str):
        return payload
    stripped = payload.strip()
    if not stripped:
        return ""
    try:
        return json.loads(stripped)
    except json.JSONDecodeError:
        return stripped
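
# Illustrative behavior of _coerce_json_payload:
#
#     _coerce_json_payload('{"ok": true}')  -> {"ok": True}
#     _coerce_json_payload("plain text")    -> "plain text"
#     _coerce_json_payload({"ok": True})    -> {"ok": True}  (non-strings pass through)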


def _render_source_lines(payload: object) -> list[str]:
    """Render tool-result source attributions as `<source>` lines."""
    if not isinstance(payload, dict):
        return []
    nested_content = payload.get("content")
    if isinstance(nested_content, dict):
        nested_lines = _render_source_lines(nested_content)
        if nested_lines:
            return nested_lines
    raw_sources = payload.get("sources", payload.get("source", []))
    if isinstance(raw_sources, dict):
        sources = [raw_sources]
    elif isinstance(raw_sources, list):
        sources = raw_sources
    elif raw_sources:
        sources = [raw_sources]
    else:
        sources = []

    lines: list[str] = []
    for source in sources:
        if isinstance(source, dict):
            title = str(source.get("title", source.get("name", "source"))).strip()
            url = str(source.get("url", source.get("uri", ""))).strip()
            snippet = str(source.get("snippet", source.get("text", source.get("content", "")))).strip()
            parts = [part for part in (title, url, snippet) if part]
            if parts:
                lines.append(f"<source> {' | '.join(parts)}")
        elif source:
            lines.append(f"<source> {str(source).strip()}")
    return lines


def _render_tool_result(name: str, payload: object) -> list[str]:
    """Render a tool result (success or failure) as `<tool_result>` lines."""
    tool_name = name.strip() or "tool"
    parsed = _coerce_json_payload(payload)
    if isinstance(parsed, dict):
        explicit_name = str(parsed.get("name", parsed.get("tool", ""))).strip()
        if explicit_name:
            tool_name = explicit_name
        status = str(parsed.get("status", "")).casefold()
        ok_value = parsed.get("ok", None)
        error = str(parsed.get("error", parsed.get("message", ""))).strip()
        failed = ok_value is False or status in {"error", "failed", "failure", "timeout"} or bool(error)
        if failed:
            first = f"<tool_result> {tool_name} failed: {error or status or 'unknown error'}"
        else:
            summary = str(parsed.get("summary", parsed.get("content", parsed.get("text", "")))).strip()
            first = f"<tool_result> {tool_name} ok"
            if summary and not _render_source_lines(parsed):
                first = f"{first}: {summary}"
        return [first, *_render_source_lines(parsed)]
    if parsed:
        return [f"<tool_result> {tool_name} {str(parsed).strip()}"]
    return [f"<tool_result> {tool_name} empty"]


def _render_tool_call(call: object) -> str:
    """Render an assistant tool call as a `<tool_call>` line."""
    if not isinstance(call, dict):
        return f"<tool_call> {str(call).strip()}"
    function_payload = call.get("function", {})
    function = function_payload if isinstance(function_payload, dict) else {}
    name = str(call.get("name", function.get("name", "tool"))).strip() or "tool"
    arguments = call.get("arguments", function.get("arguments", {}))
    if not isinstance(arguments, str):
        arguments = json.dumps(arguments, ensure_ascii=False, separators=(",", ":"))
    return f"<tool_call> {name} {arguments}".strip()
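
# Illustrative rendering of an OpenAI-style tool call:
#
#     _render_tool_call({"function": {"name": "search", "arguments": {"q": "reframr"}}})
#     -> '<tool_call> search {"q":"reframr"}'
#
# Non-string arguments are serialized as compact JSON before rendering.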


def compose_generation_context(
    prompt: str,
    *,
    system: str = "",
    messages: object | None = None,
    tool_results: object | None = None,
) -> str:
    """Compose the generation context from prompt, system text, chat messages, and tool results."""
    clean_prompt = prompt.strip()
    clean_system = system.strip()
    lines: list[str] = []
    tool_protocol_seen = False
    if clean_system:
        lines.append(clean_system)

    if isinstance(messages, list):
        for message in messages:
            if not isinstance(message, dict):
                continue
            role = str(message.get("role", "")).casefold()
            content = _content_to_text(message.get("content", ""))
            if role == "system":
                if content:
                    lines.append(f"System instruction: {content}")
            elif role == "user":
                if content:
                    lines.append(f"User: {content}")
            elif role == "assistant":
                if content:
                    lines.append(f"Assistant: {content}")
                    if "<tool_call>" in content:
                        tool_protocol_seen = True
                tool_calls = message.get("tool_calls", [])
                if isinstance(tool_calls, list):
                    for call in tool_calls:
                        lines.append(_render_tool_call(call))
                        tool_protocol_seen = True
            elif role == "tool":
                tool_name = str(message.get("name", message.get("tool_call_id", "tool")))
                lines.extend(_render_tool_result(tool_name, message.get("content", "")))
                tool_protocol_seen = True
            elif content:
                lines.append(f"{role.capitalize()}: {content}")

    if clean_prompt:
        lines.append(f"User: {clean_prompt}" if isinstance(messages, list) else clean_prompt)

    if isinstance(tool_results, list):
        for result in tool_results:
            tool_name = "tool"
            if isinstance(result, dict):
                tool_name = str(result.get("name", result.get("tool", "tool")))
            lines.extend(_render_tool_result(tool_name, result))
            tool_protocol_seen = True
    elif tool_results:
        lines.extend(_render_tool_result("tool", tool_results))
        tool_protocol_seen = True

    if tool_protocol_seen:
        lines.append("<final>")
    return "\n".join(line for line in lines if line).strip()
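
# Illustrative composition (a sketch of the rendered context):
#
#     compose_generation_context(
#         "What changed?",
#         system="Be brief.",
#         messages=[{"role": "assistant", "content": "Ready."}],
#     )
#     -> "Be brief.\nAssistant: Ready.\nUser: What changed?"
#
# Tool calls and tool results additionally append a closing "<final>" line
# so decoding resumes after the tool protocol.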


def command_generate_batch(args: argparse.Namespace) -> int:
    """Generate answers for a whole prompt suite while loading the checkpoint once."""
    model = ReframrModel.load(args.model)
    prompts = load_prompt_suite(args.prompts)
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    rows: list[dict[str, object]] = []
    with output_path.open("w", encoding="utf-8") as handle:
        for index, record in enumerate(prompts):
            prompt = str(record["prompt"])
            record_mode = str(
                record.get(
                    "reasoning_mode",
                    args.reasoning_mode or model.config.default_reasoning_profile,
                )
            )
            context = compose_generation_context(
                prompt,
                system=str(record.get("system", "")),
                messages=record.get("messages"),
                tool_results=record.get("tool_results"),
            )
            max_tokens = int(record.get("max_tokens", args.max_tokens))
            generated_text = model.generate_text(
                context,
                max_tokens=max_tokens,
                reasoning_mode=record_mode,
                temperature=args.temperature,
                top_k=args.decode_top_k,
                top_p=args.decode_top_p,
                repetition_penalty=args.repetition_penalty,
            )
            row = {
                "index": index,
                "prompt": prompt,
                "context": context,
                "system": record.get("system", ""),
                "tags": record.get("tags", []),
                "reasoning_mode": record_mode,
                "reasoning_tokens": reasoning_prefix(record_mode),
                "generated_token_count": len(generated_text.split()),
                "generated_text": generated_text,
            }
            rows.append(row)
            handle.write(json.dumps(row, ensure_ascii=False, separators=(",", ":")) + "\n")
    payload = {
        "status": "generated",
        "sample_count": len(rows),
        "model_path": str(Path(args.model).resolve()),
        "prompts_path": str(Path(args.prompts).resolve()),
        "output_path": str(output_path.resolve()),
        "model_loads": 1,
    }
    print(json.dumps(payload))
    return 0
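
# Illustrative JSONL prompt-suite row (hypothetical values; per the loop
# above only "prompt" is required, and the other keys override the CLI
# defaults per record):
#
#     {"prompt": "Summarize the plan.", "system": "Be concise.",
#      "max_tokens": 48, "reasoning_mode": "none", "tags": ["smoke"]}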


def command_serve(args: argparse.Namespace) -> int:
    """Answer JSONL generation requests from stdin with one loaded checkpoint."""
    model = ReframrModel.load(args.model)
    default_mode = args.reasoning_mode or model.config.default_reasoning_profile
    generated_history_by_context: dict[str, list[str]] = {}
    session_turns_by_id: dict[str, list[tuple[str, str]]] = {}
    for index, raw_line in enumerate(sys.stdin):
        line = raw_line.strip()
        if not line:
            continue
        try:
            request = json.loads(line)
        except json.JSONDecodeError as exc:
            response = {
                "index": index,
                "error": "invalid_json",
                "message": str(exc),
                "model_loads": 1,
            }
            sys.stdout.write(json.dumps(response, ensure_ascii=False, separators=(",", ":")) + "\n")
            sys.stdout.flush()
            continue
        if isinstance(request, str):
            raw_context = request
            base_context = request
            request_payload: dict[str, object] = {}
        elif isinstance(request, dict):
            request_payload = request
            raw_context = str(request_payload.get("prompt", request_payload.get("context", "")))
            base_context = compose_generation_context(
                raw_context,
                system=str(request_payload.get("system", "")),
                messages=request_payload.get("messages"),
                tool_results=request_payload.get("tool_results", request_payload.get("toolResults")),
            )
        else:
            response = {
                "index": index,
                "error": "invalid_request",
                "message": "request must be a JSON object or string",
                "model_loads": 1,
            }
            sys.stdout.write(json.dumps(response, ensure_ascii=False, separators=(",", ":")) + "\n")
            sys.stdout.flush()
            continue
        session_id = str(
            request_payload.get(
                "session_id",
                request_payload.get("conversation_id", request_payload.get("thread_id", "")),
            )
        ).strip()
        memory_turn_limit = max(
            0,
            int(request_payload.get("memory_turns", getattr(args, "memory_turns", 16))),
        )
        session_turns = session_turns_by_id.get(session_id, []) if session_id else []
        memory_context = ""
        if session_turns and memory_turn_limit > 0:
            memory_lines = ["Conversation memory:"]
            for prior_user, prior_assistant in session_turns[-memory_turn_limit:]:
                if prior_user.strip():
                    memory_lines.append(f"Previous user: {prior_user.strip()}")
                if prior_assistant.strip():
                    memory_lines.append(f"Previous assistant: {prior_assistant.strip()}")
            memory_context = "\n".join(memory_lines)
        context = (
            f"{memory_context}\nCurrent user: {base_context}"
            if memory_context
            else base_context
        )
        active_mode = str(request_payload.get("reasoning_mode", default_mode))
        max_tokens = int(request_payload.get("max_tokens", args.max_tokens))
        temperature = float(request_payload.get("temperature", args.temperature))
        top_k = int(request_payload.get("decode_top_k", args.decode_top_k))
        top_p = float(request_payload.get("decode_top_p", args.decode_top_p))
        repetition_penalty = float(
            request_payload.get("repetition_penalty", args.repetition_penalty)
        )
        history_key = " ".join(base_context.split())
        avoid_texts = generated_history_by_context.get(history_key, [])
        generated_text = model.generate_text(
            context,
            max_tokens=max_tokens,
            reasoning_mode=active_mode,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            avoid_texts=avoid_texts,
        )
        if generated_text.strip():
            next_history = [*avoid_texts, generated_text]
            generated_history_by_context[history_key] = next_history[-8:]
            if session_id:
                user_memory_text = raw_context if raw_context.strip() else base_context
                next_session_turns = [*session_turns, (user_memory_text, generated_text)]
                session_turns_by_id[session_id] = next_session_turns[-max(1, memory_turn_limit):]
        response = {
            "index": index,
            "context": context,
            "reasoning_mode": active_mode,
            "reasoning_tokens": reasoning_prefix(active_mode),
            "generated_token_count": len(generated_text.split()),
            "generated_text": generated_text,
            "memory_turn_count": len(session_turns[-memory_turn_limit:]) if memory_turn_limit > 0 else 0,
            "model_loads": 1,
        }
        sys.stdout.write(json.dumps(response, ensure_ascii=False, separators=(",", ":")) + "\n")
        sys.stdout.flush()
    return 0
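
# Illustrative stdin request lines for `serve` (hypothetical values; bare
# JSON strings are also accepted and treated as plain prompts):
#
#     {"prompt": "Hello there", "session_id": "s1", "max_tokens": 32}
#     {"prompt": "And a follow-up?", "session_id": "s1"}
#
# Turns sharing a session_id accumulate conversation memory, up to
# --memory-turns prior exchanges per request.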


def command_chat_completion(args: argparse.Namespace) -> int:
    """Run one OpenAI-compatible chat completion request and print the response."""
    from .openai_compat import build_chat_completion_response, iter_sse_chat_completion

    request_path = str(getattr(args, "request", "")).strip()
    if request_path:
        request_text = Path(request_path).read_text(encoding="utf-8")
    else:
        request_text = sys.stdin.read()
    request = json.loads(request_text)
    if not isinstance(request, dict):
        raise ValueError("chat-completion request must be a JSON object")
    model = ReframrModel.load(args.model)
    if bool(request.get("stream", False)):
        for event in iter_sse_chat_completion(model, request):
            sys.stdout.write(event)
            sys.stdout.flush()
        return 0
    response = build_chat_completion_response(model, request)
    sys.stdout.write(json.dumps(response, ensure_ascii=False, separators=(",", ":")) + "\n")
    sys.stdout.flush()
    return 0
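
# Illustrative request payload (hypothetical values; which OpenAI-style
# fields are honored beyond "stream" is defined by .openai_compat):
#
#     {"messages": [{"role": "user", "content": "Hi"}], "stream": false}
#
# With "stream": true the response is written as SSE events instead of a
# single JSON object.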


def command_trace(args: argparse.Namespace) -> int:
    """Trace reasoning components through a short generation and print the payload."""
    model = ReframrModel.load(args.model)
    payload = model.trace_generation(
        args.context,
        max_tokens=args.max_tokens,
        reasoning_mode=args.reasoning_mode,
        top_k=args.top_k,
        temperature=args.temperature,
        top_p=args.decode_top_p,
        repetition_penalty=args.repetition_penalty,
    )
    print(json.dumps(payload))
    return 0


def command_inspect(args: argparse.Namespace) -> int:
    """Print checkpoint metadata as one JSON line."""
    print(json.dumps(inspect_checkpoint(args.model)))
    return 0


def command_craft_corpus(args: argparse.Namespace) -> int:
    """Write a bootstrap corpus package and print its manifest paths."""
    package = (
        build_generalization_corpus()
        if args.variant == "generalization"
        else build_foundation_corpus()
    )
    paths = write_corpus_package(package, args.output_dir)
    payload = {
        "name": package.name,
        "corpus_path": paths["corpus_path"],
        "manifest_path": paths["manifest_path"],
        "prompt_suite_path": paths["prompt_suite_path"],
        "token_count_estimate": len(package.text.split()),
        "memorization_samples": len(package.memorization_samples),
        "generalization_samples": len(package.generalization_samples),
        "generalization_prompt_count": len(package.open_ended_samples),
        "variant": args.variant,
        "section_counts": package.section_counts,
    }
    print(json.dumps(payload))
    return 0


def command_craft_curriculum(args: argparse.Namespace) -> int:
    """Write the OkeyMeta curriculum package and print its summary."""
    payload = write_curriculum_package(
        args.output_dir,
        CurriculumConfig(
            records_per_category=args.records_per_category,
            seed=args.seed,
            train_ratio=args.train_ratio,
        ),
        effective_token_target=args.effective_token_target or None,
    )
    print(json.dumps(payload))
    return 0


def command_craft_v2_plan(args: argparse.Namespace) -> int:
    """Write the v2 streaming recompute plan and print its summary."""
    payload = write_v2_streaming_plan(
        args.output,
        rows_per_source=args.rows_per_source,
        effective_token_target=args.effective_token_target,
        wikipedia_mode=args.wikipedia_mode,
        local_curriculum_paths=args.local_curriculum,
        local_curriculum_limit=args.local_curriculum_limit,
    )
    print(json.dumps(payload))
    return 0


def command_materialize_plan(args: argparse.Namespace) -> int:
    """Materialize bounded normalized JSONL shards from a corpus plan."""
    max_bytes = int(max(0.0, float(args.max_gb)) * (1024 ** 3))
    shard_bytes = int(max(1, int(args.shard_mb)) * (1024 ** 2))
    payload = materialize_corpus_plan(
        load_corpus_plan(args.plan),
        args.output_dir,
        max_bytes=max_bytes,
        shard_bytes=shard_bytes,
        log_every=args.log_every,
    )
    print(json.dumps(payload))
    return 0


def command_craft_blind_prompts(args: argparse.Namespace) -> int:
    """Write the blind open-prompt suite and print its summary."""
    payload = write_blind_prompt_suite(
        args.output,
        seed=args.seed,
        variants_per_intent=args.variants_per_intent,
    )
    print(json.dumps(payload))
    return 0


def command_evaluate(args: argparse.Namespace) -> int:
    """Evaluate a checkpoint against a benchmark manifest and print the results."""
    model = ReframrModel.load(args.model)
    manifest = load_manifest(args.manifest)
    payload = evaluate_manifest(
        model,
        manifest,
        reasoning_mode=args.reasoning_mode,
        top_k=args.top_k,
    )
    print(json.dumps(payload))
    return 0


def command_benchmark_open(args: argparse.Namespace) -> int:
    """Benchmark open-ended prompts, with optional source-replay flagging."""
    model = ReframrModel.load(args.model)
    prompts = load_prompt_suite(args.prompts)
    replay_sources = load_replay_sources(
        args.replay_source,
        limit=args.replay_source_limit,
    )
    payload = benchmark_open_prompts(
        model,
        prompts,
        reasoning_mode=args.reasoning_mode,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.decode_top_k,
        top_p=args.decode_top_p,
        repetition_penalty=args.repetition_penalty,
        replay_sources=replay_sources,
        replay_ngram_size=args.replay_ngram_size,
        replay_overlap_threshold=args.replay_overlap_threshold,
    )
    serialized = json.dumps(payload, ensure_ascii=False)
    output_path = str(getattr(args, "output", "")).strip()
    if output_path:
        target = Path(output_path)
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(serialized + "\n", encoding="utf-8")
    print(serialized)
    return 0


def command_sparse_context_benchmark(args: argparse.Namespace) -> int:
    """Benchmark sparse-context selection speed on a checkpoint embedding table."""
    import random

    model = ReframrModel.load(args.model)
    if model.embedding_model is None:
        raise RuntimeError("checkpoint does not contain embeddings")
    if args.selector == "hashed":
        kernel = HashedSparseAttention(
            model.embedding_model.embeddings,
            k_neighbors=args.top_k,
            hash_bits=args.hash_bits,
            probe_radius=args.probe_radius,
            seed=args.seed,
            candidate_multiplier=args.candidate_multiplier,
        )
    elif args.selector == "faiss":
        kernel = FaissSparseAttention(
            model.embedding_model.embeddings,
            k_neighbors=args.top_k,
            approximate=args.faiss_hnsw,
            hnsw_neighbors=args.hnsw_neighbors,
            ef_search=args.ef_search,
        )
    else:
        kernel = AnalyticalSparseAttention(
            model.embedding_model.embeddings,
            k_neighbors=args.top_k,
        )
    vocab_size = len(model.embedding_model.id_to_token)
    rng = random.Random(int(args.seed))
    context_tokens = [rng.randrange(vocab_size) for _ in range(max(0, int(args.context_tokens)))]
    query_tokens = [rng.randrange(vocab_size) for _ in range(max(0, int(args.query_count)))]
    payload = kernel.benchmark_selection(
        context_tokens,
        query_tokens,
        top_k=args.top_k,
    )
    if args.compare_exact and args.selector == "hashed":
        payload["exact_recall"] = compare_selectors(
            model.embedding_model.embeddings,
            context_tokens,
            query_tokens,
            top_k=args.top_k,
            hash_bits=args.hash_bits,
            probe_radius=args.probe_radius,
            seed=args.seed,
        )
    payload.update(
        {
            "schema_version": "reframr.sparse_context_benchmark.v1",
            "model": str(Path(args.model).resolve()),
            "selector": args.selector,
            "hash_bits": int(args.hash_bits) if args.selector == "hashed" else 0,
            "probe_radius": int(args.probe_radius) if args.selector == "hashed" else 0,
            "candidate_multiplier": int(args.candidate_multiplier) if args.selector == "hashed" else 0,
            "faiss_approximate": bool(args.selector == "faiss" and args.faiss_hnsw),
            "hnsw_neighbors": int(args.hnsw_neighbors) if args.selector == "faiss" and args.faiss_hnsw else 0,
            "ef_search": int(args.ef_search) if args.selector == "faiss" and args.faiss_hnsw else 0,
            "tokenizer_vocab_size": vocab_size,
            "embedding_dim": kernel.embedding_dim,
        }
    )
    serialized = json.dumps(payload, ensure_ascii=False)
    output_path = str(getattr(args, "output", "")).strip()
    if output_path:
        target = Path(output_path)
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(serialized + "\n", encoding="utf-8")
    print(serialized)
    return 0


def command_import_hf(args: argparse.Namespace) -> int:
    """Import a Hugging Face dataset into the REFRAMR JSON record standard."""
    payload = import_hf_dataset(
        dataset=args.dataset,
        output_path=args.output,
        config=args.config,
        split=args.split,
        text_field=args.text_field,
        limit=args.limit,
        streaming=not args.no_streaming,
        preference_target=args.preference_target,
        min_words=args.min_words,
        max_words=args.max_words,
        min_alpha_ratio=args.min_alpha_ratio,
        allowed_languages=tuple(
            segment.strip()
            for segment in args.allowed_languages.split(",")
            if segment.strip()
        ),
    )
    print(json.dumps(payload))
    return 0


def main(argv: list[str] | None = None) -> int:
    """Parse arguments, dispatch the selected subcommand, and return its exit code."""
    configure_stdio()
    parser = build_parser()
    args = parser.parse_args(argv)
    if args.command in {"compute", "train"}:
        return command_compute(args)
    if args.command == "recompute":
        return command_recompute(args)
    if args.command == "predict":
        return command_predict(args)
    if args.command == "generate":
        return command_generate(args)
    if args.command == "generate-batch":
        return command_generate_batch(args)
    if args.command == "serve":
        return command_serve(args)
    if args.command == "chat-completion":
        return command_chat_completion(args)
    if args.command == "trace":
        return command_trace(args)
    if args.command == "inspect":
        return command_inspect(args)
    if args.command == "craft-corpus":
        return command_craft_corpus(args)
    if args.command == "craft-curriculum":
        return command_craft_curriculum(args)
    if args.command == "craft-v2-plan":
        return command_craft_v2_plan(args)
    if args.command == "materialize-plan":
        return command_materialize_plan(args)
    if args.command == "craft-blind-prompts":
        return command_craft_blind_prompts(args)
    if args.command == "evaluate":
        return command_evaluate(args)
    if args.command == "benchmark-open":
        return command_benchmark_open(args)
    if args.command == "sparse-context-benchmark":
        return command_sparse_context_benchmark(args)
    if args.command == "import-hf":
        return command_import_hf(args)
    parser.error(f"Unknown command: {args.command}")
    # parser.error() raises SystemExit, so this return only satisfies type checkers.
    return 2
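

# A minimal entry-point guard, added on the assumption that direct execution
# (e.g. `python cli.py`) should behave like the installed console script.
if __name__ == "__main__":
    raise SystemExit(main())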