File size: 5,208 Bytes
a1933cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""Stateful simulation engine for step-by-step email triage."""

from __future__ import annotations

import logging
import uuid
from threading import RLock
from typing import Any, Mapping

from core_engine.evaluator import TriageEvaluator
from core_engine.mail_factory import MailFactory
from core_engine.schemas import (
    AgentDecision,
    EvaluationRecord,
    PayloadValidationError,
    SyntheticMail,
)
from core_engine.score_bounds import enforce_strict_score

LOGGER = logging.getLogger(__name__)


class SimulationError(RuntimeError):
    """Signal that a requested operation conflicts with the simulation's state.

    The engine raises it when a reset asks for a non-positive number of
    emails, when an action arrives before the simulation was initialized, or
    when an action arrives after every email has already been processed.
    """


class SimulationEngine:
    """Manage generated emails, agent actions, and scoring state.

    The engine holds one batch of synthetic emails at a time and walks through
    it in order: every call to :meth:`step` must target the email at the
    current cursor position. Public methods take a re-entrant lock before
    reading or mutating the batch, the evaluation records, or the cursor.
    """

    # Modes accepted by ``__init__``; anything else falls back to "easy".
    _SUPPORTED_MODES = frozenset({"easy", "hard"})

    def __init__(
        self,
        batch_size: int = 12,
        random_seed: int | None = None,
        simulation_mode: str = "easy",
        mail_factory: MailFactory | None = None,
        evaluator: TriageEvaluator | None = None,
    ) -> None:
        """Configure the engine; no emails exist until :meth:`reset` is called.

        Args:
            batch_size: Number of emails generated by :meth:`reset` when no
                explicit ``message_count`` is given.
            random_seed: Seed forwarded to the default ``MailFactory``. It is
                not used when ``mail_factory`` is supplied.
            simulation_mode: ``"easy"`` or ``"hard"``. Any other value falls
                back to ``"easy"`` and a warning is logged.
            mail_factory: Optional pre-built factory. ``simulation_mode`` is
                only forwarded to the default factory, not to an injected one.
            evaluator: Optional evaluator; defaults to ``TriageEvaluator()``.
        """
        self._batch_size = batch_size
        if simulation_mode in self._SUPPORTED_MODES:
            self._simulation_mode = simulation_mode
        else:
            # Keep the historical fallback, but make the misconfiguration visible.
            LOGGER.warning(
                "Unsupported simulation mode %r; falling back to 'easy'.",
                simulation_mode,
            )
            self._simulation_mode = "easy"
        self._mail_factory = mail_factory or MailFactory(
            seed=random_seed,
            simulation_mode=self._simulation_mode,
        )
        self._evaluator = evaluator or TriageEvaluator()
        self._messages: list[SyntheticMail] = []
        self._records: list[EvaluationRecord] = []
        self._cursor = 0  # index of the next email awaiting a decision
        self._run_id = ""
        self._lock = RLock()

    def reset(self, message_count: int | None = None) -> dict[str, Any]:
        """Start a new simulation and return the initial visible state.

        Args:
            message_count: Number of emails to generate. ``None`` uses the
                configured ``batch_size``.

        Returns:
            The visible state dictionary (same shape as :meth:`get_state`).

        Raises:
            SimulationError: If the resolved email count is not positive.
        """
        with self._lock:
            # Only ``None`` selects the default: an explicit 0 must be rejected
            # by the guard below rather than silently replaced by batch_size.
            count = self._batch_size if message_count is None else message_count
            if count <= 0:
                raise SimulationError("Simulation must contain at least one email.")

            self._messages = self._mail_factory.build_batch(count)
            self._records = []
            self._cursor = 0
            self._run_id = str(uuid.uuid4())
            LOGGER.info(
                "Simulation %s initialized with %s emails in %s mode.",
                self._run_id,
                count,
                self._simulation_mode,
            )
            return self._state_unlocked()

    def step(self, action: AgentDecision | Mapping[str, Any]) -> dict[str, Any]:
        """Apply one agent action and return state, reward, and completion flag.

        Args:
            action: An ``AgentDecision``, or a mapping converted with
                ``AgentDecision.from_payload``. It must reference the current
                email's id.

        Returns:
            A dict with keys ``"state"`` (visible state after the step),
            ``"reward"`` (``step_score / 100`` passed through
            ``enforce_strict_score``), ``"done"``, ``"evaluation"`` (this
            step's record) and ``"score"`` (running summary).

        Raises:
            SimulationError: If the simulation was not initialized or is
                already complete.
            PayloadValidationError: If the decision's id does not match the
                current email. Payload parsing may presumably raise it too —
                confirm against ``AgentDecision.from_payload``.
        """
        # Parse outside the lock: it only touches the caller's payload.
        decision = (
            action if isinstance(action, AgentDecision) else AgentDecision.from_payload(action)
        )

        with self._lock:
            if not self._messages:
                raise SimulationError("Initialize the simulation before sending actions.")
            if self._cursor >= len(self._messages):
                raise SimulationError("Simulation is already complete.")

            current_message = self._messages[self._cursor]
            if decision.mail_id != current_message.mail_id:
                raise PayloadValidationError(
                    "Decision id must match the current email id "
                    f"'{current_message.mail_id}'."
                )

            record = self._evaluator.evaluate(current_message, decision)
            self._records.append(record)
            self._cursor += 1
            done = self._cursor >= len(self._messages)

            # step_score appears to be on a 0-100 scale; the reward handed back
            # to the agent is the rescaled, bounded value. Log both so the log
            # line matches what the caller actually receives.
            reward = enforce_strict_score(record.step_score / 100)
            LOGGER.info(
                "Processed email %s with score %.2f (reward %.4f).",
                current_message.mail_id,
                record.step_score,
                reward,
            )
            return {
                "state": self._state_unlocked(),
                "reward": reward,
                "done": done,
                "evaluation": record.to_dict(),
                "score": self._evaluator.summarize(
                    self._records, len(self._messages)
                ).to_dict(),
            }

    def get_state(self) -> dict[str, Any]:
        """Return the current visible simulation state."""
        with self._lock:
            return self._state_unlocked()

    def _state_unlocked(self) -> dict[str, Any]:
        """Build the visible state dictionary without taking the lock.

        Every caller in this class already holds ``self._lock`` when invoking
        this helper.

        Returns:
            A dict with ``"emails"``, ``"run_id"``, ``"simulation_mode"``,
            ``"current_email"`` (``None`` before reset or once done),
            ``"progress"``, ``"done"`` and ``"score"``.
        """
        processed_ids = {record.mail_id for record in self._records}
        # An empty batch (before any reset) is "not started", not "done".
        done = bool(self._messages) and self._cursor >= len(self._messages)
        current_email = None if done or not self._messages else self._messages[self._cursor]
        total_count = len(self._messages)

        return {
            "emails": [
                message.public_view(processed=message.mail_id in processed_ids)
                for message in self._messages
            ],
            "run_id": self._run_id,
            "simulation_mode": self._simulation_mode,
            "current_email": (
                None
                if current_email is None
                else current_email.public_view(
                    processed=current_email.mail_id in processed_ids
                )
            ),
            "progress": {
                "processed": len(self._records),
                "remaining": max(total_count - len(self._records), 0),
                "total": total_count,
            },
            "done": done,
            "score": self._evaluator.summarize(self._records, total_count).to_dict(),
        }