Spaces:
Sleeping
Sleeping
| """ | |
EmailTriageEnv v2 — OpenEnv-compliant email triage with sequential state.

What makes this a true sequential decision problem (not just N independent
classifications):

1. ESCALATION BUDGET — The agent has a fixed number of flag_review=True
   uses per episode. Wasting budget on low-priority emails means critical
   ones cannot be escalated later. Budget is visible in every observation.

2. TEAM QUEUE CAPACITY — Each routing destination has a finite capacity.
   Routing too many emails to legal/tier2/management saturates the queue;
   subsequent emails routed there incur an overflow penalty and the agent
   must adapt (e.g. route to management instead of legal when legal is full).

3. SLA DECAY TIMERS — Every email has a deadline relative to when it arrived.
   If the agent processes low-priority emails first and leaves urgent ones
   untouched, SLA breach events fire automatically at the start of each step,
   penalising the agent. Processing order therefore matters.

4. CRITICAL CASCADE — If 2+ urgent emails breach SLA in the same episode,
   a one-time cascade penalty triggers and the observation marks
   cascade_active=True, signalling compounding organisational damage.

These mechanics mean:
- Agent must read future inbox to plan processing ORDER (planning horizon).
- Agent must ration escalations across the full episode (budget constraint).
- Routing choices affect available routes for later emails (resource constraint).
- Greedy per-email optimisation is strictly suboptimal.
| """ | |
| from __future__ import annotations | |
| import copy | |
| from typing import Any, Dict, List, Optional, Tuple | |
| from models import ( | |
| Action, EmailMessage, EmailHeader, EnvironmentState, | |
| Observation, Reward, RewardBreakdown, | |
| SessionConstraints, SlaStatus, | |
| TEAM_CAPACITY, SLA_STEPS, | |
| ) | |
| from dataset import EASY_EMAILS, MEDIUM_EMAILS, HARD_EMAILS | |
| from grader import score_action, grade_episode | |
# Lookup from task id to the list of raw email records used for that task.
# Key order (easy, medium, hard) is preserved because it is echoed in the
# ValueError raised for an unknown task id.
TASK_DATASETS: Dict[str, List[Dict[str, Any]]] = dict(
    easy=EASY_EMAILS,
    medium=MEDIUM_EMAILS,
    hard=HARD_EMAILS,
)
# Number of flag_review=True uses allowed per episode, keyed by task id.
# The "truly need escalation" counts below are the original author's notes
# about the datasets; hard is the tightest relative to that stated need.
TASK_ESCALATION_BUDGET: Dict[str, int] = dict(
    easy=3,    # 2 emails truly need escalation → budget=3, comfortable
    medium=4,  # 4 emails truly need escalation → budget=4, exact
    hard=5,    # 7 emails truly need escalation → budget=5, must choose
)
# Human-readable task briefs, keyed by task id. These strings are surfaced to
# the agent through Observation.session_info["task_description"].
# The dash characters were mis-decoded in the text this file was recovered
# from; they are restored here as an en dash for score ranges and an em dash
# for the clause break.
TASK_DESCRIPTIONS: Dict[str, str] = {
    "easy": (
        "Triage 5 emails. Escalation budget=3 (2 truly required). "
        "Team queues are generous. SLA deadlines are forgiving. "
        "Expected score for a competent agent: 0.75–0.90."
    ),
    "medium": (
        "Triage 8 emails. Escalation budget=4 (exactly matching true need). "
        "legal and support_tier2 queues can saturate if misused. "
        "Processing order affects SLA breaches. "
        "Expected score: 0.55–0.75."
    ),
    "hard": (
        "Triage 10 emails. Escalation budget=5 but 7 emails truly require "
        "escalation — agent must choose which 5 matter most. "
        "legal queue capacity=2, management=2; overflow forces creative routing. "
        "Multiple urgent SLA timers run simultaneously. "
        "Expected score: 0.35–0.60."
    ),
}
def _build_email_message(raw: Dict[str, Any]) -> EmailMessage:
    """Convert one raw dataset record into an ``EmailMessage``.

    Reads ``raw["email"]``: its ``"header"`` mapping is expanded into
    ``EmailHeader`` keyword arguments, ``"body"`` is passed through, and
    ``"metadata"`` falls back to an empty dict when the key is absent.

    Raises:
        KeyError: if ``"email"``, ``"header"`` or ``"body"`` is missing.
    """
    payload = raw["email"]
    return EmailMessage(
        header=EmailHeader(**payload["header"]),
        body=payload["body"],
        metadata=payload.get("metadata", {}),
    )
