"""Core SupportBench environment.""" from __future__ import annotations import copy from typing import Any, Dict, Optional, Tuple from .models import ( Action, CustomerProfile, EpisodeState, Observation, OrderInfo, Reward, ) from .tasks import AVAILABLE_ACTIONS, get_task from .reward import compute_step_reward from .graders import grade class SupportBenchEnv: """ OpenEnv-compliant environment for customer support ticket resolution. Usage: env = SupportBenchEnv() obs = env.reset(task_id="easy_ticket_triage") obs, reward, done, info = env.step(action) state = env.state() env.close() """ def __init__(self) -> None: self._state: Optional[EpisodeState] = None self._task_spec: Optional[Dict[str, Any]] = None # ------------------------------------------------------------------ # reset # ------------------------------------------------------------------ def reset(self, task_id: str = "easy_ticket_triage") -> Observation: spec = get_task(task_id) self._task_spec = spec self._state = EpisodeState( task_id=spec["task_id"], task_spec=spec, max_steps=spec.get("max_steps", 8), current_status=spec.get("current_status", "open"), ticket_history=list(spec.get("ticket_history", [])), ) return self._build_observation() # ------------------------------------------------------------------ # step # ------------------------------------------------------------------ def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]: if self._state is None: raise RuntimeError("Call reset() before step().") state = self._state if state.done: obs = self._build_observation() reward = Reward(value=0.0, reason="episode already done", cumulative=state.cumulative_reward) return obs, reward, True, {"score": grade(state), "violations": state.violations} state.steps += 1 error: Optional[str] = None result: str = "" # Validate action error = self._validate_action(action) if not error: result, error = self._apply_action(action, state) # Track repeated actions action_key = f"{action.action_type}:{action.category or action.priority or action.resolution or action.escalate_to or ''}" state.repeated_action_counts[action_key] = state.repeated_action_counts.get(action_key, 0) + 1 # Compute reward reward_delta, reason = compute_step_reward(action, state, result, error) state.cumulative_reward = max(-1.0, min(2.0, state.cumulative_reward + reward_delta)) # Record history history_entry: Dict[str, Any] = { "step": state.steps, "action_type": action.action_type, } if action.category: history_entry["category"] = action.category if action.priority: history_entry["priority"] = action.priority if action.resolution: history_entry["resolution"] = action.resolution if action.escalate_to: history_entry["escalate_to"] = action.escalate_to if action.message: history_entry["message"] = action.message[:100] history_entry["reward"] = round(reward_delta, 3) history_entry["result"] = result state.ticket_history.append(history_entry) state.action_history.append(history_entry) state.last_action_result = result or None state.last_action_error = error or None # Check termination if state.steps >= state.max_steps: state.done = True if action.action_type == "resolve" and not error: state.done = True state.resolved = True final_score = grade(state) if state.done else None if state.done: state.success = (final_score or 0.0) >= 0.6 obs = self._build_observation() reward = Reward(value=reward_delta, reason=reason, cumulative=state.cumulative_reward) info: Dict[str, Any] = { "score": final_score, "violations": list(state.violations), "step_result": result, "step_error": error, } return obs, reward, state.done, info # ------------------------------------------------------------------ # state # ------------------------------------------------------------------ def state(self) -> Dict[str, Any]: if self._state is None: return {} s = self._state return { "task_id": s.task_id, "steps": s.steps, "max_steps": s.max_steps, "done": s.done, "success": s.success, "cumulative_reward": round(s.cumulative_reward, 4), "violations": s.violations, "classified": s.classified, "priority_set": s.priority_set, "asked_customer": s.asked_customer, "identity_verified": s.identity_verified, "escalated": s.escalated, "resolved": s.resolved, "score": grade(s), } # ------------------------------------------------------------------ # close # ------------------------------------------------------------------ def close(self) -> Dict[str, Any]: """Finalize episode and return summary.""" if self._state is None: return {"score": 0.0, "steps": 0, "success": False} s = self._state s.done = True final_score = grade(s) s.success = final_score >= 0.6 return { "score": round(final_score, 4), "steps": s.steps, "success": s.success, "cumulative_reward": round(s.cumulative_reward, 4), "violations": s.violations, } # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _build_observation(self) -> Observation: spec = self._task_spec state = self._state profile_data = spec["customer_profile"] return Observation( task_id=spec["task_id"], task_name=spec["task_name"], customer_message=spec["customer_message"], customer_profile=CustomerProfile(**profile_data), order_info=OrderInfo(**spec["order_info"]), policy_snippets=spec["policy_snippets"], ticket_history=list(state.ticket_history), current_status=state.current_status, available_actions=AVAILABLE_ACTIONS, steps_taken=state.steps, max_steps=state.max_steps, last_action_result=state.last_action_result, last_action_error=state.last_action_error, ) def _validate_action(self, action: Action) -> Optional[str]: """Return error string if action is invalid, else None.""" if action.action_type == "classify_ticket" and not action.category: return "classify_ticket requires 'category'" if action.action_type == "set_priority" and not action.priority: return "set_priority requires 'priority'" if action.action_type == "ask_customer" and not action.message: return "ask_customer requires 'message'" if action.action_type == "propose_resolution" and not action.resolution: return "propose_resolution requires 'resolution'" if action.action_type == "apply_resolution" and not action.resolution: return "apply_resolution requires 'resolution'" if action.action_type == "escalate" and not action.escalate_to: return "escalate requires 'escalate_to'" return None def _apply_action(self, action: Action, state: EpisodeState) -> Tuple[str, Optional[str]]: """Apply action side effects to state. Returns (result_message, error).""" task_expected = state.task_spec.get("expected", {}) if action.action_type == "classify_ticket": state.classified = True state.correct_category = action.category == task_expected.get("category") state.current_status = "classified" return f"Ticket classified as '{action.category}'", None elif action.action_type == "set_priority": state.priority_set = True exp_pri = task_expected.get("priority") state.correct_priority = ( action.priority == exp_pri or (action.priority in ("high", "urgent") and exp_pri in ("high", "urgent")) ) state.current_status = f"priority:{action.priority}" return f"Priority set to '{action.priority}'", None elif action.action_type == "ask_customer": state.asked_customer = True msg_lower = (action.message or "").lower() # Check if identity verification was requested if any(kw in msg_lower for kw in ["name", "card", "email", "verify", "identity", "last 4", "confirm"]): state.identity_verified = True state.current_status = "awaiting_customer_response" return f"Message sent to customer: '{action.message[:80]}...' " if len(action.message or "") > 80 else f"Message sent to customer: '{action.message}'", None elif action.action_type == "propose_resolution": state.resolution_proposed = True must_deny = task_expected.get("must_deny_refund", False) if action.resolution == "replacement": state.replacement_offered = True if must_deny and action.resolution == "deny_refund": state.refund_denied_when_required = True if must_deny and action.resolution == "replacement": state.refund_denied_when_required = True state.replacement_offered = True state.policy_referenced = bool(action.message and len(action.message) > 20) state.current_status = f"resolution_proposed:{action.resolution}" return f"Resolution proposed: '{action.resolution}'", None elif action.action_type == "apply_resolution": state.resolution_applied = True state.current_status = f"resolution_applied:{action.resolution}" return f"Resolution applied: '{action.resolution}'", None elif action.action_type == "escalate": state.escalated = True correct_target = task_expected.get("correct_escalation_target") state.escalation_correct = action.escalate_to == correct_target state.current_status = f"escalated:{action.escalate_to}" return f"Ticket escalated to '{action.escalate_to}'", None elif action.action_type == "resolve": state.resolved = True state.current_status = "resolved" return "Ticket resolved and closed.", None return "", f"Unknown action type '{action.action_type}'"