| from __future__ import annotations |
|
|
| import copy |
| from typing import Any |
|
|
| from .graders import grade_task |
| from .models import Action, Observation, Reward, StepInfo, TicketView |
| from .tasks import TaskSpec, get_tasks |
|
|
|
|
class SupportTriageEnv:
    """
    OpenEnv-compatible environment for customer support ticket triage.

    API:
    - reset(task_id: str | None = None) -> Observation
    - step(action: Action) -> tuple[Observation, Reward, bool, dict[str, Any]]
    - state() -> dict[str, Any]

    An episode ends when a ticket is resolved, when the task's ``max_steps``
    is reached, or once ``_MAX_INVALID_ACTIONS`` invalid actions have been
    taken (done_reason "resolved", "max_steps" or "invalid_action").
    """

    # Invalid actions (bad/missing ticket_id, empty draft, unknown action type)
    # tolerated before the episode is terminated with "invalid_action".
    _MAX_INVALID_ACTIONS = 3

    # Action types that must reference a ticket_id present in the task inbox.
    _TICKET_ACTIONS = frozenset({"read_ticket", "classify_ticket", "resolve_ticket"})

    def __init__(self) -> None:
        # Materialize the task list once so the lookup table and the
        # round-robin order are built from the same TaskSpec objects.
        tasks = list(get_tasks())
        self._tasks: dict[str, TaskSpec] = {t.task_id: t for t in tasks}
        self._task_order = [t.task_id for t in tasks]
        self._task_index = 0
        self._current_task: TaskSpec | None = None
        self._state: dict[str, Any] = {}

    @property
    def task_ids(self) -> list[str]:
        """Return a copy of the task ids, in the order ``reset()`` cycles them."""
        return list(self._task_order)

    def reset(self, task_id: str | None = None) -> Observation:
        """Start a new episode and return its initial observation.

        Args:
            task_id: Task to load. When ``None``, the next task in
                ``task_ids`` is chosen round-robin.

        Raises:
            ValueError: If ``task_id`` is unknown, or if ``task_id`` is
                ``None`` and no tasks are registered.
        """
        if task_id is None:
            if not self._task_order:
                # Guard the modulo below against an empty task registry.
                raise ValueError("No tasks are registered; cannot choose a task to reset into.")
            task_id = self._task_order[self._task_index % len(self._task_order)]
            self._task_index += 1
        if task_id not in self._tasks:
            raise ValueError(f"Unknown task_id '{task_id}'. Available: {sorted(self._tasks.keys())}")

        self._current_task = self._tasks[task_id]
        self._state = {
            "step_count": 0,
            "read_ticket_ids": set(),
            "selected_ticket_id": None,
            "classification": None,
            "draft_reply": None,
            "resolved": False,
            "resolved_ticket_id": None,
            "invalid_actions": 0,
            "repeat_actions": 0,
            "action_history": [],
            "last_note": "Environment reset.",
            "done": False,
            "done_reason": "ongoing",
        }
        return self._build_observation()

    def step(self, action: Action) -> tuple[Observation, Reward, bool, dict[str, Any]]:
        """Apply one agent action and return ``(observation, reward, done, info)``.

        Raises:
            RuntimeError: If called before ``reset()`` or after the episode
                has already finished.
        """
        if self._current_task is None:
            raise RuntimeError("Call reset() before step().")
        if self._state["done"]:
            raise RuntimeError("Episode already done. Call reset() for a new episode.")

        task = self._current_task
        st = self._state
        st["step_count"] += 1

        # An action identical (by serialized JSON) to the immediately
        # preceding one counts as a repeat.
        action_fingerprint = action.model_dump_json()
        history = st["action_history"]
        is_repeat = bool(history) and history[-1] == action_fingerprint
        if is_repeat:
            st["repeat_actions"] += 1
        history.append(action_fingerprint)

        step_penalty = self._apply_action(task, action)

        # Termination checks run for every action, including rejected ones,
        # so an invalid action on the last allowed step still ends the episode.
        if not st["done"] and st["invalid_actions"] >= self._MAX_INVALID_ACTIONS:
            st["done"] = True
            st["done_reason"] = "invalid_action"
        if not st["done"] and st["step_count"] >= task.max_steps:
            st["done"] = True
            st["done_reason"] = "max_steps"
            st["last_note"] = "Reached max_steps."

        # Escalating per-step penalty, charged only on steps that repeat the
        # previous action; the cumulative repetition penalty is applied
        # separately in _assemble_step_response.
        if is_repeat:
            step_penalty -= min(0.04, 0.01 * st["repeat_actions"])

        return self._assemble_step_response(step_penalty)

    def _apply_action(self, task: TaskSpec, action: Action) -> float:
        """Mutate episode state for ``action`` and return its immediate penalty (<= 0)."""
        st = self._state
        action_type = action.action_type

        if action_type in self._TICKET_ACTIONS:
            valid_ticket_ids = {t["ticket_id"] for t in task.tickets}
            if not action.ticket_id or action.ticket_id not in valid_ticket_ids:
                st["invalid_actions"] += 1
                st["last_note"] = "Invalid or missing ticket_id."
                return -0.03

        if action_type == "read_ticket":
            st["read_ticket_ids"].add(action.ticket_id)
            st["selected_ticket_id"] = action.ticket_id
            st["last_note"] = f"Read ticket {action.ticket_id}."
            return 0.0

        if action_type == "classify_ticket":
            st["classification"] = {
                "ticket_id": action.ticket_id,
                "priority": action.priority,
                "category": action.category,
                "needs_escalation": action.needs_escalation,
            }
            st["last_note"] = f"Saved classification for {action.ticket_id}."
            # Classifying a ticket other than the task's target is recorded
            # but mildly penalized.
            return -0.01 if action.ticket_id != task.target_ticket_id else 0.0

        if action_type == "draft_reply":
            text = (action.message or "").strip()
            if not text:
                st["invalid_actions"] += 1
                st["last_note"] = "Draft reply is empty."
                return -0.02
            st["draft_reply"] = text
            st["last_note"] = "Draft reply saved."
            return 0.0

        if action_type == "resolve_ticket":
            st["resolved"] = True
            st["resolved_ticket_id"] = action.ticket_id
            st["done"] = True
            st["done_reason"] = "resolved"
            st["last_note"] = f"Resolved ticket {action.ticket_id}."
            return 0.0

        st["invalid_actions"] += 1
        st["last_note"] = f"Unknown action {action_type}."
        return -0.03

    def state(self) -> dict[str, Any]:
        """Return a deep copy of the episode state plus the current ``task_id``.

        ``read_ticket_ids`` is converted from a set to a sorted list.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self._current_task is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        visible = copy.deepcopy(self._state)
        visible["read_ticket_ids"] = sorted(visible["read_ticket_ids"])
        visible["task_id"] = self._current_task.task_id
        return visible

    def _build_observation(self, grade: Any = None) -> Observation:
        """Build the agent-facing observation for the current state.

        Args:
            grade: Result of ``grade_task`` for the current state, if already
                computed by the caller; recomputed when ``None``.
        """
        assert self._current_task is not None
        task = self._current_task
        st = self._state

        # Ticket body is revealed only for the selected ticket once it was read.
        content = None
        selected = st.get("selected_ticket_id")
        if selected in st["read_ticket_ids"]:
            ticket = next(t for t in task.tickets if t["ticket_id"] == selected)
            content = ticket["content"]

        inbox = [
            TicketView(
                ticket_id=t["ticket_id"],
                subject=t["subject"],
                customer_tier=t["customer_tier"],
                age_minutes=t["age_minutes"],
                read=t["ticket_id"] in st["read_ticket_ids"],
            )
            for t in task.tickets
        ]

        if grade is None:
            grade = grade_task(task, st)
        return Observation(
            task_id=task.task_id,
            objective=task.objective,
            step_count=st["step_count"],
            max_steps=task.max_steps,
            inbox=inbox,
            current_ticket_content=content,
            latest_system_note=st.get("last_note", ""),
            score_hint={
                "read": grade.read_score,
                "classify": grade.classify_score,
                "reply": grade.reply_score,
                "resolve": grade.resolve_score,
            },
        )

    def _assemble_step_response(self, step_penalty: float) -> tuple[Observation, Reward, bool, dict[str, Any]]:
        """Grade the current state and package the ``step()`` return tuple.

        Reward = 0.75 * grade.total minus penalties:
        invalid actions (0.04 each, capped at 0.2), repetitions (0.02 each,
        capped at 0.15) and this step's ``step_penalty``. On terminal steps
        the reward is raised to at least ``grade.total``. The result is
        clamped to [0, 1].
        """
        assert self._current_task is not None
        task = self._current_task
        st = self._state

        grade = grade_task(task, st)
        progress_signal = 0.75 * grade.total
        penalty_total = 0.0

        penalties: dict[str, float] = {}
        if st["invalid_actions"]:
            penalties["invalid_actions"] = round(min(0.2, 0.04 * st["invalid_actions"]), 4)
            penalty_total += penalties["invalid_actions"]
        if st["repeat_actions"]:
            penalties["repetition"] = round(min(0.15, 0.02 * st["repeat_actions"]), 4)
            penalty_total += penalties["repetition"]
        if step_penalty < 0:
            penalties["step_penalty"] = round(abs(step_penalty), 4)
            penalty_total += abs(step_penalty)

        reward_value = progress_signal - penalty_total
        if st["done"]:
            # Terminal steps pay out at least the full grader score.
            reward_value = max(reward_value, grade.total)

        reward_value = max(0.0, min(1.0, reward_value))

        reward = Reward(
            value=round(reward_value, 4),
            components={
                "progress_signal": round(progress_signal, 4),
                "grade_total": grade.total,
                "read_score": grade.read_score,
                "classify_score": grade.classify_score,
                "reply_score": grade.reply_score,
                "resolve_score": grade.resolve_score,
                "penalty_total": round(penalty_total, 4),
            },
            reasoning="Shaped reward from grader progress with penalties for invalid or looping actions.",
        )

        info = StepInfo(
            task_id=task.task_id,
            done_reason=st["done_reason"],
            grader_score=grade.total,
            reward_components=reward.components,
            penalties=penalties,
            state_snapshot=self.state(),
        ).model_dump()

        # State is unchanged since grading, so reuse the grade for the hint.
        obs = self._build_observation(grade)
        return obs, reward, st["done"], info
|
|