# bit-vector-tensor-control-policy / scripts / run_codex_inference.py
# Hugging Face Space upload metadata: "Initial Space upload" by J94,
# commit 3436bdd (verified).
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Any
import yaml
ROOT = Path(__file__).resolve().parents[1]
DEFAULT_CONFIG = ROOT / "inference.yaml"
DEFAULT_SCHEMA = ROOT / "schemas" / "inference_output_v0.json"
DEFAULT_POLICY = ROOT / "policy" / "control_language_v0.json"
def load_yaml(path: Path) -> dict[str, Any]:
    """Parse *path* as YAML and require the top-level value to be a mapping."""
    parsed = yaml.safe_load(path.read_text(encoding="utf-8"))
    if isinstance(parsed, dict):
        return parsed
    raise ValueError(f"{path} did not decode to a mapping")
def backend_timeout_seconds(backend: dict[str, Any]) -> float | None:
raw = backend.get("timeout_seconds")
if raw in (None, "", 0):
return None
timeout = float(raw)
if timeout <= 0:
raise ValueError("timeout_seconds must be positive when configured")
return timeout
def build_prompt(*, user_input: str, system_context: dict[str, Any], policy: dict[str, Any], lanes: list[str]) -> str:
compact_context = {
"current_position": system_context.get("current_position", {}),
"runtime_contract": system_context.get("runtime_contract", {}),
"latest_runtime_state": system_context.get("latest_runtime_state", {}),
"agent_bootstrap": system_context.get("agent_bootstrap", {}),
}
compact_policy = {
"bits": policy.get("bits", []),
"vectors": policy.get("vectors", []),
"tensors": policy.get("tensors", []),
"invariants": policy.get("invariants", []),
}
return "\n".join(
[
"You are the inference backend for bit_vector_tensor_control_policy.",
"Work as the reasoning engine under the deterministic turn kernel.",
"Do not execute code, write files, or open an execution lane.",
f"Choose lane from: {', '.join(lanes)}.",
"Use `audit` when the ask is mainly about verification, inconsistency, evaluation, or review.",
"Use `tooling` when the ask is mainly about how to use a tool or operator surface.",
"Use `memory` otherwise.",
"Abstain plainly when the ask depends on missing or unsupported state.",
"Keep `answer_text` concise and user-facing.",
"Keep `decision_brief` to one sentence.",
"Return JSON only matching the provided schema.",
"",
"System context:",
json.dumps(compact_context, ensure_ascii=True, separators=(",", ":")),
"",
"Policy context:",
json.dumps(compact_policy, ensure_ascii=True, separators=(",", ":")),
"",
f"User ask: {user_input}",
]
)
def run_inference(
    *,
    config_path: Path,
    system_context_path: Path,
    user_input: str,
    output_path: Path,
    schema_path: Path,
) -> dict[str, Any]:
    """Run one inference turn through the configured Codex CLI backend.

    Reads the backend config (YAML), the system context (JSON), and the
    repo policy file, builds a prompt, invokes `codex exec` with an output
    schema, and returns the parsed result dict (also rewritten to
    *output_path* with a `backend` provenance section appended).

    Raises RuntimeError when the backend is disabled, the provider is not
    `codex_cli`, the subprocess times out, or it exits non-zero.
    Raises KeyError/ValueError on malformed config.
    """
    config = load_yaml(config_path)
    # The active backend is chosen by config, not by CLI flag.
    backend_id = config["default_backend"]
    backend = config["backends"][backend_id]
    if not backend.get("enabled", False):
        raise RuntimeError(f"backend {backend_id} is disabled")
    if backend.get("provider") != "codex_cli":
        raise RuntimeError(f"unsupported provider {backend.get('provider')}")
    system_context = json.loads(system_context_path.read_text(encoding="utf-8"))
    # Policy always comes from the repo-pinned file, not from config.
    policy = json.loads(DEFAULT_POLICY.read_text(encoding="utf-8"))
    lanes = backend.get("lanes", {}).get("allowed", ["memory", "tooling", "audit"])
    prompt = build_prompt(user_input=user_input, system_context=system_context, policy=policy, lanes=lanes)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Copy the schema into a temp file so codex gets a stable absolute path
    # (delete=False: the file must outlive this `with` block; removed in `finally`).
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as temp_schema:
        temp_schema.write(schema_path.read_text(encoding="utf-8"))
        temp_schema_path = Path(temp_schema.name)
    # Assemble the `codex exec` invocation; list form avoids shell quoting issues.
    command = [backend.get("command", "codex"), "exec"]
    if backend.get("model"):
        command.extend(["-m", str(backend["model"])])
    if backend.get("sandbox"):
        command.extend(["-s", str(backend["sandbox"])])
    if backend.get("ephemeral", False):
        command.append("--ephemeral")
    if backend.get("skip_git_repo_check", False):
        command.append("--skip-git-repo-check")
    command.extend(
        [
            "-C",
            str(ROOT),
            "--output-schema",
            str(temp_schema_path),
            "-o",
            str(output_path),
            "-",  # read the prompt from stdin
        ]
    )
    try:
        completed = subprocess.run(
            command,
            input=prompt,
            text=True,
            capture_output=True,
            cwd=ROOT,
            check=False,  # returncode is inspected manually below
            timeout=backend_timeout_seconds(backend),
        )
    except subprocess.TimeoutExpired as exc:
        raise RuntimeError(f"codex exec timed out after {exc.timeout} seconds") from exc
    finally:
        # Always clean up the temp schema, even on timeout/OSError.
        temp_schema_path.unlink(missing_ok=True)
    if completed.returncode != 0:
        stderr = completed.stderr.strip()
        raise RuntimeError(stderr or "codex exec failed without stderr")
    # codex wrote the structured result to output_path via `-o`.
    result = json.loads(output_path.read_text(encoding="utf-8"))
    # Append provenance so downstream consumers know which backend produced this.
    result["backend"] = {
        "id": backend_id,
        "provider": backend.get("provider", "codex_cli"),
        "model": backend.get("model", ""),
        "sandbox": backend.get("sandbox", ""),
        "uses_chatgpt_subscription": bool(backend.get("use_chatgpt_subscription", False)),
    }
    # Rewrite the output file with the provenance section included.
    output_path.write_text(json.dumps(result, indent=2, sort_keys=True) + "\n", encoding="utf-8")
    return result
def main() -> int:
    """CLI entry point: parse flags, run one inference turn, print the result JSON."""
    parser = argparse.ArgumentParser(description="Run repo-configured Codex CLI inference for the conversational turn kernel.")
    parser.add_argument("--config", default=str(DEFAULT_CONFIG))
    parser.add_argument("--schema", default=str(DEFAULT_SCHEMA))
    parser.add_argument("--system-context", required=True)
    parser.add_argument("--user-input", required=True)
    parser.add_argument("--output", required=True)
    opts = parser.parse_args()
    outcome = run_inference(
        config_path=Path(opts.config),
        system_context_path=Path(opts.system_context),
        user_input=opts.user_input,
        output_path=Path(opts.output),
        schema_path=Path(opts.schema),
    )
    # Echo the final (provenance-augmented) result to stdout for callers.
    sys.stdout.write(json.dumps(outcome, indent=2) + "\n")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())