Spaces:
Running
Running
| """Scripted expert oracle policy (mandatory baseline, §22.1). | |
| Reads hidden ground truth to follow the gold SOP graph exactly: | |
| keyword-match classification, derived impact/urgency, correct queue + template selection, | |
| avoids unnecessary escalation. Used for regression-test ground truth and demo traces. | |
| Public API: | |
| ScriptedExpert(env) — wraps a TriageSieveEnvironment | |
| ScriptedExpert.run_episode(seed, difficulty=None) — runs one full episode, returns structured trace dict | |
| """ | |
| from __future__ import annotations | |
| from typing import Any | |
| from ..models import ( | |
| ActionType, | |
| CloseReason, | |
| Priority, | |
| QueueId, | |
| TriageSieveAction, | |
| TaskDifficulty, | |
| ) | |
| from ..server.scorer import ( | |
| EpisodeScoringContext, | |
| ScoreBreakdown, | |
| compute_episode_score, | |
| ) | |
| from ..server.triagesieve_env_environment import TriageSieveEnvironment | |
# Only the oracle class is part of this module's public surface.
__all__ = ["ScriptedExpert"]

# Sort rank per priority: the earlier a priority appears in the tuple below,
# the smaller its rank and the sooner its tickets are processed.
_PRIORITY_ORDER: dict[Priority, int] = {
    priority: rank
    for rank, priority in enumerate(
        (
            Priority.CRITICAL,
            Priority.HIGH,
            Priority.MEDIUM,
            Priority.LOW,
        )
    )
}

