afroimam's picture
Upload folder using huggingface_hub
1395b2e verified
from __future__ import annotations
import copy
from typing import Any
from .graders import grade_task
from .models import Action, Observation, Reward, StepInfo, TicketView
from .tasks import TaskSpec, get_tasks
class SupportTriageEnv:
    """
    OpenEnv-compatible environment for customer support ticket triage.

    API:
    - reset(task_id: str | None = None) -> Observation
    - step(action: Action) -> tuple[Observation, Reward, bool, dict[str, Any]]
    - state() -> dict[str, Any]
    """

    def __init__(self) -> None:
        """Load the task catalogue and start with no active episode."""
        # Call get_tasks() once so the lookup table and the round-robin order
        # are guaranteed to describe the same set of tasks.
        tasks = list(get_tasks())
        self._tasks: dict[str, TaskSpec] = {t.task_id: t for t in tasks}
        self._task_order = [t.task_id for t in tasks]
        self._task_index = 0
        self._current_task: TaskSpec | None = None
        self._state: dict[str, Any] = {}

    @property
    def task_ids(self) -> list[str]:
        """Return a copy of the task ids in their round-robin order."""
        return list(self._task_order)

    def reset(self, task_id: str | None = None) -> Observation:
        """Start a new episode and return its first observation.

        Args:
            task_id: Task to load. When ``None`` the next task in the
                round-robin order is used and the rotation index advances.

        Raises:
            RuntimeError: If no tasks are registered and ``task_id`` is None.
            ValueError: If ``task_id`` is not a known task.
        """
        if task_id is None:
            if not self._task_order:
                # Without this guard the modulo below raises ZeroDivisionError.
                raise RuntimeError("No tasks available to reset into.")
            task_id = self._task_order[self._task_index % len(self._task_order)]
            self._task_index += 1
        if task_id not in self._tasks:
            raise ValueError(f"Unknown task_id '{task_id}'. Available: {sorted(self._tasks.keys())}")

        self._current_task = self._tasks[task_id]
        self._state = {
            "step_count": 0,
            "read_ticket_ids": set(),
            "selected_ticket_id": None,
            "classification": None,
            "draft_reply": None,
            "resolved": False,
            "resolved_ticket_id": None,
            "invalid_actions": 0,
            "repeat_actions": 0,
            "action_history": [],
            "last_note": "Environment reset.",
            "done": False,
            "done_reason": "ongoing",
        }
        return self._build_observation()

    def step(self, action: Action) -> tuple[Observation, Reward, bool, dict[str, Any]]:
        """Apply one agent action and return ``(observation, reward, done, info)``.

        Raises:
            RuntimeError: If called before ``reset()`` or after the episode
                has already finished.
        """
        if self._current_task is None:
            raise RuntimeError("Call reset() before step().")
        if self._state["done"]:
            raise RuntimeError("Episode already done. Call reset() for a new episode.")

        task = self._current_task
        st = self._state
        st["step_count"] += 1

        # An action identical to the immediately preceding one counts as a repeat.
        action_fingerprint = action.model_dump_json()
        if st["action_history"] and st["action_history"][-1] == action_fingerprint:
            st["repeat_actions"] += 1
        st["action_history"].append(action_fingerprint)

        valid_ticket_ids = {t["ticket_id"] for t in task.tickets}
        step_penalty = 0.0

        # Ticket-scoped actions must reference a ticket that exists in this task.
        if action.action_type in {"read_ticket", "classify_ticket", "resolve_ticket"}:
            if not action.ticket_id or action.ticket_id not in valid_ticket_ids:
                st["invalid_actions"] += 1
                st["last_note"] = "Invalid or missing ticket_id."
                step_penalty -= 0.03
                if st["invalid_actions"] >= 3:
                    st["done"] = True
                    st["done_reason"] = "invalid_action"
                # This early exit must still honour the step budget; otherwise
                # an invalid action on the final step lets the episode run
                # past max_steps.
                self._enforce_step_limit()
                return self._assemble_step_response(step_penalty)

        if action.action_type == "read_ticket":
            st["read_ticket_ids"].add(action.ticket_id)
            st["selected_ticket_id"] = action.ticket_id
            st["last_note"] = f"Read ticket {action.ticket_id}."
        elif action.action_type == "classify_ticket":
            # Classifying a ticket other than the task's target costs a little,
            # but the classification is still stored.
            if action.ticket_id != task.target_ticket_id:
                step_penalty -= 0.01
            st["classification"] = {
                "ticket_id": action.ticket_id,
                "priority": action.priority,
                "category": action.category,
                "needs_escalation": action.needs_escalation,
            }
            st["last_note"] = f"Saved classification for {action.ticket_id}."
        elif action.action_type == "draft_reply":
            text = (action.message or "").strip()
            if not text:
                st["invalid_actions"] += 1
                st["last_note"] = "Draft reply is empty."
                step_penalty -= 0.02
            else:
                st["draft_reply"] = text
                st["last_note"] = "Draft reply saved."
        elif action.action_type == "resolve_ticket":
            # Resolving always ends the episode.
            st["resolved"] = True
            st["resolved_ticket_id"] = action.ticket_id
            st["done"] = True
            st["done_reason"] = "resolved"
            st["last_note"] = f"Resolved ticket {action.ticket_id}."
        else:
            st["invalid_actions"] += 1
            st["last_note"] = f"Unknown action {action.action_type}."
            step_penalty -= 0.03

        self._enforce_step_limit()

        if st["repeat_actions"] > 0:
            step_penalty -= min(0.04, 0.01 * st["repeat_actions"])
        return self._assemble_step_response(step_penalty)

    def state(self) -> dict[str, Any]:
        """Return a deep-copied snapshot of the episode state.

        The ``read_ticket_ids`` set is converted to a sorted list and the
        active ``task_id`` is added; the internal state is not modified.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        if self._current_task is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        visible = copy.deepcopy(self._state)
        visible["read_ticket_ids"] = sorted(visible["read_ticket_ids"])
        visible["task_id"] = self._current_task.task_id
        return visible

    def _enforce_step_limit(self) -> None:
        """Mark the episode done with reason ``max_steps`` once the budget is spent.

        Does nothing if the episode already ended for another reason.
        """
        assert self._current_task is not None
        st = self._state
        if st["step_count"] >= self._current_task.max_steps and not st["done"]:
            st["done"] = True
            st["done_reason"] = "max_steps"
            st["last_note"] = "Reached max_steps."

    def _build_observation(self) -> Observation:
        """Build the agent-facing observation from the current task and state."""
        assert self._current_task is not None
        task = self._current_task
        st = self._state

        # Ticket content is only exposed for the selected ticket once it has
        # been read.
        content = None
        if st.get("selected_ticket_id") in st["read_ticket_ids"]:
            ticket = next(t for t in task.tickets if t["ticket_id"] == st["selected_ticket_id"])
            content = ticket["content"]

        inbox = [
            TicketView(
                ticket_id=t["ticket_id"],
                subject=t["subject"],
                customer_tier=t["customer_tier"],
                age_minutes=t["age_minutes"],
                read=t["ticket_id"] in st["read_ticket_ids"],
            )
            for t in task.tickets
        ]

        partial = grade_task(task, st)
        return Observation(
            task_id=task.task_id,
            objective=task.objective,
            step_count=st["step_count"],
            max_steps=task.max_steps,
            inbox=inbox,
            current_ticket_content=content,
            latest_system_note=st.get("last_note", ""),
            score_hint={
                "read": partial.read_score,
                "classify": partial.classify_score,
                "reply": partial.reply_score,
                "resolve": partial.resolve_score,
            },
        )

    def _assemble_step_response(self, step_penalty: float) -> tuple[Observation, Reward, bool, dict[str, Any]]:
        """Grade the current state and package the ``step()`` return tuple.

        Args:
            step_penalty: Non-positive penalty accumulated by the action just
                processed; its magnitude is added to the penalty total.
        """
        assert self._current_task is not None
        task = self._current_task
        st = self._state
        grade = grade_task(task, st)

        # While the episode runs, only 75% of the grader score is paid out.
        progress_signal = 0.75 * grade.total

        penalty_total = 0.0
        penalties: dict[str, float] = {}
        if st["invalid_actions"]:
            penalties["invalid_actions"] = round(min(0.2, 0.04 * st["invalid_actions"]), 4)
            penalty_total += penalties["invalid_actions"]
        if st["repeat_actions"]:
            penalties["repetition"] = round(min(0.15, 0.02 * st["repeat_actions"]), 4)
            penalty_total += penalties["repetition"]
        if step_penalty < 0:
            penalties["step_penalty"] = round(abs(step_penalty), 4)
            penalty_total += abs(step_penalty)

        reward_value = progress_signal - penalty_total
        if st["done"]:
            # At episode end the reward is never below the raw grader score.
            reward_value = max(reward_value, grade.total)
        reward_value = max(0.0, min(1.0, reward_value))

        reward = Reward(
            value=round(reward_value, 4),
            components={
                "progress_signal": round(progress_signal, 4),
                "grade_total": grade.total,
                "read_score": grade.read_score,
                "classify_score": grade.classify_score,
                "reply_score": grade.reply_score,
                "resolve_score": grade.resolve_score,
                "penalty_total": round(penalty_total, 4),
            },
            reasoning="Shaped reward from grader progress with penalties for invalid or looping actions.",
        )
        info = StepInfo(
            task_id=task.task_id,
            done_reason=st["done_reason"],
            grader_score=grade.total,
            reward_components=reward.components,
            penalties=penalties,
            state_snapshot=self.state(),
        ).model_dump()
        obs = self._build_observation()
        return obs, reward, st["done"], info