from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

from .fixtures import (
    BENCHMARK_NAME,
    DEFAULT_SUCCESS_THRESHOLD,
    KB_ARTICLES,
    KnowledgeBaseArticle,
    TaskFixture,
    get_task_fixture,
    list_task_ids,
)
from .models import (
    ACTION_TYPE_NAMES,
    AccountLookupResult,
    ConversationTurn,
    KBSearchResult,
    ErrorToolResult,
    EscalateTicketAction,
    EscalationResult,
    IssueRefundAction,
    LookupAccountAction,
    RefundResult,
    ReplyResult,
    ResolveResult,
    SearchKBAction,
    SupportTicketAction,
    SupportTicketObservation,
    SupportTicketStepResult,
    ToolResult,
    parse_action,
)
from .scoring import build_scorecard, normalize_text


@dataclass
class SessionState:
    """Mutable per-episode state for a single support-ticket session.

    Created by ``SupportTicketEnvironment.reset``, mutated by every ``step``,
    and passed to ``build_scorecard`` to compute the current rubric score.
    """

    # Task definition (ticket + account) this episode runs against.
    fixture: TaskFixture
    # "open" until resolve_ticket sets "resolved" or escalate_ticket sets "escalated".
    ticket_status: str = "open"
    # Number of steps consumed, including steps whose action failed to parse.
    steps_taken: int = 0
    # Customer message followed by every agent reply, in order.
    conversation_history: list[ConversationTurn] = field(default_factory=list)
    # One audit entry per step (see SupportTicketEnvironment._record_action).
    action_history: list[dict[str, Any]] = field(default_factory=list)
    # Stripped reply text plus the step index it was sent at.
    reply_history: list[dict[str, Any]] = field(default_factory=list)
    # Facts revealed to the agent (KB articles seen, account summary).
    known_facts: dict[str, Any] = field(default_factory=dict)
    # IDs of every KB article returned by any search so far.
    kb_articles_seen: set[str] = field(default_factory=set)
    # Normalized query strings already searched (for the repeat penalty).
    search_signatures: set[str] = field(default_factory=set)
    lookup_performed: bool = False
    lookup_customer_id: str | None = None
    # Set only by a successful issue_refund; never overwritten once set.
    refund_record: dict[str, Any] | None = None
    # True as soon as issue_refund is attempted, whether or not it succeeds.
    refund_attempted: bool = False
    resolution_code: str | None = None
    escalation: dict[str, Any] | None = None
    done: bool = False
    # "resolved", "escalated", or "max_steps_exceeded" once done.
    terminal_reason: str | None = None
    # Rubric score after the previous step; reward is the delta from this.
    previous_score: float = 0.0
    last_tool_result: ToolResult | None = None
    last_action_error: str | None = None


