Spaces:
Sleeping
Sleeping
| """Core SupportBench environment.""" | |
| from __future__ import annotations | |
| import copy | |
| from typing import Any, Dict, Optional, Tuple | |
| from .models import ( | |
| Action, | |
| CustomerProfile, | |
| EpisodeState, | |
| Observation, | |
| OrderInfo, | |
| Reward, | |
| ) | |
| from .tasks import AVAILABLE_ACTIONS, get_task | |
| from .reward import compute_step_reward | |
| from .graders import grade | |
class SupportBenchEnv:
    """
    OpenEnv-compliant environment for customer support ticket resolution.

    Usage:
        env = SupportBenchEnv()
        obs = env.reset(task_id="easy_ticket_triage")
        obs, reward, done, info = env.step(action)
        state = env.state()
        env.close()
    """

    # Final grade at or above which an episode is marked successful.
    SUCCESS_THRESHOLD = 0.6

    # Cumulative reward is clamped to this closed interval after every step.
    CUMULATIVE_REWARD_BOUNDS = (-1.0, 2.0)

    # Max characters of an agent message stored in a history entry.
    HISTORY_MESSAGE_LIMIT = 100

    # Max characters of an agent message echoed back in an ask_customer result.
    RESULT_MESSAGE_LIMIT = 80

    # action_type -> Action field that must be non-empty for the action to be valid.
    _REQUIRED_FIELDS = {
        "classify_ticket": "category",
        "set_priority": "priority",
        "ask_customer": "message",
        "propose_resolution": "resolution",
        "apply_resolution": "resolution",
        "escalate": "escalate_to",
    }

    # Case-insensitive substrings in an ask_customer message that mark it as an
    # identity-verification request. NOTE(review): plain substring matching, so
    # e.g. "username" matches "name" and any "confirm ..." question counts.
    _IDENTITY_KEYWORDS = ("name", "card", "email", "verify", "identity", "last 4", "confirm")

    # Priorities treated as interchangeable when judging set_priority.
    _HIGH_PRIORITIES = ("high", "urgent")

    def __init__(self) -> None:
        self._state: Optional[EpisodeState] = None
        self._task_spec: Optional[Dict[str, Any]] = None

    # ------------------------------------------------------------------
    # reset
    # ------------------------------------------------------------------
    def reset(self, task_id: str = "easy_ticket_triage") -> Observation:
        """Start a new episode for ``task_id`` and return the first observation.

        The spec from ``get_task`` is deep-copied so that nothing exposed to the
        agent (e.g. the seeded ticket-history dicts) can write back into the
        object ``get_task`` returned, in case it is shared between episodes.
        """
        spec = copy.deepcopy(get_task(task_id))
        self._task_spec = spec
        self._state = EpisodeState(
            task_id=spec["task_id"],
            task_spec=spec,
            max_steps=spec.get("max_steps", 8),
            current_status=spec.get("current_status", "open"),
            ticket_history=list(spec.get("ticket_history", [])),
        )
        return self._build_observation()

    # ------------------------------------------------------------------
    # step
    # ------------------------------------------------------------------
    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Apply one agent action and advance the episode.

        Returns:
            ``(observation, reward, done, info)``. ``info`` holds ``score``
            (the final grade, or ``None`` while the episode is still running),
            a copy of ``violations``, and this step's ``step_result`` /
            ``step_error``.

        Raises:
            RuntimeError: if called before ``reset()``.
        """
        if self._state is None:
            raise RuntimeError("Call reset() before step().")
        state = self._state

        if state.done:
            # Stepping a finished episode changes nothing; just report the outcome.
            obs = self._build_observation()
            reward = Reward(value=0.0, reason="episode already done", cumulative=state.cumulative_reward)
            return obs, reward, True, {"score": grade(state), "violations": list(state.violations)}

        # Invalid actions still consume a step; they are simply not applied.
        state.steps += 1
        result = ""
        error = self._validate_action(action)
        if not error:
            result, error = self._apply_action(action, state)

        # Count (action_type, first set argument) pairs; this is updated before
        # the reward is computed, presumably so compute_step_reward can see it.
        target = (
            action.category
            or action.priority
            or action.resolution
            or action.escalate_to
            or ""
        )
        action_key = f"{action.action_type}:{target}"
        state.repeated_action_counts[action_key] = state.repeated_action_counts.get(action_key, 0) + 1

        reward_delta, reason = compute_step_reward(action, state, result, error)
        low, high = self.CUMULATIVE_REWARD_BOUNDS
        state.cumulative_reward = max(low, min(high, state.cumulative_reward + reward_delta))

        history_entry = self._make_history_entry(action, state.steps, reward_delta, result)
        state.ticket_history.append(history_entry)
        # Store an independent copy so the two histories never share a dict.
        state.action_history.append(dict(history_entry))
        state.last_action_result = result or None
        state.last_action_error = error or None

        # Episode ends on step budget exhaustion or a successful "resolve".
        if state.steps >= state.max_steps:
            state.done = True
        if action.action_type == "resolve" and not error:
            state.done = True
            state.resolved = True

        final_score = grade(state) if state.done else None
        if state.done:
            state.success = (final_score or 0.0) >= self.SUCCESS_THRESHOLD

        obs = self._build_observation()
        reward = Reward(value=reward_delta, reason=reason, cumulative=state.cumulative_reward)
        info: Dict[str, Any] = {
            "score": final_score,
            "violations": list(state.violations),
            "step_result": result,
            "step_error": error,
        }
        return obs, reward, state.done, info

    # ------------------------------------------------------------------
    # state
    # ------------------------------------------------------------------
    def state(self) -> Dict[str, Any]:
        """Return a snapshot of the episode's progress flags and current grade.

        Returns ``{}`` before ``reset()``. Mutable fields are copied so callers
        cannot alter the live episode through the snapshot.
        """
        if self._state is None:
            return {}
        s = self._state
        return {
            "task_id": s.task_id,
            "steps": s.steps,
            "max_steps": s.max_steps,
            "done": s.done,
            "success": s.success,
            "cumulative_reward": round(s.cumulative_reward, 4),
            "violations": list(s.violations),
            "classified": s.classified,
            "priority_set": s.priority_set,
            "asked_customer": s.asked_customer,
            "identity_verified": s.identity_verified,
            "escalated": s.escalated,
            "resolved": s.resolved,
            "score": grade(s),
        }

    # ------------------------------------------------------------------
    # close
    # ------------------------------------------------------------------
    def close(self) -> Dict[str, Any]:
        """Finalize the episode (mark it done, grade it) and return a summary.

        The summary always has the keys ``score``, ``steps``, ``success``,
        ``cumulative_reward`` and ``violations``, even before ``reset()``.
        """
        if self._state is None:
            return {
                "score": 0.0,
                "steps": 0,
                "success": False,
                "cumulative_reward": 0.0,
                "violations": [],
            }
        s = self._state
        s.done = True
        final_score = grade(s)
        s.success = final_score >= self.SUCCESS_THRESHOLD
        return {
            "score": round(final_score, 4),
            "steps": s.steps,
            "success": s.success,
            "cumulative_reward": round(s.cumulative_reward, 4),
            "violations": list(s.violations),
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _build_observation(self) -> Observation:
        """Build the agent-facing Observation from the task spec and episode state.

        Only valid after ``reset()``: both ``_task_spec`` and ``_state`` are
        dereferenced without a None check.
        """
        spec = self._task_spec
        state = self._state
        return Observation(
            task_id=spec["task_id"],
            task_name=spec["task_name"],
            customer_message=spec["customer_message"],
            customer_profile=CustomerProfile(**spec["customer_profile"]),
            order_info=OrderInfo(**spec["order_info"]),
            policy_snippets=spec["policy_snippets"],
            ticket_history=list(state.ticket_history),
            current_status=state.current_status,
            available_actions=AVAILABLE_ACTIONS,
            steps_taken=state.steps,
            max_steps=state.max_steps,
            last_action_result=state.last_action_result,
            last_action_error=state.last_action_error,
        )

    def _make_history_entry(
        self, action: Action, step: int, reward_delta: float, result: str
    ) -> Dict[str, Any]:
        """Build the history record for one step, including only the action fields that were set."""
        entry: Dict[str, Any] = {"step": step, "action_type": action.action_type}
        for field in ("category", "priority", "resolution", "escalate_to"):
            value = getattr(action, field)
            if value:
                entry[field] = value
        if action.message:
            entry["message"] = action.message[: self.HISTORY_MESSAGE_LIMIT]
        entry["reward"] = round(reward_delta, 3)
        entry["result"] = result
        return entry

    def _validate_action(self, action: Action) -> Optional[str]:
        """Return an error string if a required argument is missing, else None.

        Unknown action types pass here and are rejected by ``_apply_action``.
        """
        field = self._REQUIRED_FIELDS.get(action.action_type)
        if field is not None and not getattr(action, field):
            return f"{action.action_type} requires '{field}'"
        return None

    def _apply_action(self, action: Action, state: EpisodeState) -> Tuple[str, Optional[str]]:
        """Apply the action's side effects to ``state``.

        Correctness flags are judged against ``task_spec["expected"]``.

        Returns:
            ``(result_message, error)``; ``error`` is non-None only for an
            unknown action type.
        """
        expected = state.task_spec.get("expected", {})
        kind = action.action_type

        if kind == "classify_ticket":
            state.classified = True
            state.correct_category = action.category == expected.get("category")
            state.current_status = "classified"
            return f"Ticket classified as '{action.category}'", None

        if kind == "set_priority":
            state.priority_set = True
            exp_pri = expected.get("priority")
            both_high = action.priority in self._HIGH_PRIORITIES and exp_pri in self._HIGH_PRIORITIES
            state.correct_priority = action.priority == exp_pri or both_high
            state.current_status = f"priority:{action.priority}"
            return f"Priority set to '{action.priority}'", None

        if kind == "ask_customer":
            state.asked_customer = True
            message = action.message or ""
            lowered = message.lower()
            if any(kw in lowered for kw in self._IDENTITY_KEYWORDS):
                state.identity_verified = True
            state.current_status = "awaiting_customer_response"
            limit = self.RESULT_MESSAGE_LIMIT
            if len(message) > limit:
                return f"Message sent to customer: '{message[:limit]}...'", None
            return f"Message sent to customer: '{message}'", None

        if kind == "propose_resolution":
            state.resolution_proposed = True
            must_deny = expected.get("must_deny_refund", False)
            if action.resolution == "replacement":
                state.replacement_offered = True
            # When policy forbids a refund, an explicit denial and a replacement
            # offer both count as having denied the refund.
            if must_deny and action.resolution in ("deny_refund", "replacement"):
                state.refund_denied_when_required = True
            # NOTE(review): overwritten on every proposal, so only the latest
            # proposal's message length counts. Confirm this is intended.
            state.policy_referenced = bool(action.message and len(action.message) > 20)
            state.current_status = f"resolution_proposed:{action.resolution}"
            return f"Resolution proposed: '{action.resolution}'", None

        if kind == "apply_resolution":
            state.resolution_applied = True
            state.current_status = f"resolution_applied:{action.resolution}"
            return f"Resolution applied: '{action.resolution}'", None

        if kind == "escalate":
            state.escalated = True
            state.escalation_correct = action.escalate_to == expected.get("correct_escalation_target")
            state.current_status = f"escalated:{action.escalate_to}"
            return f"Ticket escalated to '{action.escalate_to}'", None

        if kind == "resolve":
            state.resolved = True
            state.current_status = "resolved"
            return "Ticket resolved and closed.", None

        return "", f"Unknown action type '{action.action_type}'"