# Action types treated as "substantive" when recording the first substantive
# step per ticket for priority-order scoring (§19).
_SUBSTANTIVE_ACTIONS: frozenset[ActionType] = frozenset(
    (
        ActionType.CLASSIFY_TICKET,
        ActionType.ROUTE_TICKET,
        ActionType.CLOSE_TICKET,
    )
)
class ScriptedExpert:
    """Oracle policy that reads hidden truth to produce optimal action sequences.

    This is NOT a fair agent — it accesses internal ground truth via
    ``env._ticket_index[ticket_id].hidden_truth``. Its purpose is to:

    1. Prove environment solvability.
    2. Produce reference traces for regression testing.
    3. Establish a score ceiling for comparison with learned policies.

    Args:
        env: A fresh (or reusable) TriageSieveEnvironment instance.
    """

    def __init__(self, env: TriageSieveEnvironment) -> None:
        self.env = env

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def run_episode(
        self,
        seed: int,
        difficulty: TaskDifficulty | None = None,
    ) -> dict[str, Any]:
        """Run a full episode with oracle actions, return a structured trace.

        The environment is reset in ``eval_strict`` mode, tickets are processed
        in gold-priority order, and a FINISH_EPISODE action is issued at the end
        when the episode is still open and budget remains.

        Args:
            seed: Deterministic seed for episode generation.
            difficulty: Task difficulty tier. If None, no ``difficulty`` keyword
                is passed to ``env.reset``.

        Returns:
            Trace dict with keys: episode_id, seed, task_difficulty, done,
            action_sequence, final_score, score_breakdown.
        """
        reset_kwargs: dict[str, Any] = {"mode": "eval_strict"}
        if difficulty is not None:
            reset_kwargs["difficulty"] = difficulty.value
        obs = self.env.reset(seed=seed, **reset_kwargs)
        state = self.env.state

        # Plan ticket processing order: highest priority first (§19).
        ordered_ticket_ids = self._plan_ticket_order()

        action_sequence: list[dict[str, Any]] = []

        # Bookkeeping handed to the scorer once the episode is over.
        templates_used: dict[str, list[str]] = {}
        route_count: dict[str, int] = {}
        first_substantive_step: dict[str, int] = {}

        for ticket_id in ordered_ticket_ids:
            # Once the episode is over or the budget is spent no later ticket
            # can be acted on, so stop planning altogether.
            if obs.done or obs.action_budget_remaining <= 0:
                break

            for action in self._plan_ticket_actions(ticket_id):
                if obs.done or obs.action_budget_remaining <= 0:
                    break

                obs = self._step_and_record(action, action_sequence)
                # Exactly one trace entry is appended per executed step.
                step_num = len(action_sequence)

                tid = action.ticket_id
                if tid is None:
                    continue

                if action.template_id is not None:
                    templates_used.setdefault(tid, []).append(action.template_id)

                if action.action_type == ActionType.ROUTE_TICKET:
                    route_count[tid] = route_count.get(tid, 0) + 1

                if action.action_type in _SUBSTANTIVE_ACTIONS:
                    # setdefault keeps the earliest step and ignores later ones.
                    first_substantive_step.setdefault(tid, step_num)

        # FINISH_EPISODE if not already done.
        if not obs.done and obs.action_budget_remaining > 0:
            finish = TriageSieveAction(
                action_type=ActionType.FINISH_EPISODE,
                metadata={},
            )
            obs = self._step_and_record(finish, action_sequence)

        # Every step whose result is anything other than "ok" counts as invalid.
        invalid_count = sum(
            1 for entry in action_sequence if entry["result"] != "ok"
        )
        score_breakdown = self._compute_score(
            templates_used, route_count, first_substantive_step, invalid_count
        )

        return {
            "episode_id": state.episode_id,
            "seed": seed,
            "task_difficulty": state.task_difficulty.value,
            "done": obs.done,
            "action_sequence": action_sequence,
            "final_score": score_breakdown.final_score,
            "score_breakdown": {
                "terminal_business_score": score_breakdown.terminal_business_score,
                "ujcs_openenv": score_breakdown.ujcs_openenv,
                "episode_penalties": score_breakdown.episode_penalties.total_penalty,
                "priority_order_score": score_breakdown.priority_order_score,
                "invalid_action_count": score_breakdown.invalid_action_count,
                "reassignment_count": score_breakdown.reassignment_count,
            },
        }

    def _step_and_record(
        self,
        action: TriageSieveAction,
        action_sequence: list[dict[str, Any]],
    ) -> Any:
        """Execute one action and append its trace entry to ``action_sequence``.

        Args:
            action: Action to pass to ``env.step``.
            action_sequence: Trace list, mutated in place. The new entry's
                ``step`` is the 1-based position it receives in this list.

        Returns:
            The observation returned by ``env.step``.
        """
        obs = self.env.step(action)
        action_sequence.append({
            "step": len(action_sequence) + 1,
            "action": self._serialize_action(action),
            "result": obs.last_action_result,
            "step_reward": obs.reward,
        })
        return obs

    # ------------------------------------------------------------------
    # Scoring
    # ------------------------------------------------------------------

    def _compute_score(
        self,
        templates_used: dict[str, list[str]],
        route_count: dict[str, int],
        first_substantive_step: dict[str, int],
        invalid_action_count: int,
    ) -> ScoreBreakdown:
        """Build EpisodeScoringContext from environment state and compute score.

        The per-ticket maps are read from the environment's private attributes
        and shallow-copied before being handed to the scorer.

        Args:
            templates_used: Map ticket_id → list of template_ids used.
            route_count: Map ticket_id → number of route actions.
            first_substantive_step: Map ticket_id → step number of first
                substantive action.
            invalid_action_count: Number of actions that returned non-"ok"
                results.

        Returns:
            ScoreBreakdown from scorer.
        """
        env = self.env
        ctx = EpisodeScoringContext(
            tickets=list(env._ticket_index.values()),
            ticket_states=dict(env._ticket_states),
            ticket_classifications=dict(env._ticket_classifications),
            ticket_impact_urgency=dict(env._ticket_impact_urgency),
            ticket_routed_to=dict(env._ticket_routed_to),
            ticket_escalated_to=dict(env._ticket_escalated_to),
            ticket_close_reasons=dict(env._ticket_close_reasons),
            ticket_info_requested=dict(env._ticket_info_requested),
            ticket_info_received=dict(env._ticket_info_received),
            ticket_merged_to=dict(env._ticket_merged_to),
            ticket_templates_used=templates_used,
            sop_trackers=dict(env._sop_trackers),
            invalid_action_count=invalid_action_count,
            ticket_route_count=route_count,
            ticket_first_substantive_step=first_substantive_step,
        )
        return compute_episode_score(ctx)

    # ------------------------------------------------------------------
    # Ticket ordering
    # ------------------------------------------------------------------

    def _plan_ticket_order(self) -> list[str]:
        """Sort tickets by gold priority descending (critical first).

        Reads hidden truth priority to maximize §19 priority-order score.
        The sort is stable, so tickets sharing a priority keep the order in
        which ``env._ticket_index`` yields them.

        Returns:
            Ordered list of ticket_ids.
        """
        ranked = sorted(
            self.env._ticket_index.values(),
            key=lambda t: _PRIORITY_ORDER[t.hidden_truth.priority],
        )
        return [t.ticket_id for t in ranked]

    # ------------------------------------------------------------------
    # Per-ticket action planning
    # ------------------------------------------------------------------

    def _plan_ticket_actions(self, ticket_id: str) -> list[TriageSieveAction]:
        """Plan the full oracle action sequence for a single ticket.

        Reads hidden truth and branches on:

        - Non-actionable → open + classify + close(non_actionable)
        - Duplicate → open + merge
        - Feature request → open + classify + route + close(feature_request)
        - Normal flow → open, classify, set_impact_urgency, [request_info],
          route, [escalate], close(resolved)

        Args:
            ticket_id: Ticket to plan actions for.

        Returns:
            Ordered list of TriageSieveAction objects.
        """
        ht = self.env._ticket_index[ticket_id].hidden_truth

        def make(action_type: ActionType, **fields: Any) -> TriageSieveAction:
            # Every planned action targets this ticket and carries empty metadata.
            return TriageSieveAction(
                action_type=action_type,
                ticket_id=ticket_id,
                metadata={},
                **fields,
            )

        def classify() -> TriageSieveAction:
            return make(
                ActionType.CLASSIFY_TICKET,
                issue_family=ht.issue_family,
                issue_subtype=ht.issue_subtype,
            )

        # 1. Always open first.
        actions: list[TriageSieveAction] = [make(ActionType.OPEN_TICKET)]

        # 2. Branch: non-actionable.
        # Classify first so the SOP tracker advances through the "identify_*"
        # checkpoint (spam, benign, automation_false_positive, data_error
        # archetypes all require it).
        if ht.non_actionable_subtype is not None:
            actions.append(classify())
            actions.append(
                make(
                    ActionType.CLOSE_TICKET,
                    close_reason=CloseReason.NON_ACTIONABLE,
                )
            )
            return actions

        # 3. Branch: duplicate.
        if ht.is_duplicate and ht.duplicate_of is not None:
            actions.append(
                make(
                    ActionType.MERGE_DUPLICATE,
                    target_ticket_id=ht.duplicate_of,
                )
            )
            return actions

        # 4. Branch: feature request routed to sales_or_feature_requests.
        # SOP requires: classify → route(sales_or_feature_requests)
        # → close(feature_request).
        if ht.required_queue == QueueId.SALES_OR_FEATURE_REQUESTS:
            actions.append(classify())
            actions.append(
                make(ActionType.ROUTE_TICKET, queue_id=ht.required_queue)
            )
            actions.append(
                make(
                    ActionType.CLOSE_TICKET,
                    close_reason=CloseReason.FEATURE_REQUEST,
                )
            )
            return actions

        # 5. Normal flow: classify.
        actions.append(classify())

        # 6. Set impact/urgency.
        actions.append(
            make(
                ActionType.SET_IMPACT_URGENCY,
                impact=ht.impact,
                urgency=ht.urgency,
            )
        )

        # 7. Request information if fields are missing; the first correct
        #    template (when any) accompanies the request.
        if ht.required_missing_fields:
            request_template_id = (
                ht.correct_template_ids[0] if ht.correct_template_ids else None
            )
            actions.append(
                make(
                    ActionType.REQUEST_INFORMATION,
                    template_id=request_template_id,
                    requested_fields=list(ht.required_missing_fields),
                )
            )

        # 8. Route, then escalate when required
        #    (route → escalated is valid per §12).
        actions.append(make(ActionType.ROUTE_TICKET, queue_id=ht.required_queue))
        if ht.escalation_required and ht.escalation_target is not None:
            actions.append(
                make(
                    ActionType.ESCALATE_TICKET,
                    queue_id=ht.escalation_target,
                    reason_code="expert_escalation",
                )
            )

        # 9. Close with the last correct template (when any).
        close_template_id = (
            ht.correct_template_ids[-1] if ht.correct_template_ids else None
        )
        actions.append(
            make(
                ActionType.CLOSE_TICKET,
                close_reason=CloseReason.RESOLVED,
                template_id=close_template_id,
            )
        )
        return actions

    # ------------------------------------------------------------------
    # Serialization helper
    # ------------------------------------------------------------------

    @staticmethod
    def _serialize_action(action: TriageSieveAction) -> dict[str, Any]:
        """Serialize an action to a plain dict for traces.

        Includes only non-None fields for readability. Values exposing a
        ``value`` attribute (e.g. enum members) are replaced by that attribute;
        everything else is stored as-is.

        Declared as a static method because it reads no instance state and is
        invoked as ``self._serialize_action(action)``.

        Args:
            action: Action to serialize.

        Returns:
            Dict holding ``action_type`` plus every optional field that is set.
        """
        data: dict[str, Any] = {"action_type": action.action_type.value}
        optional_fields = (
            "ticket_id",
            "issue_family",
            "issue_subtype",
            "impact",
            "urgency",
            "queue_id",
            "reason_code",
            "template_id",
            "requested_fields",
            "target_ticket_id",
            "close_reason",
        )
        for field_name in optional_fields:
            value = getattr(action, field_name, None)
            if value is None:
                continue
            data[field_name] = value.value if hasattr(value, "value") else value
        return data