# corp-env / inference.py
# Navigam's picture
# feat: enhance agent functionality and memory management in corporate environment
# febe155
# Raw
# History Blame Contribute Delete
# 11.9 kB
# inference.py — Master agent baseline for CORP-ENV (local Environment + OpenAI-compatible API)
#
# Uses the four action types: delegate, update_swd (JSON Patch), query_swd (JSONPath), finalize.
from __future__ import annotations
import argparse
import json
import os
import re
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
from dotenv import load_dotenv
from openai import OpenAI
from corp_env.models import CorpAction, CorpObservation
from server.agents.master_prompts import build_system_prompt
from server.environment import CorpEnvironment
from server.llm_env import openai_client_kwargs_master
# Load variables from a .env file first: the configuration below is read from
# the process environment at import time.
load_dotenv()
# Keyword arguments later splatted into OpenAI(**MASTER_KWARGS) in main().
MASTER_KWARGS = openai_client_kwargs_master()
# When this is falsy, main() skips the LLM master and runs the offline smoke test.
MASTER_API_KEY = MASTER_KWARGS.get("api_key")
# CORP_MASTER_MODEL takes precedence over MODEL_NAME; None when neither is set.
MODEL_NAME = os.getenv("CORP_MASTER_MODEL") or os.getenv("MODEL_NAME")
# Value printed as env=... in the [START] log line.
BENCHMARK = "corp-env"
# Default cap applied by trim_history() to the chat transcript.
MAX_HISTORY_MESSAGES = 40
# Attempts per completion request in run_episode().
MAX_RETRIES = 5
# Seconds; run_episode() sleeps RETRY_BASE_DELAY * 2**attempt between retries.
RETRY_BASE_DELAY = 2
# Task ids run when --tasks is not given.
DEFAULT_TASKS = ["e1_launch_readiness", "m1_budget_reallocation", "h1_acquisition_defence"]
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line that opens an episode's console log."""
    fields = ["[START]", f"task={task}", f"env={env}", f"model={model}"]
    print(" ".join(fields), flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line describing a single environment step.

    The reward is rendered with three decimals, ``done`` as a lower-cased
    string, and a missing/empty ``error`` as the literal ``null``.
    """
    error_text = error or "null"
    done_text = str(done).lower()
    fields = [
        "[STEP]",
        f"step={step}",
        f"action={action}",
        f"reward={reward:.3f}",
        f"done={done_text}",
        f"error={error_text}",
    ]
    print(" ".join(fields), flush=True)
def log_end(task: str, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` summary line for an episode.

    ``rewards`` is rendered as a comma-separated list with three decimals
    per entry (an empty list yields an empty ``rewards=`` field).
    """
    reward_csv = ",".join(map("{:.3f}".format, rewards))
    summary = " ".join(
        ["[END]", f"task={task}", f"steps={steps}", f"score={score:.3f}", f"rewards={reward_csv}"]
    )
    print(summary, flush=True)
class SwdTraceWriter:
    """Append SWD snapshots to a dedicated file (not mixed with console logs).

    Two output formats are supported, selected by the file extension:

    * ``*.jsonl`` - one JSON object per line. The trace header is written as
      a JSON record too (``{"phase": "trace_start", ...}``) so that every
      line of the file stays parseable as JSON.
    * anything else - a human-readable text log with a banner per trace and
      pretty-printed action / SWD blocks per step.

    A writer constructed with an empty ``path`` is inert: it touches no file
    and :meth:`write` returns immediately.
    """

    def __init__(self, path: Optional[str], task_id: str) -> None:
        """Prepare the trace file and append a header for this task.

        Args:
            path: Destination file; surrounding whitespace is stripped.
                Falsy values disable tracing.
            task_id: Task identifier recorded in the header.
        """
        self.path = path.strip() if path else None
        self.task_id = task_id
        self._jsonl = bool(self.path and self.path.lower().endswith(".jsonl"))
        if not self.path:
            return
        target = Path(self.path)
        target.parent.mkdir(parents=True, exist_ok=True)
        ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%SZ")
        if self._jsonl:
            # A plain-text banner would make the file invalid JSON Lines, so
            # the header is emitted as a regular JSON record instead.
            header = (
                json.dumps(
                    {"phase": "trace_start", "task_id": task_id, "started_utc": ts},
                    ensure_ascii=False,
                )
                + "\n"
            )
        else:
            header = (
                f"\n{'=' * 72}\n"
                f"# CORP-ENV SWD trace | task={task_id} | started_utc={ts}\n"
                f"{'=' * 72}\n"
            )
        self._append(header)

    def _append(self, text: str) -> None:
        """Append ``text`` to the trace file as UTF-8."""
        with Path(self.path).open("a", encoding="utf-8") as f:
            f.write(text)

    def write(
        self,
        *,
        phase: str,
        step_index: int,
        action: Optional[CorpAction],
        obs: CorpObservation,
    ) -> None:
        """Append one snapshot (action taken + resulting observation).

        Args:
            phase: Free-form label for the snapshot (callers in this file
                use ``"after_reset"`` and ``"after_step"``).
            step_index: Step counter supplied by the caller.
            action: The action that produced ``obs``; ``None`` marks the
                initial observation after a reset.
            obs: Observation whose ``turn``, ``reward``, ``done``, ``error``
                and ``swd`` attributes are recorded.
        """
        if not self.path:
            return
        action_blob: Dict[str, Any]
        if action is None:
            action_blob = {"note": "initial observation after reset"}
        else:
            action_blob = action.model_dump(mode="json", exclude_none=True)
        if self._jsonl:
            record = {
                "phase": phase,
                "step_index": step_index,
                "env_turn": obs.turn,
                "reward": obs.reward,
                "done": obs.done,
                "error": obs.error,
                "action": action_blob,
                "swd": obs.swd,
            }
            self._append(json.dumps(record, ensure_ascii=False) + "\n")
            return
        self._append(
            f"\n--- {phase} step_index={step_index} env_turn={obs.turn} "
            f"reward={obs.reward} done={obs.done} ---\n"
            f"action: {json.dumps(action_blob, indent=2, ensure_ascii=False)}\n"
            f"swd:\n{json.dumps(obs.swd, indent=2, ensure_ascii=False)}\n"
        )
def extract_json(raw_text: str) -> dict:
    """Extract the first JSON object from a model reply.

    A leading/trailing Markdown code fence (```` ``` ```` or ```` ```json ````)
    is stripped first. If the remaining text is itself a JSON object it is
    returned directly; otherwise every ``{`` is tried in order as the start
    of an object, so surrounding prose (including prose that contains stray
    braces) does not prevent extraction.

    Args:
        raw_text: Raw text returned by the model.

    Returns:
        The decoded JSON object. The result is always a ``dict``.

    Raises:
        ValueError: If the text contains no ``{`` at all, or no ``{`` starts
            a decodable JSON object.
    """
    cleaned = raw_text.strip()
    cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned)
    cleaned = re.sub(r"\s*```\s*$", "", cleaned)
    cleaned = cleaned.strip()

    try:
        whole = json.loads(cleaned)
    except json.JSONDecodeError:
        whole = None
    # Only accept a top-level object; lists/scalars fall through to the scan
    # so the declared ``dict`` return type always holds.
    if isinstance(whole, dict):
        return whole

    start = cleaned.find("{")
    if start == -1:
        raise ValueError("No JSON object found")

    decoder = json.JSONDecoder()
    while start != -1:
        try:
            # raw_decode parses one JSON document and ignores trailing text,
            # which handles strings, escapes and nesting correctly.
            candidate, _end = decoder.raw_decode(cleaned[start:])
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(candidate, dict):
                return candidate
        start = cleaned.find("{", start + 1)
    raise ValueError("No valid JSON object found")
