# inference.py — Master agent baseline for CORP-ENV (local Environment + OpenAI-compatible API) # # Uses the four action types: delegate, update_swd (JSON Patch), query_swd (JSONPath), finalize. from __future__ import annotations import argparse import json import os import re import time from datetime import datetime, timezone from pathlib import Path from typing import Any, Dict, List, Optional from dotenv import load_dotenv from openai import OpenAI from corp_env.models import CorpAction, CorpObservation from server.agents.master_prompts import build_system_prompt from server.environment import CorpEnvironment from server.llm_env import openai_client_kwargs_master load_dotenv() MASTER_KWARGS = openai_client_kwargs_master() MASTER_API_KEY = MASTER_KWARGS.get("api_key") MODEL_NAME = os.getenv("CORP_MASTER_MODEL") or os.getenv("MODEL_NAME") BENCHMARK = "corp-env" MAX_HISTORY_MESSAGES = 40 MAX_RETRIES = 5 RETRY_BASE_DELAY = 2 DEFAULT_TASKS = ["e1_launch_readiness", "m1_budget_reallocation", "h1_acquisition_defence"] def log_start(task: str, env: str, model: str) -> None: print(f"[START] task={task} env={env} model={model}", flush=True) def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None: err = error if error else "null" print( f"[STEP] step={step} action={action} reward={reward:.3f} done={str(done).lower()} error={err}", flush=True, ) def log_end(task: str, steps: int, score: float, rewards: List[float]) -> None: rs = ",".join(f"{r:.3f}" for r in rewards) print(f"[END] task={task} steps={steps} score={score:.3f} rewards={rs}", flush=True) class SwdTraceWriter: """Append SWD snapshots to a dedicated file (not mixed with console logs).""" def __init__(self, path: Optional[str], task_id: str) -> None: self.path = path.strip() if path else None self.task_id = task_id self._jsonl = bool(self.path and self.path.lower().endswith(".jsonl")) if not self.path: return p = Path(self.path) p.parent.mkdir(parents=True, exist_ok=True) ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%SZ") with p.open("a", encoding="utf-8") as f: f.write( f"\n{'=' * 72}\n" f"# CORP-ENV SWD trace | task={task_id} | started_utc={ts}\n" f"{'=' * 72}\n" ) def write( self, *, phase: str, step_index: int, action: Optional[CorpAction], obs: CorpObservation, ) -> None: if not self.path: return action_blob: Dict[str, Any] if action is None: action_blob = {"note": "initial observation after reset"} else: action_blob = action.model_dump(mode="json", exclude_none=True) if self._jsonl: record = { "phase": phase, "step_index": step_index, "env_turn": obs.turn, "reward": obs.reward, "done": obs.done, "error": obs.error, "action": action_blob, "swd": obs.swd, } line = json.dumps(record, ensure_ascii=False) with Path(self.path).open("a", encoding="utf-8") as f: f.write(line + "\n") return with Path(self.path).open("a", encoding="utf-8") as f: f.write( f"\n--- {phase} step_index={step_index} env_turn={obs.turn} " f"reward={obs.reward} done={obs.done} ---\n" ) f.write(f"action: {json.dumps(action_blob, indent=2, ensure_ascii=False)}\n") f.write(f"swd:\n{json.dumps(obs.swd, indent=2, ensure_ascii=False)}\n") def extract_json(raw_text: str) -> dict: cleaned = raw_text.strip() cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned) cleaned = re.sub(r"\s*```\s*$", "", cleaned) cleaned = cleaned.strip() try: return json.loads(cleaned) except json.JSONDecodeError: pass start = cleaned.find("{") if start == -1: raise ValueError("No JSON object found") depth = 0 in_string = False escape_next = False for i in range(start, len(cleaned)): c = cleaned[i] if escape_next: escape_next = False continue if c == "\\" and in_string: escape_next = True continue if c == '"' and not escape_next: in_string = not in_string continue if in_string: continue if c == "{": depth += 1 elif c == "}": depth -= 1 if depth == 0: return json.loads(cleaned[start : i + 1]) raise ValueError("Unbalanced braces") def parse_action(raw_text: str) -> CorpAction: d = extract_json(raw_text) d.pop("thought", None) return CorpAction.model_validate(d) def build_observation_message(step: int, obs: CorpObservation) -> str: parts = [ f"--- Step {step} ---", f"Role: {obs.role} (tier: {obs.master_tier})", f"Task: {obs.task_description}", f"Available agents: {', '.join(obs.available_agents)}", f"Turn: {obs.turn} tokens_used: {obs.tokens_used}/{obs.token_budget}", ] if obs.available_actions: parts.append("Available actions:\n- " + "\n- ".join(obs.available_actions)) if obs.next_step_hint: parts.append(f"Next-step hint: {obs.next_step_hint}") if obs.recent_actions: parts.append("Recent actions: " + " | ".join(obs.recent_actions)) parts.append(f"SWD:\n{json.dumps(obs.swd, indent=2)[:12000]}") if obs.agent_last_output: parts.append(f"Last worker output:\n{obs.agent_last_output[:4000]}") if obs.query_result is not None: parts.append(f"Query result: {json.dumps(obs.query_result)[:2000]}") if obs.error: parts.append(f"Error: {obs.error}") parts.append(f"Reward (last step): {obs.reward}") parts.append("Respond with your next JSON action.") return "\n".join(parts) def trim_history(messages: list, max_messages: int = MAX_HISTORY_MESSAGES) -> None: while len(messages) > max_messages: messages.pop(1) def run_episode( client: OpenAI, task_id: str, max_steps: int, swd_trace: Optional[SwdTraceWriter], ) -> tuple[float, int, List[float]]: os.environ["CORP_TASK_ID"] = task_id env = CorpEnvironment() rewards: List[float] = [] total = 0.0 steps = 0 log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME) obs = env.reset(task_id=task_id) if swd_trace: swd_trace.write(phase="after_reset", step_index=0, action=None, obs=obs) system_prompt = build_system_prompt(obs.master_tier, obs.role) messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": build_observation_message(0, obs)}, ] for step in range(1, max_steps + 1): if obs.done: break trim_history(messages) raw_text = None for attempt in range(MAX_RETRIES): try: completion = client.chat.completions.create( model=MODEL_NAME, messages=messages, temperature=0.2, max_tokens=2048, ) raw_text = (completion.choices[0].message.content or "").strip() break except Exception as exc: exc_s = str(exc) if ("429" in exc_s or "rate" in exc_s.lower()) and attempt < MAX_RETRIES - 1: time.sleep(RETRY_BASE_DELAY * (2**attempt)) continue print(f"[ERROR] {exc}", flush=True) log_end(task_id, step, total, rewards) return total, step, rewards if raw_text is None: continue messages.append({"role": "assistant", "content": raw_text}) try: action = parse_action(raw_text) alog = action.model_dump_json(exclude_none=True) except Exception as exc: action = CorpAction(action_type="query_swd", payload="$.phase") alog = f"PARSE_ERROR: {exc}" messages.append( { "role": "user", "content": f"Invalid JSON action: {exc}. Fix and output only JSON.", } ) obs = env.step(action) rewards.append(float(obs.reward or 0.0)) total += float(obs.reward or 0.0) steps = step log_step(step, alog[:200], float(obs.reward or 0.0), obs.done, obs.error) if swd_trace: swd_trace.write(phase="after_step", step_index=step, action=action, obs=obs) messages.append({"role": "user", "content": build_observation_message(step, obs)}) if obs.done: break log_end(task_id, steps, total, rewards) return total, steps, rewards def deterministic_e1_smoke(swd_trace: Optional[SwdTraceWriter] = None) -> None: """Offline smoke: E1 solved with stub workers (no master LLM).""" os.environ["CORP_TASK_ID"] = "e1_launch_readiness" os.environ["CORP_STUB_WORKERS"] = "1" env = CorpEnvironment() obs = env.reset(task_id="e1_launch_readiness") if swd_trace: swd_trace.write(phase="after_reset", step_index=0, action=None, obs=obs) seq = [ CorpAction( action_type="delegate", agent_id="qa_engineer", payload="Report current test status for the 48h launch window.", ), CorpAction( action_type="log_reasoning", payload="QA reports blockers; will align with release plan before finalizing.", ), CorpAction( action_type="log_decision", payload="Proceed with GO pending QA-flagged mitigations.", ), CorpAction(action_type="finalize", payload="GO"), ] total = 0.0 rlist: List[float] = [] for i, act in enumerate(seq, start=1): obs = env.step(act) r = float(obs.reward or 0.0) total += r rlist.append(r) log_step(i, act.action_type, r, obs.done, obs.error) if swd_trace: swd_trace.write(phase="after_step", step_index=i, action=act, obs=obs) log_end("e1_launch_readiness", len(seq), total, rlist) def main() -> None: parser = argparse.ArgumentParser(description="CORP-ENV baseline master agent") parser.add_argument( "--tasks", type=str, default=",".join(DEFAULT_TASKS), help="Comma-separated task ids", ) parser.add_argument("--max-steps", type=int, default=30, help="Max steps per episode") parser.add_argument( "--swd-trace", type=str, default=os.getenv("CORP_SWD_TRACE_FILE", ""), help="Append SWD evolution to this file (.jsonl recommended). Overrides CORP_SWD_TRACE_FILE.", ) args = parser.parse_args() trace_path = (args.swd_trace or "").strip() or None if not MASTER_API_KEY: print( "No master API key (set CORP_MASTER_API_KEY or HF_TOKEN / OPENAI_API_KEY) - " "running deterministic E1 smoke only. Set keys to run the LLM master on --tasks.", flush=True, ) tw = SwdTraceWriter(trace_path, "e1_launch_readiness") if trace_path else None deterministic_e1_smoke(swd_trace=tw) return client = OpenAI(**MASTER_KWARGS) for tid in [t.strip() for t in args.tasks.split(",") if t.strip()]: ms = args.max_steps * 2 if tid == "h1_acquisition_defence" else args.max_steps tw = SwdTraceWriter(trace_path, tid) if trace_path else None run_episode(client, tid, max_steps=ms, swd_trace=tw) if __name__ == "__main__": main()