# Source: triagesieve_env/baseline/scripted_expert.py
# Uploaded by Angshuman28 via huggingface_hub (verified commit b89c8aa).
"""Scripted expert oracle policy (mandatory baseline, §22.1).

Reads hidden ground truth to follow the gold SOP graph exactly: the gold
classification and impact/urgency are taken directly from the ground truth,
the correct queue and templates are selected, and escalation happens only when
the ground truth requires it. Used for regression-test ground truth and demo
traces.

Public API:
    ScriptedExpert(env) — wraps a TriageSieveEnvironment
    ScriptedExpert.run_episode(seed, difficulty=None) — runs one full episode
        and returns a structured trace dict
"""
from __future__ import annotations
from typing import Any
from ..models import (
ActionType,
CloseReason,
Priority,
QueueId,
TriageSieveAction,
TaskDifficulty,
)
from ..server.scorer import (
EpisodeScoringContext,
ScoreBreakdown,
compute_episode_score,
)
from ..server.triagesieve_env_environment import TriageSieveEnvironment
__all__ = ["ScriptedExpert"]

# Sort key used when ordering tickets: the more urgent the gold priority, the
# smaller the rank, so an ascending sort processes CRITICAL tickets first.
_PRIORITY_ORDER: dict[Priority, int] = {
    level: rank
    for rank, level in enumerate(
        (Priority.CRITICAL, Priority.HIGH, Priority.MEDIUM, Priority.LOW)
    )
}

