Spaces:
Sleeping
Sleeping
| """Stateful simulation engine for step-by-step email triage.""" | |
| from __future__ import annotations | |
| import logging | |
| import uuid | |
| from threading import RLock | |
| from typing import Any, Mapping | |
| from core_engine.evaluator import TriageEvaluator | |
| from core_engine.mail_factory import MailFactory | |
| from core_engine.schemas import ( | |
| AgentDecision, | |
| EvaluationRecord, | |
| PayloadValidationError, | |
| SyntheticMail, | |
| ) | |
| from core_engine.score_bounds import enforce_strict_score | |
| LOGGER = logging.getLogger(__name__) | |
class SimulationError(RuntimeError):
    """Signal that an action cannot be applied in the simulation's current state.

    Examples include stepping before a run has been initialized, stepping
    after every email has been processed, or requesting an empty run.
    """

    pass
class SimulationEngine:
    """Manage generated emails, agent actions, and scoring state.

    A run is a fixed, ordered batch of synthetic emails that the agent
    triages one at a time:

    * ``reset`` builds a new batch and rewinds the cursor.
    * ``step`` evaluates the decision for the current email and advances.
    * ``get_state`` returns the snapshot the agent is allowed to see.

    Every public method holds an internal re-entrant lock while it reads or
    mutates run state.
    """

    # Modes accepted by the constructor; anything else falls back to "easy".
    _SUPPORTED_MODES = frozenset({"easy", "hard"})

    def __init__(
        self,
        batch_size: int = 12,
        random_seed: int | None = None,
        simulation_mode: str = "easy",
        mail_factory: MailFactory | None = None,
        evaluator: TriageEvaluator | None = None,
    ) -> None:
        """Configure the engine; no emails are generated until ``reset``.

        Args:
            batch_size: Number of emails ``reset`` generates when it is called
                without an explicit count.
            random_seed: Seed forwarded to the default ``MailFactory``.
                Ignored when ``mail_factory`` is supplied.
            simulation_mode: ``"easy"`` or ``"hard"``. Any other value falls
                back to ``"easy"`` and a warning is logged.
            mail_factory: Optional pre-built factory. When given, neither
                ``random_seed`` nor ``simulation_mode`` is passed to it, so the
                mode reported in state may not match the factory's own
                configuration.
            evaluator: Optional evaluator; defaults to ``TriageEvaluator()``.
        """
        self._batch_size = batch_size
        if simulation_mode in self._SUPPORTED_MODES:
            self._simulation_mode = simulation_mode
        else:
            # Keep the lenient fallback callers may rely on, but make it visible.
            LOGGER.warning(
                "Unsupported simulation mode %r; falling back to 'easy'.",
                simulation_mode,
            )
            self._simulation_mode = "easy"
        # Explicit ``is None`` checks, so an injected collaborator is never
        # replaced just because it happens to evaluate as falsy.
        if mail_factory is None:
            mail_factory = MailFactory(
                seed=random_seed,
                simulation_mode=self._simulation_mode,
            )
        self._mail_factory = mail_factory
        self._evaluator = evaluator if evaluator is not None else TriageEvaluator()
        self._messages: list[SyntheticMail] = []  # emails of the current run, in order
        self._records: list[EvaluationRecord] = []  # one record per processed email
        self._cursor = 0  # index of the next email awaiting a decision
        self._run_id = ""  # empty until the first reset()
        self._lock = RLock()

    def reset(self, message_count: int | None = None) -> dict[str, Any]:
        """Start a new simulation and return the initial visible state.

        Args:
            message_count: Number of emails to generate. ``None`` uses the
                configured ``batch_size``.

        Returns:
            The same snapshot ``get_state`` returns, for the fresh run.

        Raises:
            SimulationError: If the resolved email count is not positive.
        """
        with self._lock:
            # ``is None`` rather than ``or``: an explicit 0 must be rejected
            # below, not silently replaced by the default batch size.
            count = self._batch_size if message_count is None else message_count
            if count <= 0:
                raise SimulationError("Simulation must contain at least one email.")
            self._messages = self._mail_factory.build_batch(count)
            self._records = []
            self._cursor = 0
            self._run_id = str(uuid.uuid4())
            LOGGER.info(
                "Simulation %s initialized with %s emails in %s mode.",
                self._run_id,
                count,
                self._simulation_mode,
            )
            return self._state_unlocked()

    def step(self, action: AgentDecision | Mapping[str, Any]) -> dict[str, Any]:
        """Apply one agent action and return state, reward, and completion flag.

        Args:
            action: An ``AgentDecision``, or a mapping that is converted with
                ``AgentDecision.from_payload``. Its ``mail_id`` must match the
                email currently awaiting a decision.

        Returns:
            A dict with these keys:

            * ``state``: the post-step snapshot, in the same shape ``get_state``
              returns.
            * ``reward``: ``record.step_score / 100`` passed through
              ``enforce_strict_score``. The division suggests ``step_score`` is
              on a 0-100 scale (TODO confirm against the evaluator).
            * ``done``: whether every email has now been processed.
            * ``evaluation``: this step's record, as a dict.
            * ``score``: the cumulative summary, as a dict.

        Raises:
            SimulationError: If no run has been started, or the run is complete.
            PayloadValidationError: If the decision targets a different email.
                Conversion of a mapping payload may also raise; presumably the
                same type, but that is defined in ``core_engine.schemas``.
        """
        # Parse before taking the lock: conversion does not touch engine state.
        decision = (
            action if isinstance(action, AgentDecision) else AgentDecision.from_payload(action)
        )
        with self._lock:
            if not self._messages:
                raise SimulationError("Initialize the simulation before sending actions.")
            if self._cursor >= len(self._messages):
                raise SimulationError("Simulation is already complete.")
            current_message = self._messages[self._cursor]
            if decision.mail_id != current_message.mail_id:
                raise PayloadValidationError(
                    "Decision id must match the current email id "
                    f"'{current_message.mail_id}'."
                )
            record = self._evaluator.evaluate(current_message, decision)
            self._records.append(record)
            self._cursor += 1
            done = self._cursor >= len(self._messages)
            LOGGER.info(
                "Processed email %s with reward %.2f.",
                current_message.mail_id,
                record.step_score,
            )
            return {
                "state": self._state_unlocked(),
                "reward": enforce_strict_score(record.step_score / 100),
                "done": done,
                "evaluation": record.to_dict(),
                "score": self._evaluator.summarize(
                    self._records, len(self._messages)
                ).to_dict(),
            }

    def get_state(self) -> dict[str, Any]:
        """Return the current visible simulation state (see ``_state_unlocked``)."""
        with self._lock:
            return self._state_unlocked()

    def _state_unlocked(self) -> dict[str, Any]:
        """Build the visible state snapshot. The caller must hold ``self._lock``.

        Returns:
            A dict with these keys:

            * ``emails``: the public view of every email in the run.
            * ``run_id``: the identifier of the current run.
            * ``simulation_mode``: the mode the engine was configured with.
            * ``current_email``: the public view of the email awaiting a
              decision, or ``None`` when the run is finished or not started.
            * ``progress``: ``processed``/``remaining``/``total`` counts.
            * ``done``: whether the run has finished.
            * ``score``: the cumulative summary, as a dict.

            Before the first ``reset``, ``done`` is ``False`` and
            ``current_email`` is ``None``.
        """
        processed_ids = {record.mail_id for record in self._records}
        total_count = len(self._messages)
        # A run with no messages has not started, so it is not "done".
        done = bool(self._messages) and self._cursor >= total_count
        current_email = None if done or not self._messages else self._messages[self._cursor]
        return {
            "emails": [
                message.public_view(processed=message.mail_id in processed_ids)
                for message in self._messages
            ],
            "run_id": self._run_id,
            "simulation_mode": self._simulation_mode,
            "current_email": (
                None
                if current_email is None
                else current_email.public_view(
                    processed=current_email.mail_id in processed_ids
                )
            ),
            "progress": {
                "processed": len(self._records),
                "remaining": max(total_count - len(self._records), 0),
                "total": total_count,
            },
            "done": done,
            "score": self._evaluator.summarize(self._records, total_count).to_dict(),
        }