Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Customer Support Ticket Management Environment Implementation. | |
| A real-world environment simulating customer support ticket handling. | |
| The agent must categorize tickets, assign priorities, route to appropriate teams, | |
| and draft professional responses. | |
| Three task difficulties: | |
| - EASY: Basic ticket classification | |
| - MEDIUM: Priority assignment + team routing | |
| - HARD: Complete ticket resolution with quality response drafting | |
| """ | |
| from uuid import uuid4 | |
| from typing import Optional | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| try: | |
| from ..models import CustomerSupportAction, CustomerSupportObservation | |
| from ..tasks import generate_ticket, get_grader | |
| except ImportError: | |
| from models import CustomerSupportAction, CustomerSupportObservation | |
| from tasks import generate_ticket, get_grader | |
class CustomerSupportEnvironment(Environment):
    """
    Customer Support Ticket Management Environment.

    This environment simulates a real-world customer support system where an AI agent
    must handle incoming tickets by categorizing, prioritizing, routing, and responding.

    Action Space:
        - category: billing, technical, account, shipping, general
        - priority: low, medium, high, critical
        - assigned_team: tier1, tier2, billing, technical, management
        - response_draft: Text response to customer (min 10 chars)
        - internal_notes: Optional notes for the team
        - escalate: Boolean flag for escalation

    Observation Space:
        - Ticket metadata (ID, timestamp, customer ID, channel)
        - Customer message (the support request)
        - Customer history (account age, previous tickets, satisfaction, premium status, LTV)
        - Additional context (previous interactions, attachments)

    Reward Function (computed by ``_compute_reward``):
        - Category correctness: 0.25
        - Priority correctness: 0.20
        - Team routing correctness: 0.25
        - Response quality: up to 0.20
        - Bonuses: up to 0.15 (0.10 when category, priority and team are all
          correct; 0.05 for a premium customer given high/critical priority
          and a response mentioning "value")
        - Penalties: up to -0.30 (-0.15 for a response shorter than 20
          characters, -0.10 for critical priority routed to tier1, -0.05 for
          not escalating a ticket whose true priority is critical)
        - The total is clamped to the range [-0.5, 1.0].
        The task-specific grader score is not used as the reward; it is
        reported in ``observation.metadata["grader_score"]``.

    Tasks:
        - easy: Ticket classification only (threshold: 0.8)
        - medium: Category + priority + routing (threshold: 0.75)
        - hard: Full resolution with quality response (threshold: 0.70)

    Example:
        >>> env = CustomerSupportEnvironment(task_id="easy")
        >>> obs = env.reset()
        >>> action = CustomerSupportAction(
        ...     category="billing",
        ...     priority="high",
        ...     assigned_team="billing",
        ...     response_draft="I'll help you resolve this billing issue immediately."
        ... )
        >>> obs = env.step(action)
        >>> print(obs.reward)  # Score based on correctness
    """

    # Enable concurrent WebSocket sessions.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Terms whose presence in a response draft counts toward "professional tone".
    _PROFESSIONAL_TERMS = (
        "help",
        "assist",
        "sorry",
        "apologize",
        "thank",
        "appreciate",
        "resolve",
    )

    def __init__(self, task_id: str = "easy", seed: Optional[int] = None):
        """
        Initialize the customer support environment.

        Args:
            task_id: Task difficulty level ("easy", "medium", "hard")
            seed: Random seed for reproducibility. Passed to ``generate_ticket``
                as-is; ``None`` is forwarded unchanged, and every integer
                (including 0) is treated as a real seed.

        Raises:
            ValueError: If ``task_id`` is not one of the configured tasks.
        """
        # Let the framework base class initialise whatever attributes it owns.
        super().__init__()

        # Task configurations
        self.task_configs = {
            "easy": {
                "name": "Ticket Classification",
                "description": "Categorize support tickets into the correct category",
                "max_steps": 1,
                "success_threshold": 0.8,
            },
            "medium": {
                "name": "Priority Assignment & Routing",
                "description": "Categorize, prioritize, and route tickets correctly",
                "max_steps": 1,
                "success_threshold": 0.75,
            },
            "hard": {
                "name": "Complete Ticket Resolution",
                "description": "Fully resolve tickets with professional responses",
                "max_steps": 1,
                "success_threshold": 0.70,
            },
        }
        # Fail fast: step() indexes task_configs by task_id and would otherwise
        # raise an opaque KeyError only after the first action.
        if task_id not in self.task_configs:
            raise ValueError(
                f"Unknown task_id {task_id!r}; expected one of "
                f"{sorted(self.task_configs)}"
            )

        self.task_id = task_id
        self.seed = seed
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self.current_observation: Optional[CustomerSupportObservation] = None
        self.ground_truth: Optional[dict] = None
        self.cumulative_reward: float = 0.0
        self.grader = get_grader(task_id)

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
    ) -> CustomerSupportObservation:
        """
        Reset the environment to start a new episode.

        Args:
            seed: Optional new seed. When given it replaces ``self.seed`` and is
                used for this and all later tickets; when omitted, the seed
                from ``__init__`` (or a previous reset) is reused.
            episode_id: Optional identifier for the new episode. A random UUID
                is generated when omitted or empty.

        Returns:
            CustomerSupportObservation with a new support ticket
        """
        if seed is not None:
            self.seed = seed

        self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
        self.cumulative_reward = 0.0

        # Generate a new ticket together with its ground-truth labels.
        self.current_observation, self.ground_truth = generate_ticket(
            seed=self.seed, task_id=self.task_id
        )
        return self.current_observation

    def step(
        self, action: CustomerSupportAction
    ) -> CustomerSupportObservation:  # type: ignore[override]
        """
        Execute a step in the environment by processing the agent's action.

        Args:
            action: CustomerSupportAction containing the agent's decisions

        Returns:
            CustomerSupportObservation carrying the reward, done flag and
            metadata for ``action``. If the episode continues this is the
            next ticket; otherwise it is the ticket that was just graded.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self.ground_truth is None or self.current_observation is None:
            raise RuntimeError("reset() must be called before step().")

        self._state.step_count += 1

        # Task-specific grader score: reported in metadata only.
        score = self.grader(action, self.ground_truth, self.current_observation)
        # Shaped reward actually returned to the agent.
        reward = self._compute_reward(action, self.ground_truth)
        self.cumulative_reward += reward

        # Check if episode is done (single-step tasks for now).
        config = self.task_configs[self.task_id]
        done = self._state.step_count >= config["max_steps"]

        # Built before a new ticket is generated so it describes the ticket
        # that was just graded, not the next one.
        metadata = {
            "task_id": self.task_id,
            "task_name": config["name"],
            "episode_id": self._state.episode_id,
            "step_count": self._state.step_count,
            "grader_score": score,
            "cumulative_reward": self.cumulative_reward,
            "ground_truth": {
                "category": self.ground_truth["category"],
                "priority": self.ground_truth["priority"],
                "team": self.ground_truth["team"],
            },
            "agent_action": {
                "category": action.category,
                "priority": action.priority,
                "team": action.assigned_team,
                "escalate": action.escalate,
            },
        }

        if not done:
            # Offset the seed per step so each ticket differs. Compare against
            # None so that seed=0 is honoured instead of treated as "no seed".
            next_seed = (
                None if self.seed is None else self.seed + self._state.step_count
            )
            self.current_observation, self.ground_truth = generate_ticket(
                seed=next_seed,
                task_id=self.task_id,
            )
        # When done, the graded ticket's observation is kept as the final state.

        self.current_observation.reward = reward
        self.current_observation.done = done
        self.current_observation.metadata = metadata
        return self.current_observation

    @property
    def state(self) -> State:
        """
        Get the current environment state.

        Exposed as a read-only property (``env.state``) to match the OpenEnv
        ``Environment.state`` property; as a plain method, ``env.state`` would
        evaluate to a bound method instead of a ``State``.

        Returns:
            Current State with episode_id and step_count
        """
        return self._state

    def _compute_reward(
        self, action: CustomerSupportAction, ground_truth: dict
    ) -> float:
        """
        Compute the shaped reward for one action against one ticket.

        Components:
            - category / priority / team exact matches: 0.25 / 0.20 / 0.25
            - response quality: ``_evaluate_response_quality`` scaled by 0.20
            - +0.10 when category, priority and team are all correct
            - +0.05 for a premium customer given high/critical priority and a
              response containing "value"
            - -0.15 for a response shorter than 20 characters
            - -0.10 for critical priority routed to tier1
            - -0.05 when the true priority is critical and ``escalate`` is falsy

        Args:
            action: The action taken by the agent
            ground_truth: Ground-truth labels for the ticket; keys read are
                "category", "priority", "team", "keywords" and "is_premium".

        Returns:
            float: Total reward, clamped to [-0.5, 1.0].
        """
        # Treat a missing draft as empty rather than crashing on len(None).
        response = action.response_draft or ""

        # Component scores
        category_correct = 0.25 if action.category == ground_truth["category"] else 0.0
        priority_correct = 0.20 if action.priority == ground_truth["priority"] else 0.0
        team_correct = 0.25 if action.assigned_team == ground_truth["team"] else 0.0

        # Response quality, judged against this ticket's own premium flag.
        response_quality = (
            self._evaluate_response_quality(
                response,
                ground_truth["keywords"],
                is_premium=ground_truth["is_premium"],
            )
            * 0.20
        )

        # Efficiency bonus for getting every routing decision right.
        efficiency_bonus = 0.0
        if category_correct > 0 and priority_correct > 0 and team_correct > 0:
            efficiency_bonus = 0.10

        # Premium customer handling bonus
        if ground_truth["is_premium"]:
            if action.priority in ["high", "critical"] and "value" in response.lower():
                efficiency_bonus += 0.05

        # Penalties
        penalty = 0.0
        # Penalty for extremely short responses
        if len(response) < 20:
            penalty -= 0.15
        # Penalty for mismatched priority-team assignment
        if action.priority == "critical" and action.assigned_team == "tier1":
            penalty -= 0.10
        # Penalty for not escalating critical issues
        if ground_truth["priority"] == "critical" and not action.escalate:
            penalty -= 0.05

        total = (
            category_correct
            + priority_correct
            + team_correct
            + response_quality
            + efficiency_bonus
            + penalty
        )
        # Ensure total is in the valid range.
        return max(min(total, 1.0), -0.5)

    def _evaluate_response_quality(
        self,
        response: str,
        keywords: list,
        is_premium: Optional[bool] = None,
    ) -> float:
        """
        Evaluate the quality of the response draft.

        Scoring (capped at 1.0):
            - 0.0 outright for an empty response or one shorter than 20 characters
            - up to 0.4 for the fraction of ``keywords`` found (case-insensitive)
            - up to 0.3 for professional terms, saturating at three distinct terms
            - 0.2 for 10-200 words, 0.1 for more than 200 words
            - 0.1 for a premium customer when "value" or "priority" appears

        Args:
            response: The drafted response
            keywords: Relevant keywords for the ticket
            is_premium: Whether the ticket's customer is premium. When ``None``
                (the default, kept for backward compatibility) it is read from
                ``self.ground_truth["is_premium"]``.

        Returns:
            float: Quality score between 0.0 and 1.0
        """
        if not response or len(response) < 20:
            return 0.0

        response_lower = response.lower()

        # Keyword relevance; max(..., 1) guards against an empty keyword list.
        keyword_matches = sum(1 for kw in keywords if kw.lower() in response_lower)
        keyword_score = min(keyword_matches / max(len(keywords), 1), 1.0)
        score = keyword_score * 0.4

        # Professional language
        professional_count = sum(
            1 for term in self._PROFESSIONAL_TERMS if term in response_lower
        )
        score += min(professional_count / 3, 1.0) * 0.3

        # Reasonable length
        word_count = len(response.split())
        if 10 <= word_count <= 200:
            score += 0.2
        elif word_count > 200:
            score += 0.1

        # Bonus for premium customer language
        if is_premium is None:
            is_premium = self.ground_truth["is_premium"]
        if is_premium and ("value" in response_lower or "priority" in response_lower):
            score += 0.1

        return min(score, 1.0)