# parthpethia's picture
# Add Email Triage OpenEnv environment - production-ready with 3 graded tasks and Flask API
# fee8744
from environment.types import (
Email, Observation, Action, Reward, State, GroundTruth,
EmailCategory, Team
)
from environment.data_generator import DataGenerator
from environment.graders import (
SpamDetectionGrader, MultiClassRoutingGrader,
ContextAwareTriageGrader, compute_step_reward
)
from datetime import datetime
from typing import Tuple, Dict, Any, List, Optional
class EmailTriageEnv:
    """Main email triage environment implementing OpenEnv spec.

    An episode walks once through the list of emails generated for the
    selected task. Every ``step`` grades one action against the ground truth
    of the current email and then advances to the next email.
    """

    def __init__(self, task_name: str = "spam_detection"):
        """Create the environment and load the data for ``task_name``.

        Args:
            task_name: One of ``"spam_detection"``, ``"multi_class_routing"``
                or ``"context_aware_triage"``.

        Raises:
            ValueError: If ``task_name`` is not a known task.
        """
        self.task_name = task_name
        self.generator = DataGenerator()

        # Data for current task
        self.emails: List[Email] = []
        self.ground_truths: List[GroundTruth] = []
        # Last observation handed to the caller; stays None until the first
        # reset() or step().
        self.current_observation: Optional[Observation] = None

        self._reset_progress()

        # Set up task
        self._setup_task(task_name)

    def _reset_progress(self) -> None:
        """Put all per-episode counters back to their initial values."""
        self.step_count = 0
        self.current_email_idx = 0
        self.actions_taken = []
        self.rewards_accumulated = 0.0
        self.done = False
        self.current_observation = None

    def _setup_task(self, task_name: str):
        """Initialize task-specific data (emails, ground truths and grader).

        Raises:
            ValueError: If ``task_name`` is not a known task.
        """
        if task_name == "spam_detection":
            self.emails, self.ground_truths = self.generator.generate_task1_emails()
            self.grader = SpamDetectionGrader()
        elif task_name == "multi_class_routing":
            self.emails, self.ground_truths = self.generator.generate_task2_emails()
            self.grader = MultiClassRoutingGrader()
        elif task_name == "context_aware_triage":
            self.emails, self.ground_truths = self.generator.generate_task3_emails()
            self.grader = ContextAwareTriageGrader()
        else:
            raise ValueError(f"Unknown task: {task_name}")

    @staticmethod
    def _placeholder_email(email_id: str, subject: str = "", body: str = "") -> Email:
        """Build a stand-in email used when there is no real email to show."""
        return Email(
            email_id=email_id,
            subject=subject,
            body=body,
            sender_domain="",
            timestamp=datetime.now(),
        )

    def _count_category(self, category, start: int = 0) -> int:
        """Count ground truths of ``category`` from index ``start`` onwards."""
        return sum(
            1 for truth in self.ground_truths[start:] if truth.category == category
        )

    def reset(self) -> Observation:
        """Reset environment to initial state and return the first observation.

        The generated emails are kept; only the episode progress is cleared.
        When the task has no emails, a placeholder observation with the email
        id ``"none"`` is returned.
        """
        self._reset_progress()

        if self.emails:
            # Get first email
            observation = self._get_observation()
        else:
            observation = Observation(
                current_email=self._placeholder_email("none"),
                inbox_state={"pending": 0, "spam": 0, "urgent": 0, "processed": 0},
                step_count=0,
                task_name=self.task_name,
            )

        # Remember what the caller saw so state() can report it without
        # rebuilding the observation.
        self.current_observation = observation
        return observation

    def _get_observation(self) -> Observation:
        """Get observation for current email.

        Once every email has been processed this marks the episode as done
        and returns a terminal observation whose ``info`` carries the final
        score.
        """
        position = self.current_email_idx

        if position >= len(self.emails):
            # End of task
            self.done = True
            return Observation(
                current_email=self._placeholder_email(
                    "done", subject="Task Complete", body="All emails processed"
                ),
                inbox_state={
                    "pending": 0,
                    # The terminal counts cover the whole task, not only the
                    # (now empty) remainder of the inbox.
                    "spam": self._count_category(EmailCategory.SPAM),
                    "urgent": self._count_category(EmailCategory.URGENT),
                    "processed": position,
                },
                step_count=self.step_count,
                task_name=self.task_name,
                info={"done": True, "final_score": self._compute_final_score()},
            )

        # While the episode is running, spam/urgent describe the emails that
        # are still pending (the current one included).
        inbox_state = {
            "pending": len(self.emails) - position,
            "spam": self._count_category(EmailCategory.SPAM, position),
            "urgent": self._count_category(EmailCategory.URGENT, position),
            "processed": position,
        }
        return Observation(
            current_email=self.emails[position],
            inbox_state=inbox_state,
            step_count=self.step_count,
            task_name=self.task_name,
        )

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Process one email with the given action.

        Args:
            action: The classification, team and priority chosen for the
                current email.

        Returns:
            A tuple ``(observation, reward, done, info)``. When the episode is
            already over, the reward is zero and ``info`` is empty. On the
            final step ``info`` additionally holds ``final_score`` and
            ``task_complete``.
        """
        if self.current_email_idx >= len(self.emails):
            # Stepping past the end is a no-op that keeps reporting "done".
            self.done = True
            terminal_obs = self._get_observation()
            self.current_observation = terminal_obs
            return terminal_obs, Reward(value=0.0), True, {}

        # Get ground truth for current email
        ground_truth = self.ground_truths[self.current_email_idx]

        # Compute reward for this step
        step_reward, breakdown = compute_step_reward(action, ground_truth)
        reward = Reward(value=step_reward, breakdown=breakdown)

        self.actions_taken.append(action)
        self.rewards_accumulated += step_reward
        self.step_count += 1
        self.current_email_idx += 1

        # Check if done
        if self.current_email_idx >= len(self.emails):
            self.done = True

        # Get next observation and remember it for state()
        next_obs = self._get_observation()
        self.current_observation = next_obs

        info = {
            "email_id": ground_truth.email_id,
            "ground_truth_category": ground_truth.category,
            "ground_truth_team": ground_truth.team,
            "ground_truth_priority": ground_truth.priority,
            "action_classification": action.classification,
            "action_team": action.team,
            "action_priority": action.priority,
        }
        if self.done:
            info["final_score"] = self._compute_final_score()
            info["task_complete"] = True

        return next_obs, reward, self.done, info

    def _compute_final_score(self) -> float:
        """Compute final task score; 0.0 when no action has been taken."""
        if not self.actions_taken:
            return 0.0
        return self.grader.score_actions(self.actions_taken, self.ground_truths)

    def state(self) -> State:
        """Return current complete state.

        The state contains the last observation given to the caller, the mean
        reward per step so far, the done flag, the action history paired with
        the matching ground truths, and summary info.
        """
        observation = self.current_observation
        if observation is None:
            # Neither reset() nor step() has run yet: build the observation
            # on demand.
            observation = self._get_observation()

        history = [
            {
                "step": index,
                "action": action.model_dump(),
                "ground_truth": truth.model_dump(),
                "email_id": truth.email_id,
            }
            for index, (action, truth) in enumerate(
                zip(self.actions_taken, self.ground_truths)
            )
        ]
        final_score = self._compute_final_score() if self.done else None

        return State(
            current_observation=observation,
            # max(1, ...) avoids dividing by zero before the first step.
            current_reward=self.rewards_accumulated / max(1, self.step_count),
            done=self.done,
            history=history,
            info={
                "task_name": self.task_name,
                "step_count": self.step_count,
                "total_emails": len(self.emails),
                "final_score": final_score,
            },
        )

    def describe_action_space(self) -> Dict[str, Any]:
        """Describe the action space as a JSON-schema-like dictionary."""
        return {
            "type": "object",
            "properties": {
                "classification": {
                    "type": "string",
                    "enum": [cat.value for cat in EmailCategory],
                    "description": "Email classification category",
                },
                "team": {
                    "type": "string",
                    "enum": [t.value for t in Team],
                    "description": "Team to route email to",
                },
                "priority": {
                    "type": "integer",
                    "minimum": 0,
                    "maximum": 3,
                    "description": "Priority level (0=low, 3=high)",
                },
            },
            "required": ["classification", "team", "priority"],
        }

    def describe_observation_space(self) -> Dict[str, Any]:
        """Describe the observation space as a JSON-schema-like dictionary."""
        return {
            "type": "object",
            "properties": {
                "current_email": {
                    "type": "object",
                    "properties": {
                        "email_id": {"type": "string"},
                        "subject": {"type": "string"},
                        "body": {"type": "string"},
                        "sender_domain": {"type": "string"},
                        "timestamp": {"type": "string", "format": "date-time"},
                        "is_vip_sender": {"type": "boolean"},
                        "sla_hours": {"type": ["integer", "null"]},
                    },
                },
                "inbox_state": {
                    "type": "object",
                    "properties": {
                        "pending": {"type": "integer"},
                        "spam": {"type": "integer"},
                        "urgent": {"type": "integer"},
                        "processed": {"type": "integer"},
                    },
                },
                "step_count": {"type": "integer"},
                "task_name": {"type": "string"},
            },
        }