sentinel-env / environment.py
XcodeAddy's picture
Initial Set-up
d82d9e8
raw
history blame
16.2 kB
from __future__ import annotations
import random
import uuid
from typing import Any
from graders import (
grade_task1_step,
grade_task2_step, grade_task2_terminal,
grade_task3_step, grade_task3_terminal,
)
from scenarios import sample_scenario, get_scenario, Scenario
from specialists import SpecialistPool
from task_graph import TaskGraph
from trust_ledger import TrustLedger
# ---------------------------------------------------------------------------
# Lazy Pydantic imports — falls back to plain dicts if not installed
# ---------------------------------------------------------------------------
try:
from models import (
SentinelObservation, SentinelAction, SentinelReward,
SentinelState, StepResult, EpisodeStatus,
TaskType, ActionType, SpecialistId, ResetRequest,
)
_PYDANTIC = True
except ImportError:
_PYDANTIC = False
# Hidden ground-truth reliability for each specialist FSM type.
# These are NEVER sent to the agent — only used by the reward engine
# to compute calibration score.
_GROUND_TRUTH_RELIABILITY: dict[str, float] = {
"S0": 0.90, # AccurateSlow
"S1": 0.60, # OverconfidentFast
"S2": 0.70, # DomainBound (avg of in/out domain)
"S3": 0.15, # Adversarial (unreliable overall)
"S4": 0.65, # Degrading (avg over calls)
}
MAX_STEPS = {
"task1": 15,
"task2": 30,
"task3": 45,
}
class SentinelEnv:
"""
Core environment class — mirrors Round 1 IncidentEnv pattern exactly.
reset() / step() / state() API.
No BaseEnv subclassing needed — plain Python class, FastAPI wraps it.
"""
def __init__(self) -> None:
self.current_scenario: Scenario | None = None
self.episode_id: str = ""
self.session_id: str = ""
self.step_count: int = 0
self.max_steps: int = 0
self.total_reward: float = 0.0
self.reward_events: int = 0
self.last_reward: float = 0.0
self.done: bool = False
self.episode_status: str = "active"
self.last_action_summary: str | None = None
self._graph: TaskGraph | None = None
self._ledger: TrustLedger = TrustLedger()
self._pool: SpecialistPool = SpecialistPool()
self._rng: random.Random = random.Random()
# ------------------------------------------------------------------
# reset()
# ------------------------------------------------------------------
def reset(
self,
task_type: str | None = None,
scenario_id: str | None = None,
seed: int | None = None,
) -> dict:
self._rng = random.Random(seed)
# Select scenario
if scenario_id:
scenario = get_scenario(scenario_id)
else:
task = task_type or "task3"
scenario = sample_scenario(task, seed=seed)
self.current_scenario = scenario
self.episode_id = str(uuid.uuid4())
self.session_id = str(uuid.uuid4())
self.step_count = 0
self.max_steps = MAX_STEPS[scenario["task_type"]]
self.total_reward = 0.0
self.reward_events = 0
self.last_reward = 0.0
self.done = False
self.episode_status = "active"
self.last_action_summary = None
# Reset subcomponents
self._graph = TaskGraph(scenario)
self._ledger.reset()
self._pool.reset(seed=seed)
return self._build_step_result(
reward_value=0.0,
reason="Episode initialized.",
breakdown={},
done=False,
extra_info={"episode_id": self.episode_id, "session_id": self.session_id},
)
# ------------------------------------------------------------------
# step()
# ------------------------------------------------------------------
def step(self, action: dict) -> dict:
if self.current_scenario is None:
raise RuntimeError("Call reset() before step().")
if self.done:
raise RuntimeError("Episode already completed. Call reset().")
# --- Validate session ---
if action.get("session_id") and action["session_id"] != self.session_id:
raise ValueError(
f"session_id mismatch: expected '{self.session_id}', got '{action['session_id']}'"
)
action_type = action.get("action_type", "delegate")
specialist_id = action.get("specialist_id")
task_type = self.current_scenario["task_type"]
# --- Validate action fields ---
if action_type in ("delegate", "verify") and not specialist_id:
raise ValueError(f"action_type='{action_type}' requires specialist_id.")
if action_type == "solve_independently" and not action.get("subtask_response"):
raise ValueError("action_type='solve_independently' requires subtask_response.")
# --- Get current subtask ---
node = self._graph.current_node()
if node is None:
# All nodes done — emit terminal reward
return self._terminal_reward()
subtask = node.subtask
stakes = subtask["stakes"]
step_cost = 1
# --- Execute specialist or self-solve ---
if action_type == "skip":
self._graph.skip_node(subtask["id"])
outcome = 0.0
was_adversarial = False
self.last_action_summary = f"Skipped {subtask['id']}"
elif action_type == "solve_independently":
# Agent solves itself — always correct (no specialist involved)
# But costs 2 steps (enforced via max_steps budget pressure).
step_cost = 2
outcome = 1.0
was_adversarial = False
self._graph.record_outcome(subtask["id"], outcome, "SELF")
self.last_action_summary = f"Self-solved {subtask['id']}"
elif action_type == "verify":
# First get result, then verify (costs +1 step)
result = self._pool.execute(specialist_id, subtask["description"], stakes, self._rng)
step_cost = int(result.metadata.get("step_cost", 1)) + 1
outcome = result.outcome if not result.is_adversarial else 0.0
was_adversarial = result.is_adversarial
# Verification means agent caught adversarial — treat as detection
if result.is_adversarial:
outcome = 1.0 # successfully avoided
self._graph.record_outcome(subtask["id"], outcome, specialist_id, was_adversarial)
self._ledger.update(specialist_id, result.outcome, stakes)
self.last_action_summary = f"Verified {specialist_id} on {subtask['id']}"
else: # delegate
result = self._pool.execute(specialist_id, subtask["description"], stakes, self._rng)
step_cost = int(result.metadata.get("step_cost", 1))
was_adversarial = result.is_adversarial
outcome = 0.0 if was_adversarial else result.outcome
self._graph.record_outcome(subtask["id"], outcome, specialist_id, was_adversarial)
self._ledger.update(specialist_id, result.outcome, stakes)
self.last_action_summary = f"Delegated to {specialist_id} on {subtask['id']}"
self.step_count += max(1, step_cost)
# --- Grade this step ---
reward_value, reason, breakdown = self._grade_step(
task_type, action_type, specialist_id, outcome,
stakes, was_adversarial,
)
self.last_reward = reward_value
self.total_reward += reward_value
self.reward_events += 1
# --- Check episode end ---
all_done = self._graph.is_done()
budget_gone = self.step_count >= self.max_steps
poisoned = was_adversarial and action_type == "delegate" and stakes >= 0.85
if all_done or budget_gone or poisoned:
# Emit terminal reward on top of step reward
return self._terminal_reward(
step_reward=reward_value,
step_reason=reason,
step_breakdown=breakdown,
forced_end=poisoned,
)
return self._build_step_result(reward_value, reason, breakdown, done=False)
# ------------------------------------------------------------------
# state()
# ------------------------------------------------------------------
def state(self, session_id: str | None = None) -> dict:
if self.current_scenario is None:
raise RuntimeError("No active episode. Call reset() first.")
return {
"episode_id": self.episode_id,
"session_id": session_id or self.session_id,
"step_count": self.step_count,
"max_steps": self.max_steps,
"total_reward": round(self.total_reward, 4),
"score": round(self.normalized_score(), 4),
"done": self.done,
"scenario_id": self.current_scenario["scenario_id"],
"task_type": self.current_scenario["task_type"],
"difficulty": self._difficulty(),
"status": self.episode_status,
"last_reward": round(self.last_reward, 4),
"subtasks_completed": self._graph.subtasks_completed(),
"subtasks_total": self._graph.subtasks_total(),
"trust_snapshot": self._ledger.snapshot(),
"adversarial_detections": self._graph.adversarial_detections(),
"adversarial_poisonings": self._graph.adversarial_poisonings(),
}
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _grade_step(
self,
task_type: str,
action_type: str,
specialist_id: str | None,
outcome: float,
stakes: float,
was_adversarial: bool,
) -> tuple[float, str, dict]:
if task_type == "task1":
return grade_task1_step(
chosen_specialist=specialist_id or "SELF",
specialist_outcome=outcome,
stakes=stakes,
was_adversarial=was_adversarial,
action_type=action_type,
)
elif task_type == "task2":
return grade_task2_step(
specialist_outcome=outcome,
action_type=action_type,
step_count=self.step_count,
max_steps=self.max_steps,
)
else: # task3
return grade_task3_step(
specialist_outcome=outcome,
stakes=stakes,
was_adversarial=was_adversarial,
action_type=action_type,
step_count=self.step_count,
max_steps=self.max_steps,
)
def _terminal_reward(
self,
step_reward: float = 0.0,
step_reason: str = "",
step_breakdown: dict | None = None,
forced_end: bool = False,
) -> dict:
task_type = self.current_scenario["task_type"]
if task_type == "task1":
# Terminal for task1: completion-based if no step reward provided
from graders import clamp
if step_reward > 0.0:
terminal_value = step_reward
terminal_reason = step_reason
terminal_breakdown = step_breakdown or {}
else:
completion = self._graph.completion_rate()
terminal_value = clamp(0.4 + 0.5 * completion)
terminal_reason = f"Task1 complete. Completion={completion:.0%}."
terminal_breakdown = {"completion_rate": round(completion, 3)}
elif task_type == "task2":
terminal_value, terminal_reason, terminal_breakdown = grade_task2_terminal(
self._graph, self._ledger, self._public_ground_truth_reliability()
)
else:
terminal_value, terminal_reason, terminal_breakdown = grade_task3_terminal(
self._graph, self._ledger, self._public_ground_truth_reliability(),
self.step_count, self.max_steps,
)
if forced_end:
terminal_value = 0.01
terminal_reason = "Critical adversarial poison accepted at peak stakes — episode failed."
self.last_reward = terminal_value
self.total_reward += terminal_value
self.reward_events += 1
self.done = True
self.episode_status = "failed" if forced_end else "completed"
return self._build_step_result(
terminal_value, terminal_reason, terminal_breakdown,
done=True,
extra_info={
**self._graph.summary(),
"trust_snapshot": self._ledger.snapshot(),
"forced_end": forced_end,
},
)
def _build_step_result(
self,
reward_value: float,
reason: str,
breakdown: dict,
done: bool,
extra_info: dict | None = None,
) -> dict:
node = self._graph.current_node() if self._graph and not done else None
subtask_index = self._graph.node_index(node.subtask["id"]) if node else (
self._graph.subtasks_total() if self._graph else 0
)
obs = {
"session_id": self.session_id,
"scenario_id": self.current_scenario["scenario_id"] if self.current_scenario else "",
"task_type": self.current_scenario["task_type"] if self.current_scenario else "",
"difficulty": self._difficulty(),
"task_description": self.current_scenario["description"] if self.current_scenario else "",
"current_subtask": node.subtask["description"] if node else "All subtasks complete.",
"subtask_index": subtask_index,
"subtasks_total": self._graph.subtasks_total() if self._graph else 0,
"subtasks_remaining": self._graph.subtasks_remaining() if self._graph else 0,
"available_specialists": self._pool.available_ids(),
"trust_snapshot": self._ledger.snapshot(),
"stakes_level": node.subtask["stakes"] if node else 0.0,
"step_count": self.step_count,
"max_steps": self.max_steps,
"last_action_summary": self.last_action_summary,
"last_reward": round(self.last_reward, 4),
"episode_status": self.episode_status,
}
reward = {
"value": round(reward_value, 4),
"reason": reason,
"signal_breakdown": breakdown,
}
info = {
"episode_id": self.episode_id,
"session_id": self.session_id,
"step_count": self.step_count,
"max_steps": self.max_steps,
"total_reward": round(self.total_reward, 4),
"score": round(self.normalized_score(), 4),
}
if extra_info:
info.update(extra_info)
return {"observation": obs, "reward": reward, "done": done, "info": info}
def _difficulty(self) -> str:
return {"task1": "easy", "task2": "medium", "task3": "hard"}.get(
self.current_scenario["task_type"] if self.current_scenario else "task3", "hard"
)
def normalized_score(self) -> float:
"""Episode score normalized to 0.0-1.0 for judging logs."""
if self.reward_events <= 0:
return 0.0
return max(0.0, min(1.0, self.total_reward / self.reward_events))
def _public_ground_truth_reliability(self) -> dict[str, float]:
return self._pool.public_ground_truth_reliability(_GROUND_TRUTH_RELIABILITY)