# Hosting-page header text captured together with the file (not part of the module):
# dbatcode28's picture / fixex / 3e50fa2
from __future__ import annotations
from typing import Dict, List, Optional, Tuple
from .graders import grade_task
from .models import Action, Observation, RewardModel, StateModel, StepInfo, TaskSpec, TicketObservation
from .reward import STEP_PENALTY, build_reward
from .state import discovered_for_ticket, initial_tracking, update_mapping
from .tasks import get_all_tasks, get_task
class SupportOpsEnv:
    """OpenEnv-shaped benchmark for support operations workflows.

    One task is active at a time.  ``reset`` starts a fresh episode,
    ``step`` applies a single agent action and returns
    ``(observation, reward, done, info)``, and ``state`` exposes a deep copy
    of the internal tracking state.
    """

    # Action names advertised in every observation, in a fixed order.
    _AVAILABLE_ACTIONS = (
        "inspect_ticket",
        "request_context",
        "set_priority",
        "set_route",
        "set_resolution",
        "escalate",
        "rank_queue",
        "finalize",
    )

    def __init__(self, task_id: Optional[str] = None):
        """Load every known task and select the starting one.

        Args:
            task_id: Task to start with.  When falsy, the first task id in
                sorted order is used.
        """
        self._tasks = {task.task_id: task for task in get_all_tasks()}
        self._task_order = sorted(self._tasks)
        self._task_id = task_id or self._task_order[0]
        self._task: TaskSpec = self._tasks[self._task_id]
        self._state: StateModel = initial_tracking(self._task)
        # Ticket ids inspected during the current episode.  Kept outside
        # ``latest_score`` because ``step`` replaces that dict after every
        # graded action, which would forget earlier inspections.
        self._inspected_ticket_ids = set()

    def reset(self, task_id: Optional[str] = None) -> Observation:
        """Start a new episode and return its first observation.

        Args:
            task_id: Switch to this task before resetting.  ``None`` keeps
                the current task.
        """
        if task_id is not None:
            self._task = get_task(task_id)
            self._task_id = task_id
        self._state = initial_tracking(self._task)
        self._inspected_ticket_ids = set()
        return self._build_observation()

    def state(self) -> StateModel:
        """Return a deep copy of the tracking state so callers cannot mutate it."""
        return self._state.model_copy(deep=True)

    def step(self, action: Action) -> Tuple[Observation, RewardModel, bool, Dict[str, object]]:
        """Apply one action and report the outcome.

        Args:
            action: The agent action; ``action.action_type`` selects the
                handler, ``action.target`` / ``action.value`` carry its
                arguments.

        Returns:
            ``(observation, reward, done, info)`` where ``info`` is the
            dumped ``StepInfo`` for this step.
        """
        if self._state.done:
            # Acting after the episode ended is penalised but changes no state.
            reward = build_reward({"invalid_after_done": -0.1}, "Episode already finished.")
            info = StepInfo(
                task_id=self._task.task_id,
                step_count=self._state.step_count,
                task_score=self._state.latest_score.get("task_score", 0.0),
                done_reason="already_done",
                event="invalid_after_done",
                event_score=reward.components,
            )
            return self._build_observation(), reward, True, info.model_dump()

        self._state.step_count += 1
        event_scores: Dict[str, float] = {"step_penalty": STEP_PENALTY}
        event_name = action.action_type

        if action.action_type == "finalize":
            # ``done`` is set before grading, so the grader receives the
            # state with the episode already marked finished.
            self._state.done = True
            grade = self._grade_current_state()
            event_scores["terminal_grade"] = grade.score
            reward = self._apply_reward(event_scores, "Final task grade applied.")
            info = StepInfo(
                task_id=self._task.task_id,
                step_count=self._state.step_count,
                task_score=grade.score,
                done_reason="agent_finalized",
                grade=grade,
                event=event_name,
                event_score=reward.components,
            )
            return self._build_observation(), reward, True, info.model_dump()

        handlers = {
            "inspect_ticket": self._handle_inspect,
            "request_context": self._handle_request_context,
            "set_priority": self._handle_priority,
            "set_route": self._handle_route,
            "set_resolution": self._handle_resolution,
            "escalate": self._handle_escalation,
            "rank_queue": self._handle_rank_queue,
        }
        handler = handlers.get(action.action_type)
        if handler is None:
            event_scores["invalid_action"] = -0.1
            event_name = "invalid_action"
        else:
            event_scores.update(handler(action))

        grade = self._grade_current_state()
        done_reason = None
        if self._state.step_count >= self._task.max_steps and not self._state.done:
            # Step budget exhausted: end the episode and fold the grade in.
            self._state.done = True
            done_reason = "max_steps"
            event_scores["timeout_grade"] = grade.score

        reward = self._apply_reward(event_scores, f"Processed {event_name}.")
        info = StepInfo(
            task_id=self._task.task_id,
            step_count=self._state.step_count,
            task_score=grade.score,
            done_reason=done_reason,
            # The full grade is only attached once the episode has ended.
            grade=grade if self._state.done else None,
            event=event_name,
            event_score=reward.components,
        )
        return self._build_observation(), reward, self._state.done, info.model_dump()

    def _grade_current_state(self):
        """Grade the task against the current state and cache it as ``latest_score``."""
        grade = grade_task(self._task, self._state)
        self._state.latest_score = {"task_score": grade.score, **grade.component_scores}
        return grade

    def _apply_reward(self, event_scores: Dict[str, float], message: str) -> RewardModel:
        """Build the step reward and add its value to the cumulative reward."""
        reward = build_reward(event_scores, message)
        self._state.cumulative_reward = round(self._state.cumulative_reward + reward.value, 4)
        return reward

    def _build_observation(self) -> Observation:
        """Assemble the agent-facing view of the task and current selections.

        Hidden context is only exposed for keys recorded in
        ``discovered_keys`` for the ticket; ``remaining_steps`` never goes
        below zero.
        """
        tickets: List[TicketObservation] = []
        for ticket in self._task.tickets:
            keys = self._state.discovered_keys.get(ticket.ticket_id, [])
            discovered_context = {key: ticket.hidden_context[key] for key in keys}
            tickets.append(
                TicketObservation(
                    ticket_id=ticket.ticket_id,
                    summary=ticket.summary,
                    visible_context=ticket.visible_context,
                    discovered_context=discovered_context,
                    selected_priority=self._state.priorities.get(ticket.ticket_id),
                    selected_route=self._state.routes.get(ticket.ticket_id),
                    selected_resolution=self._state.resolutions.get(ticket.ticket_id),
                    escalation_team=self._state.escalations.get(ticket.ticket_id),
                )
            )
        return Observation(
            task_id=self._task.task_id,
            difficulty=self._task.difficulty,
            title=self._task.title,
            instruction=self._task.instruction,
            queue_mode=self._task.queue_mode,
            tickets=tickets,
            remaining_steps=max(self._task.max_steps - self._state.step_count, 0),
            available_actions=list(self._AVAILABLE_ACTIONS),
            current_queue_order=self._state.queue_order,
            score_hint=self._state.latest_score,
        )

    def _find_ticket(self, ticket_id: str):
        """Return the current task's ticket with ``ticket_id``, or ``None``."""
        for ticket in self._task.tickets:
            if ticket.ticket_id == ticket_id:
                return ticket
        return None

    def _handle_inspect(self, action: Action) -> Dict[str, float]:
        """Score an ``inspect_ticket`` action.

        The first inspection of a ticket in an episode is rewarded; any
        repeat is penalised.  Unknown tickets are penalised as invalid.
        """
        ticket = self._find_ticket(action.target)
        if ticket is None:
            return {"invalid_ticket": -0.1}
        key = f"inspected::{ticket.ticket_id}"
        # These markers were historically written into ``latest_score`` and
        # are kept so the state handed to the grader is unchanged.
        # NOTE(review): confirm whether anything still reads them.
        self._state.latest_score.setdefault("inspections", 0.0)
        # Redundancy is decided from the per-episode set: ``latest_score`` is
        # replaced by ``step`` after every action, so it cannot remember
        # earlier inspections.
        if ticket.ticket_id in self._inspected_ticket_ids:
            return {"redundant_inspect": -0.03}
        self._inspected_ticket_ids.add(ticket.ticket_id)
        self._state.latest_score[key] = 1.0
        return {"inspect": 0.03}

    def _handle_request_context(self, action: Action) -> Dict[str, float]:
        """Score a ``request_context`` action and record newly found keys.

        ``action.value`` names the hidden-context key being requested.
        Required keys earn more than optional ones; unknown or repeated
        requests are penalised.
        """
        ticket = self._find_ticket(action.target)
        if ticket is None or not action.value:
            return {"invalid_context_request": -0.1}
        if action.value not in ticket.hidden_context:
            return {"unknown_context_key": -0.08}
        discovered = discovered_for_ticket(self._state.discovered_keys, ticket.ticket_id)
        if action.value in discovered:
            return {"redundant_context_request": -0.05}
        discovered.append(action.value)
        if action.value in ticket.required_context:
            return {"required_context_found": 0.12}
        return {"optional_context_found": 0.04}

    def _handle_priority(self, action: Action) -> Dict[str, float]:
        """Store the chosen priority and score it against the gold priority."""
        ticket = self._find_ticket(action.target)
        if ticket is None or not action.value:
            return {"invalid_priority": -0.1}
        current = self._state.priorities.get(ticket.ticket_id)
        update_mapping(self._state.priorities, ticket.ticket_id, action.value)
        if action.value == current:
            return {"redundant_priority": -0.03}
        return {"priority_match": 0.08 if action.value == ticket.gold_priority else -0.04}

    def _handle_route(self, action: Action) -> Dict[str, float]:
        """Store the chosen route and score it against the gold route."""
        ticket = self._find_ticket(action.target)
        if ticket is None or not action.value:
            return {"invalid_route": -0.1}
        current = self._state.routes.get(ticket.ticket_id)
        update_mapping(self._state.routes, ticket.ticket_id, action.value)
        if action.value == current:
            return {"redundant_route": -0.03}
        return {"route_match": 0.1 if action.value == ticket.gold_route else -0.06}

    def _handle_resolution(self, action: Action) -> Dict[str, float]:
        """Store the chosen resolution and score it against the gold resolution."""
        ticket = self._find_ticket(action.target)
        if ticket is None or not action.value:
            return {"invalid_resolution": -0.1}
        current = self._state.resolutions.get(ticket.ticket_id)
        update_mapping(self._state.resolutions, ticket.ticket_id, action.value)
        if action.value == current:
            return {"redundant_resolution": -0.03}
        return {"resolution_match": 0.12 if action.value == ticket.gold_resolution else -0.08}

    def _handle_escalation(self, action: Action) -> Dict[str, float]:
        """Store the escalation team (``None`` clears it) and score the choice."""
        ticket = self._find_ticket(action.target)
        if ticket is None:
            return {"invalid_escalation": -0.1}
        team = action.value
        current = self._state.escalations.get(ticket.ticket_id)
        update_mapping(self._state.escalations, ticket.ticket_id, team)
        if team == current:
            return {"redundant_escalation": -0.03}
        # The "no escalation wanted, none chosen" case must be tested before
        # the generic equality check: ``None == None`` would otherwise be
        # reported as ``correct_escalation`` and this branch could never run.
        if team is None and ticket.gold_escalation_team is None:
            return {"correct_no_escalation": 0.03}
        if team == ticket.gold_escalation_team:
            return {"correct_escalation": 0.1}
        return {"incorrect_escalation": -0.1}

    def _handle_rank_queue(self, action: Action) -> Dict[str, float]:
        """Store a queue ranking and score how many positions match the gold order.

        ``action.value`` is a comma-separated list of ticket ids.  The ranking
        must name every ticket of the task exactly once; anything else is
        rejected as malformed and leaves the stored order untouched.
        """
        if not self._task.queue_mode or not action.value:
            return {"invalid_queue_ranking": -0.1}
        ranked = [item.strip() for item in action.value.split(",") if item.strip()]
        valid_ticket_ids = {ticket.ticket_id for ticket in self._task.tickets}
        # Set equality alone would accept duplicates such as "A,A,B"; the
        # length check rejects them.  An empty ranking is rejected too, so
        # the division below never sees a zero length.
        if not ranked or len(ranked) != len(valid_ticket_ids) or set(ranked) != valid_ticket_ids:
            return {"malformed_queue_ranking": -0.08}
        self._state.queue_order = ranked
        correct_positions = sum(
            1 for observed, expected in zip(ranked, self._task.gold_queue_order) if observed == expected
        )
        return {"queue_progress": round((correct_positions / len(ranked)) * 0.12, 4)}