# Action types treated as "substantive" work on a ticket for the §19
# priority-order score (the step of the first such action per ticket is
# recorded while an episode runs).
_SUBSTANTIVE_ACTIONS: frozenset[ActionType] = frozenset(
    (
        ActionType.CLASSIFY_TICKET,
        ActionType.ROUTE_TICKET,
        ActionType.CLOSE_TICKET,
    )
)
class ScriptedExpert:
    """Oracle policy that reads hidden truth to produce optimal action sequences.

    This is NOT a fair agent — it accesses internal ground truth via
    ``env._ticket_index[ticket_id].hidden_truth`` and reads other private
    environment attributes when building the scoring context. Its purpose is to:

    1. Prove environment solvability.
    2. Produce reference traces for regression testing.
    3. Establish a score ceiling for comparison with learned policies.

    Args:
        env: A fresh (or reusable) TriageSieveEnvironment instance; it is
            reset at the start of every ``run_episode`` call.
    """

    def __init__(self, env: TriageSieveEnvironment) -> None:
        self.env = env

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def run_episode(
        self,
        seed: int,
        difficulty: TaskDifficulty | None = None,
    ) -> dict[str, Any]:
        """Run a full episode with oracle actions, return a structured trace.

        The environment is reset in ``"eval_strict"`` mode, tickets are
        processed in gold-priority order, and ``FINISH_EPISODE`` is sent if
        the episode is still running and budget remains. The final score is
        computed locally with ``compute_episode_score``.

        Args:
            seed: Deterministic seed for episode generation.
            difficulty: Task difficulty tier. If None, seed-derived.

        Returns:
            Trace dict with keys: episode_id, seed, task_difficulty, done,
            action_sequence, final_score, score_breakdown.
        """
        reset_kwargs: dict[str, Any] = {"mode": "eval_strict"}
        if difficulty is not None:
            reset_kwargs["difficulty"] = difficulty.value
        obs = self.env.reset(seed=seed, **reset_kwargs)
        state = self.env.state

        # Plan ticket processing order: highest priority first (§19).
        ordered_ticket_ids = self._plan_ticket_order()

        action_sequence: list[dict[str, Any]] = []
        step_num = 0

        # Bookkeeping handed to the scorer via EpisodeScoringContext.
        templates_used: dict[str, list[str]] = {}
        route_count: dict[str, int] = {}
        first_substantive_step: dict[str, int] = {}

        for ticket_id in ordered_ticket_ids:
            # Nothing more can be executed once the episode has ended or the
            # budget is spent, so stop planning the remaining tickets.
            if obs.done or obs.action_budget_remaining <= 0:
                break
            for action in self._plan_ticket_actions(ticket_id):
                if obs.done or obs.action_budget_remaining <= 0:
                    break
                step_num += 1
                obs = self.env.step(action)
                action_sequence.append(self._trace_entry(step_num, action, obs))

                # NOTE(review): the bookkeeping below runs whether or not the
                # environment accepted the action (a non-"ok" result is still
                # tracked); confirm this matches what the scorer expects.
                tid = action.ticket_id
                if tid is None:
                    continue
                if action.template_id is not None:
                    templates_used.setdefault(tid, []).append(action.template_id)
                if action.action_type == ActionType.ROUTE_TICKET:
                    route_count[tid] = route_count.get(tid, 0) + 1
                if (
                    action.action_type in _SUBSTANTIVE_ACTIONS
                    and tid not in first_substantive_step
                ):
                    first_substantive_step[tid] = step_num

        # Explicitly end the episode if it is still running and budget remains.
        if not obs.done and obs.action_budget_remaining > 0:
            finish = TriageSieveAction(
                action_type=ActionType.FINISH_EPISODE,
                metadata={},
            )
            step_num += 1
            obs = self.env.step(finish)
            action_sequence.append(self._trace_entry(step_num, finish, obs))

        # Any action whose result is not exactly "ok" counts as invalid
        # (including the FINISH_EPISODE step, if one was sent).
        invalid_count = sum(1 for entry in action_sequence if entry["result"] != "ok")
        score_breakdown = self._compute_score(
            templates_used, route_count, first_substantive_step, invalid_count
        )
        return {
            "episode_id": state.episode_id,
            "seed": seed,
            "task_difficulty": state.task_difficulty.value,
            "done": obs.done,
            "action_sequence": action_sequence,
            "final_score": score_breakdown.final_score,
            "score_breakdown": {
                "terminal_business_score": score_breakdown.terminal_business_score,
                "ujcs_openenv": score_breakdown.ujcs_openenv,
                "episode_penalties": score_breakdown.episode_penalties.total_penalty,
                "priority_order_score": score_breakdown.priority_order_score,
                "invalid_action_count": score_breakdown.invalid_action_count,
                "reassignment_count": score_breakdown.reassignment_count,
            },
        }

    @classmethod
    def _trace_entry(
        cls, step_num: int, action: TriageSieveAction, obs: Any
    ) -> dict[str, Any]:
        """Build one ``action_sequence`` entry from an executed action.

        Args:
            step_num: 1-based index of the step within the episode.
            action: The action that was just sent to the environment.
            obs: Observation returned by ``env.step`` for that action.

        Returns:
            Dict with keys ``step``, ``action``, ``result``, ``step_reward``.
        """
        return {
            "step": step_num,
            "action": cls._serialize_action(action),
            "result": obs.last_action_result,
            "step_reward": obs.reward,
        }

    # ------------------------------------------------------------------
    # Scoring
    # ------------------------------------------------------------------

    def _compute_score(
        self,
        templates_used: dict[str, list[str]],
        route_count: dict[str, int],
        first_substantive_step: dict[str, int],
        invalid_action_count: int,
    ) -> ScoreBreakdown:
        """Build EpisodeScoringContext from environment state and compute score.

        Environment state is read from private ``env._*`` attributes; each
        mapping is passed as a shallow ``dict`` copy.

        Args:
            templates_used: Map ticket_id → list of template_ids used.
            route_count: Map ticket_id → number of route actions.
            first_substantive_step: Map ticket_id → step number of first substantive action.
            invalid_action_count: Number of actions that returned non-"ok" results.

        Returns:
            ScoreBreakdown from scorer.
        """
        env = self.env
        ctx = EpisodeScoringContext(
            tickets=list(env._ticket_index.values()),
            ticket_states=dict(env._ticket_states),
            ticket_classifications=dict(env._ticket_classifications),
            ticket_impact_urgency=dict(env._ticket_impact_urgency),
            ticket_routed_to=dict(env._ticket_routed_to),
            ticket_escalated_to=dict(env._ticket_escalated_to),
            ticket_close_reasons=dict(env._ticket_close_reasons),
            ticket_info_requested=dict(env._ticket_info_requested),
            ticket_info_received=dict(env._ticket_info_received),
            ticket_merged_to=dict(env._ticket_merged_to),
            ticket_templates_used=templates_used,
            sop_trackers=dict(env._sop_trackers),
            invalid_action_count=invalid_action_count,
            ticket_route_count=route_count,
            ticket_first_substantive_step=first_substantive_step,
        )
        return compute_episode_score(ctx)

    # ------------------------------------------------------------------
    # Ticket ordering
    # ------------------------------------------------------------------

    def _plan_ticket_order(self) -> list[str]:
        """Sort tickets by gold priority descending (critical first).

        Reads hidden truth priority to maximize the §19 priority-order score.
        The sort is stable, so tickets of equal priority keep the iteration
        order of ``env._ticket_index``. A priority missing from
        ``_PRIORITY_ORDER`` raises ``KeyError``.

        Returns:
            Ordered list of ticket_ids.
        """
        tickets = list(self.env._ticket_index.values())
        tickets.sort(key=lambda t: _PRIORITY_ORDER[t.hidden_truth.priority])
        return [t.ticket_id for t in tickets]

    # ------------------------------------------------------------------
    # Per-ticket action planning
    # ------------------------------------------------------------------

    def _plan_ticket_actions(self, ticket_id: str) -> list[TriageSieveAction]:
        """Plan the full oracle action sequence for a single ticket.

        Reads hidden truth and branches on:

        - Non-actionable → open, classify, close(non_actionable)
        - Duplicate (with a known ``duplicate_of``) → open, merge_duplicate
        - Feature request (gold queue is sales_or_feature_requests)
          → open, classify, route, close(feature_request)
        - Normal flow → open, classify, set_impact_urgency,
          [request_information], route, [escalate], close(resolved)

        Args:
            ticket_id: Ticket to plan actions for.

        Returns:
            Ordered list of TriageSieveAction objects (never empty).
        """
        ht = self.env._ticket_index[ticket_id].hidden_truth

        # 1. Always open first.
        actions: list[TriageSieveAction] = [
            TriageSieveAction(
                action_type=ActionType.OPEN_TICKET,
                ticket_id=ticket_id,
                metadata={},
            )
        ]

        # 2. Branch: non-actionable.
        # Classify first so the SOP tracker advances through the "identify_*"
        # checkpoint (spam, benign, automation_false_positive, data_error
        # archetypes all require it).
        if ht.non_actionable_subtype is not None:
            actions.append(self._classify_action(ticket_id, ht))
            actions.append(TriageSieveAction(
                action_type=ActionType.CLOSE_TICKET,
                ticket_id=ticket_id,
                close_reason=CloseReason.NON_ACTIONABLE,
                metadata={},
            ))
            return actions

        # 3. Branch: duplicate. Without a known target, fall through to the
        # normal flow.
        if ht.is_duplicate and ht.duplicate_of is not None:
            actions.append(TriageSieveAction(
                action_type=ActionType.MERGE_DUPLICATE,
                ticket_id=ticket_id,
                target_ticket_id=ht.duplicate_of,
                metadata={},
            ))
            return actions

        # 4. Branch: feature request routed to sales_or_feature_requests.
        # SOP requires: classify → route(sales_or_feature_requests) → close(feature_request)
        if ht.required_queue == QueueId.SALES_OR_FEATURE_REQUESTS:
            actions.append(self._classify_action(ticket_id, ht))
            actions.append(self._route_action(ticket_id, ht.required_queue))
            actions.append(TriageSieveAction(
                action_type=ActionType.CLOSE_TICKET,
                ticket_id=ticket_id,
                close_reason=CloseReason.FEATURE_REQUEST,
                metadata={},
            ))
            return actions

        # 5. Normal flow: classify.
        actions.append(self._classify_action(ticket_id, ht))

        # 6. Set impact/urgency from the gold values.
        actions.append(TriageSieveAction(
            action_type=ActionType.SET_IMPACT_URGENCY,
            ticket_id=ticket_id,
            impact=ht.impact,
            urgency=ht.urgency,
            metadata={},
        ))

        # 7. Request information if fields are missing (first correct template).
        if ht.required_missing_fields:
            template_id = ht.correct_template_ids[0] if ht.correct_template_ids else None
            actions.append(TriageSieveAction(
                action_type=ActionType.REQUEST_INFORMATION,
                ticket_id=ticket_id,
                template_id=template_id,
                requested_fields=list(ht.required_missing_fields),
                metadata={},
            ))

        # 8. Always route to the gold queue; escalate afterwards only when the
        # ground truth requires it and names a target (route → escalated is
        # valid per §12).
        actions.append(self._route_action(ticket_id, ht.required_queue))
        if ht.escalation_required and ht.escalation_target is not None:
            actions.append(TriageSieveAction(
                action_type=ActionType.ESCALATE_TICKET,
                ticket_id=ticket_id,
                queue_id=ht.escalation_target,
                reason_code="expert_escalation",
                metadata={},
            ))

        # 9. Close as resolved with the last correct template.
        # NOTE(review): with a single correct template, the same id is used
        # for both request_information and close; confirm that is intended.
        close_template_id = (
            ht.correct_template_ids[-1]
            if ht.correct_template_ids
            else None
        )
        actions.append(TriageSieveAction(
            action_type=ActionType.CLOSE_TICKET,
            ticket_id=ticket_id,
            close_reason=CloseReason.RESOLVED,
            template_id=close_template_id,
            metadata={},
        ))
        return actions

    @staticmethod
    def _classify_action(ticket_id: str, ht: Any) -> TriageSieveAction:
        """Build a CLASSIFY_TICKET action carrying the gold family/subtype."""
        return TriageSieveAction(
            action_type=ActionType.CLASSIFY_TICKET,
            ticket_id=ticket_id,
            issue_family=ht.issue_family,
            issue_subtype=ht.issue_subtype,
            metadata={},
        )

    @staticmethod
    def _route_action(ticket_id: str, queue_id: QueueId) -> TriageSieveAction:
        """Build a ROUTE_TICKET action sending ``ticket_id`` to ``queue_id``."""
        return TriageSieveAction(
            action_type=ActionType.ROUTE_TICKET,
            ticket_id=ticket_id,
            queue_id=queue_id,
            metadata={},
        )

    # ------------------------------------------------------------------
    # Serialization helper
    # ------------------------------------------------------------------

    @staticmethod
    def _serialize_action(action: TriageSieveAction) -> dict[str, Any]:
        """Serialize an action to a plain dict for traces.

        Includes only non-None fields for readability. Values exposing a
        ``.value`` attribute (e.g. enum members) are replaced by that value;
        ``metadata`` is not included.
        """
        data: dict[str, Any] = {"action_type": action.action_type.value}
        for field_name in (
            "ticket_id",
            "issue_family",
            "issue_subtype",
            "impact",
            "urgency",
            "queue_id",
            "reason_code",
            "template_id",
            "requested_fields",
            "target_ticket_id",
            "close_reason",
        ):
            value = getattr(action, field_name, None)
            if value is not None:
                data[field_name] = value.value if hasattr(value, "value") else value
        return data