# Rishi Prasad
# Clean submission upload
# bc8b288
"""Core SupportBench environment."""
from __future__ import annotations
import copy
from typing import Any, Dict, Optional, Tuple
from .models import (
Action,
CustomerProfile,
EpisodeState,
Observation,
OrderInfo,
Reward,
)
from .tasks import AVAILABLE_ACTIONS, get_task
from .reward import compute_step_reward
from .graders import grade
class SupportBenchEnv:
    """
    OpenEnv-compliant environment for customer support ticket resolution.

    Usage:
        env = SupportBenchEnv()
        obs = env.reset(task_id="easy_ticket_triage")
        obs, reward, done, info = env.step(action)
        state = env.state()
        env.close()
    """

    # Minimum grader score for an episode to be flagged as a success.
    SUCCESS_THRESHOLD = 0.6
    # Step budget used when the task spec has no "max_steps" entry.
    DEFAULT_MAX_STEPS = 8
    # Bounds applied to the running cumulative reward after every step.
    MIN_CUMULATIVE_REWARD = -1.0
    MAX_CUMULATIVE_REWARD = 2.0
    # (action_type, field) pairs: the field must be truthy for that action.
    _REQUIRED_FIELDS = (
        ("classify_ticket", "category"),
        ("set_priority", "priority"),
        ("ask_customer", "message"),
        ("propose_resolution", "resolution"),
        ("apply_resolution", "resolution"),
        ("escalate", "escalate_to"),
    )
    # Substrings of an ask_customer message that mark identity as verified.
    _IDENTITY_KEYWORDS = (
        "name", "card", "email", "verify", "identity", "last 4", "confirm",
    )
    # Optional action fields copied into a history entry when truthy.
    _HISTORY_FIELDS = ("category", "priority", "resolution", "escalate_to")

    def __init__(self) -> None:
        """Create an environment with no active episode; call reset() next."""
        self._state: Optional[EpisodeState] = None
        self._task_spec: Optional[Dict[str, Any]] = None

    # ------------------------------------------------------------------
    # reset
    # ------------------------------------------------------------------
    def reset(self, task_id: str = "easy_ticket_triage") -> Observation:
        """Start a fresh episode for ``task_id`` and return its first observation.

        The task spec is looked up with ``get_task``; the step budget, the
        initial status and the initial ticket history are read from it, with
        defaults of ``DEFAULT_MAX_STEPS``, ``"open"`` and an empty history.
        """
        spec = get_task(task_id)
        self._task_spec = spec
        self._state = EpisodeState(
            task_id=spec["task_id"],
            task_spec=spec,
            max_steps=spec.get("max_steps", self.DEFAULT_MAX_STEPS),
            current_status=spec.get("current_status", "open"),
            # Copy so appending step entries never touches the spec's list.
            ticket_history=list(spec.get("ticket_history", [])),
        )
        return self._build_observation()

    # ------------------------------------------------------------------
    # step
    # ------------------------------------------------------------------
    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        ``info`` always carries the keys ``score``, ``violations``,
        ``step_result`` and ``step_error``. ``score`` is ``None`` until the
        episode is done. Stepping a finished episode changes nothing and
        yields a zero reward.

        Raises:
            RuntimeError: if reset() has not been called yet.
        """
        if self._state is None:
            raise RuntimeError("Call reset() before step().")
        state = self._state

        if state.done:
            obs = self._build_observation()
            reward = Reward(
                value=0.0,
                reason="episode already done",
                cumulative=state.cumulative_reward,
            )
            # Same key set as the normal path, and a copy of the violations,
            # so callers can treat every info dict uniformly.
            done_info: Dict[str, Any] = {
                "score": grade(state),
                "violations": list(state.violations),
                "step_result": "",
                "step_error": "episode already done",
            }
            return obs, reward, True, done_info

        state.steps += 1

        # Validate first; side effects are applied only to valid actions.
        result: str = ""
        error: Optional[str] = self._validate_action(action)
        if not error:
            result, error = self._apply_action(action, state)

        # Track repeated actions
        action_key = f"{action.action_type}:{action.category or action.priority or action.resolution or action.escalate_to or ''}"
        state.repeated_action_counts[action_key] = state.repeated_action_counts.get(action_key, 0) + 1

        # Compute reward and clamp the running total.
        reward_delta, reason = compute_step_reward(action, state, result, error)
        state.cumulative_reward = max(
            self.MIN_CUMULATIVE_REWARD,
            min(self.MAX_CUMULATIVE_REWARD, state.cumulative_reward + reward_delta),
        )

        # Record history
        history_entry: Dict[str, Any] = {
            "step": state.steps,
            "action_type": action.action_type,
        }
        for field_name in self._HISTORY_FIELDS:
            field_value = getattr(action, field_name)
            if field_value:
                history_entry[field_name] = field_value
        if action.message:
            history_entry["message"] = action.message[:100]
        history_entry["reward"] = round(reward_delta, 3)
        history_entry["result"] = result
        state.ticket_history.append(history_entry)
        # Separate dict: editing one history must not alter the other.
        state.action_history.append(dict(history_entry))
        state.last_action_result = result or None
        state.last_action_error = error or None

        # Check termination: step budget exhausted or a successful resolve.
        if state.steps >= state.max_steps:
            state.done = True
        if action.action_type == "resolve" and not error:
            state.done = True
            state.resolved = True

        final_score = grade(state) if state.done else None
        if state.done:
            state.success = (final_score or 0.0) >= self.SUCCESS_THRESHOLD

        obs = self._build_observation()
        reward = Reward(value=reward_delta, reason=reason, cumulative=state.cumulative_reward)
        info: Dict[str, Any] = {
            "score": final_score,
            "violations": list(state.violations),
            "step_result": result,
            "step_error": error,
        }
        return obs, reward, state.done, info

    # ------------------------------------------------------------------
    # state
    # ------------------------------------------------------------------
    def state(self) -> Dict[str, Any]:
        """Return a summary dict of the current episode, or ``{}`` before reset().

        The ``score`` entry is computed by calling ``grade`` on the current
        state at the time of the call.
        """
        if self._state is None:
            return {}
        s = self._state
        return {
            "task_id": s.task_id,
            "steps": s.steps,
            "max_steps": s.max_steps,
            "done": s.done,
            "success": s.success,
            "cumulative_reward": round(s.cumulative_reward, 4),
            # Copy so callers cannot mutate the episode's own list.
            "violations": list(s.violations),
            "classified": s.classified,
            "priority_set": s.priority_set,
            "asked_customer": s.asked_customer,
            "identity_verified": s.identity_verified,
            "escalated": s.escalated,
            "resolved": s.resolved,
            "score": grade(s),
        }

    # ------------------------------------------------------------------
    # close
    # ------------------------------------------------------------------
    def close(self) -> Dict[str, Any]:
        """Finalize episode and return summary.

        Marks the episode done, grades it and sets its success flag. If
        reset() was never called, a zeroed summary is returned instead.
        """
        if self._state is None:
            return {"score": 0.0, "steps": 0, "success": False}
        s = self._state
        s.done = True
        final_score = grade(s)
        s.success = final_score >= self.SUCCESS_THRESHOLD
        return {
            "score": round(final_score, 4),
            "steps": s.steps,
            "success": s.success,
            "cumulative_reward": round(s.cumulative_reward, 4),
            # Copy so callers cannot mutate the episode's own list.
            "violations": list(s.violations),
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _build_observation(self) -> Observation:
        """Assemble an Observation from the task spec and the episode state.

        Raises:
            RuntimeError: if no episode has been started with reset().
        """
        spec = self._task_spec
        state = self._state
        if spec is None or state is None:
            raise RuntimeError("Call reset() before building an observation.")
        return Observation(
            task_id=spec["task_id"],
            task_name=spec["task_name"],
            customer_message=spec["customer_message"],
            customer_profile=CustomerProfile(**spec["customer_profile"]),
            order_info=OrderInfo(**spec["order_info"]),
            policy_snippets=spec["policy_snippets"],
            # Shallow copy: the observation gets its own list object.
            ticket_history=list(state.ticket_history),
            current_status=state.current_status,
            available_actions=AVAILABLE_ACTIONS,
            steps_taken=state.steps,
            max_steps=state.max_steps,
            last_action_result=state.last_action_result,
            last_action_error=state.last_action_error,
        )

    def _validate_action(self, action: Action) -> Optional[str]:
        """Return error string if action is invalid, else None.

        Only checks that the field required by the action type is truthy;
        unknown action types pass here and are rejected by _apply_action.
        """
        for action_type, required in self._REQUIRED_FIELDS:
            if action.action_type == action_type and not getattr(action, required):
                return f"{action_type} requires '{required}'"
        return None

    def _apply_action(self, action: Action, state: EpisodeState) -> Tuple[str, Optional[str]]:
        """Apply action side effects to state. Returns (result_message, error).

        Every recognised action type returns ``(message, None)``; an
        unrecognised one leaves ``state`` untouched and returns
        ``("", "Unknown action type ...")``.
        """
        task_expected = state.task_spec.get("expected", {})
        kind = action.action_type

        if kind == "classify_ticket":
            state.classified = True
            state.correct_category = action.category == task_expected.get("category")
            state.current_status = "classified"
            return f"Ticket classified as '{action.category}'", None

        if kind == "set_priority":
            state.priority_set = True
            exp_pri = task_expected.get("priority")
            # "high" and "urgent" are accepted interchangeably.
            state.correct_priority = (
                action.priority == exp_pri
                or (action.priority in ("high", "urgent") and exp_pri in ("high", "urgent"))
            )
            state.current_status = f"priority:{action.priority}"
            return f"Priority set to '{action.priority}'", None

        if kind == "ask_customer":
            state.asked_customer = True
            msg_lower = (action.message or "").lower()
            # Check if identity verification was requested
            if any(kw in msg_lower for kw in self._IDENTITY_KEYWORDS):
                state.identity_verified = True
            state.current_status = "awaiting_customer_response"
            # Messages longer than 80 characters are truncated in the result.
            if len(action.message or "") > 80:
                return f"Message sent to customer: '{action.message[:80]}...' ", None
            return f"Message sent to customer: '{action.message}'", None

        if kind == "propose_resolution":
            state.resolution_proposed = True
            must_deny = task_expected.get("must_deny_refund", False)
            if action.resolution == "replacement":
                state.replacement_offered = True
            # When a refund must be denied, either an explicit denial or a
            # replacement offer sets the refund-denied flag.
            if must_deny and action.resolution in ("deny_refund", "replacement"):
                state.refund_denied_when_required = True
            # A message longer than 20 characters sets the policy flag.
            state.policy_referenced = bool(action.message and len(action.message) > 20)
            state.current_status = f"resolution_proposed:{action.resolution}"
            return f"Resolution proposed: '{action.resolution}'", None

        if kind == "apply_resolution":
            state.resolution_applied = True
            state.current_status = f"resolution_applied:{action.resolution}"
            return f"Resolution applied: '{action.resolution}'", None

        if kind == "escalate":
            state.escalated = True
            correct_target = task_expected.get("correct_escalation_target")
            state.escalation_correct = action.escalate_to == correct_target
            state.current_status = f"escalated:{action.escalate_to}"
            return f"Ticket escalated to '{action.escalate_to}'", None

        if kind == "resolve":
            state.resolved = True
            state.current_status = "resolved"
            return "Ticket resolved and closed.", None

        return "", f"Unknown action type '{action.action_type}'"