Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from typing import Any | |
| from .fixtures import ( | |
| BENCHMARK_NAME, | |
| DEFAULT_SUCCESS_THRESHOLD, | |
| KB_ARTICLES, | |
| KnowledgeBaseArticle, | |
| TaskFixture, | |
| get_task_fixture, | |
| list_task_ids, | |
| ) | |
| from .models import ( | |
| ACTION_TYPE_NAMES, | |
| AccountLookupResult, | |
| ConversationTurn, | |
| KBSearchResult, | |
| ErrorToolResult, | |
| EscalateTicketAction, | |
| EscalationResult, | |
| IssueRefundAction, | |
| LookupAccountAction, | |
| RefundResult, | |
| ReplyResult, | |
| ResolveResult, | |
| SearchKBAction, | |
| SupportTicketAction, | |
| SupportTicketObservation, | |
| SupportTicketStepResult, | |
| ToolResult, | |
| parse_action, | |
| ) | |
| from .scoring import build_scorecard, normalize_text | |
@dataclass
class SessionState:
    """Mutable per-episode state for one support-ticket session.

    The class must be a dataclass: the environment constructs it with keyword
    arguments (``SessionState(fixture=..., conversation_history=[...])``) and
    the mutable defaults below rely on ``field(default_factory=...)`` so that
    each episode gets its own fresh containers.
    """

    # Task fixture the episode is played against (ticket, account, ...).
    fixture: TaskFixture
    # "open" until a terminal action sets it to "resolved" or "escalated".
    ticket_status: str = "open"
    # Number of steps consumed so far (invalid actions count as well).
    steps_taken: int = 0
    # Customer/agent turns, starting with the customer's opening message.
    conversation_history: list[ConversationTurn] = field(default_factory=list)
    # Audit log: one entry per recorded action.
    action_history: list[dict[str, Any]] = field(default_factory=list)
    # Each reply sent to the customer, with the step it was sent on.
    reply_history: list[dict[str, Any]] = field(default_factory=list)
    # Facts surfaced to the agent by tools (KB articles seen, account summary).
    known_facts: dict[str, Any] = field(default_factory=dict)
    # Ids of every KB article returned by a search in this episode.
    kb_articles_seen: set[str] = field(default_factory=set)
    # Normalized queries already searched; used to detect repeated searches.
    search_signatures: set[str] = field(default_factory=set)
    # Whether an account lookup has succeeded, and for which customer id.
    lookup_performed: bool = False
    lookup_customer_id: str | None = None
    # Payload of the successful refund, if any.
    refund_record: dict[str, Any] | None = None
    # True once any refund action was attempted, successful or not.
    refund_attempted: bool = False
    # Resolution code supplied when the ticket was resolved.
    resolution_code: str | None = None
    # Queue/priority/summary supplied when the ticket was escalated.
    escalation: dict[str, Any] | None = None
    # Episode termination flag and the reason it terminated.
    done: bool = False
    terminal_reason: str | None = None
    # Rubric score after the previous step; rewards are based on the delta.
    previous_score: float = 0.0
    # Outcome of the most recent action, exposed in the next observation.
    last_tool_result: ToolResult | None = None
    last_action_error: str | None = None
