import argparse import json import sys from dataclasses import replace from pathlib import Path from .checkpoint import inspect_checkpoint from .config import ReframrConfig from .corpus_recipes import ( build_foundation_corpus, build_generalization_corpus, write_corpus_package, ) from .curriculum import CurriculumConfig, write_curriculum_package from .datasets import load_prompt_suite, load_text_corpus from .evaluation import ( benchmark_open_prompts, evaluate_manifest, load_manifest, load_replay_sources, ) from .hf_import import import_hf_dataset from .materialize import DEFAULT_CACHE_BYTE_LIMIT, DEFAULT_SHARD_BYTE_LIMIT, materialize_corpus_plan from .model import ReframrModel from .reasoning import REASONING_PROFILES, TOKENIZER_NAME, reasoning_prefix from .sparse_context import ( AnalyticalSparseAttention, FaissSparseAttention, HashedSparseAttention, compare_selectors, ) from .streaming import estimate_corpus_plan, fit_model_from_corpus_plan, load_corpus_plan from .tokenizer import MAX_TOKENIZER_VOCAB_SIZE, clamp_vocab_size, recommend_vocab_size from .v2_data import write_blind_prompt_suite, write_v2_streaming_plan def configure_stdio() -> None: for stream in (sys.stdout, sys.stderr): reconfigure = getattr(stream, "reconfigure", None) if reconfigure is not None: reconfigure(encoding="utf-8") def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( prog="reframr", description="Compute and query REFRAMR analytical language model checkpoints.", ) subparsers = parser.add_subparsers(dest="command", required=True) compute = subparsers.add_parser( "compute", aliases=["train"], help="Compute a REFRAMR checkpoint from a text corpus with no epoch loop.", ) compute.add_argument( "--input", required=True, help="Path to a text, JSON, or JSONL corpus file, or a directory of such files.", ) compute.add_argument("--output", required=True, help="Path to write the .safetensors checkpoint.") compute.add_argument("--embedding-dim", type=int, default=16) compute.add_argument("--state-dim", type=int, default=32) compute.add_argument("--timescales", default="1.0,0.5,0.25,0.125") compute.add_argument("--window-size", type=int, default=2) compute.add_argument("--regularization", type=float, default=1e-3) compute.add_argument("--min-frequency", type=int, default=1) compute.add_argument( "--max-vocab", type=int, default=256, help="Cap analytical embedding vocabulary to keep weight computation fast on CPU.", ) compute.add_argument("--tokenizer-vocab-size", type=int, default=0) compute.add_argument("--tokenizer-min-pair-frequency", type=int, default=2) compute.add_argument( "--max-training-examples", type=int, default=60000, help="Cap sampled recurrent training states while still reading the full corpus for tokenizer, embeddings, and transitions.", ) compute.add_argument( "--max-memory-examples", type=int, default=-1, help="Cap saved associative memory examples separately from readout training. Use -1 to match --max-training-examples.", ) compute.add_argument( "--max-state-tokens-per-document", type=int, default=768, help="Cap recurrent state steps per document with a deterministic corpus sketch. Use 0 to step full documents.", ) compute.add_argument( "--max-transition-contexts", type=int, default=4096, help="Keep only the strongest learned transition contexts per order. Use 0 to disable the cap.", ) compute.add_argument( "--max-transition-next-tokens", type=int, default=4, help="Keep this many learned next-token choices per transition context.", ) case_group = compute.add_mutually_exclusive_group() case_group.add_argument( "--lowercase", action="store_true", help="Normalize corpus text to lowercase before tokenization.", ) case_group.add_argument("--preserve-case", action="store_true", help=argparse.SUPPRESS) compute.add_argument( "--reasoning-profile", choices=sorted(REASONING_PROFILES), default="none", help="Default reasoning-control profile baked into the checkpoint.", ) compute.add_argument( "--layout-profile", default="rfm-base", help="Structured analytical layout label to store in checkpoint metadata, such as rfm-70b-structured.", ) compute.add_argument( "--effective-parameter-target", type=int, default=0, help="Dense-equivalent structured target to store in checkpoint metadata; this does not allocate dense tensors.", ) recompute = subparsers.add_parser( "recompute", help="Compute a REFRAMR checkpoint from a streaming corpus plan with no raw-text cache.", ) recompute.add_argument("--plan", required=True, help="Path to a streaming corpus plan JSON file.") recompute.add_argument("--output", required=True, help="Path to write the .safetensors checkpoint.") recompute.add_argument("--embedding-dim", type=int, default=16) recompute.add_argument("--state-dim", type=int, default=32) recompute.add_argument("--timescales", default="1.0,0.5,0.25,0.125") recompute.add_argument("--window-size", type=int, default=2) recompute.add_argument("--regularization", type=float, default=1e-3) recompute.add_argument("--min-frequency", type=int, default=1) recompute.add_argument("--max-vocab", type=int, default=256) recompute.add_argument("--tokenizer-vocab-size", type=int, default=0) recompute.add_argument("--tokenizer-min-pair-frequency", type=int, default=2) recompute.add_argument("--max-training-examples", type=int, default=60000) recompute.add_argument("--max-memory-examples", type=int, default=-1) recompute.add_argument("--max-state-tokens-per-document", type=int, default=768) recompute.add_argument("--max-transition-contexts", type=int, default=4096) recompute.add_argument("--max-transition-next-tokens", type=int, default=4) recompute.add_argument("--log-every", type=int, default=0) recompute.add_argument( "--dry-run", action="store_true", help="Estimate accepted rows and compute shape without fitting or saving a checkpoint.", ) recompute.add_argument( "--estimate-max-rows-per-source", type=int, default=0, help="Optional cap for preflight row scanning per local source.", ) recompute.add_argument( "--calibrate-rows", type=int, default=0, help="Run a bounded representative fit first and estimate full-run wall-clock time.", ) recompute.add_argument( "--calibrate-only", action="store_true", help="Stop after calibration instead of computing and saving the full checkpoint.", ) recompute_case_group = recompute.add_mutually_exclusive_group() recompute_case_group.add_argument("--lowercase", action="store_true") recompute_case_group.add_argument("--preserve-case", action="store_true", help=argparse.SUPPRESS) recompute.add_argument( "--reasoning-profile", choices=sorted(REASONING_PROFILES), default="none", help="Default reasoning-control profile baked into the checkpoint.", ) recompute.add_argument( "--layout-profile", default="rfm-base", help="Structured analytical layout label to store in checkpoint metadata, such as rfm-70b-structured.", ) recompute.add_argument( "--effective-parameter-target", type=int, default=0, help="Dense-equivalent structured target to store in checkpoint metadata; this does not allocate dense tensors.", ) predict = subparsers.add_parser("predict", help="Predict the next-token distribution from a saved model.") predict.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.") predict.add_argument("--context", required=True, help="Input context text.") predict.add_argument("--top-k", type=int, default=5) predict.add_argument( "--reasoning-mode", choices=sorted(REASONING_PROFILES), default=None, help="Override the checkpoint's default reasoning-control profile.", ) generate = subparsers.add_parser("generate", help="Generate long-form text from a saved model.") generate.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.") generate.add_argument("--context", required=True, help="Prompt or starting context text.") generate.add_argument("--system", default="", help="Optional system instruction to prepend as learned context.") generate.add_argument("--max-tokens", type=int, default=64) generate.add_argument("--temperature", type=float, default=0.82) generate.add_argument("--decode-top-k", type=int, default=24) generate.add_argument("--decode-top-p", type=float, default=0.92) generate.add_argument("--repetition-penalty", type=float, default=1.18) generate.add_argument( "--reasoning-mode", choices=sorted(REASONING_PROFILES), default=None, help="Override the checkpoint's default reasoning-control profile.", ) generate_batch = subparsers.add_parser( "generate-batch", help="Generate answers for a prompt file while keeping one checkpoint loaded.", ) generate_batch.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.") generate_batch.add_argument("--prompts", required=True, help="Path to a TXT, JSON, or JSONL prompt suite.") generate_batch.add_argument("--output", required=True, help="Path to write JSONL generations.") generate_batch.add_argument("--max-tokens", type=int, default=64) generate_batch.add_argument("--temperature", type=float, default=0.82) generate_batch.add_argument("--decode-top-k", type=int, default=24) generate_batch.add_argument("--decode-top-p", type=float, default=0.92) generate_batch.add_argument("--repetition-penalty", type=float, default=1.18) generate_batch.add_argument( "--reasoning-mode", choices=sorted(REASONING_PROFILES), default=None, help="Override the checkpoint's default reasoning-control profile.", ) serve = subparsers.add_parser( "serve", help="Keep one checkpoint loaded and answer JSONL generation requests from stdin.", ) serve.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.") serve.add_argument("--max-tokens", type=int, default=64) serve.add_argument("--temperature", type=float, default=0.82) serve.add_argument("--decode-top-k", type=int, default=24) serve.add_argument("--decode-top-p", type=float, default=0.92) serve.add_argument("--repetition-penalty", type=float, default=1.18) serve.add_argument( "--memory-turns", type=int, default=16, help="Number of prior JSONL session turns to prepend as conversation memory.", ) serve.add_argument( "--reasoning-mode", choices=sorted(REASONING_PROFILES), default=None, help="Override the checkpoint's default reasoning-control profile.", ) chat_completion = subparsers.add_parser( "chat-completion", help="Run one OpenAI-compatible chat completion request from stdin or a JSON file.", ) chat_completion.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.") chat_completion.add_argument( "--request", default="", help="Optional path to a JSON request. Defaults to stdin.", ) trace = subparsers.add_parser("trace", help="Trace REFRAMR reasoning components through generation steps.") trace.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.") trace.add_argument("--context", required=True, help="Prompt or starting context text.") trace.add_argument("--max-tokens", type=int, default=8) trace.add_argument("--top-k", type=int, default=5) trace.add_argument("--temperature", type=float, default=0.82) trace.add_argument("--decode-top-p", type=float, default=0.92) trace.add_argument("--repetition-penalty", type=float, default=1.18) trace.add_argument( "--reasoning-mode", choices=sorted(REASONING_PROFILES), default=None, help="Override the checkpoint's default reasoning-control profile.", ) inspect = subparsers.add_parser("inspect", help="Inspect a REFRAMR safetensors checkpoint.") inspect.add_argument("--model", required=True, help="Path to a .safetensors checkpoint.") craft = subparsers.add_parser( "craft-corpus", help="Generate a JSON-first bootstrap corpus, manifest, and generalization prompt suite.", ) craft.add_argument("--output-dir", required=True, help="Directory to write corpus and manifest files.") craft.add_argument( "--variant", choices=("foundation", "generalization"), default="foundation", help="Choose between the mixed foundation corpus and the language-first generalization corpus.", ) craft_curriculum = subparsers.add_parser( "craft-curriculum", help="Generate the OkeyMeta JSON curriculum shard, manifest, holdout prompts, and recompute plan.", ) craft_curriculum.add_argument("--output-dir", required=True, help="Directory to write curriculum files.") craft_curriculum.add_argument( "--records-per-category", type=int, default=1000, help="How many JSON records to generate for each curriculum category.", ) craft_curriculum.add_argument("--seed", type=int, default=7) craft_curriculum.add_argument("--train-ratio", type=float, default=0.92) craft_curriculum.add_argument( "--effective-token-target", type=int, default=0, help="Set plan weighting so compact curriculum statistics represent this many effective tokens.", ) craft_v2_plan = subparsers.add_parser( "craft-v2-plan", help="Write a strict streaming Hugging Face recompute plan for the v2 data mix.", ) craft_v2_plan.add_argument("--output", required=True, help="Path to write the streaming plan JSON.") craft_v2_plan.add_argument( "--rows-per-source", type=int, default=10_000, help="Base accepted row target per source before per-domain multipliers.", ) craft_v2_plan.add_argument( "--effective-token-target", type=int, default=0, help="Optional effective token target recorded in the plan metadata.", ) craft_v2_plan.add_argument( "--wikipedia-mode", choices=("skip", "hf", "viewer"), default="skip", help="Use skip for fast smoke runs; hf/viewer include Wikipedia through the fast HF viewer adapter.", ) craft_v2_plan.add_argument( "--local-curriculum", action="append", default=[], help="Local JSON/JSONL curriculum shard to blend before HF sources.", ) craft_v2_plan.add_argument( "--local-curriculum-limit", type=int, default=0, help="Maximum accepted rows per local curriculum shard. Use 0 for all rows.", ) materialize_plan = subparsers.add_parser( "materialize-plan", help="Write bounded normalized JSONL shards from a corpus plan, then emit a local recompute plan.", ) materialize_plan.add_argument("--plan", required=True, help="Path to a streaming corpus plan JSON file.") materialize_plan.add_argument("--output-dir", required=True, help="Directory for normalized JSONL shards.") materialize_plan.add_argument( "--max-gb", type=float, default=DEFAULT_CACHE_BYTE_LIMIT / (1024 ** 3), help="Maximum normalized cache size in GB. Defaults to 3GB.", ) materialize_plan.add_argument( "--shard-mb", type=int, default=DEFAULT_SHARD_BYTE_LIMIT // (1024 ** 2), help="Maximum size per JSONL shard in MB.", ) materialize_plan.add_argument("--log-every", type=int, default=0) craft_blind_prompts = subparsers.add_parser( "craft-blind-prompts", help="Write a blind open-prompt JSONL suite for v2 generalization checks.", ) craft_blind_prompts.add_argument("--output", required=True, help="Path to write JSONL prompts.") craft_blind_prompts.add_argument("--seed", type=int, default=2026) craft_blind_prompts.add_argument( "--variants-per-intent", type=int, default=4, help="How many prompt variants to generate per evaluation intent.", ) evaluate = subparsers.add_parser( "evaluate", help="Evaluate memorization and held-out generalization from a benchmark manifest.", ) evaluate.add_argument("--model", required=True, help="Path to a REFRAMR .safetensors checkpoint.") evaluate.add_argument("--manifest", required=True, help="Path to a corpus benchmark manifest JSON file.") evaluate.add_argument( "--reasoning-mode", choices=sorted(REASONING_PROFILES), default=None, help="Override the checkpoint's default reasoning-control profile during evaluation.", ) evaluate.add_argument("--top-k", type=int, default=5) benchmark_open = subparsers.add_parser( "benchmark-open", help="Run arbitrary prompt files through a checkpoint with open-ended output metrics.", ) benchmark_open.add_argument("--model", required=True, help="Path to a REFRAMR .safetensors checkpoint.") benchmark_open.add_argument("--prompts", required=True, help="Path to a TXT, JSON, or JSONL prompt suite.") benchmark_open.add_argument("--max-tokens", type=int, default=64) benchmark_open.add_argument("--temperature", type=float, default=0.82) benchmark_open.add_argument("--decode-top-k", type=int, default=24) benchmark_open.add_argument("--decode-top-p", type=float, default=0.92) benchmark_open.add_argument("--repetition-penalty", type=float, default=1.18) benchmark_open.add_argument( "--replay-source", action="append", default=[], help="JSON/JSONL/TXT corpus path used only to flag generated source-row replay.", ) benchmark_open.add_argument( "--replay-source-limit", type=int, default=10_000, help="Maximum source rows to load for replay checks.", ) benchmark_open.add_argument("--replay-ngram-size", type=int, default=8) benchmark_open.add_argument("--replay-overlap-threshold", type=float, default=0.70) benchmark_open.add_argument( "--output", default="", help="Optional UTF-8 JSON path for benchmark results.", ) benchmark_open.add_argument( "--reasoning-mode", choices=sorted(REASONING_PROFILES), default=None, help="Override the checkpoint's default reasoning-control profile during benchmarking.", ) sparse_benchmark = subparsers.add_parser( "sparse-context-benchmark", help="Measure analytical sparse-context selection speed on a checkpoint embedding table.", ) sparse_benchmark.add_argument("--model", required=True, help="Path to a REFRAMR .safetensors checkpoint.") sparse_benchmark.add_argument("--context-tokens", type=int, default=100_000) sparse_benchmark.add_argument("--query-count", type=int, default=64) sparse_benchmark.add_argument("--top-k", type=int, default=64) sparse_benchmark.add_argument("--seed", type=int, default=2026) sparse_benchmark.add_argument( "--selector", choices=("exact", "hashed", "faiss"), default="hashed", help="Use exact cosine scan or hashed approximate sparse selection.", ) sparse_benchmark.add_argument("--hash-bits", type=int, default=12) sparse_benchmark.add_argument("--probe-radius", type=int, default=1) sparse_benchmark.add_argument("--candidate-multiplier", type=int, default=12) sparse_benchmark.add_argument("--faiss-hnsw", action="store_true") sparse_benchmark.add_argument("--hnsw-neighbors", type=int, default=32) sparse_benchmark.add_argument("--ef-search", type=int, default=64) sparse_benchmark.add_argument( "--compare-exact", action="store_true", help="Also compute exact top-k recall for the selected query set.", ) sparse_benchmark.add_argument("--output", default="", help="Optional UTF-8 JSON path for benchmark results.") import_hf = subparsers.add_parser( "import-hf", help="Import Hugging Face dataset text into the REFRAMR JSON record standard.", ) import_hf.add_argument("--dataset", required=True, help="Hugging Face dataset id.") import_hf.add_argument("--output", required=True, help="Path to write the JSONL corpus.") import_hf.add_argument("--config", default=None, help="Optional dataset config/subset.") import_hf.add_argument("--split", default="train", help="Dataset split to import.") import_hf.add_argument("--text-field", default=None, help="Explicit text column name.") import_hf.add_argument("--limit", type=int, default=1000, help="Maximum records to import.") import_hf.add_argument( "--min-words", type=int, default=0, help="Drop imported records shorter than this many words.", ) import_hf.add_argument( "--max-words", type=int, default=0, help="Drop imported records longer than this many words. Use 0 to disable.", ) import_hf.add_argument( "--min-alpha-ratio", type=float, default=0.0, help="Drop imported records whose alphabetic-character ratio falls below this threshold.", ) import_hf.add_argument( "--allowed-languages", default="", help="Optional comma-separated language codes to keep, such as en,yo,ig,ha.", ) import_hf.add_argument( "--preference-target", choices=("both", "chosen", "rejected"), default="chosen", help="When importing preference datasets, keep both sides or only the chosen/rejected side.", ) import_hf.add_argument( "--no-streaming", action="store_true", help="Disable streaming dataset reads.", ) return parser def parse_timescales(raw_timescales: str) -> tuple[float, ...]: values = [segment.strip() for segment in raw_timescales.split(",") if segment.strip()] if not values: raise ValueError("At least one timescale is required.") return tuple(float(value) for value in values) def command_compute(args: argparse.Namespace) -> int: text = load_text_corpus(args.input) requested_vocab_size = args.tokenizer_vocab_size or recommend_vocab_size( text, lowercase=args.lowercase, ) tokenizer_vocab_size = clamp_vocab_size(requested_vocab_size) config = ReframrConfig( embedding_dim=args.embedding_dim, state_dim=args.state_dim, timescales=parse_timescales(args.timescales), window_size=args.window_size, regularization=args.regularization, min_frequency=args.min_frequency, max_vocab=args.max_vocab, tokenizer_vocab_size=tokenizer_vocab_size, tokenizer_min_pair_frequency=args.tokenizer_min_pair_frequency, max_training_examples=args.max_training_examples, max_memory_examples=( None if args.max_memory_examples < 0 else args.max_memory_examples ), max_state_tokens_per_document=( None if args.max_state_tokens_per_document <= 0 else args.max_state_tokens_per_document ), max_transition_contexts_per_order=( args.max_transition_contexts if args.max_transition_contexts > 0 else None ), max_transition_next_tokens=args.max_transition_next_tokens, lowercase=args.lowercase, default_reasoning_profile=args.reasoning_profile, layout_profile=args.layout_profile, effective_parameter_target=args.effective_parameter_target, ) model = ReframrModel(config).fit(text) model.save(args.output) assert model.tokenizer is not None assert model.embedding_model is not None summary = { "status": "computed", "format": "safetensors", "model_path": str(Path(args.output).resolve()), "tokenizer_name": TOKENIZER_NAME, "vocab_size": len(model.embedding_model.id_to_token), "tokenizer_vocab_budget": config.tokenizer_vocab_size, "tokenizer_vocab_budget_max": MAX_TOKENIZER_VOCAB_SIZE, "tokenizer_vocab_size": model.tokenizer.vocab_size, "reasoning_profile": config.default_reasoning_profile, "reasoning_tokens": reasoning_prefix(config.default_reasoning_profile), "lowercase": config.lowercase, "max_training_examples": config.max_training_examples, "max_memory_examples": config.max_memory_examples, "max_state_tokens_per_document": config.max_state_tokens_per_document, "max_transition_contexts_per_order": config.max_transition_contexts_per_order, "max_transition_next_tokens": config.max_transition_next_tokens, "embedding_dim": config.embedding_dim, "state_dim": config.state_dim, "timescales": list(config.timescales), "layout_profile": config.layout_profile, "effective_parameter_target": config.effective_parameter_target, } print(json.dumps(summary)) return 0 def command_recompute(args: argparse.Namespace) -> int: plan = load_corpus_plan(args.plan) requested_vocab_size = args.tokenizer_vocab_size or 1024 tokenizer_vocab_size = clamp_vocab_size(requested_vocab_size) config = ReframrConfig( embedding_dim=args.embedding_dim, state_dim=args.state_dim, timescales=parse_timescales(args.timescales), window_size=args.window_size, regularization=args.regularization, min_frequency=args.min_frequency, max_vocab=args.max_vocab, tokenizer_vocab_size=tokenizer_vocab_size, tokenizer_min_pair_frequency=args.tokenizer_min_pair_frequency, max_training_examples=args.max_training_examples, max_memory_examples=( None if args.max_memory_examples < 0 else args.max_memory_examples ), max_state_tokens_per_document=( None if args.max_state_tokens_per_document <= 0 else args.max_state_tokens_per_document ), max_transition_contexts_per_order=( args.max_transition_contexts if args.max_transition_contexts > 0 else None ), max_transition_next_tokens=args.max_transition_next_tokens, lowercase=args.lowercase, default_reasoning_profile=args.reasoning_profile, layout_profile=args.layout_profile, effective_parameter_target=args.effective_parameter_target, ) if args.dry_run: estimate = estimate_corpus_plan( plan, max_rows_per_source=args.estimate_max_rows_per_source, ) accepted = int(estimate.get("accepted_documents", 0) or 0) state_cap = config.max_state_tokens_per_document or 768 estimated_state_tokens = accepted * state_cap summary = { "status": "dry_run", "plan_path": str(Path(args.plan).resolve()), "output_path": str(Path(args.output).resolve()), "accepted_documents": accepted, "seen_texts": estimate.get("seen_texts", 0), "rejected_texts": estimate.get("rejected_texts", 0), "estimated_words": estimate.get("estimated_words", 0), "estimated_state_token_budget": estimated_state_tokens, "embedding_dim": config.embedding_dim, "state_dim": config.state_dim, "tokenizer_vocab_budget": config.tokenizer_vocab_size, "max_vocab": config.max_vocab, "max_training_examples": config.max_training_examples, "max_memory_examples": config.max_memory_examples, "max_state_tokens_per_document": config.max_state_tokens_per_document, "max_transition_contexts_per_order": config.max_transition_contexts_per_order, "max_transition_next_tokens": config.max_transition_next_tokens, "layout_profile": config.layout_profile, "effective_parameter_target": config.effective_parameter_target, "estimate_seconds": estimate.get("seconds", 0), "sources": estimate.get("sources", []), } print(json.dumps(summary)) return 0 if args.calibrate_rows > 0: calibration = _calibrate_recompute_plan( plan, config, target_rows=args.calibrate_rows, estimate_max_rows_per_source=args.estimate_max_rows_per_source, log_every=args.log_every, ) print(json.dumps(calibration), flush=True) if args.calibrate_only: return 0 model, payload = fit_model_from_corpus_plan( plan, config, log_every=args.log_every, ) model.save(args.output) summary = { "status": "recomputed", "format": "safetensors", "streaming": True, "plan_path": str(Path(args.plan).resolve()), "model_path": str(Path(args.output).resolve()), "tokenizer_name": TOKENIZER_NAME, "tokenizer_vocab_budget": config.tokenizer_vocab_size, "tokenizer_vocab_budget_max": MAX_TOKENIZER_VOCAB_SIZE, "tokenizer_vocab_size": payload["tokenizer_vocab_size"], "vocab_size": payload["embedding_vocab_size"], "documents_processed": payload["documents_processed"], "source_counts": payload["source_counts"], "examples_processed": payload["examples_processed"], "associative_examples": payload["associative_examples"], "answer_associative_examples": payload.get("answer_associative_examples", 0), "general_associative_examples": payload.get("general_associative_examples", 0), "answer_intent_examples": payload.get("answer_intent_examples", 0), "answer_start_examples": payload.get("answer_start_examples", 0), "answer_sequence_examples": payload.get("answer_sequence_examples", 0), "prompt_answer_readout_examples": payload.get("prompt_answer_readout_examples", 0), "prompt_answer_start_readout_examples": payload.get("prompt_answer_start_readout_examples", 0), "preference_pairs": payload.get("preference_pairs", 0), "preference_state_pairs": payload.get("preference_state_pairs", 0), "stage_seconds": payload.get("stage_seconds", {}), "readout_solver": payload.get("readout_solver"), "reasoning_profile": config.default_reasoning_profile, "reasoning_tokens": reasoning_prefix(config.default_reasoning_profile), "lowercase": config.lowercase, "max_training_examples": config.max_training_examples, "max_memory_examples": config.max_memory_examples, "max_state_tokens_per_document": config.max_state_tokens_per_document, "state_tokens_before_sketch": payload.get("state_tokens_before_sketch", 0), "state_tokens_after_sketch": payload.get("state_tokens_after_sketch", 0), "max_transition_contexts_per_order": config.max_transition_contexts_per_order, "max_transition_next_tokens": config.max_transition_next_tokens, "embedding_dim": config.embedding_dim, "state_dim": config.state_dim, "timescales": list(config.timescales), "layout_profile": config.layout_profile, "effective_parameter_target": config.effective_parameter_target, } print(json.dumps(summary)) return 0 def _limited_calibration_plan( plan: list[object], *, target_rows: int, full_accepted: int, ) -> list[object]: if target_rows <= 0: return plan ratio = min(1.0, target_rows / max(1, full_accepted)) limited: list[object] = [] fallback_limit = max(1, target_rows // max(1, len(plan))) for entry in plan: raw_limit = int(getattr(entry, "limit", 0) or 0) if raw_limit > 0: next_limit = max(1, min(raw_limit, int((raw_limit * ratio) + 0.999999))) else: record_count = len(getattr(entry, "records", ()) or ()) source_cap = record_count if record_count > 0 else fallback_limit next_limit = max(1, min(source_cap, fallback_limit)) limited.append(replace(entry, limit=next_limit)) return limited def _estimate_full_seconds_from_calibration( *, full_documents: int, full_state_tokens: int, calibration_payload: dict[str, object], ) -> dict[str, object]: calibration_documents = max(1, int(calibration_payload.get("documents_processed", 0) or 0)) calibration_state_tokens = max( 1, int(calibration_payload.get("state_tokens_after_sketch", 0) or 0), ) document_scale = full_documents / calibration_documents state_scale = full_state_tokens / calibration_state_tokens stage_seconds = calibration_payload.get("stage_seconds", {}) if not isinstance(stage_seconds, dict): stage_seconds = {} fixed_weighted = {"tokenizer_fit", "embedding", "kernel_warmup", "preference"} state_weighted = {"state_and_readout", "finalize_prompt_readouts", "finalize_memory_arrays"} document_weighted = { "stream_and_segment", "vocabulary", "cooccurrence", "model_finalize", "finalize_answer_sequences", "finalize_transition_tables", } stage_estimates: dict[str, float] = {} for stage, raw_seconds in stage_seconds.items(): seconds = float(raw_seconds) if stage in fixed_weighted: scale = 1.0 elif stage in state_weighted: scale = state_scale elif stage in document_weighted: scale = document_scale else: scale = max(document_scale, state_scale) stage_estimates[str(stage)] = round(seconds * scale, 3) total_seconds = round(sum(stage_estimates.values()), 3) return { "estimated_full_seconds": total_seconds, "estimated_full_minutes": round(total_seconds / 60.0, 3), "scale_documents": round(document_scale, 4), "scale_state_tokens": round(state_scale, 4), "stage_estimates": stage_estimates, } def _calibrate_recompute_plan( plan: list[object], config: ReframrConfig, *, target_rows: int, estimate_max_rows_per_source: int, log_every: int, ) -> dict[str, object]: full_estimate = estimate_corpus_plan( plan, max_rows_per_source=estimate_max_rows_per_source, ) full_documents = int(full_estimate.get("accepted_documents", 0) or 0) state_cap = config.max_state_tokens_per_document or 768 full_state_tokens = full_documents * state_cap calibration_plan = _limited_calibration_plan( plan, target_rows=target_rows, full_accepted=full_documents, ) _, calibration_payload = fit_model_from_corpus_plan( calibration_plan, config, log_every=log_every, ) runtime_estimate = _estimate_full_seconds_from_calibration( full_documents=full_documents, full_state_tokens=full_state_tokens, calibration_payload=calibration_payload, ) return { "status": "calibration", "target_rows": target_rows, "full_accepted_documents": full_documents, "full_estimated_words": full_estimate.get("estimated_words", 0), "full_estimated_state_token_budget": full_state_tokens, "calibration_documents": calibration_payload.get("documents_processed", 0), "calibration_state_tokens": calibration_payload.get("state_tokens_after_sketch", 0), "calibration_stage_seconds": calibration_payload.get("stage_seconds", {}), **runtime_estimate, } def command_predict(args: argparse.Namespace) -> int: model = ReframrModel.load(args.model) distribution = model.predict_next_distribution( args.context, reasoning_mode=args.reasoning_mode, ) predictions = sorted( distribution.items(), key=lambda item: item[1], reverse=True, )[: args.top_k] payload = { "context": args.context, "reasoning_mode": args.reasoning_mode or model.config.default_reasoning_profile, "reasoning_tokens": reasoning_prefix(args.reasoning_mode or model.config.default_reasoning_profile), "predictions": [ {"token": token, "probability": probability} for token, probability in predictions ], } print(json.dumps(payload)) return 0 def command_generate(args: argparse.Namespace) -> int: model = ReframrModel.load(args.model) context = compose_generation_context(args.context, system=args.system) generated_text = model.generate_text( context, max_tokens=args.max_tokens, reasoning_mode=args.reasoning_mode, temperature=args.temperature, top_k=args.decode_top_k, top_p=args.decode_top_p, repetition_penalty=args.repetition_penalty, ) payload = { "context": context, "reasoning_mode": args.reasoning_mode or model.config.default_reasoning_profile, "reasoning_tokens": reasoning_prefix(args.reasoning_mode or model.config.default_reasoning_profile), "generated_token_count": len(generated_text.split()), "generated_text": generated_text, } print(json.dumps(payload)) return 0 def _content_to_text(content: object) -> str: if content is None: return "" if isinstance(content, str): return content.strip() if isinstance(content, list): parts: list[str] = [] for item in content: if isinstance(item, dict): text = item.get("text", item.get("content", item.get("input_text", ""))) if text: parts.append(str(text).strip()) elif item is not None: parts.append(str(item).strip()) return "\n".join(part for part in parts if part) if isinstance(content, (dict, tuple)): return json.dumps(content, ensure_ascii=False, separators=(",", ":")) return str(content).strip() def _coerce_json_payload(payload: object) -> object: if not isinstance(payload, str): return payload stripped = payload.strip() if not stripped: return "" try: return json.loads(stripped) except json.JSONDecodeError: return stripped def _render_source_lines(payload: object) -> list[str]: if not isinstance(payload, dict): return [] nested_content = payload.get("content") if isinstance(nested_content, dict): nested_lines = _render_source_lines(nested_content) if nested_lines: return nested_lines raw_sources = payload.get("sources", payload.get("source", [])) if isinstance(raw_sources, dict): sources = [raw_sources] elif isinstance(raw_sources, list): sources = raw_sources elif raw_sources: sources = [raw_sources] else: sources = [] lines: list[str] = [] for source in sources: if isinstance(source, dict): title = str(source.get("title", source.get("name", "source"))).strip() url = str(source.get("url", source.get("uri", ""))).strip() snippet = str(source.get("snippet", source.get("text", source.get("content", "")))).strip() parts = [part for part in (title, url, snippet) if part] if parts: lines.append(f" {' | '.join(parts)}") elif source: lines.append(f" {str(source).strip()}") return lines def _render_tool_result(name: str, payload: object) -> list[str]: tool_name = name.strip() or "tool" parsed = _coerce_json_payload(payload) if isinstance(parsed, dict): explicit_name = str(parsed.get("name", parsed.get("tool", ""))).strip() if explicit_name: tool_name = explicit_name status = str(parsed.get("status", "")).casefold() ok_value = parsed.get("ok", None) error = str(parsed.get("error", parsed.get("message", ""))).strip() failed = ok_value is False or status in {"error", "failed", "failure", "timeout"} or bool(error) if failed: first = f" {tool_name} failed: {error or status or 'unknown error'}" else: summary = str(parsed.get("summary", parsed.get("content", parsed.get("text", "")))).strip() first = f" {tool_name} ok" if summary and not _render_source_lines(parsed): first = f"{first}: {summary}" return [first, *_render_source_lines(parsed)] if parsed: return [f" {tool_name} {str(parsed).strip()}"] return [f" {tool_name} empty"] def _render_tool_call(call: object) -> str: if not isinstance(call, dict): return f" {str(call).strip()}" function_payload = call.get("function", {}) function = function_payload if isinstance(function_payload, dict) else {} name = str(call.get("name", function.get("name", "tool"))).strip() or "tool" arguments = call.get("arguments", function.get("arguments", {})) if not isinstance(arguments, str): arguments = json.dumps(arguments, ensure_ascii=False, separators=(",", ":")) return f" {name} {arguments}".strip() def compose_generation_context( prompt: str, *, system: str = "", messages: object | None = None, tool_results: object | None = None, ) -> str: clean_prompt = prompt.strip() clean_system = system.strip() lines: list[str] = [] tool_protocol_seen = False if clean_system: lines.append(clean_system) if isinstance(messages, list): for message in messages: if not isinstance(message, dict): continue role = str(message.get("role", "")).casefold() content = _content_to_text(message.get("content", "")) if role == "system": if content: lines.append(f"System instruction: {content}") elif role == "user": if content: lines.append(f"User: {content}") elif role == "assistant": if content: lines.append(f"Assistant: {content}") if "" in content: tool_protocol_seen = True tool_calls = message.get("tool_calls", []) if isinstance(tool_calls, list): for call in tool_calls: lines.append(_render_tool_call(call)) tool_protocol_seen = True elif role == "tool": tool_name = str(message.get("name", message.get("tool_call_id", "tool"))) lines.extend(_render_tool_result(tool_name, message.get("content", ""))) tool_protocol_seen = True elif content: lines.append(f"{role.capitalize()}: {content}") if clean_prompt: lines.append(f"User: {clean_prompt}" if isinstance(messages, list) else clean_prompt) if isinstance(tool_results, list): for result in tool_results: tool_name = "tool" if isinstance(result, dict): tool_name = str(result.get("name", result.get("tool", "tool"))) lines.extend(_render_tool_result(tool_name, result)) tool_protocol_seen = True elif tool_results: lines.extend(_render_tool_result("tool", tool_results)) tool_protocol_seen = True if tool_protocol_seen: lines.append("") return "\n".join(line for line in lines if line).strip() def command_generate_batch(args: argparse.Namespace) -> int: model = ReframrModel.load(args.model) prompts = load_prompt_suite(args.prompts) output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) rows: list[dict[str, object]] = [] with output_path.open("w", encoding="utf-8") as handle: for index, record in enumerate(prompts): prompt = str(record["prompt"]) record_mode = str( record.get( "reasoning_mode", args.reasoning_mode or model.config.default_reasoning_profile, ) ) context = compose_generation_context( prompt, system=str(record.get("system", "")), messages=record.get("messages"), tool_results=record.get("tool_results"), ) max_tokens = int(record.get("max_tokens", args.max_tokens)) generated_text = model.generate_text( context, max_tokens=max_tokens, reasoning_mode=record_mode, temperature=args.temperature, top_k=args.decode_top_k, top_p=args.decode_top_p, repetition_penalty=args.repetition_penalty, ) row = { "index": index, "prompt": prompt, "context": context, "system": record.get("system", ""), "tags": record.get("tags", []), "reasoning_mode": record_mode, "reasoning_tokens": reasoning_prefix(record_mode), "generated_token_count": len(generated_text.split()), "generated_text": generated_text, } rows.append(row) handle.write(json.dumps(row, ensure_ascii=False, separators=(",", ":")) + "\n") payload = { "status": "generated", "sample_count": len(rows), "model_path": str(Path(args.model).resolve()), "prompts_path": str(Path(args.prompts).resolve()), "output_path": str(output_path.resolve()), "model_loads": 1, } print(json.dumps(payload)) return 0 def command_serve(args: argparse.Namespace) -> int: model = ReframrModel.load(args.model) default_mode = args.reasoning_mode or model.config.default_reasoning_profile generated_history_by_context: dict[str, list[str]] = {} session_turns_by_id: dict[str, list[tuple[str, str]]] = {} for index, raw_line in enumerate(sys.stdin): line = raw_line.strip() if not line: continue try: request = json.loads(line) except json.JSONDecodeError as exc: response = { "index": index, "error": "invalid_json", "message": str(exc), "model_loads": 1, } sys.stdout.write(json.dumps(response, ensure_ascii=False, separators=(",", ":")) + "\n") sys.stdout.flush() continue if isinstance(request, str): raw_context = request base_context = request request_payload: dict[str, object] = {} elif isinstance(request, dict): request_payload = request raw_context = str(request_payload.get("prompt", request_payload.get("context", ""))) base_context = compose_generation_context( raw_context, system=str(request_payload.get("system", "")), messages=request_payload.get("messages"), tool_results=request_payload.get("tool_results", request_payload.get("toolResults")), ) else: response = { "index": index, "error": "invalid_request", "message": "request must be a JSON object or string", "model_loads": 1, } sys.stdout.write(json.dumps(response, ensure_ascii=False, separators=(",", ":")) + "\n") sys.stdout.flush() continue session_id = str( request_payload.get( "session_id", request_payload.get("conversation_id", request_payload.get("thread_id", "")), ) ).strip() memory_turn_limit = max( 0, int(request_payload.get("memory_turns", getattr(args, "memory_turns", 16))), ) session_turns = session_turns_by_id.get(session_id, []) if session_id else [] memory_context = "" if session_turns and memory_turn_limit > 0: memory_lines = ["Conversation memory:"] for prior_user, prior_assistant in session_turns[-memory_turn_limit:]: if prior_user.strip(): memory_lines.append(f"Previous user: {prior_user.strip()}") if prior_assistant.strip(): memory_lines.append(f"Previous assistant: {prior_assistant.strip()}") memory_context = "\n".join(memory_lines) context = ( f"{memory_context}\nCurrent user: {base_context}" if memory_context else base_context ) active_mode = str(request_payload.get("reasoning_mode", default_mode)) max_tokens = int(request_payload.get("max_tokens", args.max_tokens)) temperature = float(request_payload.get("temperature", args.temperature)) top_k = int(request_payload.get("decode_top_k", args.decode_top_k)) top_p = float(request_payload.get("decode_top_p", args.decode_top_p)) repetition_penalty = float( request_payload.get("repetition_penalty", args.repetition_penalty) ) history_key = " ".join(base_context.split()) avoid_texts = generated_history_by_context.get(history_key, []) generated_text = model.generate_text( context, max_tokens=max_tokens, reasoning_mode=active_mode, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, avoid_texts=avoid_texts, ) if generated_text.strip(): next_history = [*avoid_texts, generated_text] generated_history_by_context[history_key] = next_history[-8:] if session_id: user_memory_text = raw_context if raw_context.strip() else base_context next_session_turns = [*session_turns, (user_memory_text, generated_text)] session_turns_by_id[session_id] = next_session_turns[-max(1, memory_turn_limit):] response = { "index": index, "context": context, "reasoning_mode": active_mode, "reasoning_tokens": reasoning_prefix(active_mode), "generated_token_count": len(generated_text.split()), "generated_text": generated_text, "memory_turn_count": len(session_turns[-memory_turn_limit:]) if memory_turn_limit > 0 else 0, "model_loads": 1, } sys.stdout.write(json.dumps(response, ensure_ascii=False, separators=(",", ":")) + "\n") sys.stdout.flush() return 0 def command_chat_completion(args: argparse.Namespace) -> int: from .openai_compat import build_chat_completion_response, iter_sse_chat_completion request_path = str(getattr(args, "request", "")).strip() if request_path: request_text = Path(request_path).read_text(encoding="utf-8") else: request_text = sys.stdin.read() request = json.loads(request_text) if not isinstance(request, dict): raise ValueError("chat-completion request must be a JSON object") model = ReframrModel.load(args.model) if bool(request.get("stream", False)): for event in iter_sse_chat_completion(model, request): sys.stdout.write(event) sys.stdout.flush() return 0 response = build_chat_completion_response(model, request) sys.stdout.write(json.dumps(response, ensure_ascii=False, separators=(",", ":")) + "\n") sys.stdout.flush() return 0 def command_trace(args: argparse.Namespace) -> int: model = ReframrModel.load(args.model) payload = model.trace_generation( args.context, max_tokens=args.max_tokens, reasoning_mode=args.reasoning_mode, top_k=args.top_k, temperature=args.temperature, top_p=args.decode_top_p, repetition_penalty=args.repetition_penalty, ) print(json.dumps(payload)) return 0 def command_inspect(args: argparse.Namespace) -> int: print(json.dumps(inspect_checkpoint(args.model))) return 0 def command_craft_corpus(args: argparse.Namespace) -> int: package = ( build_generalization_corpus() if args.variant == "generalization" else build_foundation_corpus() ) paths = write_corpus_package(package, args.output_dir) payload = { "name": package.name, "corpus_path": paths["corpus_path"], "manifest_path": paths["manifest_path"], "prompt_suite_path": paths["prompt_suite_path"], "token_count_estimate": len(package.text.split()), "memorization_samples": len(package.memorization_samples), "generalization_samples": len(package.generalization_samples), "generalization_prompt_count": len(package.open_ended_samples), "variant": args.variant, "section_counts": package.section_counts, } print(json.dumps(payload)) return 0 def command_craft_curriculum(args: argparse.Namespace) -> int: payload = write_curriculum_package( args.output_dir, CurriculumConfig( records_per_category=args.records_per_category, seed=args.seed, train_ratio=args.train_ratio, ), effective_token_target=args.effective_token_target or None, ) print(json.dumps(payload)) return 0 def command_craft_v2_plan(args: argparse.Namespace) -> int: payload = write_v2_streaming_plan( args.output, rows_per_source=args.rows_per_source, effective_token_target=args.effective_token_target, wikipedia_mode=args.wikipedia_mode, local_curriculum_paths=args.local_curriculum, local_curriculum_limit=args.local_curriculum_limit, ) print(json.dumps(payload)) return 0 def command_materialize_plan(args: argparse.Namespace) -> int: max_bytes = int(max(0.0, float(args.max_gb)) * (1024 ** 3)) shard_bytes = int(max(1, int(args.shard_mb)) * (1024 ** 2)) payload = materialize_corpus_plan( load_corpus_plan(args.plan), args.output_dir, max_bytes=max_bytes, shard_bytes=shard_bytes, log_every=args.log_every, ) print(json.dumps(payload)) return 0 def command_craft_blind_prompts(args: argparse.Namespace) -> int: payload = write_blind_prompt_suite( args.output, seed=args.seed, variants_per_intent=args.variants_per_intent, ) print(json.dumps(payload)) return 0 def command_evaluate(args: argparse.Namespace) -> int: model = ReframrModel.load(args.model) manifest = load_manifest(args.manifest) payload = evaluate_manifest( model, manifest, reasoning_mode=args.reasoning_mode, top_k=args.top_k, ) print(json.dumps(payload)) return 0 def command_benchmark_open(args: argparse.Namespace) -> int: model = ReframrModel.load(args.model) prompts = load_prompt_suite(args.prompts) replay_sources = load_replay_sources( args.replay_source, limit=args.replay_source_limit, ) payload = benchmark_open_prompts( model, prompts, reasoning_mode=args.reasoning_mode, max_tokens=args.max_tokens, temperature=args.temperature, top_k=args.decode_top_k, top_p=args.decode_top_p, repetition_penalty=args.repetition_penalty, replay_sources=replay_sources, replay_ngram_size=args.replay_ngram_size, replay_overlap_threshold=args.replay_overlap_threshold, ) serialized = json.dumps(payload, ensure_ascii=False) output_path = str(getattr(args, "output", "")).strip() if output_path: target = Path(output_path) target.parent.mkdir(parents=True, exist_ok=True) target.write_text(serialized + "\n", encoding="utf-8") print(serialized) return 0 def command_sparse_context_benchmark(args: argparse.Namespace) -> int: import random model = ReframrModel.load(args.model) if model.embedding_model is None: raise RuntimeError("checkpoint does not contain embeddings") if args.selector == "hashed": kernel = HashedSparseAttention( model.embedding_model.embeddings, k_neighbors=args.top_k, hash_bits=args.hash_bits, probe_radius=args.probe_radius, seed=args.seed, candidate_multiplier=args.candidate_multiplier, ) elif args.selector == "faiss": kernel = FaissSparseAttention( model.embedding_model.embeddings, k_neighbors=args.top_k, approximate=args.faiss_hnsw, hnsw_neighbors=args.hnsw_neighbors, ef_search=args.ef_search, ) else: kernel = AnalyticalSparseAttention( model.embedding_model.embeddings, k_neighbors=args.top_k, ) vocab_size = len(model.embedding_model.id_to_token) rng = random.Random(int(args.seed)) context_tokens = [rng.randrange(vocab_size) for _ in range(max(0, int(args.context_tokens)))] query_tokens = [rng.randrange(vocab_size) for _ in range(max(0, int(args.query_count)))] payload = kernel.benchmark_selection( context_tokens, query_tokens, top_k=args.top_k, ) if args.compare_exact and args.selector == "hashed": payload["exact_recall"] = compare_selectors( model.embedding_model.embeddings, context_tokens, query_tokens, top_k=args.top_k, hash_bits=args.hash_bits, probe_radius=args.probe_radius, seed=args.seed, ) payload.update( { "schema_version": "reframr.sparse_context_benchmark.v1", "model": str(Path(args.model).resolve()), "selector": args.selector, "hash_bits": int(args.hash_bits) if args.selector == "hashed" else 0, "probe_radius": int(args.probe_radius) if args.selector == "hashed" else 0, "candidate_multiplier": int(args.candidate_multiplier) if args.selector == "hashed" else 0, "faiss_approximate": bool(args.selector == "faiss" and args.faiss_hnsw), "hnsw_neighbors": int(args.hnsw_neighbors) if args.selector == "faiss" and args.faiss_hnsw else 0, "ef_search": int(args.ef_search) if args.selector == "faiss" and args.faiss_hnsw else 0, "tokenizer_vocab_size": vocab_size, "embedding_dim": kernel.embedding_dim, } ) serialized = json.dumps(payload, ensure_ascii=False) output_path = str(getattr(args, "output", "")).strip() if output_path: target = Path(output_path) target.parent.mkdir(parents=True, exist_ok=True) target.write_text(serialized + "\n", encoding="utf-8") print(serialized) return 0 def command_import_hf(args: argparse.Namespace) -> int: payload = import_hf_dataset( dataset=args.dataset, output_path=args.output, config=args.config, split=args.split, text_field=args.text_field, limit=args.limit, streaming=not args.no_streaming, preference_target=args.preference_target, min_words=args.min_words, max_words=args.max_words, min_alpha_ratio=args.min_alpha_ratio, allowed_languages=tuple( segment.strip() for segment in args.allowed_languages.split(",") if segment.strip() ), ) print(json.dumps(payload)) return 0 def main(argv: list[str] | None = None) -> int: configure_stdio() parser = build_parser() args = parser.parse_args(argv) if args.command in {"compute", "train"}: return command_compute(args) if args.command == "recompute": return command_recompute(args) if args.command == "predict": return command_predict(args) if args.command == "generate": return command_generate(args) if args.command == "generate-batch": return command_generate_batch(args) if args.command == "serve": return command_serve(args) if args.command == "chat-completion": return command_chat_completion(args) if args.command == "trace": return command_trace(args) if args.command == "inspect": return command_inspect(args) if args.command == "craft-corpus": return command_craft_corpus(args) if args.command == "craft-curriculum": return command_craft_curriculum(args) if args.command == "craft-v2-plan": return command_craft_v2_plan(args) if args.command == "materialize-plan": return command_materialize_plan(args) if args.command == "craft-blind-prompts": return command_craft_blind_prompts(args) if args.command == "evaluate": return command_evaluate(args) if args.command == "benchmark-open": return command_benchmark_open(args) if args.command == "sparse-context-benchmark": return command_sparse_context_benchmark(args) if args.command == "import-hf": return command_import_hf(args) parser.error(f"Unknown command: {args.command}") return 2