Rishi Prasad
Clean submission upload
bc8b288
"""Pydantic models for SupportBench OpenEnv environment."""
from __future__ import annotations
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Action types
# ---------------------------------------------------------------------------
# Verbs the agent may issue as ``Action.action_type``.
ActionType = Literal[
    "classify_ticket",
    "set_priority",
    "ask_customer",
    "propose_resolution",
    "apply_resolution",
    "escalate",
    "resolve",
]

# Labels accepted by ``Action.category`` (used with classify_ticket).
TicketCategory = Literal[
    "delivery_issue",
    "refund_request",
    "damaged_item",
    "duplicate_charge",
    "wrong_item",
    "account_access",
]

# Priority levels accepted by ``Action.priority`` (used with set_priority).
Priority = Literal[
    "low",
    "medium",
    "high",
    "urgent",
]

# Resolution codes accepted by ``Action.resolution``
# (used with propose_resolution / apply_resolution).
Resolution = Literal[
    "refund",
    "replacement",
    "troubleshooting",
    "account_recovery",
    "verify_identity",
    "escalate_billing",
    "escalate_human",
    "deny_refund",
    "close_case",
]

# Destinations accepted by ``Action.escalate_to`` (used with escalate).
EscalationTarget = Literal[
    "billing",
    "fraud",
    "supervisor",
    "legal",
    "technical",
]
class Action(BaseModel):
    """A single structured step submitted by the agent.

    Only ``action_type`` is required.  Each remaining field is optional and,
    per its description, is relevant to particular action types; this model
    itself does not enforce which fields accompany which action type.
    """

    # Required discriminator for what the agent wants to do.
    action_type: ActionType = Field(
        ...,
        description="The type of action to perform",
    )

    # Optional payload fields, each tied to specific action types.
    category: Optional[TicketCategory] = Field(
        default=None,
        description="Ticket category (for classify_ticket)",
    )
    priority: Optional[Priority] = Field(
        default=None,
        description="Priority level (for set_priority)",
    )
    message: Optional[str] = Field(
        default=None,
        description="Message to customer or internal note",
    )
    resolution: Optional[Resolution] = Field(
        default=None,
        description="Resolution type (for propose_resolution / apply_resolution)",
    )
    escalate_to: Optional[EscalationTarget] = Field(
        default=None,
        description="Escalation target (for escalate)",
    )

    # Drop unrecognised input keys instead of failing validation.
    model_config = dict(extra="ignore")
# ---------------------------------------------------------------------------
# Observation
# ---------------------------------------------------------------------------
class CustomerProfile(BaseModel):
    """Customer details carried in each ``Observation.customer_profile``."""

    # Identity and contact details.
    customer_id: str
    name: str
    email: str

    # Account history signals.
    account_age_days: int
    loyalty_tier: str
    past_disputes: int

    # The only optional field; it is False unless supplied.
    payment_card_changed_recently: bool = False
class OrderInfo(BaseModel):
    """Order details carried in each ``Observation.order_info``.

    Keys beyond the declared fields are kept on the model rather than
    dropped (``extra="allow"``).  Presumably this lets individual task specs
    attach order-specific attributes; confirm against the task definitions.
    """

    order_id: str
    product: str
    amount: float
    order_date: str
    status: str

    # Retain undeclared keys (contrast with Action, which ignores them).
    model_config = dict(extra="allow")
class Observation(BaseModel):
    """Everything the agent is shown at a given step of an episode."""

    # Which task this episode is running.
    task_id: str
    task_name: str

    # The ticket itself and its surrounding context.
    customer_message: str
    customer_profile: CustomerProfile
    order_info: OrderInfo
    policy_snippets: List[str]
    ticket_history: List[Dict[str, Any]]

    # Ticket status, the action menu, and the step budget.
    current_status: str
    available_actions: List[str]
    steps_taken: int
    max_steps: int

    # Outcome of the most recent action; both are None unless set.
    last_action_result: Optional[str] = None
    last_action_error: Optional[str] = None
# ---------------------------------------------------------------------------
# Reward
# ---------------------------------------------------------------------------
class Reward(BaseModel):
    """Per-step reward plus the running episode total.

    All three fields are required.
    """

    value: float = Field(
        ...,
        description="Step reward value",
    )
    reason: str = Field(
        ...,
        description="Human-readable explanation of reward",
    )
    cumulative: float = Field(
        ...,
        description="Cumulative reward this episode",
    )
# ---------------------------------------------------------------------------
# Hidden episode state (not exposed to agent directly)
# ---------------------------------------------------------------------------
class EpisodeState(BaseModel):
    """Environment-side bookkeeping for one episode.

    Only ``task_id`` and ``task_spec`` are required; every other field starts
    from a default (False flags, zero counters, empty collections, "open"
    status).
    """

    task_id: str
    task_spec: Dict[str, Any]

    # --- Subgoal progress flags --------------------------------------------
    classified: bool = False
    correct_category: bool = False
    priority_set: bool = False
    correct_priority: bool = False
    asked_customer: bool = False
    identity_verified: bool = False
    resolution_proposed: bool = False
    resolution_applied: bool = False
    escalated: bool = False
    escalation_correct: bool = False
    resolved: bool = False

    # --- Policy-compliance flags -------------------------------------------
    refund_denied_when_required: bool = False
    replacement_offered: bool = False
    policy_referenced: bool = False
    # True marks a policy violation.
    acted_before_verification: bool = False

    # --- Trajectory bookkeeping --------------------------------------------
    steps: int = 0
    max_steps: int = 8
    # Mutable containers use factories so every instance gets its own copy.
    action_history: List[Dict[str, Any]] = Field(default_factory=list)
    violations: List[str] = Field(default_factory=list)
    repeated_action_counts: Dict[str, int] = Field(default_factory=dict)

    # --- Scoring / termination ---------------------------------------------
    cumulative_reward: float = 0.0
    done: bool = False
    success: bool = False

    # --- Ticket state surfaced to the agent --------------------------------
    current_status: str = "open"
    ticket_history: List[Dict[str, Any]] = Field(default_factory=list)
    last_action_result: Optional[str] = None
    last_action_error: Optional[str] = None