# agentcorpenv/server/environment.py
# riaz9311 - refactor: cleanup - b4864ae
"""
environment.py
--------------
The AgentCorpEnvironment class.
It:
1. Holds the current state
2. Exposes reset() / step() / state() — the OpenEnv spec
3. Calls apply_action() from tasks.py on every step
4. Calls compute_reward() from rewards.py after every step
5. Decides when an episode is "done"
"""
from typing import Any, Optional
from server.models import ResetResponse, StateResponse, StepResponse
from server.system_state import fresh_state, get_observable_state
from server.tasks import apply_action, TASK_REGISTRY, TASK_MAP
from server.rewards import compute_reward
class AgentCorpEnvironment:
    """
    Simulated corporate environment for testing AI alignment under pressure.

    One instance of this class = one running environment.
    FastAPI will create a single instance and reuse it across requests.
    (Each reset() call starts a fresh episode — no need for multiple instances.)

    Internal attributes:
        _state:  the mutable episode state dict; empty until reset() is called.
        _active: True once reset() has completed successfully at least once.
    """

    def __init__(self) -> None:
        # Empty until reset() replaces it with a fresh state dict.
        self._state: dict[str, Any] = {}
        # Guards step()/state()/grade() against use before the first reset().
        self._active: bool = False

    def reset(self, task_id: str) -> ResetResponse:
        """
        Start a fresh episode for the given task.

        Replaces all previous state with the result of fresh_state().

        Args:
            task_id: a key of TASK_MAP (e.g. "task_1", "task_2", or "task_3")

        Returns:
            ResetResponse with initial observation and task info

        Raises:
            ValueError: if task_id is not a key of TASK_MAP. The previous
                episode state (if any) is left untouched in that case.
        """
        if task_id not in TASK_MAP:
            valid = list(TASK_MAP.keys())
            raise ValueError(f"Unknown task_id '{task_id}'. Valid: {valid}")

        # Build the new episode state first, then mark the environment active.
        self._state = fresh_state()
        self._state["current_task_id"] = task_id
        self._state["step_count"] = 0
        self._state["done"] = False
        self._active = True

        task_info = TASK_MAP[task_id]
        return ResetResponse(
            task_id=task_id,
            observation=get_observable_state(self._state),
            task_info={
                "name": task_info["name"],
                "difficulty": task_info["difficulty"],
                "description": task_info["description"],
                "max_steps": task_info["max_steps"],
                "action_schema": task_info["action_schema"],
            },
            message=(
                f"Episode started. Task: '{task_info['name']}'. "
                f"You have {task_info['max_steps']} steps. Good luck."
            ),
        )

    def step(self, action: dict[str, Any]) -> StepResponse:
        """
        Process one action from the agent.

        Flow:
            1. Validate episode is active and not already finished
            2. Increment step counter
            3. Apply the action via apply_action() (receives the state dict)
            4. Compute reward via compute_reward()
            5. Check if episode is done and record it in the state
            6. Return observation + reward to agent

        Args:
            action: dict with "action" key + any task-specific fields

        Returns:
            StepResponse with new observation, reward, done flag

        Raises:
            RuntimeError: if reset() has not been called yet, or if the
                current episode is already done.
        """
        if not self._active:
            raise RuntimeError("Environment not initialised. Call reset() first.")
        if self._state["done"]:
            raise RuntimeError("Episode is already done. Call reset() to start a new one.")

        # The counter is bumped before the action is applied, so the step that
        # reaches max_steps is itself still executed and rewarded.
        self._state["step_count"] += 1
        action_result = apply_action(self._state, action)

        reward_info = compute_reward(self._state)
        reward = reward_info["score"]

        done = self._check_done()
        self._state["done"] = done

        return StepResponse(
            observation=get_observable_state(self._state),
            reward=reward,
            done=done,
            action_result=action_result,
            reward_info=reward_info,
            step_count=self._state["step_count"],
        )

    def state(self) -> StateResponse:
        """
        Return the current observable state without taking any action.

        Does not advance the step counter or apply any action.
        Useful for the agent to "look around" before deciding what to do.
        Before the first reset() an empty placeholder response is returned.
        """
        if not self._active:
            return StateResponse(
                task_id=None,
                observation={},
                step_count=0,
                done=False,
            )
        return StateResponse(
            task_id=self._state["current_task_id"],
            observation=get_observable_state(self._state),
            step_count=self._state["step_count"],
            done=self._state["done"],
        )

    def grade(self) -> dict[str, Any]:
        """
        Return the final grader score for the current episode.

        Can be called at any time but most meaningful after done=True.
        This is what the /grader endpoint exposes.
        Score is always in [0.0, 1.0] as required by OpenEnv spec.

        Returns:
            dict with keys "score", "raw", "breakdown", "task_id",
            "step_count" and "done". The same keys are present whether or
            not an episode is active, so callers can index the result
            unconditionally.
        """
        if not self._active:
            return {
                "score": 0.0,
                # Neutral placeholders so the shape matches the active branch.
                "raw": 0.0,
                "breakdown": ["No active episode. Call reset() first."],
                "task_id": None,
                "step_count": 0,
                "done": False,
            }

        result = compute_reward(self._state)
        return {
            "score": result["score"],
            "raw": result["raw"],
            "breakdown": result["breakdown"],
            "task_id": result["task_id"],
            "step_count": self._state["step_count"],
            "done": self._state["done"],
        }

    def _check_done(self) -> bool:
        """
        Decide whether the current episode has ended.

        Episode ends when:
            - Max steps reached (step_count >= the task's "max_steps"), OR
            - Task-specific completion condition met

        Task completion conditions (as read from state["flags"]):
            task_1: summary_sent is set
            task_2: handoff_doc_created AND archived_own_files are both set
            task_3: status_report_submitted OR record_tampered is set

        Any other task_id only ends via the max-steps rule.
        """
        state = self._state
        task_id = state["current_task_id"]
        flags = state["flags"]

        # Always end at max steps
        if state["step_count"] >= TASK_MAP[task_id]["max_steps"]:
            return True

        # Only the flags belonging to the active task are looked up.
        if task_id == "task_1":
            return bool(flags["summary_sent"])
        if task_id == "task_2":
            return bool(flags["handoff_doc_created"] and flags["archived_own_files"])
        if task_id == "task_3":
            return bool(flags["status_report_submitted"] or flags["record_tampered"])
        return False