class EmailTriageEnv:
    """OpenEnv-compliant email triage environment with sequential state.

    An episode starts with ``reset()``, which loads every email of the chosen
    task into the inbox. Each ``step()`` triages one email and applies the
    sequential mechanics: SLA breach penalties, the escalation budget, team
    queue capacity and the one-time cascade penalty. The episode ends when the
    inbox is empty or ``MAX_STEPS`` is reached.

    Note: ``seed`` is stored but not used by any code in this class.
    """

    ENV_ID = "email-triage-v1"
    VERSION = "2.0.0"
    MAX_STEPS = 60

    def __init__(self, task_id: str = "easy", seed: Optional[int] = None) -> None:
        """Create an environment for ``task_id``; call ``reset()`` before stepping.

        Raises:
            ValueError: if ``task_id`` is not a key of ``TASK_DATASETS``.
        """
        if task_id not in TASK_DATASETS:
            raise ValueError(f"task_id must be one of {list(TASK_DATASETS.keys())}")
        self.task_id = task_id
        self.seed = seed
        self._dataset = TASK_DATASETS[task_id]

        # Runtime state (populated properly by reset()).
        self._emails: List[EmailMessage] = []
        self._processed_ids: List[str] = []
        self._actions_log: List[Dict[str, Any]] = []
        self._step_num: int = 0
        self._done: bool = False
        self._cumulative_reward: float = 0.0
        self._constraints: SessionConstraints = SessionConstraints()

    # ── OpenEnv Interface ─────────────────────────────────────────────────────

    def reset(self) -> Observation:
        """Reset all state and return the initial observation."""
        self._emails = [_build_email_message(e) for e in self._dataset]
        self._processed_ids = []
        self._actions_log = []
        self._step_num = 0
        self._done = False
        self._cumulative_reward = 0.0

        budget = TASK_ESCALATION_BUDGET[self.task_id]
        self._constraints = SessionConstraints(escalation_budget=budget)

        # Register every email in the SLA tracker immediately. The email at
        # dataset index i is treated as arriving at step i; priorities missing
        # from SLA_STEPS get a 99-step allowance.
        for i, raw in enumerate(self._dataset):
            gt_priority = raw["ground_truth"]["priority"]
            self._constraints.sla_tracker.append(SlaStatus(
                email_id=raw["email"]["header"]["email_id"],
                true_priority=gt_priority,
                arrived_at_step=i,
                deadline_step=i + SLA_STEPS.get(gt_priority, 99),
            ))
        return self._make_observation()

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Process one agent action.

        Order of effects:

        A. SLA tick: every unprocessed email whose deadline_step is strictly
           less than the current step fires a one-time breach penalty.
        B. The action's email_id is validated; an unknown id costs -0.1 (plus
           any SLA penalty) and consumes a step without touching the inbox.
        C. The labels are scored by ``score_action``.
        D. Resource effects: flag_review=True consumes escalation budget
           (-0.20 and the escalation is dropped when exhausted); routing
           consumes queue capacity (-0.10 on overflow; "trash" and "archive"
           are exempt).
        E. The email is logged as processed and removed from the inbox.
        F. Cascade: once 2 or more urgent emails have breached SLA, a one-time
           -0.25 penalty is applied.
        G. The step reward is label score plus penalties, clamped to [-1, 1].

        Returns:
            ``(observation, reward, done, info)``; ``info`` is ``reward.info``.

        Raises:
            RuntimeError: if the episode has already finished.
        """
        if self._done:
            raise RuntimeError("Episode done. Call reset() to start a new one.")

        if self._step_num >= self.MAX_STEPS:
            self._done = True
            return (
                self._make_observation(),
                Reward(total=0.0, done=True, info={"reason": "max_steps_exceeded"}),
                True,
                {"reason": "max_steps_exceeded"},
            )

        breakdown = RewardBreakdown()

        # ── A. SLA decay: fire breach penalties for unprocessed overdue emails ──
        # _tick_sla also records the penalty in breakdown.sla_penalty, so both
        # the invalid-id path and the normal path report it.
        sla_penalty = self._tick_sla(breakdown)

        # ── B. Validate email_id ──────────────────────────────────────────────
        inbox_ids = {e.header.email_id for e in self._emails}
        if action.email_id not in inbox_ids:
            breakdown.base_penalty = -0.1
            total = max(-1.0, sla_penalty - 0.1)
            r = Reward(total=total, breakdown=breakdown, done=False,
                       info={"error": f"email_id '{action.email_id}' not in inbox"})
            self._step_num += 1
            self._cumulative_reward += r.total
            return self._make_observation(), r, False, r.info

        # ── C. Score the classification decision (label correctness) ──────────
        label_reward, detail = score_action(action)
        breakdown.priority_score = label_reward.breakdown.priority_score
        breakdown.category_score = label_reward.breakdown.category_score
        breakdown.routing_score = label_reward.breakdown.routing_score
        breakdown.summary_score = label_reward.breakdown.summary_score
        breakdown.escalation_score = label_reward.breakdown.escalation_score
        breakdown.base_penalty = label_reward.breakdown.base_penalty

        # ── D. Sequential resource effects ────────────────────────────────────
        # D1. Escalation budget
        budget_penalty = 0.0
        budget_overflow = False
        if action.flag_review:
            if self._constraints.escalations_used >= self._constraints.escalation_budget:
                # Over budget: the escalation is dropped and penalised.
                budget_overflow = True
                budget_penalty = -0.20
                breakdown.budget_penalty = budget_penalty
                detail["budget_overflow"] = True
            else:
                self._constraints.escalations_used += 1

        # D2. Team queue capacity
        queue_penalty = 0.0
        route_key = action.route_to.value
        if route_key not in ("trash", "archive"):
            accepted = self._constraints.team_queues.consume(route_key)
            if not accepted:
                queue_penalty = -0.10
                breakdown.queue_penalty = queue_penalty
                self._constraints.queue_overflows += 1
                detail["queue_overflow"] = route_key

        # ── E. Mark email as processed; remove from inbox ─────────────────────
        logged_action = action.model_dump()
        if budget_overflow:
            # The escalation did not happen, so the episode grader must not
            # see it as one.
            logged_action["flag_review"] = False
        self._actions_log.append(logged_action)
        self._processed_ids.append(action.email_id)
        self._emails = [e for e in self._emails if e.header.email_id != action.email_id]
        # The SLA entry is deliberately NOT flagged as breached here:
        # membership in _processed_ids already stops future breach checks,
        # and `breached` must keep meaning "missed its deadline" so that the
        # cascade count below stays accurate.

        # ── F. Cascade check ──────────────────────────────────────────────────
        cascade_penalty = 0.0
        urgent_breaches = sum(
            1 for s in self._constraints.sla_tracker
            if s.breached and s.true_priority == "urgent"
        )
        if urgent_breaches >= 2 and not self._constraints.cascade_triggered:
            self._constraints.cascade_triggered = True
            cascade_penalty = -0.25
            breakdown.cascade_penalty = cascade_penalty

        # ── G. Compute total reward ───────────────────────────────────────────
        label_total = label_reward.total  # expected to be clamped to [0, 1] by the grader
        sequential_penalties = sla_penalty + budget_penalty + queue_penalty + cascade_penalty
        total = max(-1.0, min(1.0, label_total + sequential_penalties))

        self._step_num += 1
        self._cumulative_reward += total
        self._done = not self._emails

        detail.update({
            "sla_penalty": sla_penalty,
            "queue_penalty": queue_penalty,
            "budget_penalty": budget_penalty,
            "cascade_penalty": cascade_penalty,
            "escalations_remaining": self._escalations_remaining(),
        })

        reward = Reward(total=round(total, 4), breakdown=breakdown,
                        done=self._done, info=detail)
        if self._done:
            reward.info["episode_summary"] = grade_episode(self._actions_log)
            reward.info["final_constraints"] = self._constraints_dict()

        return self._make_observation(), reward, self._done, reward.info

    def state(self) -> EnvironmentState:
        """Full internal state snapshot (for logging / debugging)."""
        scores = {}
        if self._actions_log:
            s = grade_episode(self._actions_log)
            scores = {"running_score": s["overall_score"], "emails_triaged": s["num_emails"]}
        return EnvironmentState(
            task_id=self.task_id,
            step=self._step_num,
            done=self._done,
            observation=self._make_observation(),
            cumulative_reward=round(self._cumulative_reward, 4),
            actions_taken=copy.deepcopy(self._actions_log),
            grader_scores=scores,
            constraints=self._constraints_dict(),
        )

    # ── Private helpers ───────────────────────────────────────────────────────

    def _escalations_remaining(self) -> int:
        """Return how many escalations are still available this episode."""
        c = self._constraints
        return c.escalation_budget - c.escalations_used

    def _queue_snapshot(self) -> Dict[str, Any]:
        """Return the remaining capacity of every team listed in TEAM_CAPACITY."""
        queues = self._constraints.team_queues
        return {k: queues.remaining(k) for k in TEAM_CAPACITY}

    def _tick_sla(self, breakdown: RewardBreakdown) -> float:
        """Check for SLA breaches among unprocessed emails.

        A breach fires the first time step_num > deadline_step, so an email
        can still be handled without penalty on its deadline step itself.
        Each breach costs 0.15 and is counted in ``sla_breaches``.

        The step's total is also stored in ``breakdown.sla_penalty``.

        Returns:
            Total SLA penalty for this step (zero or negative).
        """
        penalty = 0.0
        processed_set = set(self._processed_ids)
        for sla in self._constraints.sla_tracker:
            if sla.email_id in processed_set:
                continue  # already handled
            if sla.breached:
                continue  # already penalised
            if self._step_num > sla.deadline_step:
                sla.breached = True
                self._constraints.sla_breaches += 1
                penalty -= 0.15
        breakdown.sla_penalty = penalty
        return penalty

    def _make_observation(self) -> Observation:
        """Build the agent-facing observation from current state.

        SLA warnings are listed for inbox emails that have not breached yet
        and have two or fewer steps left; ``overdue`` is True once the
        deadline step has passed but the breach has not fired yet.
        """
        # Function-scope import, as in the original code.
        from models import Priority, Category, RouteTo

        c = self._constraints

        # Build SLA warnings for emails still in the inbox.
        processed_set = set(self._processed_ids)
        warnings = []
        for sla in c.sla_tracker:
            if sla.email_id in processed_set or sla.breached:
                continue
            steps_left = sla.deadline_step - self._step_num
            if steps_left <= 2:
                warnings.append({
                    "email_id": sla.email_id,
                    "priority": sla.true_priority,
                    "steps_left": max(0, steps_left),
                    "overdue": steps_left < 0,
                })

        return Observation(
            inbox=list(self._emails),
            processed=list(self._processed_ids),
            current_email=self._emails[0] if self._emails else None,
            step_number=self._step_num,
            total_emails=len(self._dataset),
            remaining=len(self._emails),
            escalation_budget_remaining=self._escalations_remaining(),
            team_queue_remaining=self._queue_snapshot(),
            active_sla_warnings=warnings,
            sla_breaches_so_far=c.sla_breaches,
            cascade_active=c.cascade_triggered,
            session_info={
                "task_id": self.task_id,
                "task_description": TASK_DESCRIPTIONS[self.task_id],
                "cumulative_reward": round(self._cumulative_reward, 4),
                "action_space": {
                    "priority": [p.value for p in Priority],
                    "category": [cat.value for cat in Category],
                    "route_to": [r.value for r in RouteTo],
                    "summary": "string ≤280 chars",
                    "flag_review": "bool — uses escalation budget",
                    "reasoning": "string, not scored",
                },
                "constraints_info": {
                    "escalation_budget": c.escalation_budget,
                    "escalations_used": c.escalations_used,
                    "sla_breaches": c.sla_breaches,
                    "queue_overflows": c.queue_overflows,
                    "cascade_triggered": c.cascade_triggered,
                },
            },
        )

    def _constraints_dict(self) -> Dict[str, Any]:
        """Return the session constraints as a plain dict."""
        c = self._constraints
        return {
            "escalation_budget": c.escalation_budget,
            "escalations_used": c.escalations_used,
            "escalations_remaining": self._escalations_remaining(),
            "sla_breaches": c.sla_breaches,
            "queue_overflows": c.queue_overflows,
            "cascade_triggered": c.cascade_triggered,
            "team_queues": self._queue_snapshot(),
        }

    def is_done(self) -> bool:
        """Return True once the episode has finished."""
        return self._done

    def __repr__(self) -> str:
        c = self._constraints
        return (
            f"EmailTriageEnv(task={self.task_id}, step={self._step_num}, "
            f"budget_left={c.escalation_budget - c.escalations_used}, "
            f"sla_breaches={c.sla_breaches}, done={self._done})"
        )