"""Pydantic models for SupportBench OpenEnv environment.""" from __future__ import annotations from typing import Any, Dict, List, Literal, Optional from pydantic import BaseModel, Field # --------------------------------------------------------------------------- # Action types # --------------------------------------------------------------------------- ActionType = Literal[ "classify_ticket", "set_priority", "ask_customer", "propose_resolution", "apply_resolution", "escalate", "resolve", ] TicketCategory = Literal[ "delivery_issue", "refund_request", "damaged_item", "duplicate_charge", "wrong_item", "account_access", ] Priority = Literal["low", "medium", "high", "urgent"] Resolution = Literal[ "refund", "replacement", "troubleshooting", "account_recovery", "verify_identity", "escalate_billing", "escalate_human", "deny_refund", "close_case", ] EscalationTarget = Literal["billing", "fraud", "supervisor", "legal", "technical"] class Action(BaseModel): """Structured action taken by the agent.""" action_type: ActionType = Field(..., description="The type of action to perform") category: Optional[TicketCategory] = Field( None, description="Ticket category (for classify_ticket)" ) priority: Optional[Priority] = Field( None, description="Priority level (for set_priority)" ) message: Optional[str] = Field( None, description="Message to customer or internal note" ) resolution: Optional[Resolution] = Field( None, description="Resolution type (for propose_resolution / apply_resolution)" ) escalate_to: Optional[EscalationTarget] = Field( None, description="Escalation target (for escalate)" ) model_config = {"extra": "ignore"} # --------------------------------------------------------------------------- # Observation # --------------------------------------------------------------------------- class CustomerProfile(BaseModel): customer_id: str name: str email: str account_age_days: int loyalty_tier: str past_disputes: int payment_card_changed_recently: bool = False class OrderInfo(BaseModel): order_id: str product: str amount: float order_date: str status: str model_config = {"extra": "allow"} class Observation(BaseModel): """What the agent sees at each step.""" task_id: str task_name: str customer_message: str customer_profile: CustomerProfile order_info: OrderInfo policy_snippets: List[str] ticket_history: List[Dict[str, Any]] current_status: str available_actions: List[str] steps_taken: int max_steps: int last_action_result: Optional[str] = None last_action_error: Optional[str] = None # --------------------------------------------------------------------------- # Reward # --------------------------------------------------------------------------- class Reward(BaseModel): """Reward signal returned after each step.""" value: float = Field(..., description="Step reward value") reason: str = Field(..., description="Human-readable explanation of reward") cumulative: float = Field(..., description="Cumulative reward this episode") # --------------------------------------------------------------------------- # Hidden episode state (not exposed to agent directly) # --------------------------------------------------------------------------- class EpisodeState(BaseModel): """Internal state tracked by the environment.""" task_id: str task_spec: Dict[str, Any] # Subgoal completion flags classified: bool = False correct_category: bool = False priority_set: bool = False correct_priority: bool = False asked_customer: bool = False identity_verified: bool = False resolution_proposed: bool = False resolution_applied: bool = False escalated: bool = False escalation_correct: bool = False resolved: bool = False # Policy compliance refund_denied_when_required: bool = False replacement_offered: bool = False policy_referenced: bool = False acted_before_verification: bool = False # violation # Trajectory tracking steps: int = 0 max_steps: int = 8 action_history: List[Dict[str, Any]] = Field(default_factory=list) violations: List[str] = Field(default_factory=list) repeated_action_counts: Dict[str, int] = Field(default_factory=dict) # Scoring cumulative_reward: float = 0.0 done: bool = False success: bool = False # Ticket state visible to agent current_status: str = "open" ticket_history: List[Dict[str, Any]] = Field(default_factory=list) last_action_result: Optional[str] = None last_action_error: Optional[str] = None