Spaces:
Sleeping
Sleeping
| from environment.types import ( | |
| Email, Observation, Action, Reward, State, GroundTruth, | |
| EmailCategory, Team | |
| ) | |
| from environment.data_generator import DataGenerator | |
| from environment.graders import ( | |
| SpamDetectionGrader, MultiClassRoutingGrader, | |
| ContextAwareTriageGrader, compute_step_reward | |
| ) | |
| from datetime import datetime | |
| from typing import Tuple, Dict, Any, List, Optional | |
class EmailTriageEnv:
    """Main email triage environment implementing OpenEnv spec.

    An episode walks through the task's email list one email per ``step``.
    Each action is scored against the matching ``GroundTruth`` with
    ``compute_step_reward``; the episode-level score comes from the
    task-specific grader selected in ``_setup_task``.
    """

    def __init__(self, task_name: str = "spam_detection"):
        """Create the environment and load the data for ``task_name``.

        Args:
            task_name: One of ``"spam_detection"``, ``"multi_class_routing"``
                or ``"context_aware_triage"``.

        Raises:
            ValueError: If ``task_name`` is not one of the known tasks.
        """
        self.task_name = task_name
        self.generator = DataGenerator()
        self.step_count = 0
        self.current_email_idx = 0
        self.actions_taken: List[Action] = []
        self.rewards_accumulated = 0.0
        self.done = False
        # Data for current task (parallel lists: ground_truths[i] labels emails[i])
        self.emails: List[Email] = []
        self.ground_truths: List[GroundTruth] = []
        # Last observation handed out by reset()/step(); reported by state().
        self.current_observation: Optional[Observation] = None
        # Set up task
        self._setup_task(task_name)

    def _setup_task(self, task_name: str):
        """Load emails/ground truths and pick the grader for ``task_name``.

        Raises:
            ValueError: If ``task_name`` is not a known task.
        """
        if task_name == "spam_detection":
            self.emails, self.ground_truths = self.generator.generate_task1_emails()
            self.grader = SpamDetectionGrader()
        elif task_name == "multi_class_routing":
            self.emails, self.ground_truths = self.generator.generate_task2_emails()
            self.grader = MultiClassRoutingGrader()
        elif task_name == "context_aware_triage":
            self.emails, self.ground_truths = self.generator.generate_task3_emails()
            self.grader = ContextAwareTriageGrader()
        else:
            raise ValueError(f"Unknown task: {task_name}")

    def reset(self) -> Observation:
        """Rewind the episode to the first email and return its observation.

        The email set loaded by ``_setup_task`` is reused as-is; it is not
        regenerated. If the task has no emails, a placeholder observation with
        ``email_id="none"`` and all-zero inbox counts is returned.
        """
        self.step_count = 0
        self.current_email_idx = 0
        self.actions_taken = []
        self.rewards_accumulated = 0.0
        self.done = False
        # Get first email
        if self.emails:
            observation = self._get_observation()
        else:
            observation = Observation(
                current_email=Email(
                    email_id="none",
                    subject="",
                    body="",
                    sender_domain="",
                    timestamp=datetime.now()
                ),
                inbox_state={"pending": 0, "spam": 0, "urgent": 0, "processed": 0},
                step_count=0,
                task_name=self.task_name
            )
        # Remember what the agent was shown so state() reports the same thing.
        self.current_observation = observation
        return observation

    @staticmethod
    def _count_category(truths: List[GroundTruth], category: EmailCategory) -> int:
        """Return how many entries of ``truths`` are labelled ``category``."""
        return sum(1 for truth in truths if truth.category == category)

    def _get_observation(self) -> Observation:
        """Build the observation for the email at ``current_email_idx``.

        While emails remain, ``inbox_state`` counts the emails not yet
        processed (the current one included). Once every email has been
        consumed this sets ``self.done = True`` and returns a sentinel
        ``email_id="done"`` observation whose ``info`` carries the final score.
        """
        if self.current_email_idx >= len(self.emails):
            # End of task
            self.done = True
            return Observation(
                current_email=Email(
                    email_id="done",
                    subject="Task Complete",
                    body="All emails processed",
                    sender_domain="",
                    timestamp=datetime.now()
                ),
                inbox_state={
                    "pending": 0,
                    # NOTE: at the end these are totals over the whole task,
                    # whereas mid-episode they count only unprocessed emails.
                    "spam": self._count_category(self.ground_truths, EmailCategory.SPAM),
                    "urgent": self._count_category(self.ground_truths, EmailCategory.URGENT),
                    "processed": self.current_email_idx
                },
                step_count=self.step_count,
                task_name=self.task_name,
                info={"done": True, "final_score": self._compute_final_score()}
            )
        current_email = self.emails[self.current_email_idx]
        remaining = self.ground_truths[self.current_email_idx:]
        inbox_state = {
            "pending": len(self.emails) - self.current_email_idx,
            "spam": self._count_category(remaining, EmailCategory.SPAM),
            "urgent": self._count_category(remaining, EmailCategory.URGENT),
            "processed": self.current_email_idx
        }
        return Observation(
            current_email=current_email,
            inbox_state=inbox_state,
            step_count=self.step_count,
            task_name=self.task_name
        )

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Apply ``action`` to the current email and advance to the next one.

        Returns:
            ``(next_observation, reward, done, info)``. ``info`` echoes the
            ground truth and the action fields for the email just processed;
            on the final email it also carries ``final_score`` and
            ``task_complete``. Calling ``step`` after all emails are consumed
            returns a zero reward, ``done=True`` and an empty ``info``.
        """
        if self.current_email_idx >= len(self.emails):
            self.done = True
            reward = Reward(value=0.0)
            obs = self._get_observation()
            self.current_observation = obs
            return obs, reward, True, {}
        # Get ground truth for current email
        ground_truth = self.ground_truths[self.current_email_idx]
        # Compute reward for this step
        step_reward, breakdown = compute_step_reward(action, ground_truth)
        reward = Reward(
            value=step_reward,
            breakdown=breakdown
        )
        self.actions_taken.append(action)
        self.rewards_accumulated += step_reward
        self.step_count += 1
        self.current_email_idx += 1
        # Check if done
        if self.current_email_idx >= len(self.emails):
            self.done = True
        # Get next observation
        next_obs = self._get_observation()
        self.current_observation = next_obs
        info = {
            "email_id": ground_truth.email_id,
            "ground_truth_category": ground_truth.category,
            "ground_truth_team": ground_truth.team,
            "ground_truth_priority": ground_truth.priority,
            "action_classification": action.classification,
            "action_team": action.team,
            "action_priority": action.priority,
        }
        if self.done:
            info["final_score"] = self._compute_final_score()
            info["task_complete"] = True
        return next_obs, reward, self.done, info

    def _compute_final_score(self) -> float:
        """Return the grader's score for the actions taken so far (0.0 if none)."""
        if not self.actions_taken:
            return 0.0
        return self.grader.score_actions(self.actions_taken, self.ground_truths)

    def state(self) -> State:
        """Return a snapshot of the episode.

        The snapshot contains the last observation, the mean per-step reward,
        the done flag, a per-step history pairing each action with its ground
        truth, and summary info. ``final_score`` is only filled once the
        episode is done.
        """
        observation = self.current_observation
        if observation is None:
            # Neither reset() nor step() has run yet; build one on demand.
            observation = self._get_observation()
        return State(
            current_observation=observation,
            # Mean reward per step; max() guards against division by zero.
            current_reward=self.rewards_accumulated / max(1, self.step_count),
            done=self.done,
            history=[
                {
                    "step": i,
                    "action": action.model_dump(),
                    "ground_truth": truth.model_dump(),
                    "email_id": truth.email_id
                }
                # zip stops at len(actions_taken): only processed emails appear.
                for i, (action, truth) in enumerate(zip(self.actions_taken, self.ground_truths))
            ],
            info={
                "task_name": self.task_name,
                "step_count": self.step_count,
                "total_emails": len(self.emails),
                "final_score": self._compute_final_score() if self.done else None
            }
        )

    def describe_action_space(self) -> Dict[str, Any]:
        """Return a JSON-Schema-style description of an action.

        Allowed ``classification`` and ``team`` values are taken from the
        ``EmailCategory`` and ``Team`` enums. ``priority`` is an integer in
        [0, 3].
        """
        return {
            "type": "object",
            "properties": {
                "classification": {
                    "type": "string",
                    "enum": [cat.value for cat in EmailCategory],
                    "description": "Email classification category"
                },
                "team": {
                    "type": "string",
                    "enum": [t.value for t in Team],
                    "description": "Team to route email to"
                },
                "priority": {
                    "type": "integer",
                    "minimum": 0,
                    "maximum": 3,
                    "description": "Priority level (0=low, 3=high)"
                }
            },
            "required": ["classification", "team", "priority"]
        }

    def describe_observation_space(self) -> Dict[str, Any]:
        """Return a JSON-Schema-style description of an observation.

        NOTE(review): the end-of-episode observation also passes an ``info``
        dict, which this schema does not list. Confirm whether it should.
        """
        return {
            "type": "object",
            "properties": {
                "current_email": {
                    "type": "object",
                    "properties": {
                        "email_id": {"type": "string"},
                        "subject": {"type": "string"},
                        "body": {"type": "string"},
                        "sender_domain": {"type": "string"},
                        "timestamp": {"type": "string", "format": "date-time"},
                        "is_vip_sender": {"type": "boolean"},
                        "sla_hours": {"type": ["integer", "null"]}
                    }
                },
                "inbox_state": {
                    "type": "object",
                    "properties": {
                        "pending": {"type": "integer"},
                        "spam": {"type": "integer"},
                        "urgent": {"type": "integer"},
                        "processed": {"type": "integer"}
                    }
                },
                "step_count": {"type": "integer"},
                "task_name": {"type": "string"}
            }
        }