def parse_action(raw_text: str) -> CorpAction:
    """Turn a raw model reply into a validated ``CorpAction``.

    The JSON object is pulled out of the reply with ``extract_json``; an
    optional ``"thought"`` key is discarded before the remaining fields are
    handed to ``CorpAction.model_validate``. Errors from either step
    propagate to the caller.
    """
    fields = extract_json(raw_text)
    fields.pop("thought", None)  # dropped before validation
    return CorpAction.model_validate(fields)
def build_observation_message(step: int, obs: CorpObservation) -> str:
    """Render an observation as the user-turn text sent to the master model.

    Always included: step header, role/tier, task description, available
    agents, turn and token usage, the SWD as indented JSON (cut to 12000
    characters), the last reward and a closing instruction. Included only
    when truthy (or, for ``query_result``, not ``None``): available actions,
    next-step hint, recent actions, last worker output (cut to 4000
    characters), query result (cut to 2000 characters) and error.

    Returns:
        The sections joined with newlines.
    """
    sections = [
        f"--- Step {step} ---",
        f"Role: {obs.role} (tier: {obs.master_tier})",
        f"Task: {obs.task_description}",
        f"Available agents: {', '.join(obs.available_agents)}",
        f"Turn: {obs.turn} tokens_used: {obs.tokens_used}/{obs.token_budget}",
    ]

    if obs.available_actions:
        bullets = "\n- ".join(obs.available_actions)
        sections.append("Available actions:\n- " + bullets)
    if obs.next_step_hint:
        sections.append(f"Next-step hint: {obs.next_step_hint}")
    if obs.recent_actions:
        history = " | ".join(obs.recent_actions)
        sections.append("Recent actions: " + history)

    swd_text = json.dumps(obs.swd, indent=2)
    sections.append(f"SWD:\n{swd_text[:12000]}")

    if obs.agent_last_output:
        worker_text = obs.agent_last_output[:4000]
        sections.append(f"Last worker output:\n{worker_text}")
    if obs.query_result is not None:
        query_text = json.dumps(obs.query_result)
        sections.append(f"Query result: {query_text[:2000]}")
    if obs.error:
        sections.append(f"Error: {obs.error}")

    sections.extend(
        [
            f"Reward (last step): {obs.reward}",
            "Respond with your next JSON action.",
        ]
    )
    return "\n".join(sections)