class SupportTicketEnvironment:
    """Single-ticket customer-support environment with rubric-based rewards.

    ``reset`` loads a task fixture and starts an episode, ``step`` applies one
    agent action and returns the next observation with a shaped reward, and
    ``state`` returns a JSON-friendly snapshot of the running episode.
    """

    benchmark_name = BENCHMARK_NAME
    # Step budget; reaching it ends the episode with "max_steps_exceeded".
    max_steps = 8
    # Flat cost subtracted from the reward of every finalized step.
    step_cost = 0.01
    # Penalty for unparseable, unsupported or rejected actions.
    invalid_action_penalty = 0.10
    # Penalty for repeating a KB query or an account lookup.
    repeated_action_penalty = 0.02
    success_threshold = DEFAULT_SUCCESS_THRESHOLD

    def __init__(self, task_id: str | None = None) -> None:
        """Create an environment.

        Args:
            task_id: Task used by ``reset`` when none is passed explicitly.
                Falls back to the first id returned by ``list_task_ids()``.
        """
        self._default_task_id = task_id or list_task_ids()[0]
        self._session: SessionState | None = None

    def reset(self, task_id: str | None = None) -> SupportTicketStepResult:
        """Start a new episode and return its initial result (reward 0.0).

        Args:
            task_id: Task to load; defaults to the environment's default task.
        """
        fixture = get_task_fixture(task_id or self._default_task_id)
        opening_turn = ConversationTurn(
            role="customer",
            message=fixture.ticket.message,
            step_index=0,
        )
        self._session = SessionState(
            fixture=fixture,
            conversation_history=[opening_turn],
        )
        return self._build_result(reward=0.0)

    def step(self, action: SupportTicketAction | dict[str, Any]) -> SupportTicketStepResult:
        """Apply one agent action and return the resulting step.

        Stepping a finished episode does not consume a step; it returns an
        error tool result and a reward of ``-invalid_action_penalty``. An
        action that cannot be parsed consumes a step and is penalized the
        same way.

        Raises:
            RuntimeError: If ``reset`` has not been called yet.
        """
        session = self._require_session()

        if session.done:
            session.last_action_error = "episode_already_done"
            session.last_tool_result = self._error_result(
                "episode_already_done",
                "This ticket is already terminal. Reset the environment before stepping again.",
            )
            # Report the penalty in ``info`` so it agrees with the reward.
            return self._build_result(
                reward=-self.invalid_action_penalty,
                invalid_penalty=self.invalid_action_penalty,
            )

        session.last_action_error = None
        try:
            parsed_action = parse_action(action)
        except Exception as exc:
            # Boundary with untrusted agent input: a malformed action becomes
            # a penalized step rather than crashing the episode.
            session.steps_taken += 1
            session.last_action_error = f"invalid_action: {exc}"
            session.last_tool_result = self._error_result("invalid_action", str(exc))
            self._record_action({"action_type": "invalid"}, False)
            self._enforce_step_budget(session)
            return self._finalize_step(
                invalid_penalty=self.invalid_action_penalty,
                redundancy_penalty=0.0,
            )

        session.steps_taken += 1
        tool_result, invalid_penalty, redundancy_penalty = self._apply_action(parsed_action)
        session.last_tool_result = tool_result
        action_succeeded = bool(getattr(tool_result, "success", False))
        self._record_action(parsed_action.model_dump(mode="json"), action_succeeded)
        self._enforce_step_budget(session)
        return self._finalize_step(
            invalid_penalty=invalid_penalty,
            redundancy_penalty=redundancy_penalty,
        )

    def state(self) -> dict[str, Any]:
        """Return a JSON-friendly snapshot of the current episode.

        Raises:
            RuntimeError: If ``reset`` has not been called yet.
        """
        session = self._require_session()
        scorecard = build_scorecard(session.fixture, session)
        history = [turn.model_dump(mode="json") for turn in session.conversation_history]
        return {
            "benchmark_name": self.benchmark_name,
            "task_id": session.fixture.task_id,
            "ticket_status": session.ticket_status,
            "steps_taken": session.steps_taken,
            "steps_remaining": max(self.max_steps - session.steps_taken, 0),
            "conversation_history": history,
            "audit_log": list(session.action_history),
            "known_facts": dict(session.known_facts),
            "current_rubric_score": scorecard.score,
            "score_breakdown": scorecard.model_dump(mode="json"),
            "terminal_reason": session.terminal_reason,
            "done": session.done,
        }

    def _apply_action(self, action: SupportTicketAction) -> tuple[ToolResult, float, float]:
        """Dispatch a parsed action to its handler.

        Returns:
            ``(tool_result, invalid_penalty, redundancy_penalty)``.
        """
        session = self._require_session()
        if isinstance(action, SearchKBAction):
            return self._handle_search_kb(session, action)
        if isinstance(action, LookupAccountAction):
            return self._handle_lookup_account(session, action)
        if action.action_type == "send_reply":
            return self._handle_send_reply(session, action)
        if isinstance(action, IssueRefundAction):
            return self._handle_issue_refund(session, action)
        if action.action_type == "resolve_ticket":
            return self._handle_resolve_ticket(session, action)
        if isinstance(action, EscalateTicketAction):
            return self._handle_escalate_ticket(session, action)

        session.last_action_error = "unsupported_action"
        unsupported = self._error_result(
            "unsupported_action",
            f"Unsupported action type: {type(action).__name__}",
        )
        return unsupported, self.invalid_action_penalty, 0.0

    def _handle_search_kb(
        self, session: SessionState, action: SearchKBAction
    ) -> tuple[ToolResult, float, float]:
        """Search the knowledge base and record which articles were seen."""
        signature = normalize_text(action.query)
        # Repeating an already-issued (normalized) query is penalized.
        redundancy_penalty = (
            self.repeated_action_penalty if signature in session.search_signatures else 0.0
        )
        session.search_signatures.add(signature)

        articles = self._search_knowledge_base(action.query)
        article_ids = [article.article_id for article in articles]
        session.kb_articles_seen.update(article_ids)
        seen_ids = sorted(session.kb_articles_seen)
        session.known_facts["kb_articles_seen"] = seen_ids
        session.known_facts["kb_titles_seen"] = [
            KB_ARTICLES[article_id].title for article_id in seen_ids
        ]

        result = KBSearchResult(
            tool_name="search_kb",
            success=bool(articles),
            query=action.query,
            article_ids=article_ids,
            snippets=[article.snippet for article in articles],
            message="Knowledge base search completed." if articles else "No KB articles matched the query.",
        )
        return result, 0.0, redundancy_penalty

    def _handle_lookup_account(
        self, session: SessionState, action: LookupAccountAction
    ) -> tuple[ToolResult, float, float]:
        """Look up the fixture account; unknown customer ids are rejected."""
        account = session.fixture.account
        if action.customer_id != account.customer_id:
            session.last_action_error = "unknown_customer_id"
            rejected = self._error_result(
                "unknown_customer_id",
                f"No account found for customer_id={action.customer_id}.",
            )
            return rejected, self.invalid_action_penalty, 0.0

        already_looked_up = (
            session.lookup_performed and session.lookup_customer_id == action.customer_id
        )
        redundancy_penalty = self.repeated_action_penalty if already_looked_up else 0.0

        session.lookup_performed = True
        session.lookup_customer_id = action.customer_id
        account_summary = {
            "customer_id": account.customer_id,
            "organization_name": account.organization_name,
            "plan": account.plan,
            "tenure_years": account.tenure_years,
            "arr_usd": account.arr_usd,
            "duplicate_charge_amount_cents": account.duplicate_charge_amount_cents,
            "duplicate_charge_count": account.duplicate_charge_count,
            "duplicate_charge_refund_eligible": account.duplicate_charge_refund_eligible,
            "legal_threat": account.legal_threat,
            "incident_severity": account.incident_severity,
        }
        session.known_facts["account"] = account_summary

        result = AccountLookupResult(
            tool_name="lookup_account",
            success=True,
            customer_id=action.customer_id,
            account_summary=account_summary,
            message="Account lookup completed.",
        )
        return result, 0.0, redundancy_penalty

    def _handle_send_reply(
        self, session: SessionState, action: SupportTicketAction
    ) -> tuple[ToolResult, float, float]:
        """Record a reply to the customer in the reply and conversation logs."""
        reply = action.message.strip()
        session.reply_history.append({"message": reply, "step_index": session.steps_taken})
        session.conversation_history.append(
            ConversationTurn(role="agent", message=reply, step_index=session.steps_taken)
        )
        result = ReplyResult(
            tool_name="send_reply",
            success=True,
            message_preview=reply[:120],
            message="Reply sent to the customer.",
        )
        return result, 0.0, 0.0

    def _handle_issue_refund(
        self, session: SessionState, action: IssueRefundAction
    ) -> tuple[ToolResult, float, float]:
        """Validate and record a refund.

        A refund needs a prior successful account lookup, an eligible
        duplicate charge, and a payload matching that charge exactly.
        """
        # Marked before validation, so failed attempts are remembered too.
        session.refund_attempted = True
        account = session.fixture.account

        if not session.lookup_performed:
            session.last_action_error = "lookup_required_before_refund"
            rejected = self._error_result(
                "lookup_required_before_refund",
                "lookup_account must succeed before issue_refund can be used.",
            )
            return rejected, self.invalid_action_penalty, 0.0

        if not account.duplicate_charge_refund_eligible or not account.duplicate_charge_amount_cents:
            session.last_action_error = "refund_not_applicable"
            not_applicable = RefundResult(
                tool_name="issue_refund",
                success=False,
                refunded=False,
                amount_cents=action.amount_cents,
                reason_code=action.reason_code,
                message="No duplicate charge is eligible for refund on this account.",
            )
            return not_applicable, self.invalid_action_penalty, 0.0

        payload_matches = (
            action.amount_cents == account.duplicate_charge_amount_cents
            and action.reason_code == "duplicate_charge"
        )
        if not payload_matches:
            session.last_action_error = "incorrect_refund_payload"
            mismatch = RefundResult(
                tool_name="issue_refund",
                success=False,
                refunded=False,
                amount_cents=action.amount_cents,
                reason_code=action.reason_code,
                message="Refund payload does not match the verified duplicate charge.",
            )
            return mismatch, self.invalid_action_penalty, 0.0

        # NOTE(review): a second valid refund overwrites this record without
        # any penalty -- confirm that repeated refunds are meant to be allowed.
        session.refund_record = {
            "amount_cents": action.amount_cents,
            "reason_code": action.reason_code,
            "step_index": session.steps_taken,
        }
        result = RefundResult(
            tool_name="issue_refund",
            success=True,
            refunded=True,
            amount_cents=action.amount_cents,
            reason_code=action.reason_code,
            message="Refund recorded successfully.",
        )
        return result, 0.0, 0.0

    def _handle_resolve_ticket(
        self, session: SessionState, action: SupportTicketAction
    ) -> tuple[ToolResult, float, float]:
        """Mark the ticket resolved and end the episode."""
        session.resolution_code = action.resolution_code
        session.ticket_status = "resolved"
        session.done = True
        session.terminal_reason = "resolved"
        result = ResolveResult(
            tool_name="resolve_ticket",
            success=True,
            resolution_code=action.resolution_code,
            ticket_status="resolved",
            message="Ticket marked as resolved.",
        )
        return result, 0.0, 0.0

    def _handle_escalate_ticket(
        self, session: SessionState, action: EscalateTicketAction
    ) -> tuple[ToolResult, float, float]:
        """Record the escalation details and end the episode."""
        session.escalation = {
            "queue": action.queue,
            "priority": action.priority,
            "summary": action.summary,
            "step_index": session.steps_taken,
        }
        session.ticket_status = "escalated"
        session.done = True
        session.terminal_reason = "escalated"
        result = EscalationResult(
            tool_name="escalate_ticket",
            success=True,
            queue=action.queue,
            priority=action.priority,
            summary=action.summary,
            ticket_status="escalated",
            message="Ticket escalated.",
        )
        return result, 0.0, 0.0

    def _search_knowledge_base(self, query: str) -> list[KnowledgeBaseArticle]:
        """Return up to three KB articles ranked by term overlap with ``query``.

        Articles are scored by the number of distinct normalized query terms
        found in their title, content and tags; ties are broken by article id.
        """
        query_terms = set(normalize_text(query).split())
        if not query_terms:
            # Nothing can overlap with an empty query.
            return []

        ranked: list[tuple[int, str, KnowledgeBaseArticle]] = []
        for article in KB_ARTICLES.values():
            searchable = normalize_text(
                " ".join((article.title, article.content, " ".join(article.tags)))
            )
            overlap = len(query_terms & set(searchable.split()))
            if overlap > 0:
                ranked.append((overlap, article.article_id, article))

        # Highest overlap first, then ascending article id for determinism.
        ranked.sort(key=lambda item: (-item[0], item[1]))
        return [article for _, _, article in ranked[:3]]

    def _record_action(self, action_payload: dict[str, Any], action_succeeded: bool) -> None:
        """Append one entry to the session's audit log."""
        session = self._require_session()
        entry = {
            "step_index": session.steps_taken,
            "action": action_payload,
            "success": action_succeeded,
            "ticket_status": session.ticket_status,
        }
        session.action_history.append(entry)

    def _enforce_step_budget(self, session: SessionState) -> None:
        """End a still-running episode once the step budget is used up."""
        if not session.done and session.steps_taken >= self.max_steps:
            session.done = True
            session.terminal_reason = "max_steps_exceeded"

    def _finalize_step(self, invalid_penalty: float, redundancy_penalty: float) -> SupportTicketStepResult:
        """Score the session and build the result for a consumed step.

        The reward is the change in rubric score since the previous step,
        minus the step cost and both penalties, rounded to six decimals.
        """
        session = self._require_session()
        scorecard = build_scorecard(session.fixture, session)
        score_delta = scorecard.score - session.previous_score
        reward = round(
            score_delta - self.step_cost - invalid_penalty - redundancy_penalty,
            6,
        )
        session.previous_score = scorecard.score
        return SupportTicketStepResult(
            observation=self._build_observation(),
            reward=reward,
            done=session.done,
            info=self._build_info(session, scorecard, invalid_penalty, redundancy_penalty),
        )

    def _build_observation(self) -> SupportTicketObservation:
        """Build the agent-facing observation from the current session."""
        session = self._require_session()
        ticket = session.fixture.ticket
        return SupportTicketObservation(
            task_id=session.fixture.task_id,
            ticket_id=ticket.ticket_id,
            ticket_status=session.ticket_status,
            customer_id=ticket.customer_id,
            organization_name=ticket.organization_name,
            subject=ticket.subject,
            customer_message=ticket.message,
            conversation_history=list(session.conversation_history),
            last_tool_result=session.last_tool_result,
            steps_taken=session.steps_taken,
            steps_remaining=max(self.max_steps - session.steps_taken, 0),
            available_action_types=list(ACTION_TYPE_NAMES),
            last_action_error=session.last_action_error,
            known_facts=dict(session.known_facts),
        )

    def _build_result(self, reward: float, invalid_penalty: float = 0.0) -> SupportTicketStepResult:
        """Build a result with a caller-supplied reward.

        Used for ``reset`` and for steps taken after the episode has ended,
        where no step cost or score delta applies.

        Args:
            reward: Reward to report as-is.
            invalid_penalty: Penalty to report in ``info``; defaults to 0.0.
        """
        session = self._require_session()
        scorecard = build_scorecard(session.fixture, session)
        session.previous_score = scorecard.score
        return SupportTicketStepResult(
            observation=self._build_observation(),
            reward=reward,
            done=session.done,
            info=self._build_info(session, scorecard, invalid_penalty, 0.0),
        )

    def _build_info(
        self,
        session: SessionState,
        scorecard: Any,
        invalid_penalty: float,
        redundancy_penalty: float,
    ) -> dict[str, Any]:
        """Build the ``info`` payload shared by every step result."""
        return {
            "task_id": session.fixture.task_id,
            "benchmark_name": self.benchmark_name,
            "score": scorecard.score,
            "score_breakdown": scorecard.model_dump(mode="json"),
            "success": scorecard.score >= self.success_threshold,
            "success_threshold": self.success_threshold,
            "terminal_reason": session.terminal_reason,
            "invalid_penalty": invalid_penalty,
            "redundancy_penalty": redundancy_penalty,
        }

    @staticmethod
    def _error_result(error_code: str, message: str) -> ErrorToolResult:
        """Build the generic failed tool result used for rejected actions."""
        return ErrorToolResult(
            tool_name="error",
            success=False,
            error_code=error_code,
            message=message,
        )

    def _require_session(self) -> SessionState:
        """Return the active session.

        Raises:
            RuntimeError: If ``reset`` has not been called yet.
        """
        if self._session is None:
            raise RuntimeError("Environment has not been reset yet.")
        return self._session