# NOTE(review): stray page-header text captured with this file ("Spaces: Sleeping"); not part of the module.
"""Pydantic models for SupportBench OpenEnv environment."""

from __future__ import annotations

from typing import Any, Dict, List, Literal, Optional

from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Action types
# ---------------------------------------------------------------------------
# Closed set of action names; used as the type of ``Action.action_type``.
ActionType = Literal[
    "classify_ticket",
    "set_priority",
    "ask_customer",
    "propose_resolution",
    "apply_resolution",
    "escalate",
    "resolve",
]
# Closed set of ticket categories; used as the type of ``Action.category``.
TicketCategory = Literal[
    "delivery_issue",
    "refund_request",
    "damaged_item",
    "duplicate_charge",
    "wrong_item",
    "account_access",
]
# Closed set of priority levels; used as the type of ``Action.priority``.
Priority = Literal["low", "medium", "high", "urgent"]
# Closed set of resolution kinds; used as the type of ``Action.resolution``.
Resolution = Literal[
    "refund",
    "replacement",
    "troubleshooting",
    "account_recovery",
    "verify_identity",
    "escalate_billing",
    "escalate_human",
    "deny_refund",
    "close_case",
]
# Closed set of escalation targets; used as the type of ``Action.escalate_to``.
EscalationTarget = Literal["billing", "fraud", "supervisor", "legal", "technical"]
class Action(BaseModel):
    """Structured action taken by the agent.

    Only ``action_type`` is required. Every other field is an optional
    payload that defaults to ``None``; each field's description names the
    action type it is meant for. This class declares no cross-field
    validation, so a payload that does not match ``action_type`` is still
    accepted by the model itself.
    """

    action_type: ActionType = Field(..., description="The type of action to perform")
    category: Optional[TicketCategory] = Field(
        None, description="Ticket category (for classify_ticket)"
    )
    priority: Optional[Priority] = Field(
        None, description="Priority level (for set_priority)"
    )
    message: Optional[str] = Field(
        None, description="Message to customer or internal note"
    )
    resolution: Optional[Resolution] = Field(
        None, description="Resolution type (for propose_resolution / apply_resolution)"
    )
    escalate_to: Optional[EscalationTarget] = Field(
        None, description="Escalation target (for escalate)"
    )

    # Unknown input keys are dropped instead of raising a validation error.
    model_config = {"extra": "ignore"}
# ---------------------------------------------------------------------------
# Observation
# ---------------------------------------------------------------------------
class CustomerProfile(BaseModel):
    """Customer account details embedded in an ``Observation``.

    All fields are required except ``payment_card_changed_recently``,
    which defaults to ``False``.
    """

    customer_id: str
    name: str
    email: str  # plain string; no e-mail format validation is declared here
    account_age_days: int
    loyalty_tier: str  # free-form string; allowed tier names are not constrained here
    past_disputes: int
    payment_card_changed_recently: bool = False
class OrderInfo(BaseModel):
    """Order details embedded in an ``Observation``.

    The five declared fields are required. Additional, undeclared keys are
    accepted and kept on the instance (``extra="allow"``).
    """

    order_id: str
    product: str
    amount: float
    order_date: str  # kept as a string; no date parsing is declared here
    status: str

    # Keep undeclared keys on the instance instead of rejecting or dropping them.
    model_config = {"extra": "allow"}
class Observation(BaseModel):
    """What the agent sees at each step.

    All fields are required except ``last_action_result`` and
    ``last_action_error``, which default to ``None``.
    """

    task_id: str
    task_name: str
    customer_message: str
    customer_profile: CustomerProfile
    order_info: OrderInfo
    policy_snippets: List[str]
    ticket_history: List[Dict[str, Any]]  # entry schema is not constrained here
    current_status: str
    # Typed as plain strings, so entries are not validated against ActionType.
    available_actions: List[str]
    steps_taken: int
    max_steps: int
    last_action_result: Optional[str] = None
    last_action_error: Optional[str] = None
# ---------------------------------------------------------------------------
# Reward
# ---------------------------------------------------------------------------
class Reward(BaseModel):
    """Reward signal returned after each step.

    All three fields are required; none has a default.
    """

    value: float = Field(..., description="Step reward value")
    reason: str = Field(..., description="Human-readable explanation of reward")
    cumulative: float = Field(..., description="Cumulative reward this episode")
# ---------------------------------------------------------------------------
# Hidden episode state (not exposed to agent directly)
# ---------------------------------------------------------------------------
class EpisodeState(BaseModel):
    """Internal state tracked by the environment.

    Only ``task_id`` and ``task_spec`` are required. Every other field has a
    default describing a fresh episode: all flags ``False``, zero steps, empty
    histories and status ``"open"``. Mutable defaults use ``default_factory``
    so each instance gets its own list or dict.
    """

    task_id: str
    task_spec: Dict[str, Any]  # schema is not constrained here

    # Subgoal completion flags
    classified: bool = False
    correct_category: bool = False
    priority_set: bool = False
    correct_priority: bool = False
    asked_customer: bool = False
    identity_verified: bool = False
    resolution_proposed: bool = False
    resolution_applied: bool = False
    escalated: bool = False
    escalation_correct: bool = False
    resolved: bool = False

    # Policy compliance
    refund_denied_when_required: bool = False
    replacement_offered: bool = False
    policy_referenced: bool = False
    acted_before_verification: bool = False  # violation

    # Trajectory tracking
    steps: int = 0
    max_steps: int = 8
    action_history: List[Dict[str, Any]] = Field(default_factory=list)
    violations: List[str] = Field(default_factory=list)
    repeated_action_counts: Dict[str, int] = Field(default_factory=dict)

    # Scoring
    cumulative_reward: float = 0.0
    done: bool = False
    success: bool = False

    # Ticket state visible to agent (same field names as on Observation)
    current_status: str = "open"
    ticket_history: List[Dict[str, Any]] = Field(default_factory=list)
    last_action_result: Optional[str] = None
    last_action_error: Optional[str] = None