def trim_history(messages: list, max_messages: int = MAX_HISTORY_MESSAGES) -> None:
    """Shrink ``messages`` in place to at most ``max_messages`` entries.

    The first entry (the system prompt in ``run_episode``) is always kept;
    the oldest entries after it are discarded first. Because index 0 is
    never removed, a ``max_messages`` below 1 leaves a single-entry list
    rather than raising.

    Args:
        messages: Chat transcript, mutated in place.
        max_messages: Maximum number of entries to retain.
    """
    excess = len(messages) - max_messages
    if excess <= 0:
        return
    # One slice deletion instead of repeated pop(1): linear rather than
    # quadratic, and it cannot index past the end of a short list.
    del messages[1 : 1 + excess]
def run_episode(
    client: OpenAI,
    task_id: str,
    max_steps: int,
    swd_trace: Optional[SwdTraceWriter],
) -> tuple[float, int, List[float]]:
    """Run one episode of ``task_id`` with the LLM master agent.

    Each step sends the (trimmed) transcript to the chat-completions API,
    parses the reply into a ``CorpAction``, applies it to a fresh
    ``CorpEnvironment`` and appends the resulting observation to the
    transcript. A reply that cannot be parsed is answered with a corrective
    user message and replaced by a ``query_swd`` action on ``$.phase`` so the
    step still executes. The episode ends when the environment reports
    ``done``, when ``max_steps`` is reached, or when the API call fails.

    Side effects: sets ``CORP_TASK_ID`` in ``os.environ``, prints
    ``[START]``/``[STEP]``/``[END]`` (and ``[ERROR]``) lines, and writes
    snapshots to ``swd_trace`` when one is given.

    Args:
        client: OpenAI-compatible client used for chat completions.
        task_id: Task to reset the environment with.
        max_steps: Upper bound on environment steps.
        swd_trace: Optional trace writer; ``None`` disables tracing.

    Returns:
        ``(total_reward, steps_executed, per_step_rewards)``.
    """
    os.environ["CORP_TASK_ID"] = task_id
    env = CorpEnvironment()
    rewards: List[float] = []
    total = 0.0
    steps = 0
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    obs = env.reset(task_id=task_id)
    if swd_trace:
        swd_trace.write(phase="after_reset", step_index=0, action=None, obs=obs)
    system_prompt = build_system_prompt(obs.master_tier, obs.role)
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": build_observation_message(0, obs)},
    ]
    for step in range(1, max_steps + 1):
        if obs.done:
            break
        trim_history(messages)
        raw_text = None
        for attempt in range(MAX_RETRIES):
            try:
                completion = client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=messages,
                    temperature=0.2,
                    max_tokens=2048,
                )
                raw_text = (completion.choices[0].message.content or "").strip()
                break
            except Exception as exc:
                exc_s = str(exc)
                # Rate limiting is detected from the error text; only those
                # failures are retried, with exponential back-off.
                rate_limited = "429" in exc_s or "rate" in exc_s.lower()
                if rate_limited and attempt < MAX_RETRIES - 1:
                    time.sleep(RETRY_BASE_DELAY * (2**attempt))
                    continue
                print(f"[ERROR] {exc}", flush=True)
                # Report only the steps that actually executed: the current
                # one never reached env.step(), so `steps` (not `step`) keeps
                # the count consistent with len(rewards).
                log_end(task_id, steps, total, rewards)
                return total, steps, rewards
        if raw_text is None:
            # Only reachable when MAX_RETRIES is configured as 0.
            continue
        messages.append({"role": "assistant", "content": raw_text})
        try:
            action = parse_action(raw_text)
            alog = action.model_dump_json(exclude_none=True)
        except Exception as exc:
            # Fall back to a read-only query so the episode keeps moving,
            # and tell the model what was wrong with its reply.
            action = CorpAction(action_type="query_swd", payload="$.phase")
            alog = f"PARSE_ERROR: {exc}"
            messages.append(
                {
                    "role": "user",
                    "content": f"Invalid JSON action: {exc}. Fix and output only JSON.",
                }
            )
        obs = env.step(action)
        reward = float(obs.reward or 0.0)
        rewards.append(reward)
        total += reward
        steps = step
        log_step(step, alog[:200], reward, obs.done, obs.error)
        if swd_trace:
            swd_trace.write(phase="after_step", step_index=step, action=action, obs=obs)
        messages.append({"role": "user", "content": build_observation_message(step, obs)})
        if obs.done:
            break
    log_end(task_id, steps, total, rewards)
    return total, steps, rewards
