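# NOTE: the six lines below are residue from the file-hosting web page this
# module was copied from (uploader avatar caption, commit message, commit hash,
# and toolbar labels). They are kept as comments so the module still parses.
# Dar3devil's picture
# Initial customer support OpenEnv upload
# 2b73c16 verified
# Raw
# History Blame Contribute Delete
# 18.4 kB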
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from .fixtures import (
BENCHMARK_NAME,
DEFAULT_SUCCESS_THRESHOLD,
KB_ARTICLES,
KnowledgeBaseArticle,
TaskFixture,
get_task_fixture,
list_task_ids,
)
from .models import (
ACTION_TYPE_NAMES,
AccountLookupResult,
ConversationTurn,
KBSearchResult,
ErrorToolResult,
EscalateTicketAction,
EscalationResult,
IssueRefundAction,
LookupAccountAction,
RefundResult,
ReplyResult,
ResolveResult,
SearchKBAction,
SupportTicketAction,
SupportTicketObservation,
SupportTicketStepResult,
ToolResult,
parse_action,
)
from .scoring import build_scorecard, normalize_text
@dataclass
class SessionState:
    """Mutable bookkeeping for a single support-ticket episode.

    One instance is created per ``reset()`` of the environment and is mutated
    in place by every subsequent step. Only ``fixture`` is required; every
    other field starts at an "untouched ticket" default. Container fields use
    ``default_factory`` so that separate sessions never share a list, dict or
    set.
    """

    # Task definition this episode is played against (the only required field).
    fixture: TaskFixture

    # --- Ticket progress -------------------------------------------------
    # Starts "open"; the environment moves it to "resolved" or "escalated".
    ticket_status: str = "open"
    # Number of step() calls counted against the step budget so far.
    steps_taken: int = 0

    # --- Transcripts and audit trails ------------------------------------
    # Customer/agent turns in chronological order.
    conversation_history: list[ConversationTurn] = field(default_factory=list)
    # One entry per recorded action (payload, success flag, ticket status).
    action_history: list[dict[str, Any]] = field(default_factory=list)
    # Agent replies, each stored with the step index at which it was sent.
    reply_history: list[dict[str, Any]] = field(default_factory=list)

    # --- Knowledge gathered by tool calls --------------------------------
    # Facts surfaced to the agent (KB ids/titles seen, account summary).
    known_facts: dict[str, Any] = field(default_factory=dict)
    # Ids of every KB article returned by a search in this episode.
    kb_articles_seen: set[str] = field(default_factory=set)
    # Normalized search queries already issued; used to detect repeats.
    search_signatures: set[str] = field(default_factory=set)
    # Whether an account lookup has succeeded, and for which customer id.
    lookup_performed: bool = False
    lookup_customer_id: str | None = None

    # --- Refund / resolution / escalation outcomes -----------------------
    # Details of the successful refund, if any.
    refund_record: dict[str, Any] | None = None
    # True once issue_refund has been tried, whether or not it succeeded.
    refund_attempted: bool = False
    # Resolution code supplied when the ticket was resolved.
    resolution_code: str | None = None
    # Queue/priority/summary captured when the ticket was escalated.
    escalation: dict[str, Any] | None = None

    # --- Episode termination ---------------------------------------------
    done: bool = False
    # Why the episode ended ("resolved", "escalated", "max_steps_exceeded").
    terminal_reason: str | None = None

    # --- Reward shaping and last-step feedback ---------------------------
    # Rubric score after the previous step; rewards are based on the delta.
    previous_score: float = 0.0
    # Result object and error string produced by the most recent step.
    last_tool_result: ToolResult | None = None
    last_action_error: str | None = None
class SupportTicketEnvironment:
    """Step-based environment in which an agent works one support ticket.

    Typical use is ``reset()`` followed by repeated ``step(action)`` calls
    until the returned result has ``done=True``. ``state()`` exposes a
    JSON-friendly snapshot of the current episode at any time after a reset.

    Reward for a regular step is the change in rubric score since the previous
    step, minus ``step_cost`` and minus any invalid/redundancy penalties.

    Class attributes:
        benchmark_name: Name reported in every ``info`` payload.
        max_steps: Step budget; once ``steps_taken`` reaches it the episode
            ends with ``terminal_reason="max_steps_exceeded"``.
        step_cost: Flat cost subtracted from every finalized step reward.
        invalid_action_penalty: Penalty for unparsable, unsupported or
            rejected actions, and for stepping a finished episode.
        repeated_action_penalty: Penalty for repeating a KB search (same
            normalized query) or an account lookup for the same customer.
        success_threshold: Rubric score at or above which ``info["success"]``
            is reported as true.
    """

    benchmark_name = BENCHMARK_NAME
    max_steps = 8
    step_cost = 0.01
    invalid_action_penalty = 0.10
    repeated_action_penalty = 0.02
    success_threshold = DEFAULT_SUCCESS_THRESHOLD

    def __init__(self, task_id: str | None = None) -> None:
        """Create an environment; no session exists until ``reset()`` is called.

        Args:
            task_id: Default task for ``reset()``. When omitted (or falsy) the
                first id returned by ``list_task_ids()`` is used.
        """
        self._default_task_id = task_id or list_task_ids()[0]
        self._session: SessionState | None = None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def reset(self, task_id: str | None = None) -> SupportTicketStepResult:
        """Start a fresh episode and return the initial observation.

        Args:
            task_id: Task to load; falls back to the constructor's default.

        Returns:
            A step result with ``reward=0.0`` whose conversation history holds
            only the customer's opening message (step index 0).
        """
        fixture = get_task_fixture(task_id or self._default_task_id)
        opening_turn = ConversationTurn(
            role="customer",
            message=fixture.ticket.message,
            step_index=0,
        )
        self._session = SessionState(fixture=fixture, conversation_history=[opening_turn])
        return self._build_result(reward=0.0)

    def step(self, action: SupportTicketAction | dict[str, Any]) -> SupportTicketStepResult:
        """Apply one agent action and return the resulting observation/reward.

        Args:
            action: A typed action or a raw payload accepted by ``parse_action``.

        Returns:
            The step result. Stepping an already finished episode does not
            consume a step or touch the audit log; it returns a reward of
            ``-invalid_action_penalty`` with an ``episode_already_done`` error.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        session = self._require_session()

        if session.done:
            session.last_action_error = "episode_already_done"
            session.last_tool_result = self._error_result(
                "episode_already_done",
                "This ticket is already terminal. Reset the environment before stepping again.",
            )
            # Report the penalty in `info` too, so it agrees with the reward.
            return self._build_result(
                reward=-self.invalid_action_penalty,
                invalid_penalty=self.invalid_action_penalty,
            )

        session.last_action_error = None
        try:
            parsed_action = parse_action(action)
        except Exception as exc:  # boundary: any parse failure is an invalid action
            # A malformed action still consumes a step and is audited.
            session.steps_taken += 1
            session.last_action_error = f"invalid_action: {exc}"
            session.last_tool_result = self._error_result("invalid_action", str(exc))
            self._record_action({"action_type": "invalid"}, False)
            self._enforce_step_budget(session)
            return self._finalize_step(
                invalid_penalty=self.invalid_action_penalty,
                redundancy_penalty=0.0,
            )

        session.steps_taken += 1
        tool_result, invalid_penalty, redundancy_penalty = self._apply_action(parsed_action)
        session.last_tool_result = tool_result
        action_succeeded = bool(getattr(tool_result, "success", False))
        self._record_action(parsed_action.model_dump(mode="json"), action_succeeded)
        self._enforce_step_budget(session)
        return self._finalize_step(
            invalid_penalty=invalid_penalty,
            redundancy_penalty=redundancy_penalty,
        )

    def state(self) -> dict[str, Any]:
        """Return a JSON-friendly snapshot of the current episode.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        session = self._require_session()
        scorecard = build_scorecard(session.fixture, session)
        return {
            "benchmark_name": self.benchmark_name,
            "task_id": session.fixture.task_id,
            "ticket_status": session.ticket_status,
            "steps_taken": session.steps_taken,
            "steps_remaining": max(self.max_steps - session.steps_taken, 0),
            "conversation_history": [turn.model_dump(mode="json") for turn in session.conversation_history],
            "audit_log": list(session.action_history),
            "known_facts": dict(session.known_facts),
            "current_rubric_score": scorecard.score,
            "score_breakdown": scorecard.model_dump(mode="json"),
            "terminal_reason": session.terminal_reason,
            "done": session.done,
        }

    # ------------------------------------------------------------------
    # Action dispatch and per-action handlers
    # ------------------------------------------------------------------

    def _apply_action(self, action: SupportTicketAction) -> tuple[ToolResult, float, float]:
        """Route a parsed action to its handler.

        Returns:
            ``(tool_result, invalid_penalty, redundancy_penalty)``.
        """
        session = self._require_session()

        # The order of these checks is deliberate and must be kept.
        if isinstance(action, SearchKBAction):
            return self._handle_search_kb(session, action)
        if isinstance(action, LookupAccountAction):
            return self._handle_lookup_account(session, action)
        if action.action_type == "send_reply":
            return self._handle_send_reply(session, action)
        if isinstance(action, IssueRefundAction):
            return self._handle_issue_refund(session, action)
        if action.action_type == "resolve_ticket":
            return self._handle_resolve_ticket(session, action)
        if isinstance(action, EscalateTicketAction):
            return self._handle_escalate_ticket(session, action)

        session.last_action_error = "unsupported_action"
        unsupported = self._error_result(
            "unsupported_action",
            f"Unsupported action type: {type(action).__name__}",
        )
        return unsupported, self.invalid_action_penalty, 0.0

    def _handle_search_kb(
        self, session: SessionState, action: SearchKBAction
    ) -> tuple[ToolResult, float, float]:
        """Run a KB search, remember what was seen, and penalize repeated queries."""
        query_signature = normalize_text(action.query)
        is_repeat = query_signature in session.search_signatures
        session.search_signatures.add(query_signature)

        articles = self._search_knowledge_base(action.query)
        article_ids = [article.article_id for article in articles]
        session.kb_articles_seen.update(article_ids)
        seen_ids = sorted(session.kb_articles_seen)
        session.known_facts["kb_articles_seen"] = seen_ids
        session.known_facts["kb_titles_seen"] = [KB_ARTICLES[article_id].title for article_id in seen_ids]

        result = KBSearchResult(
            tool_name="search_kb",
            success=bool(articles),
            query=action.query,
            article_ids=article_ids,
            snippets=[article.snippet for article in articles],
            message="Knowledge base search completed." if articles else "No KB articles matched the query.",
        )
        return result, 0.0, (self.repeated_action_penalty if is_repeat else 0.0)

    def _handle_lookup_account(
        self, session: SessionState, action: LookupAccountAction
    ) -> tuple[ToolResult, float, float]:
        """Return the fixture account summary, rejecting any other customer id."""
        account = session.fixture.account
        if action.customer_id != account.customer_id:
            session.last_action_error = "unknown_customer_id"
            rejection = self._error_result(
                "unknown_customer_id",
                f"No account found for customer_id={action.customer_id}.",
            )
            return rejection, self.invalid_action_penalty, 0.0

        # Looking the same customer up again is allowed but penalized.
        is_repeat = session.lookup_performed and session.lookup_customer_id == action.customer_id
        session.lookup_performed = True
        session.lookup_customer_id = action.customer_id

        account_summary = {
            "customer_id": account.customer_id,
            "organization_name": account.organization_name,
            "plan": account.plan,
            "tenure_years": account.tenure_years,
            "arr_usd": account.arr_usd,
            "duplicate_charge_amount_cents": account.duplicate_charge_amount_cents,
            "duplicate_charge_count": account.duplicate_charge_count,
            "duplicate_charge_refund_eligible": account.duplicate_charge_refund_eligible,
            "legal_threat": account.legal_threat,
            "incident_severity": account.incident_severity,
        }
        session.known_facts["account"] = account_summary
        result = AccountLookupResult(
            tool_name="lookup_account",
            success=True,
            customer_id=action.customer_id,
            account_summary=account_summary,
            message="Account lookup completed.",
        )
        return result, 0.0, (self.repeated_action_penalty if is_repeat else 0.0)

    def _handle_send_reply(
        self, session: SessionState, action: SupportTicketAction
    ) -> tuple[ToolResult, float, float]:
        """Record the stripped reply in both the reply log and the conversation."""
        reply = action.message.strip()
        session.reply_history.append({"message": reply, "step_index": session.steps_taken})
        session.conversation_history.append(
            ConversationTurn(role="agent", message=reply, step_index=session.steps_taken)
        )
        result = ReplyResult(
            tool_name="send_reply",
            success=True,
            message_preview=reply[:120],
            message="Reply sent to the customer.",
        )
        return result, 0.0, 0.0

    def _handle_issue_refund(
        self, session: SessionState, action: IssueRefundAction
    ) -> tuple[ToolResult, float, float]:
        """Validate and record a duplicate-charge refund.

        The refund is rejected (with the invalid-action penalty) when no
        account lookup has succeeded, when the account has no refundable
        duplicate charge, or when the amount/reason do not match the account.
        """
        # The attempt is remembered even when it is rejected below.
        session.refund_attempted = True
        account = session.fixture.account

        if not session.lookup_performed:
            session.last_action_error = "lookup_required_before_refund"
            rejection = self._error_result(
                "lookup_required_before_refund",
                "lookup_account must succeed before issue_refund can be used.",
            )
            return rejection, self.invalid_action_penalty, 0.0

        if not account.duplicate_charge_refund_eligible or not account.duplicate_charge_amount_cents:
            session.last_action_error = "refund_not_applicable"
            rejection = self._rejected_refund(
                action, "No duplicate charge is eligible for refund on this account."
            )
            return rejection, self.invalid_action_penalty, 0.0

        if action.amount_cents != account.duplicate_charge_amount_cents or action.reason_code != "duplicate_charge":
            session.last_action_error = "incorrect_refund_payload"
            rejection = self._rejected_refund(
                action, "Refund payload does not match the verified duplicate charge."
            )
            return rejection, self.invalid_action_penalty, 0.0

        # NOTE(review): a second valid issue_refund overwrites this record and
        # succeeds again without any redundancy penalty - confirm that
        # repeated refunds are meant to be allowed.
        session.refund_record = {
            "amount_cents": action.amount_cents,
            "reason_code": action.reason_code,
            "step_index": session.steps_taken,
        }
        result = RefundResult(
            tool_name="issue_refund",
            success=True,
            refunded=True,
            amount_cents=action.amount_cents,
            reason_code=action.reason_code,
            message="Refund recorded successfully.",
        )
        return result, 0.0, 0.0

    def _handle_resolve_ticket(
        self, session: SessionState, action: SupportTicketAction
    ) -> tuple[ToolResult, float, float]:
        """Mark the ticket resolved and end the episode."""
        session.resolution_code = action.resolution_code
        session.ticket_status = "resolved"
        session.done = True
        session.terminal_reason = "resolved"
        result = ResolveResult(
            tool_name="resolve_ticket",
            success=True,
            resolution_code=action.resolution_code,
            ticket_status="resolved",
            message="Ticket marked as resolved.",
        )
        return result, 0.0, 0.0

    def _handle_escalate_ticket(
        self, session: SessionState, action: EscalateTicketAction
    ) -> tuple[ToolResult, float, float]:
        """Record the escalation details and end the episode."""
        session.escalation = {
            "queue": action.queue,
            "priority": action.priority,
            "summary": action.summary,
            "step_index": session.steps_taken,
        }
        session.ticket_status = "escalated"
        session.done = True
        session.terminal_reason = "escalated"
        result = EscalationResult(
            tool_name="escalate_ticket",
            success=True,
            queue=action.queue,
            priority=action.priority,
            summary=action.summary,
            ticket_status="escalated",
            message="Ticket escalated.",
        )
        return result, 0.0, 0.0

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _error_result(error_code: str, message: str) -> ErrorToolResult:
        """Build the generic failed tool result used for every error path."""
        return ErrorToolResult(
            tool_name="error",
            success=False,
            error_code=error_code,
            message=message,
        )

    @staticmethod
    def _rejected_refund(action: IssueRefundAction, message: str) -> RefundResult:
        """Build a failed refund result that echoes the requested amount/reason."""
        return RefundResult(
            tool_name="issue_refund",
            success=False,
            refunded=False,
            amount_cents=action.amount_cents,
            reason_code=action.reason_code,
            message=message,
        )

    def _enforce_step_budget(self, session: SessionState) -> None:
        """End the episode once the step budget is used up, unless already done."""
        if not session.done and session.steps_taken >= self.max_steps:
            session.done = True
            session.terminal_reason = "max_steps_exceeded"

    def _search_knowledge_base(self, query: str) -> list[KnowledgeBaseArticle]:
        """Return up to three KB articles sharing at least one term with ``query``.

        Articles are ranked by the number of distinct normalized terms they
        share with the query (title, content and tags are all searched); ties
        are broken by ascending ``article_id``.
        """
        query_terms = set(normalize_text(query).split())
        ranked: list[tuple[int, str, KnowledgeBaseArticle]] = []
        for article in KB_ARTICLES.values():
            searchable = normalize_text(" ".join((article.title, article.content, " ".join(article.tags))))
            score = len(query_terms & set(searchable.split()))
            if score > 0:
                ranked.append((score, article.article_id, article))
        ranked.sort(key=lambda item: (-item[0], item[1]))
        return [article for _, _, article in ranked[:3]]

    def _record_action(self, action_payload: dict[str, Any], action_succeeded: bool) -> None:
        """Append one entry to the session's audit log."""
        session = self._require_session()
        session.action_history.append(
            {
                "step_index": session.steps_taken,
                "action": action_payload,
                "success": action_succeeded,
                "ticket_status": session.ticket_status,
            }
        )

    # ------------------------------------------------------------------
    # Result construction
    # ------------------------------------------------------------------

    def _finalize_step(self, invalid_penalty: float, redundancy_penalty: float) -> SupportTicketStepResult:
        """Score the session and build the result for a step that consumed budget.

        Reward is the rubric-score delta since the previous step minus the
        step cost and both penalties, rounded to six decimal places.
        """
        session = self._require_session()
        scorecard = build_scorecard(session.fixture, session)
        score_delta = scorecard.score - session.previous_score
        reward = round(score_delta - self.step_cost - invalid_penalty - redundancy_penalty, 6)
        session.previous_score = scorecard.score
        return SupportTicketStepResult(
            observation=self._build_observation(),
            reward=reward,
            done=session.done,
            info=self._build_info(session, scorecard, invalid_penalty, redundancy_penalty),
        )

    def _build_observation(self) -> SupportTicketObservation:
        """Assemble the agent-facing observation from the ticket and session."""
        session = self._require_session()
        ticket = session.fixture.ticket
        return SupportTicketObservation(
            task_id=session.fixture.task_id,
            ticket_id=ticket.ticket_id,
            ticket_status=session.ticket_status,
            customer_id=ticket.customer_id,
            organization_name=ticket.organization_name,
            subject=ticket.subject,
            customer_message=ticket.message,
            conversation_history=list(session.conversation_history),
            last_tool_result=session.last_tool_result,
            steps_taken=session.steps_taken,
            steps_remaining=max(self.max_steps - session.steps_taken, 0),
            available_action_types=list(ACTION_TYPE_NAMES),
            last_action_error=session.last_action_error,
            known_facts=dict(session.known_facts),
        )

    def _build_result(self, reward: float, invalid_penalty: float = 0.0) -> SupportTicketStepResult:
        """Build a result with an explicitly supplied reward (reset / already-done).

        Unlike ``_finalize_step`` no step cost or score delta is applied; the
        stored previous score is simply re-synchronised with the current one.

        Args:
            reward: Reward to report verbatim.
            invalid_penalty: Penalty to report in ``info``; defaults to 0.0.
        """
        session = self._require_session()
        scorecard = build_scorecard(session.fixture, session)
        session.previous_score = scorecard.score
        return SupportTicketStepResult(
            observation=self._build_observation(),
            reward=reward,
            done=session.done,
            info=self._build_info(session, scorecard, invalid_penalty, 0.0),
        )

    def _build_info(
        self,
        session: SessionState,
        scorecard: Any,
        invalid_penalty: float,
        redundancy_penalty: float,
    ) -> dict[str, Any]:
        """Build the ``info`` payload shared by every step result."""
        return {
            "task_id": session.fixture.task_id,
            "benchmark_name": self.benchmark_name,
            "score": scorecard.score,
            "score_breakdown": scorecard.model_dump(mode="json"),
            "success": scorecard.score >= self.success_threshold,
            "success_threshold": self.success_threshold,
            "terminal_reason": session.terminal_reason,
            "invalid_penalty": invalid_penalty,
            "redundancy_penalty": redundancy_penalty,
        }

    def _require_session(self) -> SessionState:
        """Return the active session.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        if self._session is None:
            raise RuntimeError("Environment has not been reset yet.")
        return self._session