# File: customer_support_env/server/customer_support_env_environment.py
# Uploaded by ravindrakapse ("Upload folder using huggingface_hub"), revision d287a79 (verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Customer Support Ticket Management Environment Implementation.
A real-world environment simulating customer support ticket handling.
The agent must categorize tickets, assign priorities, route to appropriate teams,
and draft professional responses.
Three task difficulties:
- EASY: Basic ticket classification
- MEDIUM: Priority assignment + team routing
- HARD: Complete ticket resolution with quality response drafting
"""
from uuid import uuid4
from typing import Optional
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from ..models import CustomerSupportAction, CustomerSupportObservation
from ..tasks import generate_ticket, get_grader
except ImportError:
from models import CustomerSupportAction, CustomerSupportObservation
from tasks import generate_ticket, get_grader
class CustomerSupportEnvironment(Environment):
    """
    Customer Support Ticket Management Environment.

    Simulates a customer support system in which an agent handles an incoming
    ticket by categorizing, prioritizing, routing, and responding to it.

    Action Space:
        - category: billing, technical, account, shipping, general
        - priority: low, medium, high, critical
        - assigned_team: tier1, tier2, billing, technical, management
        - response_draft: Text response to customer (min 10 chars)
        - internal_notes: Optional notes for the team
        - escalate: Boolean flag for escalation

    Observation Space:
        - Ticket metadata (ID, timestamp, customer ID, channel)
        - Customer message (the support request)
        - Customer history (account age, previous tickets, satisfaction,
          premium status, LTV)
        - Additional context (previous interactions, attachments)

    Reward Function (implemented in ``_compute_reward``):
        - Category correctness: 0.25
        - Priority correctness: 0.20
        - Team routing correctness: 0.25
        - Response quality: up to 0.20
        - Bonuses: 0.10 when category, priority and team are all correct,
          plus 0.05 for premium-customer handling (up to 0.15 in total)
        - Penalties: -0.15 for a response shorter than 20 characters,
          -0.10 for routing a "critical" action to tier1, -0.05 for not
          escalating a ticket whose true priority is critical
          (up to -0.30 in total)
        - The total is clamped to the range [-0.5, 1.0]

    Tasks:
        - easy: Ticket classification only (threshold: 0.8)
        - medium: Category + priority + routing (threshold: 0.75)
        - hard: Full resolution with quality response (threshold: 0.70)

    Example:
        >>> env = CustomerSupportEnvironment(task_id="easy")
        >>> obs = env.reset()
        >>> action = CustomerSupportAction(
        ...     category="billing",
        ...     priority="high",
        ...     assigned_team="billing",
        ...     response_draft="I'll help you resolve this billing issue immediately."
        ... )
        >>> obs = env.step(action)
        >>> print(obs.reward)  # Score based on correctness
    """

    # Enable concurrent WebSocket sessions.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Substrings counted as "professional language" when scoring a response.
    _PROFESSIONAL_TERMS = (
        "help",
        "assist",
        "sorry",
        "apologize",
        "thank",
        "appreciate",
        "resolve",
    )

    def __init__(self, task_id: str = "easy", seed: Optional[int] = None):
        """
        Initialize the customer support environment.

        Args:
            task_id: Task difficulty level ("easy", "medium", "hard").
            seed: Random seed forwarded to the ticket generator for
                reproducibility; ``None`` leaves generation unseeded.

        Raises:
            ValueError: If ``task_id`` has no entry in ``task_configs``.
        """
        # Let the base class set up whatever state it owns.
        super().__init__()

        self.task_id = task_id
        self.seed = seed
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self.current_observation: Optional[CustomerSupportObservation] = None
        self.ground_truth: Optional[dict] = None
        self.cumulative_reward: float = 0.0
        self.grader = get_grader(task_id)

        # Task configurations. "success_threshold" is stored for consumers of
        # this table; it is not read anywhere in this class.
        self.task_configs = {
            "easy": {
                "name": "Ticket Classification",
                "description": "Categorize support tickets into the correct category",
                "max_steps": 1,
                "success_threshold": 0.8,
            },
            "medium": {
                "name": "Priority Assignment & Routing",
                "description": "Categorize, prioritize, and route tickets correctly",
                "max_steps": 1,
                "success_threshold": 0.75,
            },
            "hard": {
                "name": "Complete Ticket Resolution",
                "description": "Fully resolve tickets with professional responses",
                "max_steps": 1,
                "success_threshold": 0.70,
            },
        }

        # Fail fast: an unknown task would otherwise only surface as a
        # KeyError in the middle of step().
        if task_id not in self.task_configs:
            valid = ", ".join(sorted(self.task_configs))
            raise ValueError(
                f"Unknown task_id {task_id!r}; expected one of: {valid}"
            )

    def reset(self) -> CustomerSupportObservation:
        """
        Reset the environment to start a new episode.

        Assigns a fresh episode id, zeroes the step counter and cumulative
        reward, and generates a new ticket together with its ground truth.

        Returns:
            CustomerSupportObservation with a new support ticket.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self.cumulative_reward = 0.0

        # Generate a new ticket
        self.current_observation, self.ground_truth = generate_ticket(
            seed=self.seed, task_id=self.task_id
        )
        return self.current_observation

    def step(
        self, action: CustomerSupportAction
    ) -> CustomerSupportObservation:  # type: ignore[override]
        """
        Execute a step in the environment by processing the agent's action.

        Args:
            action: CustomerSupportAction containing the agent's decisions.

        Returns:
            CustomerSupportObservation carrying the step reward, the done
            flag, and a metadata dict (grader score, cumulative reward,
            ground-truth labels and the agent's choices).

        Raises:
            RuntimeError: If called before ``reset()`` has produced a ticket.
        """
        if self.ground_truth is None or self.current_observation is None:
            raise RuntimeError("reset() must be called before step()")

        self._state.step_count += 1

        # Grade the action using the task-specific grader
        score = self.grader(action, self.ground_truth, self.current_observation)

        # Compute detailed reward
        reward = self._compute_reward(action, self.ground_truth)

        # Update cumulative reward
        self.cumulative_reward += reward

        # Check if episode is done (single-step tasks for now)
        task_config = self.task_configs[self.task_id]
        done = self._state.step_count >= task_config["max_steps"]

        # Create metadata for debugging/analysis. It is built before a new
        # ticket is generated so it describes the ticket that was just graded.
        metadata = {
            "task_id": self.task_id,
            "task_name": task_config["name"],
            "episode_id": self._state.episode_id,
            "step_count": self._state.step_count,
            "grader_score": score,
            "cumulative_reward": self.cumulative_reward,
            "ground_truth": {
                "category": self.ground_truth["category"],
                "priority": self.ground_truth["priority"],
                "team": self.ground_truth["team"],
            },
            "agent_action": {
                "category": action.category,
                "priority": action.priority,
                "team": action.assigned_team,
                "escalate": action.escalate,
            },
        }

        # Generate next observation if not done; otherwise the current
        # observation is kept and returned as the final state.
        if not done:
            # Compare against None explicitly: seed=0 is a valid seed and
            # must still yield a deterministic per-step seed.
            next_seed = (
                None if self.seed is None else self.seed + self._state.step_count
            )
            self.current_observation, self.ground_truth = generate_ticket(
                seed=next_seed,
                task_id=self.task_id,
            )

        # Update observation with reward and done flag
        self.current_observation.reward = reward
        self.current_observation.done = done
        self.current_observation.metadata = metadata
        return self.current_observation

    @property
    def state(self) -> State:
        """
        Get the current environment state.

        Returns:
            Current State with episode_id and step_count.
        """
        return self._state

    def _compute_reward(
        self, action: CustomerSupportAction, ground_truth: dict
    ) -> float:
        """
        Compute the reward for one action against a ticket's ground truth.

        The reward is the sum of:
            - component scores for category (0.25), priority (0.20) and
              team (0.25), each awarded only on an exact match
            - response quality (``_evaluate_response_quality`` * 0.20)
            - a 0.10 bonus when all three components are correct
            - a 0.05 bonus for premium customers when the chosen priority is
              high/critical and the response contains "value"
            - penalties for a very short response, for routing a critical
              action to tier1, and for not escalating a critical ticket

        Args:
            action: The action taken by the agent.
            ground_truth: Ground truth labels for the ticket; read keys are
                "category", "priority", "team", "keywords", "is_premium".

        Returns:
            float: Total reward, clamped to [-0.5, 1.0].
        """
        is_premium = ground_truth["is_premium"]
        response_lower = action.response_draft.lower()

        # Component scores
        category_correct = 0.25 if action.category == ground_truth["category"] else 0.0
        priority_correct = 0.20 if action.priority == ground_truth["priority"] else 0.0
        team_correct = 0.25 if action.assigned_team == ground_truth["team"] else 0.0

        # Response quality evaluation. The premium flag is passed explicitly
        # so the score is based on the ground truth given to this call and
        # not on whatever self.ground_truth currently holds.
        response_quality = (
            self._evaluate_response_quality(
                action.response_draft,
                ground_truth["keywords"],
                is_premium=bool(is_premium),
            )
            * 0.20
        )

        # Efficiency bonus for fully correct classification
        efficiency_bonus = 0.0
        if category_correct > 0 and priority_correct > 0 and team_correct > 0:
            efficiency_bonus = 0.10

        # Premium customer handling bonus
        if (
            is_premium
            and action.priority in ("high", "critical")
            and "value" in response_lower
        ):
            efficiency_bonus += 0.05

        # Penalties
        penalty = 0.0

        # Penalty for extremely short responses
        if len(action.response_draft) < 20:
            penalty -= 0.15

        # Penalty for mismatched priority-team assignment
        if action.priority == "critical" and action.assigned_team == "tier1":
            penalty -= 0.10

        # Penalty for not escalating critical issues
        if ground_truth["priority"] == "critical" and not action.escalate:
            penalty -= 0.05

        # Calculate total reward
        total = (
            category_correct
            + priority_correct
            + team_correct
            + response_quality
            + efficiency_bonus
            + penalty
        )

        # Ensure total is in valid range
        return max(min(total, 1.0), -0.5)

    def _evaluate_response_quality(
        self,
        response: str,
        keywords: list,
        is_premium: Optional[bool] = None,
    ) -> float:
        """
        Evaluate the quality of the response draft.

        Scoring (capped at 1.0):
            - keyword relevance: fraction of ``keywords`` found in the
              response (case-insensitive substring match), weighted 0.4
            - professional tone: number of professional terms present,
              saturating at three, weighted 0.3
            - length: +0.2 for 10-200 words, +0.1 for more than 200 words
            - premium language: +0.1 when the customer is premium and the
              response mentions "value" or "priority"

        Args:
            response: The drafted response.
            keywords: Relevant keywords for the ticket.
            is_premium: Whether the ticket's customer is premium. When
                omitted, the flag is read from ``self.ground_truth`` (and
                treated as False if no ticket is loaded).

        Returns:
            float: Quality score between 0.0 and 1.0; 0.0 for any response
            shorter than 20 characters.
        """
        if len(response) < 20:
            return 0.0

        if is_premium is None:
            current = self.ground_truth
            is_premium = bool(current.get("is_premium")) if current else False

        score = 0.0
        response_lower = response.lower()

        # Check for keyword relevance; max(..., 1) avoids dividing by zero
        # when the ticket has no keywords.
        keyword_matches = sum(1 for kw in keywords if kw.lower() in response_lower)
        keyword_score = min(keyword_matches / max(len(keywords), 1), 1.0)
        score += keyword_score * 0.4

        # Check for professional language
        professional_count = sum(
            1 for term in self._PROFESSIONAL_TERMS if term in response_lower
        )
        score += min(professional_count / 3, 1.0) * 0.3

        # Check response length is reasonable
        word_count = len(response.split())
        if 10 <= word_count <= 200:
            score += 0.2
        elif word_count > 200:
            score += 0.1

        # Bonus for premium customer language
        if is_premium and (
            "value" in response_lower or "priority" in response_lower
        ):
            score += 0.1

        return min(score, 1.0)