# Corp_AI / src /auditenv /state.py
# Arpit Deep
# feat: initial AuditEnv submission
# a617acd
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Set
from uuid import uuid4
from auditenv.datasets.factory import generate_episode
from auditenv.grader import RewardConfig, grade_step, normalize_reward, terminal_missed_penalty
from auditenv.models import AuditAction, AuditObservation, AuditReward, EnvState, StepResult, TaskId
# Step budget per task id; AuditEnvRuntime.reset() looks the active task up
# here to initialise RuntimeState.steps_remaining.
MAX_STEPS: dict[str, int] = {"easy": 12, "medium": 20, "hard": 28}
@dataclass
class RuntimeState:
    """Mutable bookkeeping for the single audit session that is currently active.

    An instance is created by ``AuditEnvRuntime.reset`` and updated in place by
    ``AuditEnvRuntime.step``.
    """

    # Session identity: a generated id plus the task the episode was built for.
    session_id: str
    task_id: TaskId

    # Episode payload, copied from the generated episode in reset().
    documents: list[dict]
    ground_truth: dict[str, str]
    evidence_map: dict[str, list[str]]

    # Steps still available; step() decrements this once per call.
    steps_remaining: int

    # Running sum of raw step rewards (and the terminal penalty, once applied).
    partial_score: float = 0.0
    # Count of actions whose action_type is not "noop".
    findings_submitted: int = 0
    # Keys handed to the grader as the "found" set; a fresh set per instance.
    found_truth_keys: set[str] = field(default_factory=set)
class AuditEnvRuntime:
    """Runtime that drives one audit episode at a time.

    ``reset`` builds a new episode and stores it in ``self.current``; ``step``
    grades a single action against that episode and spends one step of its
    budget; ``state`` and ``_observation`` return snapshots of the active
    session. Only one session exists at a time: a later ``reset`` replaces it.
    """

    # Path used by reset() when the configured datasets file does not exist.
    DEFAULT_DATASETS_CONFIG = "configs/datasets.yaml"
    # Maximum number of documents included in each observation.
    MAX_OBSERVED_DOCS = 12

    def __init__(
        self,
        default_seed: int = 42,
        datasets_config_path: str = DEFAULT_DATASETS_CONFIG,
    ) -> None:
        """Store the defaults and load the reward configuration.

        Args:
            default_seed: Seed used by ``reset`` when the caller passes none.
            datasets_config_path: Datasets config file handed to the episode
                generator; ``reset`` falls back to the default path if this
                file is missing.
        """
        self.default_seed = default_seed
        self.datasets_config_path = datasets_config_path
        self.cfg = RewardConfig.from_yaml()
        self.current: RuntimeState | None = None

    def _resolve_seed(self, seed: int | None) -> int:
        """Return ``seed`` unless it is ``None``, in which case use the default.

        An explicit ``0`` is a valid seed and must be kept, so this checks for
        ``None`` rather than relying on truthiness.
        """
        return self.default_seed if seed is None else seed

    def _require_session(self) -> RuntimeState:
        """Return the active session or raise if ``reset`` has not been called.

        Raises:
            RuntimeError: If there is no active session.
        """
        if self.current is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        return self.current

    def reset(self, task_id: TaskId, seed: int | None = None) -> AuditObservation:
        """Start a new session for ``task_id`` and return its first observation.

        Args:
            task_id: Task to generate; also selects the step budget from
                ``MAX_STEPS``.
            seed: Episode seed. ``None`` means "use ``default_seed``"; any
                integer, including ``0``, is passed through unchanged.

        Returns:
            The observation for the freshly created session.
        """
        config_path = self.datasets_config_path
        if not Path(config_path).exists():
            config_path = self.DEFAULT_DATASETS_CONFIG

        episode = generate_episode(
            task_id=task_id,
            seed=self._resolve_seed(seed),
            config_path=config_path,
        )
        self.current = RuntimeState(
            session_id=str(uuid4()),
            task_id=task_id,
            documents=episode.documents,
            ground_truth=episode.ground_truth,
            evidence_map=episode.evidence_map,
            steps_remaining=MAX_STEPS[task_id],
        )
        return self._observation()

    def step(self, action: AuditAction) -> StepResult:
        """Grade one action, spend one step, and report the outcome.

        On the step that exhausts the budget, the terminal missed-findings
        penalty is added both to the session's partial score and to the reward
        returned for that step, and is reported under
        ``info["terminal_missed_penalty"]``.

        Args:
            action: The action to grade; its ``task_id`` must match the active
                session.

        Returns:
            A ``StepResult`` with the updated observation, the step reward,
            the ``done`` flag and an ``info`` dict containing the grader's
            reason.

        Raises:
            RuntimeError: If ``reset`` has not been called, or if the episode
                has already used up its step budget.
            ValueError: If ``action.task_id`` differs from the session's task.
        """
        session = self._require_session()
        if action.task_id != session.task_id:
            raise ValueError("Action task_id does not match active session task_id")
        if session.steps_remaining <= 0:
            # Stepping past the end would push steps_remaining negative and
            # apply the terminal penalty to partial_score again on every call.
            raise RuntimeError("Episode already finished. Call reset() to start a new one.")

        # The grader receives the live found-keys set; nothing in this class
        # adds to it, so presumably grade_step records new hits there — verify
        # against auditenv.grader.
        raw_reward, reason = grade_step(
            action=action,
            ground_truth=session.ground_truth,
            evidence_map=session.evidence_map,
            found=session.found_truth_keys,
            cfg=self.cfg,
        )
        session.partial_score += raw_reward
        if action.action_type != "noop":
            session.findings_submitted += 1
        session.steps_remaining -= 1

        done = session.steps_remaining <= 0
        info: dict[str, object] = {"reason": reason}
        if done:
            missed_penalty = terminal_missed_penalty(
                ground_truth=session.ground_truth,
                found=session.found_truth_keys,
                cfg=self.cfg,
            )
            session.partial_score += missed_penalty
            raw_reward += missed_penalty
            info["terminal_missed_penalty"] = missed_penalty

        reward = AuditReward(
            value=raw_reward,
            normalized=normalize_reward(raw_reward),
            reason=reason,
        )
        return StepResult(
            observation=self._observation(),
            reward=reward,
            done=done,
            info=info,
        )

    def state(self) -> EnvState:
        """Return a snapshot of the active session's counters and found keys.

        Raises:
            RuntimeError: If ``reset`` has not been called.
        """
        session = self._require_session()
        return EnvState(
            session_id=session.session_id,
            task_id=session.task_id,
            steps_remaining=session.steps_remaining,
            findings_submitted=session.findings_submitted,
            partial_score=session.partial_score,
            # Sorted so the set is returned as a list in a stable order.
            found_truth_keys=sorted(session.found_truth_keys),
        )

    def _observation(self) -> AuditObservation:
        """Build the observation for the active session.

        At most ``MAX_OBSERVED_DOCS`` documents, taken from the front of the
        session's document list, are included.

        Raises:
            RuntimeError: If ``reset`` has not been called.
        """
        session = self._require_session()
        return AuditObservation(
            session_id=session.session_id,
            task_id=session.task_id,
            documents=session.documents[: self.MAX_OBSERVED_DOCS],
            findings_submitted=session.findings_submitted,
            steps_remaining=session.steps_remaining,
            current_partial_score=session.partial_score,
        )