"""Stateful simulation engine for step-by-step email triage."""
from __future__ import annotations
import logging
import uuid
from threading import RLock
from typing import Any, Mapping
from core_engine.evaluator import TriageEvaluator
from core_engine.mail_factory import MailFactory
from core_engine.schemas import (
AgentDecision,
EvaluationRecord,
PayloadValidationError,
SyntheticMail,
)
from core_engine.score_bounds import enforce_strict_score
# Module-level logger; used to record run initialization and each processed email.
LOGGER = logging.getLogger(__name__)
class SimulationError(RuntimeError):
    """Signal that a requested action is not allowed in the engine's current state.

    Examples include stepping before a run has been initialized, stepping after
    the last email was processed, or resetting with a non-positive email count.
    """
class SimulationEngine:
    """Manage generated emails, agent actions, and scoring state.

    A run is started with :meth:`reset`, which generates a batch of synthetic
    emails. The agent then calls :meth:`step` once per email, in order; each
    decision must target the email currently under the cursor. Public methods
    acquire an internal re-entrant lock before reading or mutating run state.
    """

    # Modes accepted by the constructor; any other value falls back to the default.
    _SUPPORTED_MODES = frozenset({"easy", "hard"})
    _DEFAULT_MODE = "easy"

    def __init__(
        self,
        batch_size: int = 12,
        random_seed: int | None = None,
        simulation_mode: str = "easy",
        mail_factory: MailFactory | None = None,
        evaluator: TriageEvaluator | None = None,
    ) -> None:
        """Configure the engine; no emails are generated until :meth:`reset`.

        Args:
            batch_size: Number of emails :meth:`reset` generates when it is
                called without an explicit ``message_count``.
            random_seed: Seed forwarded to the default ``MailFactory``; unused
                when ``mail_factory`` is supplied.
            simulation_mode: ``"easy"`` or ``"hard"``. Any other value is
                replaced by ``"easy"`` and a warning is logged.
            mail_factory: Optional pre-built factory. Its own mode is not
                reconciled with ``simulation_mode``.
            evaluator: Optional evaluator; a ``TriageEvaluator()`` is created
                when omitted.
        """
        self._batch_size = batch_size
        if simulation_mode in self._SUPPORTED_MODES:
            self._simulation_mode = simulation_mode
        else:
            LOGGER.warning(
                "Unsupported simulation mode %r; falling back to %r.",
                simulation_mode,
                self._DEFAULT_MODE,
            )
            self._simulation_mode = self._DEFAULT_MODE
        # Compare against None rather than relying on truthiness so that a
        # supplied collaborator is never silently replaced.
        if mail_factory is None:
            mail_factory = MailFactory(
                seed=random_seed,
                simulation_mode=self._simulation_mode,
            )
        self._mail_factory = mail_factory
        self._evaluator = evaluator if evaluator is not None else TriageEvaluator()
        self._messages: list[SyntheticMail] = []
        self._records: list[EvaluationRecord] = []
        # Index of the next email awaiting a decision.
        self._cursor = 0
        # Empty until the first reset().
        self._run_id = ""
        self._lock = RLock()

    def reset(self, message_count: int | None = None) -> dict[str, Any]:
        """Start a new simulation and return the initial visible state.

        Args:
            message_count: Number of emails to generate. ``None`` uses the
                configured batch size.

        Returns:
            The visible state dict for the fresh run (see :meth:`get_state`).

        Raises:
            SimulationError: If the resolved email count is not positive.
        """
        with self._lock:
            # Explicit None check: ``message_count or self._batch_size`` would
            # turn an explicit 0 into the default batch size instead of
            # rejecting it below.
            count = self._batch_size if message_count is None else message_count
            if count <= 0:
                raise SimulationError("Simulation must contain at least one email.")
            self._messages = self._mail_factory.build_batch(count)
            self._records = []
            self._cursor = 0
            self._run_id = str(uuid.uuid4())
            LOGGER.info(
                "Simulation %s initialized with %s emails in %s mode.",
                self._run_id,
                count,
                self._simulation_mode,
            )
            return self._state_unlocked()

    def step(self, action: AgentDecision | Mapping[str, Any]) -> dict[str, Any]:
        """Apply one agent action and return state, reward, and completion flag.

        Args:
            action: An ``AgentDecision`` or a raw mapping, which is converted
                with ``AgentDecision.from_payload``.

        Returns:
            A dict with keys ``state`` (visible state after this step),
            ``reward`` (step score divided by 100, passed through
            ``enforce_strict_score``), ``done``, ``evaluation`` (this step's
            record as a dict) and ``score`` (running summary as a dict).

        Raises:
            SimulationError: If no run has been initialized or the run is
                already complete.
            PayloadValidationError: If the decision's ``mail_id`` does not match
                the current email. (Payload parsing may also raise it —
                presumably; verify against ``AgentDecision.from_payload``.)
        """
        # Conversion does not read engine state, so it happens outside the lock.
        decision = (
            action if isinstance(action, AgentDecision) else AgentDecision.from_payload(action)
        )
        with self._lock:
            if not self._messages:
                raise SimulationError("Initialize the simulation before sending actions.")
            if self._cursor >= len(self._messages):
                raise SimulationError("Simulation is already complete.")
            current_message = self._messages[self._cursor]
            if decision.mail_id != current_message.mail_id:
                raise PayloadValidationError(
                    "Decision id must match the current email id "
                    f"'{current_message.mail_id}'."
                )
            record = self._evaluator.evaluate(current_message, decision)
            # State is mutated only after evaluation succeeds, so an evaluator
            # error leaves the run exactly as it was.
            self._records.append(record)
            self._cursor += 1
            done = self._cursor >= len(self._messages)
            # NOTE: this logs the raw step score, i.e. before the /100 scaling
            # applied to the returned reward.
            LOGGER.info(
                "Processed email %s with reward %.2f.",
                current_message.mail_id,
                record.step_score,
            )
            return {
                "state": self._state_unlocked(),
                "reward": enforce_strict_score(record.step_score / 100),
                "done": done,
                "evaluation": record.to_dict(),
                "score": self._evaluator.summarize(
                    self._records, len(self._messages)
                ).to_dict(),
            }

    def get_state(self) -> dict[str, Any]:
        """Return the current visible simulation state."""
        with self._lock:
            return self._state_unlocked()

    def _state_unlocked(self) -> dict[str, Any]:
        """Build the visible state dict. The caller must already hold ``self._lock``.

        Keys: ``emails``, ``run_id``, ``simulation_mode``, ``current_email``
        (``None`` when finished or uninitialized), ``progress``, ``done`` and
        ``score``.
        """
        processed_ids = {record.mail_id for record in self._records}
        # An uninitialized engine (no messages) is reported as not done.
        done = bool(self._messages) and self._cursor >= len(self._messages)
        current_email = None if done or not self._messages else self._messages[self._cursor]
        total_count = len(self._messages)
        return {
            "emails": [
                message.public_view(processed=message.mail_id in processed_ids)
                for message in self._messages
            ],
            "run_id": self._run_id,
            "simulation_mode": self._simulation_mode,
            "current_email": (
                None
                if current_email is None
                else current_email.public_view(
                    processed=current_email.mail_id in processed_ids
                )
            ),
            "progress": {
                "processed": len(self._records),
                "remaining": max(total_count - len(self._records), 0),
                "total": total_count,
            },
            "done": done,
            "score": self._evaluator.summarize(self._records, total_count).to_dict(),
        }
|