File size: 4,389 Bytes
a617acd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Set
from uuid import uuid4

from auditenv.datasets.factory import generate_episode
from auditenv.grader import RewardConfig, grade_step, normalize_reward, terminal_missed_penalty
from auditenv.models import AuditAction, AuditObservation, AuditReward, EnvState, StepResult, TaskId


# Step budget per task difficulty; an episode's ``steps_remaining`` starts at
# this value and the episode ends once it has been used up.
MAX_STEPS: Dict[str, int] = {"easy": 12, "medium": 20, "hard": 28}


@dataclass
class RuntimeState:
    """Mutable bookkeeping for the episode currently being played.

    One instance is created per ``reset()`` of the runtime and updated in
    place on every ``step()``.
    """

    # Identifier of this episode (the runtime fills it with a UUID4 string).
    session_id: str
    # Task the episode was generated for; also selects the step budget.
    task_id: TaskId
    # Documents produced by the episode generator.
    documents: List[dict]
    # Expected answers handed to the grader.
    # NOTE(review): key/value semantics are defined by the dataset factory.
    ground_truth: Dict[str, str]
    # Supporting evidence handed to the grader alongside the ground truth.
    evidence_map: Dict[str, List[str]]
    # Steps left before the episode terminates; decremented once per step.
    steps_remaining: int
    # Running sum of raw step rewards (plus the terminal penalty at the end).
    partial_score: float = field(default=0.0)
    # Number of actions taken whose ``action_type`` is not "noop".
    findings_submitted: int = field(default=0)
    # Ground-truth keys already found; shared with the grader as ``found``.
    found_truth_keys: Set[str] = field(default_factory=set)


class AuditEnvRuntime:
    """Single-session runtime that plays one audit episode at a time.

    Typical usage is ``reset()`` to start an episode, then ``step()`` until
    the returned result reports ``done``; ``state()`` exposes the current
    bookkeeping at any point in between.
    """

    # Upper bound on the number of documents included in each observation.
    MAX_OBSERVED_DOCUMENTS = 12

    def __init__(self, default_seed: int = 42, datasets_config_path: str = "configs/datasets.yaml") -> None:
        """Create a runtime with no active episode.

        Args:
            default_seed: Seed used by ``reset()`` when the caller passes none.
            datasets_config_path: Path to the datasets configuration file
                forwarded to the episode generator.
        """
        self.default_seed = default_seed
        self.datasets_config_path = datasets_config_path
        self.cfg = RewardConfig.from_yaml()
        self.current: RuntimeState | None = None

    def _require_state(self) -> RuntimeState:
        """Return the active episode state or raise if ``reset()`` was never called."""
        if self.current is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        return self.current

    def reset(self, task_id: TaskId, seed: int | None = None) -> AuditObservation:
        """Start a fresh episode for ``task_id`` and return its first observation.

        Args:
            task_id: Task to generate; must be a key of ``MAX_STEPS``.
            seed: Seed for episode generation. ``None`` selects the runtime's
                ``default_seed``; any integer, including ``0``, is used as given.

        Returns:
            The initial observation of the new episode.
        """
        config_path = self.datasets_config_path
        # Fall back to the stock config location when the configured file is missing.
        if not Path(config_path).exists():
            config_path = "configs/datasets.yaml"

        # Compare against None explicitly: ``seed or default`` would silently
        # discard a legitimate seed of 0.
        effective_seed = self.default_seed if seed is None else seed
        episode = generate_episode(task_id=task_id, seed=effective_seed, config_path=config_path)

        self.current = RuntimeState(
            session_id=str(uuid4()),
            task_id=task_id,
            documents=episode.documents,
            ground_truth=episode.ground_truth,
            evidence_map=episode.evidence_map,
            steps_remaining=MAX_STEPS[task_id],
        )
        return self._observation()

    def step(self, action: AuditAction) -> StepResult:
        """Grade one action and advance the episode by a single step.

        Args:
            action: The agent's action; its ``task_id`` must match the episode.

        Returns:
            The new observation, the reward for this step, the ``done`` flag
            and an ``info`` dict holding the grader's reason (plus
            ``terminal_missed_penalty`` on the final step).

        Raises:
            RuntimeError: If no episode is active, or the active episode has
                already used up its step budget.
            ValueError: If ``action.task_id`` differs from the episode's task.
        """
        state = self._require_state()
        # Stepping a finished episode would push steps_remaining below zero and
        # apply the terminal penalty again on every extra call.
        if state.steps_remaining <= 0:
            raise RuntimeError("Episode has finished. Call reset() to start a new episode.")
        if action.task_id != state.task_id:
            raise ValueError("Action task_id does not match active session task_id")

        # NOTE(review): found_truth_keys is never mutated in this class, so
        # grade_step presumably records newly matched keys in ``found`` - confirm.
        raw_reward, reason = grade_step(
            action=action,
            ground_truth=state.ground_truth,
            evidence_map=state.evidence_map,
            found=state.found_truth_keys,
            cfg=self.cfg,
        )

        state.partial_score += raw_reward
        if action.action_type != "noop":
            state.findings_submitted += 1
        state.steps_remaining -= 1

        done = state.steps_remaining <= 0
        info = {"reason": reason}

        if done:
            # The final step also carries the adjustment for everything missed.
            missed_penalty = terminal_missed_penalty(
                ground_truth=state.ground_truth,
                found=state.found_truth_keys,
                cfg=self.cfg,
            )
            state.partial_score += missed_penalty
            raw_reward += missed_penalty
            info["terminal_missed_penalty"] = missed_penalty

        reward = AuditReward(
            value=raw_reward,
            normalized=normalize_reward(raw_reward),
            reason=reason,
        )

        return StepResult(
            observation=self._observation(),
            reward=reward,
            done=done,
            info=info,
        )

    def state(self) -> EnvState:
        """Return a snapshot of the active episode's bookkeeping.

        Raises:
            RuntimeError: If no episode is active.
        """
        state = self._require_state()
        return EnvState(
            session_id=state.session_id,
            task_id=state.task_id,
            steps_remaining=state.steps_remaining,
            findings_submitted=state.findings_submitted,
            partial_score=state.partial_score,
            # Sorted so the snapshot is a deterministic list rather than a set.
            found_truth_keys=sorted(state.found_truth_keys),
        )

    def _observation(self) -> AuditObservation:
        """Build the agent-facing observation for the active episode.

        Raises:
            RuntimeError: If no episode is active.
        """
        state = self._require_state()
        return AuditObservation(
            session_id=state.session_id,
            task_id=state.task_id,
            # Only the leading documents are exposed, capped at the class limit.
            documents=state.documents[: self.MAX_OBSERVED_DOCUMENTS],
            findings_submitted=state.findings_submitted,
            steps_remaining=state.steps_remaining,
            current_partial_score=state.partial_score,
        )