| """CORP-ENV gym: SWD-centric corporate planning with delegate / patch / query / finalize.""" |
|
|
| from __future__ import annotations |
|
|
import copy
import json
import logging
import os
import random
from typing import Any, Dict, List, Optional, Tuple
|
|
| from jsonpath_ng import parse as jsonpath_parse |
|
|
| from corp_env.models import CorpAction, CorpObservation |
| from openenv.core.env_server.interfaces import Environment |
| from openenv.core.env_server.types import EnvironmentMetadata, State |
|
|
| from server.agents.personas import AgentSlot |
| from server.agents.prompts import normalize_worker_id, swd_report_key_for_family |
| from server.memory import append_master_episode, append_worker_episode |
| from server.reward import compute_reward, compute_swd_coherence |
| from server.swd import ( |
| VALID_PHASES, |
| new_episode_id, |
| try_apply_json_patch, |
| validate_swd_structure, |
| ) |
| from server.tasks import TASKS |
| from server.tasks.base import CorpTask |
| from server.worker_client import call_worker_model |
|
|
# Hard cap on turns per episode: CorpEnvironment.step() marks the episode done
# once its turn counter reaches this value, whatever the task state is.
SAFETY_MAX_TURNS: int = 256
|
|
| _AVAILABLE_ACTIONS_HINT = [ |
| "delegate(agent_id, payload=task_prompt) -> consult a worker", |
| "log_reasoning(payload=text) -> append a reasoning_log entry", |
| "log_decision(payload=text) -> append to decisions", |
| "log_conflict(payload=JSON {id?, summary, source_agents?}) -> append to conflicts_identified", |
| "log_resolution(payload=JSON {conflict_id, resolution_type?, text?}) -> append to conflict_resolutions", |
| "advance_phase(payload=\"analysis\"|\"decision\"|\"execution\")", |
| "query_swd(payload=JSONPath) -> read the SWD", |
| "update_swd(payload=RFC6902 JSON Patch string) -> power edits", |
| "finalize(payload=GO|NO_GO|JSON object) -> end the episode", |
| ] |
|
|
|
|
class CorpEnvironment(Environment[CorpAction, CorpObservation, State]):
    """Shared workspace document environment (master + frozen workers).

    The master agent acts on a shared workspace document (SWD, a JSON-like
    dict) through the action types dispatched in :meth:`step`: ``delegate``,
    ``update_swd``, ``query_swd``, ``finalize``, ``log_reasoning``,
    ``log_decision``, ``log_conflict``, ``log_resolution`` and
    ``advance_phase``.

    ``update_swd`` lets the agent write arbitrary JSON into the SWD, so reads
    of agent-editable fields (milestones, agent reports, conflict lists, the
    version counter) are defensive instead of trusting the document shape.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self) -> None:
        super().__init__()
        # Declared for type checkers only; the attribute is assigned by reset().
        self.task: CorpTask
        self.swd: Dict[str, Any] = {}
        self.turn = 0
        self.tokens_used = 0
        self.episode_metadata: Dict[str, Any] = {}
        self._episode_id: str = ""

    def get_metadata(self) -> EnvironmentMetadata:
        """Return the static name / description / version of this environment."""
        return EnvironmentMetadata(
            name="corp-env",
            description=(
                "Multi-agent corporate decision environment with shared workspace "
                "document (SWD) governance."
            ),
            version="0.2.0",
        )

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> CorpObservation:
        """Start a new episode and return its initial observation.

        The task id is taken from ``kwargs["task_id"]``, then the
        ``CORP_TASK_ID`` environment variable, then a random key of ``TASKS``.
        An id that is not in ``TASKS`` falls back to ``"e1_launch_readiness"``.

        Args:
            seed: When given, seeds the ``random`` module before the task is
                chosen.
            episode_id: Episode identifier; a new one is generated when omitted.
            **kwargs: Only ``task_id`` is read.

        Returns:
            Observation with zero reward and ``done=False``.
        """
        if seed is not None:
            # NOTE(review): this seeds the process-wide RNG, so it is visible to
            # any other session in the same process - confirm that is intended
            # given SUPPORTS_CONCURRENT_SESSIONS.
            random.seed(seed)

        tid = (
            kwargs.get("task_id")
            or os.environ.get("CORP_TASK_ID")
            or random.choice(list(TASKS.keys()))
        )
        if tid not in TASKS:
            tid = "e1_launch_readiness"

        self.task = TASKS[tid]
        self._episode_id = episode_id or new_episode_id()
        self.swd = self.task.initial_swd(self._episode_id)
        self.turn = 0
        self.tokens_used = 0
        self.episode_metadata = {
            "task_id": self.task.task_id,
            "invalid_json_count": 0,
            "wrong_agent_count": 0,
            "consecutive_same_agent_calls": 0,
            "last_agent": None,
            "finalized": False,
            "version_decreased": False,
            "turn_completed": {},
            "token_budget_exceeded": False,
            "recent_actions": [],
        }
        return self._observation(
            reward=0.0, done=False, agent_last_output=None, query_result=None
        )

    def step(
        self,
        action: CorpAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> CorpObservation:
        """Apply one master action and return the resulting observation.

        Each step increments the turn counter, charges tokens for the payload
        and agent id, dispatches on ``action.action_type`` and records a short
        summary in ``recent_actions``. The step reward is the handler's reward
        plus 0.2 per milestone completed this step, minus a flat 0.01.

        The episode ends when ``finalize`` is handled, when the token budget is
        exceeded (-0.20), when every milestone is missed (-0.1), or when
        ``SAFETY_MAX_TURNS`` is reached.

        Args:
            action: The action to apply. An unknown ``action_type`` is reported
                through the observation's ``error`` field.
            timeout_s: Accepted but not used.
            **kwargs: Accepted but not used.
        """
        _ = timeout_s, kwargs
        self.turn += 1
        step_reward = 0.0
        done = False
        agent_last_output: Optional[str] = None
        query_result: Any = None
        error: Optional[str] = None

        # A missing payload is treated as an empty string.
        self._add_tokens((action.payload or "") + (action.agent_id or ""))

        at = action.action_type
        if at == "delegate":
            step_reward, agent_last_output, error = self._step_delegate(action)
        elif at == "update_swd":
            step_reward, error = self._step_update_swd(action)
        elif at == "query_swd":
            query_result, error = self._step_query_swd(action)
        elif at == "finalize":
            step_reward, done, error = self._step_finalize(action)
        elif at == "log_reasoning":
            step_reward, error = self._step_log_reasoning(action)
        elif at == "log_decision":
            step_reward, error = self._step_log_decision(action)
        elif at == "log_conflict":
            step_reward, error = self._step_log_conflict(action)
        elif at == "log_resolution":
            step_reward, error = self._step_log_resolution(action)
        elif at == "advance_phase":
            step_reward, error = self._step_advance_phase(action)
        else:
            error = f"unknown action_type {at}"

        self._push_recent_action(action, error)

        new_milestones = self._update_milestone_status()
        step_reward += 0.2 * len(new_milestones)
        step_reward -= 0.01

        if not done and self.tokens_used > self.task.token_budget:
            self.episode_metadata["token_budget_exceeded"] = True
            step_reward -= 0.20
            done = True

        if not done and self._all_milestones_missed():
            done = True
            step_reward -= 0.1

        if not done and self.turn >= SAFETY_MAX_TURNS:
            done = True
            note = "safety max turns reached"
            # Only join with "; " when there is an earlier error to append to.
            error = f"{error}; {note}" if error else note

        return self._observation(
            reward=step_reward,
            done=done,
            agent_last_output=agent_last_output,
            query_result=query_result,
            error=error,
        )

    @property
    def state(self) -> State:
        """Episode id (from the SWD, else the stored id) and the turn count."""
        return State(
            episode_id=str(self.swd.get("episode_id", self._episode_id)),
            step_count=self.turn,
        )

    def close(self) -> None:
        """No-op: this class does nothing on close."""
        return None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _add_tokens(self, text: str) -> None:
        """Charge ``len(text) // 4`` tokens, with a minimum of 8 per call."""
        self.tokens_used += max(8, len(text) // 4)

    def _bump_swd_version(self) -> None:
        """Increment the SWD's ``swd_version`` (a missing value counts as 0)."""
        self.swd["swd_version"] = int(self.swd.get("swd_version", 0)) + 1

    def _milestones(self) -> List[Dict[str, Any]]:
        """Return the dict entries of the SWD's ``milestones`` list.

        Entries are the live objects, so callers may update their ``status``.
        A non-list ``milestones`` value and non-dict entries are ignored.
        """
        raw = self.swd.get("milestones")
        if not isinstance(raw, list):
            return []
        return [m for m in raw if isinstance(m, dict)]

    def _agent_reports(self, *, create: bool = False) -> Dict[str, Any]:
        """Return the SWD's ``agent_reports`` mapping.

        When the field is missing or not a dict an empty dict is returned; with
        ``create=True`` that empty dict is also stored in the SWD so that
        writes to it persist.
        """
        reports = self.swd.get("agent_reports")
        if isinstance(reports, dict):
            return reports
        reports = {}
        if create:
            self.swd["agent_reports"] = reports
        return reports

    def _push_recent_action(self, action: CorpAction, error: Optional[str]) -> None:
        """Append a one-line action summary, keeping only the latest five."""
        summary = action.action_type
        if action.agent_id:
            summary += f":{action.agent_id}"
        if action.payload:
            snip = action.payload.replace("\n", " ")[:60]
            summary += f" '{snip}'"
        if error:
            summary += " [err]"
        recent: List[str] = self.episode_metadata.setdefault("recent_actions", [])
        recent.append(summary)
        if len(recent) > 5:
            del recent[:-5]

    def _next_step_hint(self) -> Optional[str]:
        """Return a one-line hint for the first pending milestone.

        The hint is picked by keyword-matching the milestone's lower-cased
        label. When no milestone is pending, a generic "call finalize" hint is
        returned.
        """
        for m in self._milestones():
            if m.get("status") != "pending":
                continue
            mid = str(m.get("id", ""))
            label = str(m.get("label", ""))
            lbl = label.lower()
            if "agent_report" in lbl or "consult" in lbl or "populated" in lbl:
                slot_ids = self.task.available_agents
                missing = self._slots_missing_reports()
                # Prefer an agent without a report, then any available agent.
                target = missing[0] if missing else (slot_ids[0] if slot_ids else "an advisor")
                return (
                    f"Milestone {mid}: use delegate with agent_id='{target}' "
                    f"and a specific prompt about the task."
                )
            if "conflict" in lbl and "resolv" in lbl:
                return (
                    f"Milestone {mid}: call log_conflict then log_resolution "
                    f"(conflict_id matching the logged conflict)."
                )
            if "conflict" in lbl:
                return f"Milestone {mid}: call log_conflict with a short summary."
            if "phased" in lbl or "recommendation" in lbl or "final" in lbl:
                return f"Milestone {mid}: call finalize with the required payload."
            if "reasoning_log" in lbl:
                return (
                    f"Milestone {mid}: call log_reasoning a few more times "
                    f"(different turns) to build the reasoning_log."
                )
            return f"Milestone {mid}: {label}"
        return "All milestones complete - call finalize with your recommendation."

    def _slots_missing_reports(self) -> List[str]:
        """Return ids of task agent slots whose report key has no value yet."""
        reports = self._agent_reports()
        return [
            slot.id
            for slot in self.task.agent_slots
            if reports.get(swd_report_key_for_family(slot.family)) is None
        ]

    def _observation(
        self,
        *,
        reward: float,
        done: bool,
        agent_last_output: Optional[str],
        query_result: Any,
        error: Optional[str] = None,
    ) -> CorpObservation:
        """Build the observation for the current state.

        The SWD and the episode metadata are copied so that an observation that
        has been handed out does not change when later steps mutate the
        environment's own state.
        """
        return CorpObservation(
            task_description=self.task.description,
            role=self.task.role,
            master_tier=getattr(self.task, "master_tier", "senior"),
            available_agents=list(self.task.available_agents),
            available_actions=list(_AVAILABLE_ACTIONS_HINT),
            next_step_hint=self._next_step_hint(),
            recent_actions=list(self.episode_metadata.get("recent_actions", [])),
            swd=json.loads(json.dumps(self.swd)),
            agent_last_output=agent_last_output,
            query_result=query_result,
            tokens_used=self.tokens_used,
            token_budget=self.task.token_budget,
            turn=self.turn,
            reward=reward,
            done=done,
            error=error,
            # Deep copy: the metadata holds a nested list and dict that are
            # mutated in place on later steps.
            metadata={"episode_metadata": copy.deepcopy(self.episode_metadata)},
        )

    def _step_delegate(
        self, action: CorpAction
    ) -> Tuple[float, Optional[str], Optional[str]]:
        """Consult a worker and store its output as that family's report.

        Rewards: -0.10 for an invalid/unavailable agent or a missing slot;
        +0.1 when the report key was empty, -0.02 when it already had a value;
        and -0.05 x the consecutive-call counter once it exceeds 1.

        Returns:
            ``(reward, worker_output, error)``; exactly one of the last two is
            ``None``.
        """
        canonical = normalize_worker_id(action.agent_id)
        if canonical is None or canonical not in self.task.available_agents:
            self.episode_metadata["wrong_agent_count"] += 1
            return -0.10, None, "invalid or unavailable agent_id for this task"

        slot = self.task.get_slot(canonical)
        if slot is None:
            self.episode_metadata["wrong_agent_count"] += 1
            return -0.10, None, f"no slot registered for {canonical}"

        if canonical == self.episode_metadata.get("last_agent"):
            self.episode_metadata["consecutive_same_agent_calls"] += 1
        else:
            self.episode_metadata["consecutive_same_agent_calls"] = 0
        self.episode_metadata["last_agent"] = canonical

        task_text = (action.payload or "").strip()
        extra = self.task.intel_injections.get(canonical)
        if extra:
            # Task-provided text for this agent is appended to the prompt.
            task_text = f"{task_text}\n\n{extra}".strip()

        reward = 0.0
        key = swd_report_key_for_family(slot.family)
        if self._agent_reports().get(key) is None:
            reward += 0.1
        else:
            reward -= 0.02

        repeats = self.episode_metadata["consecutive_same_agent_calls"]
        if repeats > 1:
            reward -= 0.05 * repeats

        out = call_worker_model(
            slot,
            task_text or "Provide your assessment.",
            task_id=self.task.task_id,
            episode_id=self._episode_id,
            turn=self.turn,
        )
        self._add_tokens(out)
        self._agent_reports(create=True)[key] = out
        self._bump_swd_version()
        return reward, out, None

    def _step_update_swd(self, action: CorpAction) -> Tuple[float, Optional[str]]:
        """Apply an RFC 6902 JSON Patch (given as a JSON string) to the SWD.

        A payload that is not a JSON list, a patch that fails to apply, a patch
        whose result is not a JSON object, or one that leaves ``swd_version``
        non-integer is rejected: ``invalid_json_count`` is incremented, 0.15 is
        subtracted and the SWD is left unchanged.

        Otherwise: -0.05 for an empty patch, +0.05 when the document changed,
        +0.1 when ``phase`` changed, -0.5 (and ``version_decreased`` set) when
        ``swd_version`` went down, +0.15 x ``compute_swd_coherence`` and +0.02
        when ``reasoning_log`` is non-empty.

        Returns:
            ``(reward, error)`` with ``error`` ``None`` on success.
        """
        reward = 0.0
        try:
            ops = json.loads(action.payload or "")
            if not isinstance(ops, list):
                raise ValueError("JSON Patch must be a list of operations")
        except ValueError as e:  # json.JSONDecodeError is a ValueError
            self.episode_metadata["invalid_json_count"] += 1
            reward -= 0.15
            return reward, f"invalid JSON patch: {e}"

        if not ops:
            reward -= 0.05

        old_phase = self.swd.get("phase")
        old_ver = int(self.swd.get("swd_version", 0))
        new_doc, err = try_apply_json_patch(self.swd, ops)

        # Validate the patched document before adopting it, so one bad patch
        # cannot leave the environment in a state that crashes later steps.
        new_ver = 0
        if not err and not isinstance(new_doc, dict):
            err = "JSON patch must leave the SWD a JSON object"
        if not err:
            try:
                new_ver = int(new_doc.get("swd_version", 0))
            except (TypeError, ValueError):
                err = "JSON patch must leave swd_version an integer"

        if err:
            self.episode_metadata["invalid_json_count"] += 1
            reward -= 0.15
            return reward, err

        if new_doc != self.swd:
            reward += 0.05

        if old_phase != new_doc.get("phase"):
            reward += 0.1

        if new_ver < old_ver:
            self.episode_metadata["version_decreased"] = True
            reward -= 0.5

        self.swd = new_doc
        reward += 0.15 * compute_swd_coherence(self.swd)
        if self.swd.get("reasoning_log", []) or []:
            reward += 0.02
        return reward, None

    def _apply_patch_or_error(
        self, ops: List[Dict[str, Any]]
    ) -> Tuple[float, Optional[str]]:
        """Apply ``ops`` to the SWD on behalf of the structured log actions.

        Returns:
            ``(-0.15, error)`` when the patch fails (``invalid_json_count`` is
            incremented), otherwise ``(0.05 + 0.15 x coherence, None)``.
        """
        new_doc, err = try_apply_json_patch(self.swd, ops)
        if err:
            self.episode_metadata["invalid_json_count"] += 1
            return -0.15, err
        self.swd = new_doc
        return 0.05 + 0.15 * compute_swd_coherence(self.swd), None

    def _step_log_reasoning(self, action: CorpAction) -> Tuple[float, Optional[str]]:
        """Append ``{"turn", "text"}`` to ``reasoning_log``."""
        text = (action.payload or "").strip()
        if not text:
            return -0.05, "log_reasoning requires non-empty payload"
        entry = {"turn": self.turn, "text": text}
        return self._apply_patch_or_error(
            [{"op": "add", "path": "/reasoning_log/-", "value": entry}]
        )

    def _step_log_decision(self, action: CorpAction) -> Tuple[float, Optional[str]]:
        """Append ``{"turn", "summary"}`` to ``decisions``."""
        text = (action.payload or "").strip()
        if not text:
            return -0.05, "log_decision requires non-empty payload"
        entry = {"turn": self.turn, "summary": text}
        return self._apply_patch_or_error(
            [{"op": "add", "path": "/decisions/-", "value": entry}]
        )

    def _step_log_conflict(self, action: CorpAction) -> Tuple[float, Optional[str]]:
        """Append a conflict record to ``conflicts_identified``.

        The payload is parsed as a JSON object when it starts with ``{``;
        otherwise it is used as the record's ``summary``. A record without an
        ``id`` gets ``c<N>``, starting at the list length plus one and skipping
        ids already present. ``turn`` defaults to the current turn.
        """
        raw = (action.payload or "").strip()
        if not raw:
            return -0.05, "log_conflict requires non-empty payload"
        try:
            obj = json.loads(raw) if raw.startswith("{") else {"summary": raw}
        except json.JSONDecodeError as e:
            return -0.15, f"log_conflict invalid JSON: {e}"
        if not isinstance(obj, dict):
            return -0.15, "log_conflict payload must be a JSON object or plain string"

        if not obj.get("id"):
            existing = self.swd.get("conflicts_identified")
            if not isinstance(existing, list):
                existing = []
            # A list, not a set: ids come from agent-written JSON and may be
            # unhashable.
            taken = [c.get("id") for c in existing if isinstance(c, dict)]
            n = len(existing) + 1
            while f"c{n}" in taken:
                n += 1
            obj["id"] = f"c{n}"
        obj.setdefault("turn", self.turn)

        return self._apply_patch_or_error(
            [{"op": "add", "path": "/conflicts_identified/-", "value": obj}]
        )

    def _step_log_resolution(self, action: CorpAction) -> Tuple[float, Optional[str]]:
        """Append a resolution record to ``conflict_resolutions``.

        The payload is parsed as a JSON object when it starts with ``{``;
        otherwise it is used as the record's ``text``. A record without a
        ``conflict_id`` takes the id of the last logged conflict, if that
        conflict has one. ``turn`` defaults to the current turn.
        """
        raw = (action.payload or "").strip()
        if not raw:
            return -0.05, "log_resolution requires non-empty payload"
        try:
            obj = json.loads(raw) if raw.startswith("{") else {"text": raw}
        except json.JSONDecodeError as e:
            return -0.15, f"log_resolution invalid JSON: {e}"
        if not isinstance(obj, dict):
            return -0.15, "log_resolution payload must be a JSON object or plain string"

        if not obj.get("conflict_id"):
            conflicts = self.swd.get("conflicts_identified")
            if isinstance(conflicts, list) and conflicts and isinstance(conflicts[-1], dict):
                last_id = conflicts[-1].get("id")
                if last_id:
                    obj["conflict_id"] = last_id
        obj.setdefault("turn", self.turn)

        return self._apply_patch_or_error(
            [{"op": "add", "path": "/conflict_resolutions/-", "value": obj}]
        )

    def _step_advance_phase(self, action: CorpAction) -> Tuple[float, Optional[str]]:
        """Replace ``/phase`` with the lower-cased payload.

        Returns -0.05 for a value outside ``VALID_PHASES`` and -0.02 when the
        SWD is already in that phase. A successful change earns 0.1 on top of
        the patch reward.
        """
        phase = (action.payload or "").strip().lower()
        if phase not in VALID_PHASES:
            return -0.05, f"advance_phase payload must be one of {sorted(VALID_PHASES)}"
        if phase == self.swd.get("phase"):
            return -0.02, f"already in phase {phase}"
        reward, err = self._apply_patch_or_error(
            [{"op": "replace", "path": "/phase", "value": phase}]
        )
        if err is None:
            reward += 0.1
        return reward, err

    def _step_query_swd(self, action: CorpAction) -> Tuple[Any, Optional[str]]:
        """Evaluate a JSONPath expression against the SWD.

        Returns:
            ``(matches, None)`` where ``matches`` is a list of copies of the
            matched values, or ``(None, error)``.
        """
        expr = (action.payload or "").strip()
        if not expr:
            return None, "empty JSONPath expression"
        try:
            matches = [m.value for m in jsonpath_parse(expr).find(self.swd)]
        except Exception as e:
            # Kept broad on purpose: any parse or evaluation failure is
            # reported to the agent through the error field, not raised.
            return None, f"query_swd error: {e}"
        # Copy so the caller cannot mutate the live SWD through the result.
        return copy.deepcopy(matches), None

    def _step_finalize(self, action: CorpAction) -> Tuple[float, bool, Optional[str]]:
        """Record the final recommendation, score the episode and end it.

        A payload starting with ``{`` or ``[`` is parsed as JSON, falling back
        to the raw string when parsing fails. The terminal reward comes from
        ``compute_reward``; if ``validate_swd_structure`` rejects the SWD it is
        lowered by 0.05 (floored at 0.0) and an error is returned.

        Returns:
            ``(reward, True, error)``. The episode ends even for an empty
            payload, which returns -0.05 without scoring.
        """
        raw = (action.payload or "").strip()
        if not raw:
            return -0.05, True, "finalize requires non-empty payload"

        final_val: Any = raw
        if raw.startswith(("{", "[")):
            try:
                final_val = json.loads(raw)
            except json.JSONDecodeError:
                final_val = raw

        self.swd["final_recommendation"] = final_val
        self._bump_swd_version()
        self.episode_metadata["finalized"] = True

        verify_result = self.task.verifier(self.swd)
        terminal = compute_reward(
            self.swd,
            verify_result,
            self.episode_metadata,
            self.task.description,
        )
        ok, err = validate_swd_structure(self.swd)
        if not ok:
            terminal = max(0.0, terminal - 0.05)
            self._persist_episode_memory(verify_result, terminal)
            return terminal, True, f"finalize on structurally invalid SWD: {err}"

        self._persist_episode_memory(verify_result, terminal)
        return terminal, True, None

    def _persist_episode_memory(
        self, verify_result: Dict[str, bool], score: float
    ) -> None:
        """Write per-worker + master episode summaries to the memory store.

        Best-effort: any exception is logged as a warning and suppressed so
        that a memory-store failure does not break finalize.
        """
        try:
            pass_rate = (
                sum(1 for v in verify_result.values() if v) / max(len(verify_result), 1)
            )
            milestones = self._milestones()
            counts = {"total": len(milestones)}
            for status in ("complete", "missed", "pending"):
                counts[status] = sum(1 for m in milestones if m.get("status") == status)

            final_snip = self.swd.get("final_recommendation")
            if isinstance(final_snip, str) and len(final_snip) > 200:
                final_snip = final_snip[:200] + "..."

            append_master_episode(
                task_id=self.task.task_id,
                episode_id=self._episode_id,
                role=self.task.role,
                master_tier=getattr(self.task, "master_tier", "senior"),
                final_recommendation=final_snip,
                score=score,
                milestone_counts=counts,
                notes=f"verifier pass_rate={round(pass_rate, 3)}",
            )

            reports = self._agent_reports()
            for slot in self.task.agent_slots:
                report = reports.get(swd_report_key_for_family(slot.family))
                if report is None:
                    continue
                summary = str(report)
                if len(summary) > 240:
                    summary = summary[:240] + "..."
                append_worker_episode(
                    task_id=self.task.task_id,
                    slot_id=slot.id,
                    episode_id=self._episode_id,
                    summary=summary,
                    verifier_pass_rate=pass_rate,
                    score=score,
                )
        except Exception:
            logging.getLogger(__name__).warning(
                "failed to persist episode memory for episode %s",
                self._episode_id,
                exc_info=True,
            )

    @staticmethod
    def _milestone_due_turn(milestone: Dict[str, Any]) -> int:
        """Return ``due_by_turn`` as an int; 9999 when missing or unparseable."""
        try:
            return int(milestone.get("due_by_turn", 9999))
        except (TypeError, ValueError):
            return 9999

    def _update_milestone_status(self) -> List[str]:
        """Mark open milestones complete or missed; return ids completed now.

        A milestone is completed when ``task.milestone_complete`` says so (the
        turn is recorded under ``turn_completed``); otherwise it is missed once
        the current turn is past its ``due_by_turn``. Milestones without an id
        or already complete/missed are left alone.
        """
        completed_now: List[str] = []
        for m in self._milestones():
            mid = m.get("id")
            if not mid or m.get("status") in ("complete", "missed"):
                continue
            if self.task.milestone_complete(self.swd, str(mid)):
                m["status"] = "complete"
                self.episode_metadata["turn_completed"][str(mid)] = self.turn
                completed_now.append(str(mid))
            elif self.turn > self._milestone_due_turn(m):
                m["status"] = "missed"
        return completed_now

    def _all_milestones_missed(self) -> bool:
        """True when there is at least one milestone and all are ``missed``."""
        ms = self._milestones()
        if not ms:
            return False
        return all(m.get("status") == "missed" for m in ms)
|
|
|
|
| |
# AgentSlot is imported above but not otherwise referenced in this module;
# presumably this assignment keeps the import from being flagged as unused
# (or keeps it importable from here) - TODO confirm before removing.
_ = AgentSlot
|
|