Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import copy | |
| from typing import Any | |
| from .graders import grade_task | |
| from .models import Action, Observation, Reward, StepInfo, TicketView | |
| from .tasks import TaskSpec, get_tasks | |
class SupportTriageEnv:
    """
    OpenEnv-compatible environment for customer support ticket triage.

    API:
    - reset(task_id: str | None = None) -> Observation
    - step(action: Action) -> tuple[Observation, Reward, bool, dict[str, Any]]
    - state() -> dict[str, Any]
    """

    # Number of invalid actions (of any kind) after which the episode ends.
    MAX_INVALID_ACTIONS = 3

    # Action types that must reference a ticket present in the current task.
    _TICKET_ACTIONS = frozenset({"read_ticket", "classify_ticket", "resolve_ticket"})

    def __init__(self) -> None:
        """Load the task catalogue and start with no active episode."""
        # Materialise once so the lookup table and the cycling order are
        # guaranteed to be derived from the same sequence of tasks.
        tasks = list(get_tasks())
        self._tasks: dict[str, TaskSpec] = {t.task_id: t for t in tasks}
        self._task_order = [t.task_id for t in tasks]
        self._task_index = 0
        self._current_task: TaskSpec | None = None
        self._state: dict[str, Any] = {}

    def task_ids(self) -> list[str]:
        """Return a copy of the task ids in their cycling order."""
        return list(self._task_order)

    def reset(self, task_id: str | None = None) -> Observation:
        """
        Start a new episode and return its first observation.

        When ``task_id`` is None the next task in round-robin order is used
        and the round-robin cursor advances; an explicit ``task_id`` does not
        move the cursor.

        Raises:
            RuntimeError: no tasks are registered and no ``task_id`` was given.
            ValueError: ``task_id`` is not a known task.
        """
        if task_id is None:
            if not self._task_order:
                # Avoid an opaque ZeroDivisionError from the modulo below.
                raise RuntimeError("No tasks are registered; cannot reset.")
            task_id = self._task_order[self._task_index % len(self._task_order)]
            self._task_index += 1
        if task_id not in self._tasks:
            raise ValueError(f"Unknown task_id '{task_id}'. Available: {sorted(self._tasks)}")

        self._current_task = self._tasks[task_id]
        self._state = {
            "step_count": 0,
            "read_ticket_ids": set(),
            "selected_ticket_id": None,
            "classification": None,
            "draft_reply": None,
            "resolved": False,
            "resolved_ticket_id": None,
            "invalid_actions": 0,
            "repeat_actions": 0,
            "action_history": [],
            "last_note": "Environment reset.",
            "done": False,
            "done_reason": "ongoing",
        }
        return self._build_observation()

    def step(self, action: Action) -> tuple[Observation, Reward, bool, dict[str, Any]]:
        """
        Apply one action and return ``(observation, reward, done, info)``.

        Every step, valid or not, goes through the same termination checks:
        the invalid-action limit first, then the ``max_steps`` budget, and
        the repeat penalty is applied last.

        Raises:
            RuntimeError: ``reset()`` has not been called, or the episode is
                already done.
        """
        if self._current_task is None:
            raise RuntimeError("Call reset() before step().")
        if self._state["done"]:
            raise RuntimeError("Episode already done. Call reset() for a new episode.")

        task = self._current_task
        st = self._state
        st["step_count"] += 1

        # An action identical to the immediately preceding one counts as a repeat.
        action_fingerprint = action.model_dump_json()
        if st["action_history"] and st["action_history"][-1] == action_fingerprint:
            st["repeat_actions"] += 1
        st["action_history"].append(action_fingerprint)

        step_penalty = self._apply_action(task, action)

        # Shared tail: no action path may bypass these checks.
        if st["invalid_actions"] >= self.MAX_INVALID_ACTIONS and not st["done"]:
            st["done"] = True
            st["done_reason"] = "invalid_action"
        if st["step_count"] >= task.max_steps and not st["done"]:
            st["done"] = True
            st["done_reason"] = "max_steps"
            st["last_note"] = "Reached max_steps."
        if st["repeat_actions"] > 0:
            step_penalty -= min(0.04, 0.01 * st["repeat_actions"])
        return self._assemble_step_response(step_penalty)

    def _apply_action(self, task: TaskSpec, action: Action) -> float:
        """Mutate episode state for ``action``; return its (non-positive) step penalty."""
        st = self._state
        kind = action.action_type
        penalty = 0.0

        if kind in self._TICKET_ACTIONS:
            valid_ticket_ids = {t["ticket_id"] for t in task.tickets}
            if not action.ticket_id or action.ticket_id not in valid_ticket_ids:
                st["invalid_actions"] += 1
                st["last_note"] = "Invalid or missing ticket_id."
                penalty -= 0.03
                return penalty

        if kind == "read_ticket":
            st["read_ticket_ids"].add(action.ticket_id)
            st["selected_ticket_id"] = action.ticket_id
            st["last_note"] = f"Read ticket {action.ticket_id}."
        elif kind == "classify_ticket":
            # Classifying a non-target ticket is still recorded, with a small penalty.
            if action.ticket_id != task.target_ticket_id:
                penalty -= 0.01
            st["classification"] = {
                "ticket_id": action.ticket_id,
                "priority": action.priority,
                "category": action.category,
                "needs_escalation": action.needs_escalation,
            }
            st["last_note"] = f"Saved classification for {action.ticket_id}."
        elif kind == "draft_reply":
            text = (action.message or "").strip()
            if not text:
                st["invalid_actions"] += 1
                st["last_note"] = "Draft reply is empty."
                penalty -= 0.02
            else:
                st["draft_reply"] = text
                st["last_note"] = "Draft reply saved."
        elif kind == "resolve_ticket":
            # Resolving always ends the episode.
            st["resolved"] = True
            st["resolved_ticket_id"] = action.ticket_id
            st["done"] = True
            st["done_reason"] = "resolved"
            st["last_note"] = f"Resolved ticket {action.ticket_id}."
        else:
            st["invalid_actions"] += 1
            st["last_note"] = f"Unknown action {action.action_type}."
            penalty -= 0.03
        return penalty

    def state(self) -> dict[str, Any]:
        """
        Return a deep copy of the episode state, safe for callers to mutate.

        ``read_ticket_ids`` is converted from a set to a sorted list and the
        current ``task_id`` is added.

        Raises:
            RuntimeError: ``reset()`` has not been called.
        """
        if self._current_task is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        visible = copy.deepcopy(self._state)
        visible["read_ticket_ids"] = sorted(visible["read_ticket_ids"])
        visible["task_id"] = self._current_task.task_id
        return visible

    def _build_observation(self) -> Observation:
        """Build the agent-facing observation from the current task and state."""
        assert self._current_task is not None
        task = self._current_task
        st = self._state

        # Ticket content is only exposed for a selected ticket that has been read.
        content = None
        if st.get("selected_ticket_id") in st["read_ticket_ids"]:
            ticket = next(t for t in task.tickets if t["ticket_id"] == st["selected_ticket_id"])
            content = ticket["content"]

        inbox = [
            TicketView(
                ticket_id=t["ticket_id"],
                subject=t["subject"],
                customer_tier=t["customer_tier"],
                age_minutes=t["age_minutes"],
                read=t["ticket_id"] in st["read_ticket_ids"],
            )
            for t in task.tickets
        ]
        partial = grade_task(task, st)
        return Observation(
            task_id=task.task_id,
            objective=task.objective,
            step_count=st["step_count"],
            max_steps=task.max_steps,
            inbox=inbox,
            current_ticket_content=content,
            latest_system_note=st.get("last_note", ""),
            score_hint={
                "read": partial.read_score,
                "classify": partial.classify_score,
                "reply": partial.reply_score,
                "resolve": partial.resolve_score,
            },
        )

    def _assemble_step_response(self, step_penalty: float) -> tuple[Observation, Reward, bool, dict[str, Any]]:
        """
        Grade the current state and package the ``step()`` return tuple.

        The reward is ``0.75 * grade.total`` minus the accumulated penalties,
        clamped to ``[0.0, 1.0]``. Once the episode is done the reward is
        never lower than ``grade.total``.
        """
        assert self._current_task is not None
        task = self._current_task
        st = self._state
        grade = grade_task(task, st)
        progress_signal = 0.75 * grade.total

        penalty_total = 0.0
        penalties: dict[str, float] = {}
        if st["invalid_actions"]:
            penalties["invalid_actions"] = round(min(0.2, 0.04 * st["invalid_actions"]), 4)
            penalty_total += penalties["invalid_actions"]
        if st["repeat_actions"]:
            penalties["repetition"] = round(min(0.15, 0.02 * st["repeat_actions"]), 4)
            penalty_total += penalties["repetition"]
        if step_penalty < 0:
            penalties["step_penalty"] = round(abs(step_penalty), 4)
            penalty_total += abs(step_penalty)

        reward_value = progress_signal - penalty_total
        if st["done"]:
            reward_value = max(reward_value, grade.total)
        reward_value = max(0.0, min(1.0, reward_value))

        reward = Reward(
            value=round(reward_value, 4),
            components={
                "progress_signal": round(progress_signal, 4),
                "grade_total": grade.total,
                "read_score": grade.read_score,
                "classify_score": grade.classify_score,
                "reply_score": grade.reply_score,
                "resolve_score": grade.resolve_score,
                "penalty_total": round(penalty_total, 4),
            },
            reasoning="Shaped reward from grader progress with penalties for invalid or looping actions.",
        )
        info = StepInfo(
            task_id=task.task_id,
            done_reason=st["done_reason"],
            grader_score=grade.total,
            reward_components=reward.components,
            penalties=penalties,
            state_snapshot=self.state(),
        ).model_dump()
        obs = self._build_observation()
        return obs, reward, st["done"], info