from environment.types import (
    Email,
    Observation,
    Action,
    Reward,
    State,
    GroundTruth,
    EmailCategory,
    Team,
)
from environment.data_generator import DataGenerator
from environment.graders import (
    SpamDetectionGrader,
    MultiClassRoutingGrader,
    ContextAwareTriageGrader,
    compute_step_reward,
)
from datetime import datetime
from typing import Tuple, Dict, Any, List, Optional


class EmailTriageEnv:
    """Main email triage environment implementing OpenEnv spec.

    One episode walks through the task's emails in order: every call to
    ``step`` consumes the current email, scores the supplied action against
    that email's ground truth, and advances to the next email. The episode is
    done once every email has been consumed.
    """

    def __init__(self, task_name: str = "spam_detection"):
        """Create the environment and load the data for ``task_name``.

        Args:
            task_name: One of ``"spam_detection"``, ``"multi_class_routing"``
                or ``"context_aware_triage"``.

        Raises:
            ValueError: If ``task_name`` is not one of the known tasks.
        """
        self.task_name = task_name
        self.generator = DataGenerator()
        self.step_count = 0
        self.current_email_idx = 0
        self.actions_taken: List[Action] = []
        self.rewards_accumulated = 0.0
        self.done = False

        # Data for current task
        self.emails: List[Email] = []
        self.ground_truths: List[GroundTruth] = []
        # Most recent observation handed out by reset()/step(); None until
        # one of them has produced an observation for a real email.
        self.current_observation: Optional[Observation] = None

        # Set up task
        self._setup_task(task_name)

    def _setup_task(self, task_name: str):
        """Load the emails, ground truths and grader for ``task_name``.

        Raises:
            ValueError: If ``task_name`` is not a known task.
        """
        if task_name == "spam_detection":
            self.emails, self.ground_truths = self.generator.generate_task1_emails()
            self.grader = SpamDetectionGrader()
        elif task_name == "multi_class_routing":
            self.emails, self.ground_truths = self.generator.generate_task2_emails()
            self.grader = MultiClassRoutingGrader()
        elif task_name == "context_aware_triage":
            self.emails, self.ground_truths = self.generator.generate_task3_emails()
            self.grader = ContextAwareTriageGrader()
        else:
            raise ValueError(f"Unknown task: {task_name}")

    def reset(self) -> Observation:
        """Reset the episode counters and return the first observation.

        The task data loaded at construction time is kept; only the
        per-episode bookkeeping is cleared.

        Returns:
            The observation for the first email, or a placeholder observation
            (email id ``"none"``, all inbox counters zero) when the task has
            no emails.
        """
        self.step_count = 0
        self.current_email_idx = 0
        self.actions_taken = []
        self.rewards_accumulated = 0.0
        self.done = False
        # Drop any observation left over from the previous episode.
        self.current_observation = None

        # Get first email
        if self.emails:
            self.current_observation = self._get_observation()
            return self.current_observation

        # Empty task: the placeholder is deliberately not cached, so state()
        # keeps deriving the terminal observation in this case.
        return Observation(
            current_email=Email(
                email_id="none",
                subject="",
                body="",
                sender_domain="",
                timestamp=datetime.now(),
            ),
            inbox_state={"pending": 0, "spam": 0, "urgent": 0, "processed": 0},
            step_count=0,
            task_name=self.task_name,
        )

    @staticmethod
    def _count_category(truths: List[GroundTruth], category: EmailCategory) -> int:
        """Return how many entries of ``truths`` have the given category."""
        return sum(1 for truth in truths if truth.category == category)

    def _get_observation(self) -> Observation:
        """Build the observation for the email at ``current_email_idx``.

        When the index is past the last email this marks the environment as
        done and returns a terminal observation (email id ``"done"``) whose
        ``info`` carries the final score.
        """
        if self.current_email_idx >= len(self.emails):
            # End of task
            self.done = True
            return Observation(
                current_email=Email(
                    email_id="done",
                    subject="Task Complete",
                    body="All emails processed",
                    sender_domain="",
                    timestamp=datetime.now(),
                ),
                inbox_state={
                    "pending": 0,
                    # Terminal counters cover the whole task, not a remainder.
                    "spam": self._count_category(
                        self.ground_truths, EmailCategory.SPAM
                    ),
                    "urgent": self._count_category(
                        self.ground_truths, EmailCategory.URGENT
                    ),
                    "processed": self.current_email_idx,
                },
                step_count=self.step_count,
                task_name=self.task_name,
                info={"done": True, "final_score": self._compute_final_score()},
            )

        # Spam/urgent counters describe only the emails not yet processed
        # (the current email included).
        remaining_truths = self.ground_truths[self.current_email_idx:]
        inbox_state = {
            "pending": len(self.emails) - self.current_email_idx,
            "spam": self._count_category(remaining_truths, EmailCategory.SPAM),
            "urgent": self._count_category(remaining_truths, EmailCategory.URGENT),
            "processed": self.current_email_idx,
        }

        return Observation(
            current_email=self.emails[self.current_email_idx],
            inbox_state=inbox_state,
            step_count=self.step_count,
            task_name=self.task_name,
        )

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Process one email with the given action.

        Args:
            action: The decision for the current email.

        Returns:
            A ``(observation, reward, done, info)`` tuple. ``info`` holds the
            ground-truth and action fields for the processed email, plus
            ``final_score`` and ``task_complete`` on the step that finishes
            the task. Calling ``step`` after the task is finished returns the
            terminal observation, a zero reward, ``True`` and an empty dict.
        """
        if self.current_email_idx >= len(self.emails):
            self.done = True
            reward = Reward(value=0.0)
            obs = self._get_observation()
            self.current_observation = obs
            return obs, reward, True, {}

        # Get ground truth for current email
        ground_truth = self.ground_truths[self.current_email_idx]

        # Compute reward for this step
        step_reward, breakdown = compute_step_reward(action, ground_truth)
        reward = Reward(value=step_reward, breakdown=breakdown)

        self.actions_taken.append(action)
        self.rewards_accumulated += step_reward
        self.step_count += 1
        self.current_email_idx += 1

        # Check if done
        if self.current_email_idx >= len(self.emails):
            self.done = True

        # Get next observation and remember it so state() can report exactly
        # what the caller was given.
        next_obs = self._get_observation()
        self.current_observation = next_obs

        info = {
            "email_id": ground_truth.email_id,
            "ground_truth_category": ground_truth.category,
            "ground_truth_team": ground_truth.team,
            "ground_truth_priority": ground_truth.priority,
            "action_classification": action.classification,
            "action_team": action.team,
            "action_priority": action.priority,
        }

        if self.done:
            info["final_score"] = self._compute_final_score()
            info["task_complete"] = True

        return next_obs, reward, self.done, info

    def _compute_final_score(self) -> float:
        """Compute final task score.

        Returns ``0.0`` when no action has been taken yet; otherwise the
        result of the task grader's ``score_actions`` over the actions taken
        and the task's ground truths.
        """
        if not self.actions_taken:
            return 0.0

        return self.grader.score_actions(self.actions_taken, self.ground_truths)

    def state(self) -> State:
        """Return current complete state.

        The state bundles the latest observation, the mean step reward so
        far, the done flag, the per-step action/ground-truth history, and
        summary info (``final_score`` is ``None`` until the task is done).
        """
        # Prefer the observation last handed to the caller; only rebuild one
        # when reset()/step() have not cached any.
        observation = self.current_observation
        if observation is None:
            observation = self._get_observation()

        history = [
            {
                "step": i,
                "action": action.model_dump(),
                "ground_truth": truth.model_dump(),
                "email_id": truth.email_id,
            }
            # zip stops at the shorter list, i.e. at the actions taken so far.
            for i, (action, truth) in enumerate(
                zip(self.actions_taken, self.ground_truths)
            )
        ]

        return State(
            current_observation=observation,
            # max(1, ...) avoids dividing by zero before the first step.
            current_reward=self.rewards_accumulated / max(1, self.step_count),
            done=self.done,
            history=history,
            info={
                "task_name": self.task_name,
                "step_count": self.step_count,
                "total_emails": len(self.emails),
                "final_score": self._compute_final_score() if self.done else None,
            },
        )

    def describe_action_space(self) -> Dict[str, Any]:
        """Describe the action space as a JSON-schema-style dictionary."""
        return {
            "type": "object",
            "properties": {
                "classification": {
                    "type": "string",
                    "enum": [cat.value for cat in EmailCategory],
                    "description": "Email classification category",
                },
                "team": {
                    "type": "string",
                    "enum": [t.value for t in Team],
                    "description": "Team to route email to",
                },
                "priority": {
                    "type": "integer",
                    "minimum": 0,
                    "maximum": 3,
                    "description": "Priority level (0=low, 3=high)",
                },
            },
            "required": ["classification", "team", "priority"],
        }

    def describe_observation_space(self) -> Dict[str, Any]:
        """Describe the observation space as a JSON-schema-style dictionary."""
        return {
            "type": "object",
            "properties": {
                "current_email": {
                    "type": "object",
                    "properties": {
                        "email_id": {"type": "string"},
                        "subject": {"type": "string"},
                        "body": {"type": "string"},
                        "sender_domain": {"type": "string"},
                        "timestamp": {"type": "string", "format": "date-time"},
                        "is_vip_sender": {"type": "boolean"},
                        "sla_hours": {"type": ["integer", "null"]},
                    },
                },
                "inbox_state": {
                    "type": "object",
                    "properties": {
                        "pending": {"type": "integer"},
                        "spam": {"type": "integer"},
                        "urgent": {"type": "integer"},
                        "processed": {"type": "integer"},
                    },
                },
                "step_count": {"type": "integer"},
                "task_name": {"type": "string"},
            },
        }