Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Dict, List, Set | |
| from uuid import uuid4 | |
| from auditenv.datasets.factory import generate_episode | |
| from auditenv.grader import RewardConfig, grade_step, normalize_reward, terminal_missed_penalty | |
| from auditenv.models import AuditAction, AuditObservation, AuditReward, EnvState, StepResult, TaskId | |
# Step budget per task difficulty; reset() copies the matching value into
# RuntimeState.steps_remaining, and the episode ends when it reaches zero.
MAX_STEPS = dict(
    easy=12,
    medium=20,
    hard=28,
)
@dataclass
class RuntimeState:
    """Mutable bookkeeping for the single active audit episode.

    ``AuditEnvRuntime.reset()`` builds instances with keyword arguments, so
    this must be a dataclass. Without the decorator there is no generated
    ``__init__``, keyword construction raises ``TypeError``, and
    ``field(default_factory=set)`` is never turned into a per-instance default.
    """

    # Unique identifier of the episode (a uuid4 string when built by reset()).
    session_id: str
    # Task the episode was generated for; step() rejects actions for other tasks.
    task_id: TaskId
    # Episode documents; observations expose only a leading prefix of this list.
    documents: List[dict]
    # NOTE(review): schema is defined by generate_episode / the grader; keys are
    # what found_truth_keys is tracked against — confirm value semantics there.
    ground_truth: Dict[str, str]
    # Passed through unchanged to grade_step; structure owned by the grader.
    evidence_map: Dict[str, List[str]]
    # Decremented once per step(); the episode is done when it reaches 0.
    steps_remaining: int
    # Running sum of raw step rewards plus the terminal missed-findings penalty.
    partial_score: float = 0.0
    # Number of actions whose action_type is not "noop".
    findings_submitted: int = 0
    # Fresh set per instance. This runtime only reads it; presumably grade_step
    # adds matched ground-truth keys to it — verify against the grader.
    found_truth_keys: Set[str] = field(default_factory=set)
class AuditEnvRuntime:
    """Single-session audit environment that generates episodes and grades actions.

    At most one episode is active at a time; ``reset()`` replaces it.
    """

    # Maximum number of episode documents included in each observation.
    max_observed_documents: int = 12

    def __init__(self, default_seed: int = 42, datasets_config_path: str = "configs/datasets.yaml") -> None:
        """Create a runtime with no active episode.

        Args:
            default_seed: Seed used by ``reset()`` when its ``seed`` is ``None``.
            datasets_config_path: Dataset config handed to ``generate_episode``.
                ``reset()`` falls back to ``"configs/datasets.yaml"`` when this
                path does not exist.
        """
        self.default_seed = default_seed
        self.datasets_config_path = datasets_config_path
        self.cfg = RewardConfig.from_yaml()
        self.current: RuntimeState | None = None

    def reset(self, task_id: TaskId, seed: int | None = None) -> AuditObservation:
        """Generate a fresh episode for ``task_id`` and return its first observation.

        Args:
            task_id: Task to run; also selects the step budget from ``MAX_STEPS``.
            seed: Episode seed. ``None`` means ``default_seed``; any other value,
                including ``0``, is used as given.

        Returns:
            The initial observation of the new session.
        """
        config_path = self.datasets_config_path
        if not Path(config_path).exists():
            config_path = "configs/datasets.yaml"
        # `seed or self.default_seed` would silently discard an explicit seed of 0.
        effective_seed = self.default_seed if seed is None else seed
        episode = generate_episode(task_id=task_id, seed=effective_seed, config_path=config_path)
        self.current = RuntimeState(
            session_id=str(uuid4()),
            task_id=task_id,
            documents=episode.documents,
            ground_truth=episode.ground_truth,
            evidence_map=episode.evidence_map,
            steps_remaining=MAX_STEPS[task_id],
        )
        return self._observation()

    def step(self, action: AuditAction) -> StepResult:
        """Grade ``action``, consume one step, and report the outcome.

        On the final step, the terminal missed-findings penalty is added to both
        the returned step reward and the session's running ``partial_score``.

        Raises:
            RuntimeError: If ``reset()`` has not been called, or the episode has
                already ended.
            ValueError: If ``action.task_id`` differs from the active task.
        """
        session = self._require_session()
        # Without this guard, stepping past the end drives steps_remaining
        # negative and re-applies the terminal penalty on every extra call.
        if session.steps_remaining <= 0:
            raise RuntimeError("Episode has ended. Call reset() to start a new episode.")
        if action.task_id != session.task_id:
            raise ValueError("Action task_id does not match active session task_id")

        raw_reward, reason = grade_step(
            action=action,
            ground_truth=session.ground_truth,
            evidence_map=session.evidence_map,
            found=session.found_truth_keys,
            cfg=self.cfg,
        )
        session.partial_score += raw_reward
        if action.action_type != "noop":
            session.findings_submitted += 1
        session.steps_remaining -= 1

        done = session.steps_remaining <= 0
        info = {"reason": reason}
        if done:
            missed_penalty = terminal_missed_penalty(
                ground_truth=session.ground_truth,
                found=session.found_truth_keys,
                cfg=self.cfg,
            )
            session.partial_score += missed_penalty
            raw_reward += missed_penalty
            info["terminal_missed_penalty"] = missed_penalty

        reward = AuditReward(
            value=raw_reward,
            normalized=normalize_reward(raw_reward),
            reason=reason,
        )
        return StepResult(
            observation=self._observation(),
            reward=reward,
            done=done,
            info=info,
        )

    def state(self) -> EnvState:
        """Return a snapshot of the active session's bookkeeping.

        ``found_truth_keys`` is returned as a sorted list.

        Raises:
            RuntimeError: If ``reset()`` has not been called.
        """
        session = self._require_session()
        return EnvState(
            session_id=session.session_id,
            task_id=session.task_id,
            steps_remaining=session.steps_remaining,
            findings_submitted=session.findings_submitted,
            partial_score=session.partial_score,
            found_truth_keys=sorted(session.found_truth_keys),
        )

    def _require_session(self) -> RuntimeState:
        """Return the active session, raising ``RuntimeError`` if ``reset()`` was never called."""
        if self.current is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        return self.current

    def _observation(self) -> AuditObservation:
        """Build the agent-facing observation, capped at ``max_observed_documents`` documents."""
        session = self._require_session()
        return AuditObservation(
            session_id=session.session_id,
            task_id=session.task_id,
            documents=session.documents[: self.max_observed_documents],
            findings_submitted=session.findings_submitted,
            steps_remaining=session.steps_remaining,
            current_partial_score=session.partial_score,
        )