class SupportTicketEnvironment:
    """Tool-use environment in which an agent works a single support ticket.

    Call ``reset`` to start an episode, then ``step`` with actions. Each
    step's reward is the change in rubric score (from ``build_scorecard``)
    minus ``step_cost`` and any invalid/redundant-action penalties.
    """

    benchmark_name = BENCHMARK_NAME
    # The episode is forced terminal once this many steps have been taken.
    max_steps = 8
    # Flat cost subtracted from the reward of every scored step.
    step_cost = 0.01
    # Penalty for unparseable, rejected, or post-terminal actions.
    invalid_action_penalty = 0.10
    # Penalty for repeating an equivalent KB search or account lookup.
    repeated_action_penalty = 0.02
    success_threshold = DEFAULT_SUCCESS_THRESHOLD

    def __init__(self, task_id: str | None = None) -> None:
        """Create an environment.

        Args:
            task_id: Task used by ``reset`` when none is passed there;
                defaults to the first entry of ``list_task_ids()``.
        """
        self._default_task_id = task_id or list_task_ids()[0]
        self._session: SessionState | None = None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def reset(self, task_id: str | None = None) -> SupportTicketStepResult:
        """Start a new episode and return its initial observation.

        Args:
            task_id: Task to load; falls back to the constructor's default.

        Returns:
            A step result with reward ``0.0``.
        """
        fixture = get_task_fixture(task_id or self._default_task_id)
        opening_turn = ConversationTurn(
            role="customer",
            message=fixture.ticket.message,
            step_index=0,
        )
        self._session = SessionState(fixture=fixture, conversation_history=[opening_turn])
        return self._build_result(reward=0.0)

    def step(self, action: SupportTicketAction | dict[str, Any]) -> SupportTicketStepResult:
        """Apply one agent action and return the resulting observation and reward.

        Args:
            action: A parsed action model, or a raw dict for ``parse_action``.

        Returns:
            The step result. Actions that fail to parse still consume a step
            and incur ``invalid_action_penalty``.

        Raises:
            RuntimeError: If ``reset`` has not been called yet.
        """
        session = self._require_session()

        if session.done:
            session.last_action_error = "episode_already_done"
            session.last_tool_result = ErrorToolResult(
                tool_name="error",
                success=False,
                error_code="episode_already_done",
                message="This ticket is already terminal. Reset the environment before stepping again.",
            )
            # Report the penalty in info too, so it agrees with the reward.
            return self._build_result(
                reward=-self.invalid_action_penalty,
                invalid_penalty=self.invalid_action_penalty,
            )

        session.last_action_error = None
        try:
            parsed_action = parse_action(action)
        except Exception as exc:
            # Boundary: any parse/validation failure becomes an invalid step
            # surfaced to the agent instead of crashing the episode.
            session.steps_taken += 1
            session.last_action_error = f"invalid_action: {exc}"
            session.last_tool_result = ErrorToolResult(
                tool_name="error",
                success=False,
                error_code="invalid_action",
                message=str(exc),
            )
            self._record_action({"action_type": "invalid"}, False)
            self._enforce_step_limit()
            return self._finalize_step(
                invalid_penalty=self.invalid_action_penalty,
                redundancy_penalty=0.0,
            )

        session.steps_taken += 1
        tool_result, invalid_penalty, redundancy_penalty = self._apply_action(parsed_action)
        session.last_tool_result = tool_result
        action_succeeded = bool(getattr(tool_result, "success", False))
        self._record_action(parsed_action.model_dump(mode="json"), action_succeeded)
        self._enforce_step_limit()
        return self._finalize_step(
            invalid_penalty=invalid_penalty,
            redundancy_penalty=redundancy_penalty,
        )

    def state(self) -> dict[str, Any]:
        """Return a snapshot of the episode, including the current rubric score.

        Raises:
            RuntimeError: If ``reset`` has not been called yet.
        """
        session = self._require_session()
        scorecard = build_scorecard(session.fixture, session)
        return {
            "benchmark_name": self.benchmark_name,
            "task_id": session.fixture.task_id,
            "ticket_status": session.ticket_status,
            "steps_taken": session.steps_taken,
            "steps_remaining": max(self.max_steps - session.steps_taken, 0),
            "conversation_history": [turn.model_dump(mode="json") for turn in session.conversation_history],
            "audit_log": list(session.action_history),
            "known_facts": dict(session.known_facts),
            "current_rubric_score": scorecard.score,
            "score_breakdown": scorecard.model_dump(mode="json"),
            "terminal_reason": session.terminal_reason,
            "done": session.done,
        }

    # ------------------------------------------------------------------
    # Action handling
    # ------------------------------------------------------------------

    def _apply_action(self, action: SupportTicketAction) -> tuple[ToolResult, float, float]:
        """Dispatch *action* to its handler.

        Returns:
            ``(tool_result, invalid_penalty, redundancy_penalty)``.
        """
        if isinstance(action, SearchKBAction):
            return self._handle_search_kb(action)
        if isinstance(action, LookupAccountAction):
            return self._handle_lookup_account(action)
        if action.action_type == "send_reply":
            return self._handle_send_reply(action)
        if isinstance(action, IssueRefundAction):
            return self._handle_issue_refund(action)
        if action.action_type == "resolve_ticket":
            return self._handle_resolve_ticket(action)
        if isinstance(action, EscalateTicketAction):
            return self._handle_escalate_ticket(action)

        session = self._require_session()
        session.last_action_error = "unsupported_action"
        return (
            ErrorToolResult(
                tool_name="error",
                success=False,
                error_code="unsupported_action",
                message=f"Unsupported action type: {type(action).__name__}",
            ),
            self.invalid_action_penalty,
            0.0,
        )

    def _handle_search_kb(self, action: SearchKBAction) -> tuple[ToolResult, float, float]:
        """Search the KB, record seen articles, and penalize repeated queries."""
        session = self._require_session()
        redundancy_penalty = 0.0
        query_signature = normalize_text(action.query)
        if query_signature in session.search_signatures:
            redundancy_penalty = self.repeated_action_penalty
        session.search_signatures.add(query_signature)

        articles = self._search_knowledge_base(action.query)
        article_ids = [article.article_id for article in articles]
        session.kb_articles_seen.update(article_ids)
        seen_ids = sorted(session.kb_articles_seen)
        session.known_facts["kb_articles_seen"] = seen_ids
        session.known_facts["kb_titles_seen"] = [KB_ARTICLES[article_id].title for article_id in seen_ids]

        result = KBSearchResult(
            tool_name="search_kb",
            success=bool(articles),
            query=action.query,
            article_ids=article_ids,
            snippets=[article.snippet for article in articles],
            message="Knowledge base search completed." if articles else "No KB articles matched the query.",
        )
        return result, 0.0, redundancy_penalty

    def _handle_lookup_account(self, action: LookupAccountAction) -> tuple[ToolResult, float, float]:
        """Reveal the fixture account summary if *action* names its customer_id."""
        session = self._require_session()
        account = session.fixture.account

        if action.customer_id != account.customer_id:
            session.last_action_error = "unknown_customer_id"
            result = ErrorToolResult(
                tool_name="error",
                success=False,
                error_code="unknown_customer_id",
                message=f"No account found for customer_id={action.customer_id}.",
            )
            return result, self.invalid_action_penalty, 0.0

        redundancy_penalty = 0.0
        if session.lookup_performed and session.lookup_customer_id == action.customer_id:
            redundancy_penalty = self.repeated_action_penalty

        session.lookup_performed = True
        session.lookup_customer_id = action.customer_id
        account_summary = {
            "customer_id": account.customer_id,
            "organization_name": account.organization_name,
            "plan": account.plan,
            "tenure_years": account.tenure_years,
            "arr_usd": account.arr_usd,
            "duplicate_charge_amount_cents": account.duplicate_charge_amount_cents,
            "duplicate_charge_count": account.duplicate_charge_count,
            "duplicate_charge_refund_eligible": account.duplicate_charge_refund_eligible,
            "legal_threat": account.legal_threat,
            "incident_severity": account.incident_severity,
        }
        session.known_facts["account"] = account_summary
        result = AccountLookupResult(
            tool_name="lookup_account",
            success=True,
            customer_id=action.customer_id,
            account_summary=account_summary,
            message="Account lookup completed.",
        )
        return result, 0.0, redundancy_penalty

    def _handle_send_reply(self, action: SupportTicketAction) -> tuple[ToolResult, float, float]:
        """Append the stripped reply to the reply log and conversation history."""
        session = self._require_session()
        reply = action.message.strip()
        session.reply_history.append({"message": reply, "step_index": session.steps_taken})
        session.conversation_history.append(
            ConversationTurn(role="agent", message=reply, step_index=session.steps_taken)
        )
        result = ReplyResult(
            tool_name="send_reply",
            success=True,
            message_preview=reply[:120],
            message="Reply sent to the customer.",
        )
        return result, 0.0, 0.0

    def _handle_issue_refund(self, action: IssueRefundAction) -> tuple[ToolResult, float, float]:
        """Validate and record a duplicate-charge refund.

        The refund is rejected (with ``invalid_action_penalty``) when no
        account lookup has happened yet, when a refund was already issued,
        when the account has no eligible duplicate charge, or when the
        amount/reason code does not match the verified duplicate charge.
        """
        session = self._require_session()
        session.refund_attempted = True
        account = session.fixture.account

        if not session.lookup_performed:
            session.last_action_error = "lookup_required_before_refund"
            result = ErrorToolResult(
                tool_name="error",
                success=False,
                error_code="lookup_required_before_refund",
                message="lookup_account must succeed before issue_refund can be used.",
            )
            return result, self.invalid_action_penalty, 0.0

        if session.refund_record is not None:
            # Never refund the same duplicate charge twice, and keep the
            # original refund record (including its step_index) intact.
            session.last_action_error = "refund_already_issued"
            result = RefundResult(
                tool_name="issue_refund",
                success=False,
                refunded=False,
                amount_cents=action.amount_cents,
                reason_code=action.reason_code,
                message="A refund has already been issued for this duplicate charge.",
            )
            return result, self.invalid_action_penalty, 0.0

        if not account.duplicate_charge_refund_eligible or not account.duplicate_charge_amount_cents:
            session.last_action_error = "refund_not_applicable"
            result = RefundResult(
                tool_name="issue_refund",
                success=False,
                refunded=False,
                amount_cents=action.amount_cents,
                reason_code=action.reason_code,
                message="No duplicate charge is eligible for refund on this account.",
            )
            return result, self.invalid_action_penalty, 0.0

        if action.amount_cents != account.duplicate_charge_amount_cents or action.reason_code != "duplicate_charge":
            session.last_action_error = "incorrect_refund_payload"
            result = RefundResult(
                tool_name="issue_refund",
                success=False,
                refunded=False,
                amount_cents=action.amount_cents,
                reason_code=action.reason_code,
                message="Refund payload does not match the verified duplicate charge.",
            )
            return result, self.invalid_action_penalty, 0.0

        session.refund_record = {
            "amount_cents": action.amount_cents,
            "reason_code": action.reason_code,
            "step_index": session.steps_taken,
        }
        result = RefundResult(
            tool_name="issue_refund",
            success=True,
            refunded=True,
            amount_cents=action.amount_cents,
            reason_code=action.reason_code,
            message="Refund recorded successfully.",
        )
        return result, 0.0, 0.0

    def _handle_resolve_ticket(self, action: SupportTicketAction) -> tuple[ToolResult, float, float]:
        """Mark the ticket resolved and end the episode."""
        session = self._require_session()
        session.resolution_code = action.resolution_code
        session.ticket_status = "resolved"
        session.done = True
        session.terminal_reason = "resolved"
        result = ResolveResult(
            tool_name="resolve_ticket",
            success=True,
            resolution_code=action.resolution_code,
            ticket_status="resolved",
            message="Ticket marked as resolved.",
        )
        return result, 0.0, 0.0

    def _handle_escalate_ticket(self, action: EscalateTicketAction) -> tuple[ToolResult, float, float]:
        """Record the escalation, mark the ticket escalated, and end the episode."""
        session = self._require_session()
        session.escalation = {
            "queue": action.queue,
            "priority": action.priority,
            "summary": action.summary,
            "step_index": session.steps_taken,
        }
        session.ticket_status = "escalated"
        session.done = True
        session.terminal_reason = "escalated"
        result = EscalationResult(
            tool_name="escalate_ticket",
            success=True,
            queue=action.queue,
            priority=action.priority,
            summary=action.summary,
            ticket_status="escalated",
            message="Ticket escalated.",
        )
        return result, 0.0, 0.0

    def _search_knowledge_base(self, query: str) -> list[KnowledgeBaseArticle]:
        """Return up to three KB articles ranked by term overlap with *query*.

        An article's score is the number of distinct normalized query terms
        that also appear in its title, content, or tags. Articles scoring 0
        are dropped; ties are broken by ascending ``article_id``.
        """
        query_terms = set(normalize_text(query).split())
        ranked: list[tuple[int, str, KnowledgeBaseArticle]] = []
        for article in KB_ARTICLES.values():
            searchable = normalize_text(" ".join((article.title, article.content, " ".join(article.tags))))
            overlap = len(query_terms & set(searchable.split()))
            if overlap > 0:
                ranked.append((overlap, article.article_id, article))
        ranked.sort(key=lambda item: (-item[0], item[1]))
        return [article for _, _, article in ranked[:3]]

    # ------------------------------------------------------------------
    # Bookkeeping and result construction
    # ------------------------------------------------------------------

    def _enforce_step_limit(self) -> None:
        """End the episode with ``max_steps_exceeded`` once the step budget is used up."""
        session = self._require_session()
        if not session.done and session.steps_taken >= self.max_steps:
            session.done = True
            session.terminal_reason = "max_steps_exceeded"

    def _record_action(self, action_payload: dict[str, Any], action_succeeded: bool) -> None:
        """Append an audit entry for the current step to ``action_history``."""
        session = self._require_session()
        session.action_history.append(
            {
                "step_index": session.steps_taken,
                "action": action_payload,
                "success": action_succeeded,
                "ticket_status": session.ticket_status,
            }
        )

    def _finalize_step(self, invalid_penalty: float, redundancy_penalty: float) -> SupportTicketStepResult:
        """Score the step and build its result.

        The reward is the rubric-score delta since the previous step, minus
        ``step_cost`` and the given penalties, rounded to 6 decimal places.
        """
        session = self._require_session()
        scorecard = build_scorecard(session.fixture, session)
        reward = round(
            (scorecard.score - session.previous_score)
            - self.step_cost
            - invalid_penalty
            - redundancy_penalty,
            6,
        )
        session.previous_score = scorecard.score
        return SupportTicketStepResult(
            observation=self._build_observation(),
            reward=reward,
            done=session.done,
            info=self._build_info(scorecard, invalid_penalty, redundancy_penalty),
        )

    def _build_observation(self) -> SupportTicketObservation:
        """Build the agent-facing observation from the current session."""
        session = self._require_session()
        ticket = session.fixture.ticket
        return SupportTicketObservation(
            task_id=session.fixture.task_id,
            ticket_id=ticket.ticket_id,
            ticket_status=session.ticket_status,
            customer_id=ticket.customer_id,
            organization_name=ticket.organization_name,
            subject=ticket.subject,
            customer_message=ticket.message,
            conversation_history=list(session.conversation_history),
            last_tool_result=session.last_tool_result,
            steps_taken=session.steps_taken,
            steps_remaining=max(self.max_steps - session.steps_taken, 0),
            available_action_types=list(ACTION_TYPE_NAMES),
            last_action_error=session.last_action_error,
            known_facts=dict(session.known_facts),
        )

    def _build_result(self, reward: float, invalid_penalty: float = 0.0) -> SupportTicketStepResult:
        """Build a step result with an explicit *reward* (used by reset and post-terminal steps).

        Re-baselines ``previous_score`` to the current rubric score.
        ``invalid_penalty`` is only reported in ``info``; it is not subtracted
        from *reward*.
        """
        session = self._require_session()
        scorecard = build_scorecard(session.fixture, session)
        session.previous_score = scorecard.score
        return SupportTicketStepResult(
            observation=self._build_observation(),
            reward=reward,
            done=session.done,
            info=self._build_info(scorecard, invalid_penalty, 0.0),
        )

    def _build_info(self, scorecard: Any, invalid_penalty: float, redundancy_penalty: float) -> dict[str, Any]:
        """Assemble the ``info`` dict shared by every step result."""
        session = self._require_session()
        return {
            "task_id": session.fixture.task_id,
            "benchmark_name": self.benchmark_name,
            "score": scorecard.score,
            "score_breakdown": scorecard.model_dump(mode="json"),
            "success": scorecard.score >= self.success_threshold,
            "success_threshold": self.success_threshold,
            "terminal_reason": session.terminal_reason,
            "invalid_penalty": invalid_penalty,
            "redundancy_penalty": redundancy_penalty,
        }

    def _require_session(self) -> SessionState:
        """Return the active session.

        Raises:
            RuntimeError: If ``reset`` has not been called yet.
        """
        if self._session is None:
            raise RuntimeError("Environment has not been reset yet.")
        return self._session