def deterministic_e1_smoke(swd_trace: Optional[SwdTraceWriter] = None) -> None:
    """Offline smoke: E1 solved with stub workers (no master LLM).

    Replays a fixed action sequence against ``e1_launch_readiness`` and
    prints ``[STEP]`` lines followed by an ``[END]`` summary. Stepping stops
    as soon as the environment reports ``done``, and the summary reports the
    number of steps actually executed.

    Side effects: sets ``CORP_TASK_ID`` and ``CORP_STUB_WORKERS`` in
    ``os.environ`` and writes snapshots to ``swd_trace`` when one is given.

    Args:
        swd_trace: Optional trace writer; ``None`` disables tracing.
    """
    os.environ["CORP_TASK_ID"] = "e1_launch_readiness"
    os.environ["CORP_STUB_WORKERS"] = "1"
    env = CorpEnvironment()
    obs = env.reset(task_id="e1_launch_readiness")
    if swd_trace:
        swd_trace.write(phase="after_reset", step_index=0, action=None, obs=obs)
    # NOTE(review): the file header lists four action types (delegate,
    # update_swd, query_swd, finalize); confirm that CorpAction also accepts
    # "log_reasoning" and "log_decision" used below.
    seq = [
        CorpAction(
            action_type="delegate",
            agent_id="qa_engineer",
            payload="Report current test status for the 48h launch window.",
        ),
        CorpAction(
            action_type="log_reasoning",
            payload="QA reports blockers; will align with release plan before finalizing.",
        ),
        CorpAction(
            action_type="log_decision",
            payload="Proceed with GO pending QA-flagged mitigations.",
        ),
        CorpAction(action_type="finalize", payload="GO"),
    ]
    total = 0.0
    rlist: List[float] = []
    executed = 0
    for i, act in enumerate(seq, start=1):
        obs = env.step(act)
        r = float(obs.reward or 0.0)
        total += r
        rlist.append(r)
        executed = i
        log_step(i, act.action_type, r, obs.done, obs.error)
        if swd_trace:
            swd_trace.write(phase="after_step", step_index=i, action=act, obs=obs)
        if obs.done:
            # Do not keep stepping a finished episode (mirrors run_episode).
            break
    log_end("e1_launch_readiness", executed, total, rlist)
def main() -> None:
    """Command-line entry point for the baseline master agent.

    Parses ``--tasks``, ``--max-steps`` and ``--swd-trace``. Without a master
    API key only the deterministic E1 smoke run is executed; otherwise one
    LLM-driven episode is run per requested task, with the step limit doubled
    for ``h1_acquisition_defence``. A fresh trace writer is created per task
    when a trace path is configured.
    """
    cli = argparse.ArgumentParser(description="CORP-ENV baseline master agent")
    cli.add_argument(
        "--tasks",
        type=str,
        default=",".join(DEFAULT_TASKS),
        help="Comma-separated task ids",
    )
    cli.add_argument("--max-steps", type=int, default=30, help="Max steps per episode")
    cli.add_argument(
        "--swd-trace",
        type=str,
        default=os.getenv("CORP_SWD_TRACE_FILE", ""),
        help="Append SWD evolution to this file (.jsonl recommended). Overrides CORP_SWD_TRACE_FILE.",
    )
    opts = cli.parse_args()
    trace_path = (opts.swd_trace or "").strip() or None

    def tracer_for(task_id: str) -> Optional[SwdTraceWriter]:
        # Tracing is opt-in: no path means no writer at all.
        if not trace_path:
            return None
        return SwdTraceWriter(trace_path, task_id)

    if not MASTER_API_KEY:
        print(
            "No master API key (set CORP_MASTER_API_KEY or HF_TOKEN / OPENAI_API_KEY) - "
            "running deterministic E1 smoke only. Set keys to run the LLM master on --tasks.",
            flush=True,
        )
        deterministic_e1_smoke(swd_trace=tracer_for("e1_launch_readiness"))
        return

    client = OpenAI(**MASTER_KWARGS)
    requested = [t.strip() for t in opts.tasks.split(",") if t.strip()]
    for tid in requested:
        if tid == "h1_acquisition_defence":
            step_cap = opts.max_steps * 2
        else:
            step_cap = opts.max_steps
        run_episode(client, tid, max_steps=step_cap, swd_trace=tracer_for(tid))
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()