"""Audit environment runtime holding one active episode at a time.

``AuditEnvRuntime.reset`` builds an episode with ``generate_episode``.
``AuditEnvRuntime.step`` grades each action with ``grade_step``. Once the
per-task step budget in ``MAX_STEPS`` is used up, ``step`` applies
``terminal_missed_penalty``.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Set
from uuid import uuid4

from auditenv.datasets.factory import generate_episode
from auditenv.grader import RewardConfig, grade_step, normalize_reward, terminal_missed_penalty
from auditenv.models import AuditAction, AuditObservation, AuditReward, EnvState, StepResult, TaskId

# Step budget per task; an episode is done once the budget reaches zero.
MAX_STEPS = {
    "easy": 12,
    "medium": 20,
    "hard": 28,
}


@dataclass
class RuntimeState:
    """Mutable bookkeeping for the currently active episode."""

    session_id: str
    task_id: TaskId
    documents: List[dict]
    ground_truth: Dict[str, str]
    evidence_map: Dict[str, List[str]]
    steps_remaining: int
    # Running sum of step rewards plus the terminal missed penalty.
    partial_score: float = 0.0
    # Number of actions whose action_type is not "noop".
    findings_submitted: int = 0
    # Handed to grade_step / terminal_missed_penalty as ``found``; presumably
    # grade_step records matched ground-truth keys here -- confirm in grader.
    found_truth_keys: Set[str] = field(default_factory=set)


class AuditEnvRuntime:
    """Single-session audit environment driven by ``reset()`` and ``step()``."""

    # Maximum number of documents exposed in one observation.
    MAX_OBSERVED_DOCS = 12

    def __init__(self, default_seed: int = 42, datasets_config_path: str = "configs/datasets.yaml") -> None:
        """Create a runtime with no active session.

        Args:
            default_seed: Seed used by ``reset()`` when no seed is given.
            datasets_config_path: Dataset config handed to ``generate_episode``;
                ``reset()`` falls back to ``configs/datasets.yaml`` if this path
                does not exist.
        """
        self.default_seed = default_seed
        self.datasets_config_path = datasets_config_path
        self.cfg = RewardConfig.from_yaml()
        self.current: RuntimeState | None = None

    def reset(self, task_id: TaskId, seed: int | None = None) -> AuditObservation:
        """Start a new episode for *task_id* and return its first observation.

        Any previously active session is replaced.

        Args:
            task_id: Task to run; must be a key of ``MAX_STEPS``.
            seed: Episode seed. ``None`` means "use ``default_seed``"; an
                explicit ``0`` is honoured rather than treated as unset.
        """
        config_path = self.datasets_config_path
        if not Path(config_path).exists():
            # Configured path is missing: fall back to the default location.
            config_path = "configs/datasets.yaml"

        # ``seed or default`` would wrongly discard a legitimate seed of 0.
        effective_seed = self.default_seed if seed is None else seed
        episode = generate_episode(task_id=task_id, seed=effective_seed, config_path=config_path)

        self.current = RuntimeState(
            session_id=str(uuid4()),
            task_id=task_id,
            documents=episode.documents,
            ground_truth=episode.ground_truth,
            evidence_map=episode.evidence_map,
            steps_remaining=MAX_STEPS[task_id],
        )
        return self._observation()

    def step(self, action: AuditAction) -> StepResult:
        """Grade *action*, consume one step of budget and return the result.

        When the budget reaches zero the episode is done. The terminal missed
        penalty is then added to both this step's reward and ``partial_score``,
        and is reported in ``info["terminal_missed_penalty"]``.

        Raises:
            RuntimeError: If ``reset()`` has not been called, or the current
                episode has already ended.
            ValueError: If ``action.task_id`` differs from the active task.
        """
        current = self._require_session()
        if action.task_id != current.task_id:
            raise ValueError("Action task_id does not match active session task_id")
        if current.steps_remaining <= 0:
            # Without this guard, post-terminal steps would keep being graded
            # and would re-apply the terminal missed penalty on every call.
            raise RuntimeError("Episode has ended. Call reset() to start a new episode.")

        raw_reward, reason = grade_step(
            action=action,
            ground_truth=current.ground_truth,
            evidence_map=current.evidence_map,
            found=current.found_truth_keys,
            cfg=self.cfg,
        )
        current.partial_score += raw_reward
        if action.action_type != "noop":
            current.findings_submitted += 1
        current.steps_remaining -= 1

        done = current.steps_remaining <= 0
        info: Dict[str, object] = {"reason": reason}
        if done:
            missed_penalty = terminal_missed_penalty(
                ground_truth=current.ground_truth,
                found=current.found_truth_keys,
                cfg=self.cfg,
            )
            current.partial_score += missed_penalty
            raw_reward += missed_penalty
            info["terminal_missed_penalty"] = missed_penalty

        reward = AuditReward(
            value=raw_reward,
            normalized=normalize_reward(raw_reward),
            reason=reason,
        )
        return StepResult(
            observation=self._observation(),
            reward=reward,
            done=done,
            info=info,
        )

    def state(self) -> EnvState:
        """Return a snapshot of the active session (found keys sorted).

        Raises:
            RuntimeError: If ``reset()`` has not been called.
        """
        current = self._require_session()
        return EnvState(
            session_id=current.session_id,
            task_id=current.task_id,
            steps_remaining=current.steps_remaining,
            findings_submitted=current.findings_submitted,
            partial_score=current.partial_score,
            found_truth_keys=sorted(current.found_truth_keys),
        )

    def _require_session(self) -> RuntimeState:
        """Return the active episode state, raising if ``reset()`` was never called."""
        if self.current is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        return self.current

    def _observation(self) -> AuditObservation:
        """Build the agent-facing observation, truncating documents to ``MAX_OBSERVED_DOCS``."""
        current = self._require_session()
        return AuditObservation(
            session_id=current.session_id,
            task_id=current.task_id,
            documents=current.documents[: self.MAX_OBSERVED_DOCS],
            findings_submitted=current.findings_submitted,
            steps_remaining=current.steps_remaining,
            current_partial_score=current.partial_score,
        )