from __future__ import annotations

from typing import Callable, Dict, List, Optional, Set, Tuple

from .graders import grade_task
from .models import (
    Action,
    Observation,
    RewardModel,
    StateModel,
    StepInfo,
    TaskSpec,
    TicketObservation,
)
from .reward import STEP_PENALTY, build_reward
from .state import discovered_for_ticket, initial_tracking, update_mapping
from .tasks import get_all_tasks, get_task


class SupportOpsEnv:
    """OpenEnv-shaped benchmark for support operations workflows.

    Exposes a ``reset`` / ``step`` / ``state`` API over one ``TaskSpec`` (a set
    of support tickets). Each ``step`` applies a single ``Action``, collects
    per-event reward components, and re-grades the task with ``grade_task``.
    The episode ends when the agent sends ``finalize`` or when the task's
    ``max_steps`` budget is used up.
    """

    # Action types advertised to agents in every observation, in this order.
    _AVAILABLE_ACTIONS: Tuple[str, ...] = (
        "inspect_ticket",
        "request_context",
        "set_priority",
        "set_route",
        "set_resolution",
        "escalate",
        "rank_queue",
        "finalize",
    )

    def __init__(self, task_id: Optional[str] = None):
        """Create the environment on ``task_id``, or on the first task id in sorted order.

        Raises:
            KeyError: if ``task_id`` is not one of the tasks from ``get_all_tasks()``.
        """
        self._tasks = {task.task_id: task for task in get_all_tasks()}
        self._task_order = sorted(self._tasks)
        self._task_id = task_id or self._task_order[0]
        self._task: TaskSpec = self._tasks[self._task_id]
        self._state: StateModel = initial_tracking(self._task)
        # Ticket ids inspected during the current episode. Kept on the env
        # rather than in ``latest_score`` because ``step`` replaces that dict
        # with fresh grade components on every step.
        self._inspected: Set[str] = set()

    def reset(self, task_id: Optional[str] = None) -> Observation:
        """Start a new episode, optionally switching to ``task_id`` first.

        Returns:
            The initial observation for the (possibly new) task.
        """
        if task_id is not None:
            self._task = get_task(task_id)
            self._task_id = task_id
        self._state = initial_tracking(self._task)
        self._inspected = set()
        return self._build_observation()

    def state(self) -> StateModel:
        """Return a deep copy of the tracking state so callers cannot mutate it."""
        return self._state.model_copy(deep=True)

    def step(self, action: Action) -> Tuple[Observation, RewardModel, bool, Dict[str, object]]:
        """Apply one agent action.

        Returns:
            ``(observation, reward, done, info)``, where ``info`` is a dumped
            ``StepInfo``.

        Stepping after the episode has ended does not advance the step counter.
        It returns an ``invalid_after_done`` penalty and ``done=True``.
        """
        if self._state.done:
            return self._step_after_done()

        self._state.step_count += 1
        event_scores: Dict[str, float] = {"step_penalty": STEP_PENALTY}
        event_name = action.action_type

        if action.action_type == "finalize":
            return self._finalize(event_scores)

        handlers: Dict[str, Callable[[Action], Dict[str, float]]] = {
            "inspect_ticket": self._handle_inspect,
            "request_context": self._handle_request_context,
            "set_priority": self._handle_priority,
            "set_route": self._handle_route,
            "set_resolution": self._handle_resolution,
            "escalate": self._handle_escalation,
            "rank_queue": self._handle_rank_queue,
        }
        handler = handlers.get(action.action_type)
        if handler is None:
            event_scores["invalid_action"] = -0.1
            event_name = "invalid_action"
        else:
            event_scores.update(handler(action))

        grade = grade_task(self._task, self._state)
        self._state.latest_score = {"task_score": grade.score, **grade.component_scores}

        done_reason = None
        # ``done`` is always False here (checked above and no handler sets it),
        # so only the step budget can end the episode on this path.
        if self._state.step_count >= self._task.max_steps:
            self._state.done = True
            done_reason = "max_steps"
            event_scores["timeout_grade"] = grade.score

        reward = build_reward(event_scores, f"Processed {event_name}.")
        self._state.cumulative_reward = round(self._state.cumulative_reward + reward.value, 4)
        info = StepInfo(
            task_id=self._task.task_id,
            step_count=self._state.step_count,
            task_score=grade.score,
            done_reason=done_reason,
            grade=grade if self._state.done else None,
            event=event_name,
            event_score=reward.components,
        )
        return self._build_observation(), reward, self._state.done, info.model_dump()

    def _step_after_done(self) -> Tuple[Observation, RewardModel, bool, Dict[str, object]]:
        """Build the penalty response for an action sent after the episode ended."""
        reward = build_reward({"invalid_after_done": -0.1}, "Episode already finished.")
        info = StepInfo(
            task_id=self._task.task_id,
            step_count=self._state.step_count,
            task_score=self._state.latest_score.get("task_score", 0.0),
            done_reason="already_done",
            event="invalid_after_done",
            event_score=reward.components,
        )
        return self._build_observation(), reward, True, info.model_dump()

    def _finalize(
        self, event_scores: Dict[str, float]
    ) -> Tuple[Observation, RewardModel, bool, Dict[str, object]]:
        """End the episode at the agent's request and add the terminal grade to the reward."""
        # Mark done before grading, matching the order the grader has always seen.
        self._state.done = True
        grade = grade_task(self._task, self._state)
        self._state.latest_score = {"task_score": grade.score, **grade.component_scores}
        event_scores["terminal_grade"] = grade.score
        reward = build_reward(event_scores, "Final task grade applied.")
        self._state.cumulative_reward = round(self._state.cumulative_reward + reward.value, 4)
        info = StepInfo(
            task_id=self._task.task_id,
            step_count=self._state.step_count,
            task_score=grade.score,
            done_reason="agent_finalized",
            grade=grade,
            event="finalize",
            event_score=reward.components,
        )
        return self._build_observation(), reward, True, info.model_dump()

    def _build_observation(self) -> Observation:
        """Assemble the agent-facing view of the task and current tracking state.

        Only the hidden context keys the agent has already discovered are revealed.
        """
        tickets: List[TicketObservation] = []
        for ticket in self._task.tickets:
            keys = self._state.discovered_keys.get(ticket.ticket_id, [])
            discovered_context = {key: ticket.hidden_context[key] for key in keys}
            tickets.append(
                TicketObservation(
                    ticket_id=ticket.ticket_id,
                    summary=ticket.summary,
                    visible_context=ticket.visible_context,
                    discovered_context=discovered_context,
                    selected_priority=self._state.priorities.get(ticket.ticket_id),
                    selected_route=self._state.routes.get(ticket.ticket_id),
                    selected_resolution=self._state.resolutions.get(ticket.ticket_id),
                    escalation_team=self._state.escalations.get(ticket.ticket_id),
                )
            )
        return Observation(
            task_id=self._task.task_id,
            difficulty=self._task.difficulty,
            title=self._task.title,
            instruction=self._task.instruction,
            queue_mode=self._task.queue_mode,
            tickets=tickets,
            remaining_steps=max(self._task.max_steps - self._state.step_count, 0),
            available_actions=list(self._AVAILABLE_ACTIONS),
            current_queue_order=self._state.queue_order,
            score_hint=self._state.latest_score,
        )

    def _find_ticket(self, ticket_id: str):
        """Return the task ticket with ``ticket_id``, or ``None`` if there is none."""
        return next(
            (ticket for ticket in self._task.tickets if ticket.ticket_id == ticket_id),
            None,
        )

    def _handle_inspect(self, action: Action) -> Dict[str, float]:
        """Reward the first inspection of each ticket; penalize repeat inspections."""
        ticket = self._find_ticket(action.target)
        if ticket is None:
            return {"invalid_ticket": -0.1}
        if ticket.ticket_id in self._inspected:
            return {"redundant_inspect": -0.03}
        self._inspected.add(ticket.ticket_id)
        return {"inspect": 0.03}

    def _handle_request_context(self, action: Action) -> Dict[str, float]:
        """Reveal the hidden context key ``action.value`` on the target ticket.

        Required keys earn more than optional ones. Unknown or already
        discovered keys are penalized.
        """
        ticket = self._find_ticket(action.target)
        if ticket is None or not action.value:
            return {"invalid_context_request": -0.1}
        if action.value not in ticket.hidden_context:
            return {"unknown_context_key": -0.08}
        # Presumably the list stored in state for this ticket, since appending
        # to it is how discovery is recorded. TODO confirm in .state.
        discovered = discovered_for_ticket(self._state.discovered_keys, ticket.ticket_id)
        if action.value in discovered:
            return {"redundant_context_request": -0.05}
        discovered.append(action.value)
        if action.value in ticket.required_context:
            return {"required_context_found": 0.12}
        return {"optional_context_found": 0.04}

    def _handle_assignment(
        self,
        action: Action,
        name: str,
        mapping: Dict[str, str],
        hit: float,
        miss: float,
    ) -> Dict[str, float]:
        """Record ``action.value`` as the ticket's ``name`` field and score it against the gold.

        Args:
            action: The setter action; ``target`` is a ticket id and ``value`` the new value.
            name: Field name. It keys the reward components (``invalid_<name>``,
                ``redundant_<name>``, ``<name>_match``) and the ticket's
                ``gold_<name>`` attribute.
            mapping: The state mapping (ticket id -> value) to update.
            hit: Reward when the new value equals the gold value.
            miss: Reward when it does not.

        NOTE(review): each change is re-scored, so alternating wrong -> right
        nets ``hit + miss`` per two steps. Confirm that this exceeds the step
        penalty cost only if intended.
        """
        ticket = self._find_ticket(action.target)
        if ticket is None or not action.value:
            return {f"invalid_{name}": -0.1}
        current = mapping.get(ticket.ticket_id)
        update_mapping(mapping, ticket.ticket_id, action.value)
        if action.value == current:
            return {f"redundant_{name}": -0.03}
        gold = getattr(ticket, f"gold_{name}")
        return {f"{name}_match": hit if action.value == gold else miss}

    def _handle_priority(self, action: Action) -> Dict[str, float]:
        """Set the ticket priority (+0.08 if it matches gold, -0.04 otherwise)."""
        return self._handle_assignment(action, "priority", self._state.priorities, 0.08, -0.04)

    def _handle_route(self, action: Action) -> Dict[str, float]:
        """Set the ticket route (+0.1 if it matches gold, -0.06 otherwise)."""
        return self._handle_assignment(action, "route", self._state.routes, 0.1, -0.06)

    def _handle_resolution(self, action: Action) -> Dict[str, float]:
        """Set the ticket resolution (+0.12 if it matches gold, -0.08 otherwise)."""
        return self._handle_assignment(action, "resolution", self._state.resolutions, 0.12, -0.08)

    def _handle_escalation(self, action: Action) -> Dict[str, float]:
        """Escalate the target ticket to team ``action.value`` (``None`` means no team)."""
        ticket = self._find_ticket(action.target)
        if ticket is None:
            return {"invalid_escalation": -0.1}
        team = action.value
        current = self._state.escalations.get(ticket.ticket_id)
        update_mapping(self._state.escalations, ticket.ticket_id, team)
        if team == current:
            return {"redundant_escalation": -0.03}
        # Check "correctly not escalated" before the generic equality test.
        # Otherwise None == None matches below and this branch can never run.
        if ticket.gold_escalation_team is None and team is None:
            return {"correct_no_escalation": 0.03}
        if team == ticket.gold_escalation_team:
            return {"correct_escalation": 0.1}
        return {"incorrect_escalation": -0.1}

    def _handle_rank_queue(self, action: Action) -> Dict[str, float]:
        """Set the queue order from a comma-separated list of ticket ids.

        Only valid in ``queue_mode``. The ranking must list every ticket
        exactly once. The reward is proportional to the fraction of positions
        that match the gold order.
        """
        if not self._task.queue_mode or not action.value:
            return {"invalid_queue_ranking": -0.1}
        ranked = [item.strip() for item in action.value.split(",") if item.strip()]
        valid_ticket_ids = {ticket.ticket_id for ticket in self._task.tickets}
        # Comparing sets alone would accept duplicates such as "A,A,B".
        # An empty ranking would also divide by zero below.
        if (
            not ranked
            or len(ranked) != len(valid_ticket_ids)
            or set(ranked) != valid_ticket_ids
        ):
            return {"malformed_queue_ranking": -0.08}
        # Like the other setters, resubmitting the current order is redundant.
        # Without this check the same ranking could be re-rewarded every step.
        if ranked == self._state.queue_order:
            return {"redundant_queue_ranking": -0.03}
        self._state.queue_order = ranked
        correct_positions = sum(
            1
            for observed, expected in zip(ranked, self._task.gold_queue_order)
            if observed == expected
        )
        return {"queue_progress": round((correct_positions / len(ranked)) * 0.12, 4)}