Vansh Jagetia
clean deploy for hf
a1933cb
"""Stateful simulation engine for step-by-step email triage."""
from __future__ import annotations
import logging
import uuid
from threading import RLock
from typing import Any, Mapping
from core_engine.evaluator import TriageEvaluator
from core_engine.mail_factory import MailFactory
from core_engine.schemas import (
AgentDecision,
EvaluationRecord,
PayloadValidationError,
SyntheticMail,
)
from core_engine.score_bounds import enforce_strict_score
# Module-scoped logger; its records carry this module's dotted name.
LOGGER = logging.getLogger(name=__name__)
class SimulationError(RuntimeError):
    """Error signalling that a simulation action is not valid in the engine's current state."""
class SimulationEngine:
    """Manage generated emails, agent actions, and scoring state.

    The engine holds one simulation run at a time:

    * ``reset`` builds a fresh batch of emails from the mail factory.
    * ``step`` scores one agent decision against the email at the cursor.
    * ``get_state`` returns the visible snapshot without changing anything.

    ``reset``, ``step``, and ``get_state`` all serialize access to the run
    state through a re-entrant lock.
    """

    def __init__(
        self,
        batch_size: int = 12,
        random_seed: int | None = None,
        simulation_mode: str = "easy",
        mail_factory: MailFactory | None = None,
        evaluator: TriageEvaluator | None = None,
    ) -> None:
        """Configure the engine. No emails are generated until ``reset`` is called.

        Args:
            batch_size: Number of emails ``reset`` generates when no explicit
                count is given.
            random_seed: Seed forwarded to the default ``MailFactory``. It is
                ignored when ``mail_factory`` is supplied.
            simulation_mode: ``"easy"`` or ``"hard"``. Any other value falls
                back to ``"easy"`` and a warning is logged.
            mail_factory: Optional pre-built factory. When omitted, one is
                created from ``random_seed`` and the resolved mode.
            evaluator: Optional scorer. Defaults to ``TriageEvaluator()``.
        """
        if simulation_mode in {"easy", "hard"}:
            self._simulation_mode = simulation_mode
        else:
            # Keep the lenient fallback, but make it visible instead of silent.
            LOGGER.warning(
                "Unknown simulation mode %r; falling back to 'easy'.",
                simulation_mode,
            )
            self._simulation_mode = "easy"
        self._batch_size = batch_size
        # Compare against None explicitly. An injected collaborator must never
        # be replaced just because it happens to evaluate as falsy.
        self._mail_factory = (
            mail_factory
            if mail_factory is not None
            else MailFactory(seed=random_seed, simulation_mode=self._simulation_mode)
        )
        self._evaluator = evaluator if evaluator is not None else TriageEvaluator()
        self._messages: list[SyntheticMail] = []
        self._records: list[EvaluationRecord] = []
        self._cursor = 0  # index of the next email awaiting a decision
        self._run_id = ""
        self._lock = RLock()

    def reset(self, message_count: int | None = None) -> dict[str, Any]:
        """Start a new simulation and return the initial visible state.

        Args:
            message_count: Number of emails to generate. ``None`` means use the
                ``batch_size`` given at construction.

        Returns:
            The same snapshot structure that ``get_state`` returns.

        Raises:
            SimulationError: If the resolved email count is not positive.
        """
        with self._lock:
            # Use ``is None`` rather than ``or``: an explicit 0 must be rejected,
            # not silently replaced by the default batch size.
            count = self._batch_size if message_count is None else message_count
            if count <= 0:
                raise SimulationError("Simulation must contain at least one email.")
            # Build the batch before touching any state. If the factory raises,
            # the previous run is left intact.
            self._messages = self._mail_factory.build_batch(count)
            self._records = []
            self._cursor = 0
            self._run_id = str(uuid.uuid4())
            LOGGER.info(
                "Simulation %s initialized with %s emails in %s mode.",
                self._run_id,
                count,
                self._simulation_mode,
            )
            return self._state_unlocked()

    def step(self, action: AgentDecision | Mapping[str, Any]) -> dict[str, Any]:
        """Apply one agent action and return state, reward, and completion flag.

        Args:
            action: An ``AgentDecision``, or a mapping that is converted with
                ``AgentDecision.from_payload``.

        Returns:
            A dict with these keys:

            * ``state``: the post-step snapshot.
            * ``reward``: the step score divided by 100 and passed through
              ``enforce_strict_score``.
            * ``done``: whether every email has now been processed.
            * ``evaluation``: this step's record as a dict.
            * ``score``: the evaluator's running summary.

        Raises:
            SimulationError: If no run has been started, or the run is already
                complete.
            PayloadValidationError: If the decision's ``mail_id`` differs from
                the current email's id.
        """
        # Payload conversion needs no engine state, so it happens before locking.
        decision = (
            action if isinstance(action, AgentDecision) else AgentDecision.from_payload(action)
        )
        with self._lock:
            if not self._messages:
                raise SimulationError("Initialize the simulation before sending actions.")
            if self._cursor >= len(self._messages):
                raise SimulationError("Simulation is already complete.")
            current_message = self._messages[self._cursor]
            # Emails must be handled strictly in order; reject out-of-turn ids.
            if decision.mail_id != current_message.mail_id:
                raise PayloadValidationError(
                    "Decision id must match the current email id "
                    f"'{current_message.mail_id}'."
                )
            record = self._evaluator.evaluate(current_message, decision)
            self._records.append(record)
            self._cursor += 1
            done = self._cursor >= len(self._messages)
            # Compute the reward once, so the log shows the same value the
            # caller receives rather than the raw 0-100 step score.
            reward = enforce_strict_score(record.step_score / 100)
            LOGGER.info(
                "Processed email %s with step score %.2f (reward %.4f).",
                current_message.mail_id,
                record.step_score,
                reward,
            )
            return {
                "state": self._state_unlocked(),
                "reward": reward,
                "done": done,
                "evaluation": record.to_dict(),
                "score": self._evaluator.summarize(
                    self._records, len(self._messages)
                ).to_dict(),
            }

    def get_state(self) -> dict[str, Any]:
        """Return the current visible simulation state without modifying it."""
        with self._lock:
            return self._state_unlocked()

    def _state_unlocked(self) -> dict[str, Any]:
        """Build the visible state snapshot. The caller must already hold ``self._lock``.

        Before the first ``reset`` there are no messages. In that case ``done``
        is False, ``current_email`` is None, and every progress counter is 0.
        """
        processed_ids = {record.mail_id for record in self._records}
        total_count = len(self._messages)
        done = bool(self._messages) and self._cursor >= total_count
        current_email = None if done or not self._messages else self._messages[self._cursor]
        return {
            "emails": [
                message.public_view(processed=message.mail_id in processed_ids)
                for message in self._messages
            ],
            "run_id": self._run_id,
            "simulation_mode": self._simulation_mode,
            "current_email": (
                None
                if current_email is None
                else current_email.public_view(
                    processed=current_email.mail_id in processed_ids
                )
            ),
            "progress": {
                "processed": len(self._records),
                "remaining": max(total_count - len(self._records), 0),
                "total": total_count,
            },
            "done": done,
            "score": self._evaluator.summarize(self._records, total_count).to_dict(),
        }