from __future__ import annotations

import copy
from typing import Any

from .graders import grade_task
from .models import Action, Observation, Reward, StepInfo, TicketView
from .tasks import TaskSpec, get_tasks


class SupportTriageEnv:
    """
    OpenEnv-compatible environment for customer support ticket triage.

    API:
    - reset(task_id: str | None = None) -> Observation
    - step(action: Action) -> tuple[Observation, Reward, bool, dict[str, Any]]
    - state() -> dict[str, Any]

    An episode ends when a ticket is resolved, when the step budget
    (``task.max_steps``) is used up, or when a ticket-targeting action fails
    ticket-id validation while the invalid-action counter has reached
    ``_INVALID_ACTION_LIMIT``.
    """

    # Action types whose ``ticket_id`` must name a ticket of the current task.
    _TICKET_ACTIONS = frozenset({"read_ticket", "classify_ticket", "resolve_ticket"})

    # Invalid-action count at which a failed ticket-id validation ends the episode.
    _INVALID_ACTION_LIMIT = 3

    def __init__(self) -> None:
        """Load the task catalogue; no episode is active until reset() is called."""
        # Fetch the catalogue once so the lookup table and the rotation order
        # are guaranteed to describe the same set of tasks.
        tasks = list(get_tasks())
        self._tasks: dict[str, TaskSpec] = {t.task_id: t for t in tasks}
        self._task_order = [t.task_id for t in tasks]
        self._task_index = 0
        self._current_task: TaskSpec | None = None
        self._state: dict[str, Any] = {}

    @property
    def task_ids(self) -> list[str]:
        """Return the known task ids in rotation order (a fresh list each call)."""
        return list(self._task_order)

    def reset(self, task_id: str | None = None) -> Observation:
        """Start a new episode and return its first observation.

        Args:
            task_id: Task to load. When ``None`` the next task in rotation
                order is used and the rotation cursor advances.

        Raises:
            RuntimeError: If ``task_id`` is ``None`` and no tasks are registered.
            ValueError: If ``task_id`` is not a known task.
        """
        if task_id is None:
            if not self._task_order:
                # Without this guard the modulo below fails with an opaque
                # ZeroDivisionError.
                raise RuntimeError("No tasks are registered.")
            task_id = self._task_order[self._task_index % len(self._task_order)]
            self._task_index += 1
        if task_id not in self._tasks:
            raise ValueError(f"Unknown task_id '{task_id}'. Available: {sorted(self._tasks.keys())}")

        self._current_task = self._tasks[task_id]
        self._state = {
            "step_count": 0,
            "read_ticket_ids": set(),
            "selected_ticket_id": None,
            "classification": None,
            "draft_reply": None,
            "resolved": False,
            "resolved_ticket_id": None,
            "invalid_actions": 0,
            "repeat_actions": 0,
            "action_history": [],
            "last_note": "Environment reset.",
            "done": False,
            "done_reason": "ongoing",
        }
        return self._build_observation()

    def step(self, action: Action) -> tuple[Observation, Reward, bool, dict[str, Any]]:
        """Apply one action and return ``(observation, reward, done, info)``.

        Args:
            action: The action to apply. Ticket-targeting actions
                (read/classify/resolve) must carry a ``ticket_id`` that
                belongs to the current task.

        Raises:
            RuntimeError: If reset() has not been called, or the current
                episode is already done.
        """
        if self._current_task is None:
            raise RuntimeError("Call reset() before step().")
        if self._state["done"]:
            raise RuntimeError("Episode already done. Call reset() for a new episode.")

        task = self._current_task
        st = self._state
        st["step_count"] += 1

        # An action counts as a repeat when it serializes identically to the
        # immediately preceding one.
        action_fingerprint = action.model_dump_json()
        history = st["action_history"]
        if history and history[-1] == action_fingerprint:
            st["repeat_actions"] += 1
        history.append(action_fingerprint)

        step_penalty = 0.0

        if action.action_type in self._TICKET_ACTIONS:
            valid_ticket_ids = {t["ticket_id"] for t in task.tickets}
            if not action.ticket_id or action.ticket_id not in valid_ticket_ids:
                st["invalid_actions"] += 1
                st["last_note"] = "Invalid or missing ticket_id."
                step_penalty -= 0.03
                if st["invalid_actions"] >= self._INVALID_ACTION_LIMIT:
                    st["done"] = True
                    st["done_reason"] = "invalid_action"
                # The step budget must also be enforced on this early-exit
                # path; otherwise an invalid action on the last allowed step
                # leaves the episode open and step_count can exceed max_steps.
                self._enforce_step_budget(task)
                return self._assemble_step_response(step_penalty)

        step_penalty += self._apply_action(action, task)
        self._enforce_step_budget(task)

        if st["repeat_actions"] > 0:
            step_penalty -= min(0.04, 0.01 * st["repeat_actions"])

        return self._assemble_step_response(step_penalty)

    def _apply_action(self, action: Action, task: TaskSpec) -> float:
        """Mutate episode state for an already-validated action.

        Returns the (non-positive) per-step penalty contributed by the action.
        """
        st = self._state
        kind = action.action_type

        if kind == "read_ticket":
            st["read_ticket_ids"].add(action.ticket_id)
            st["selected_ticket_id"] = action.ticket_id
            st["last_note"] = f"Read ticket {action.ticket_id}."
            return 0.0

        if kind == "classify_ticket":
            st["classification"] = {
                "ticket_id": action.ticket_id,
                "priority": action.priority,
                "category": action.category,
                "needs_escalation": action.needs_escalation,
            }
            st["last_note"] = f"Saved classification for {action.ticket_id}."
            # Classifying a ticket other than the task's target is stored but
            # costs a small penalty.
            return -0.01 if action.ticket_id != task.target_ticket_id else 0.0

        if kind == "draft_reply":
            text = (action.message or "").strip()
            if not text:
                st["invalid_actions"] += 1
                st["last_note"] = "Draft reply is empty."
                return -0.02
            st["draft_reply"] = text
            st["last_note"] = "Draft reply saved."
            return 0.0

        if kind == "resolve_ticket":
            st["resolved"] = True
            st["resolved_ticket_id"] = action.ticket_id
            st["done"] = True
            st["done_reason"] = "resolved"
            st["last_note"] = f"Resolved ticket {action.ticket_id}."
            return 0.0

        st["invalid_actions"] += 1
        st["last_note"] = f"Unknown action {action.action_type}."
        return -0.03

    def _enforce_step_budget(self, task: TaskSpec) -> None:
        """End the episode with reason ``max_steps`` once the budget is used up.

        Does nothing when the episode is already done, so an earlier
        termination reason is never overwritten.
        """
        st = self._state
        if st["step_count"] >= task.max_steps and not st["done"]:
            st["done"] = True
            st["done_reason"] = "max_steps"
            st["last_note"] = "Reached max_steps."

    def state(self) -> dict[str, Any]:
        """Return a deep copy of the episode state, safe for callers to mutate.

        ``read_ticket_ids`` is converted from a set to a sorted list and the
        current ``task_id`` is added.

        Raises:
            RuntimeError: If reset() has not been called.
        """
        if self._current_task is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        visible = copy.deepcopy(self._state)
        visible["read_ticket_ids"] = sorted(visible["read_ticket_ids"])
        visible["task_id"] = self._current_task.task_id
        return visible

    def _build_observation(self) -> Observation:
        """Build the agent-facing observation from the current episode state."""
        assert self._current_task is not None
        task = self._current_task
        st = self._state

        # Ticket content is exposed only for the selected ticket, and only
        # once that ticket has been read.
        content = None
        if st.get("selected_ticket_id") in st["read_ticket_ids"]:
            ticket = next(t for t in task.tickets if t["ticket_id"] == st["selected_ticket_id"])
            content = ticket["content"]

        inbox = [
            TicketView(
                ticket_id=t["ticket_id"],
                subject=t["subject"],
                customer_tier=t["customer_tier"],
                age_minutes=t["age_minutes"],
                read=t["ticket_id"] in st["read_ticket_ids"],
            )
            for t in task.tickets
        ]

        partial = grade_task(task, st)
        return Observation(
            task_id=task.task_id,
            objective=task.objective,
            step_count=st["step_count"],
            max_steps=task.max_steps,
            inbox=inbox,
            current_ticket_content=content,
            latest_system_note=st.get("last_note", ""),
            score_hint={
                "read": partial.read_score,
                "classify": partial.classify_score,
                "reply": partial.reply_score,
                "resolve": partial.resolve_score,
            },
        )

    def _assemble_step_response(self, step_penalty: float) -> tuple[Observation, Reward, bool, dict[str, Any]]:
        """Grade the current state and package the result of a step.

        The reward is ``0.75 * grade.total`` minus the accumulated penalties
        (invalid actions, repeated actions, and this step's own penalty). On a
        finished episode it is raised to at least ``grade.total``. The value is
        then clamped to ``[0, 1]`` and rounded to four decimals.

        Args:
            step_penalty: Non-positive penalty incurred by the current step.
        """
        assert self._current_task is not None
        task = self._current_task
        st = self._state

        grade = grade_task(task, st)
        progress_signal = 0.75 * grade.total

        penalty_total = 0.0
        penalties: dict[str, float] = {}
        if st["invalid_actions"]:
            penalties["invalid_actions"] = round(min(0.2, 0.04 * st["invalid_actions"]), 4)
            penalty_total += penalties["invalid_actions"]
        if st["repeat_actions"]:
            penalties["repetition"] = round(min(0.15, 0.02 * st["repeat_actions"]), 4)
            penalty_total += penalties["repetition"]
        if step_penalty < 0:
            penalties["step_penalty"] = round(abs(step_penalty), 4)
            penalty_total += abs(step_penalty)

        reward_value = progress_signal - penalty_total
        if st["done"]:
            # A finished episode never scores below its final grade.
            reward_value = max(reward_value, grade.total)
        reward_value = max(0.0, min(1.0, reward_value))

        reward = Reward(
            value=round(reward_value, 4),
            components={
                "progress_signal": round(progress_signal, 4),
                "grade_total": grade.total,
                "read_score": grade.read_score,
                "classify_score": grade.classify_score,
                "reply_score": grade.reply_score,
                "resolve_score": grade.resolve_score,
                "penalty_total": round(penalty_total, 4),
            },
            reasoning="Shaped reward from grader progress with penalties for invalid or looping actions.",
        )
        info = StepInfo(
            task_id=task.task_id,
            done_reason=st["done_reason"],
            grader_score=grade.total,
            reward_components=reward.components,
            penalties=penalties,
            state_snapshot=self.state(),
        ).model_dump()
        obs = self._build_observation()
        return obs, reward, st["done"], info