File size: 11,608 Bytes
d287a79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Customer Support Ticket Management Environment Implementation.

A real-world environment simulating customer support ticket handling.
The agent must categorize tickets, assign priorities, route to appropriate teams,
and draft professional responses.

Three task difficulties:
- EASY: Basic ticket classification
- MEDIUM: Priority assignment + team routing
- HARD: Complete ticket resolution with quality response drafting
"""

from uuid import uuid4
from typing import Optional

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

try:
    from ..models import CustomerSupportAction, CustomerSupportObservation
    from ..tasks import generate_ticket, get_grader
except ImportError:
    from models import CustomerSupportAction, CustomerSupportObservation
    from tasks import generate_ticket, get_grader


class CustomerSupportEnvironment(Environment):
    """
    Customer Support Ticket Management Environment.

    This environment simulates a real-world customer support system where an AI agent
    must handle incoming tickets by categorizing, prioritizing, routing, and responding.

    Action Space:
        - category: billing, technical, account, shipping, general
        - priority: low, medium, high, critical
        - assigned_team: tier1, tier2, billing, technical, management
        - response_draft: Text response to customer (min 10 chars)
        - internal_notes: Optional notes for the team
        - escalate: Boolean flag for escalation

    Observation Space:
        - Ticket metadata (ID, timestamp, customer ID, channel)
        - Customer message (the support request)
        - Customer history (account age, previous tickets, satisfaction, premium status, LTV)
        - Additional context (previous interactions, attachments)

    Reward Function:
        - Category correctness: 0.25
        - Priority correctness: 0.20
        - Team routing correctness: 0.25
        - Response quality: 0.20
        - Efficiency bonuses: up to 0.15
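        - Penalties for errors: up to -0.30 in total (very short response -0.15,
          critical priority routed to tier1 -0.10, critical ticket not escalated -0.05)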

    Tasks:
        - easy: Ticket classification only (threshold: 0.8)
        - medium: Category + priority + routing (threshold: 0.75)
        - hard: Full resolution with quality response (threshold: 0.70)

    Example:
        >>> env = CustomerSupportEnvironment(task_id="easy")
        >>> obs = env.reset()
        >>> action = CustomerSupportAction(
        ...     category="billing",
        ...     priority="high",
        ...     assigned_team="billing",
        ...     response_draft="I'll help you resolve this billing issue immediately."
        ... )
        >>> obs = env.step(action)
        >>> print(obs.reward)  # Score based on correctness
    """

    # Enable concurrent WebSocket sessions.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self, task_id: str = "easy", seed: Optional[int] = None):
        """
        Initialize the customer support environment.

        No ticket is generated here; call ``reset()`` before ``step()``.

        Args:
            task_id: Task difficulty level ("easy", "medium", "hard")
            seed: Random seed forwarded to the ticket generator for reproducibility

        Raises:
            ValueError: If ``task_id`` has no entry in ``task_configs``.
        """
        # Let the base class set up whatever state it owns.
        # NOTE(review): assumes Environment.__init__ takes no required
        # arguments — confirm against the installed openenv version.
        super().__init__()

        self.task_id = task_id
        self.seed = seed
        self._state = State(episode_id=str(uuid4()), step_count=0)
        # Populated by reset(); None until the first episode starts.
        self.current_observation: Optional[CustomerSupportObservation] = None
        self.ground_truth: Optional[dict] = None
        self.cumulative_reward: float = 0.0
        self.grader = get_grader(task_id)

        # Task configurations
        self.task_configs = {
            "easy": {
                "name": "Ticket Classification",
                "description": "Categorize support tickets into the correct category",
                "max_steps": 1,
                "success_threshold": 0.8,
            },
            "medium": {
                "name": "Priority Assignment & Routing",
                "description": "Categorize, prioritize, and route tickets correctly",
                "max_steps": 1,
                "success_threshold": 0.75,
            },
            "hard": {
                "name": "Complete Ticket Resolution",
                "description": "Fully resolve tickets with professional responses",
                "max_steps": 1,
                "success_threshold": 0.70,
            },
        }

        # Fail fast: step() indexes task_configs by task_id, so an unknown id
        # would otherwise only surface later as a bare KeyError mid-episode.
        if task_id not in self.task_configs:
            valid = ", ".join(sorted(self.task_configs))
            raise ValueError(
                f"Unknown task_id {task_id!r}; expected one of: {valid}"
            )

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs,
    ) -> CustomerSupportObservation:
        """
        Reset the environment to start a new episode.

        Calling ``reset()`` with no arguments behaves exactly as before: a fresh
        episode id is generated and the seed given at construction is reused.

        Args:
            seed: Optional seed for this and later episodes. When provided it
                replaces the seed stored at construction time.
            episode_id: Optional identifier for the new episode. A random UUID
                is generated when omitted.
            **kwargs: Accepted and ignored so that callers passing extra reset
                options do not fail.
                NOTE(review): intended to mirror the reset signature of the
                OpenEnv Environment interface — confirm against the installed
                version.

        Returns:
            CustomerSupportObservation with a new support ticket
        """
        # `is not None` so that an explicit seed of 0 is honoured.
        if seed is not None:
            self.seed = seed

        new_episode_id = episode_id if episode_id is not None else str(uuid4())
        self._state = State(episode_id=new_episode_id, step_count=0)
        self.cumulative_reward = 0.0

        # Generate a new ticket together with its ground-truth labels.
        self.current_observation, self.ground_truth = generate_ticket(
            seed=self.seed, task_id=self.task_id
        )

        return self.current_observation

    def step(
        self, action: CustomerSupportAction
    ) -> CustomerSupportObservation:  # type: ignore[override]
        """
        Execute a step in the environment by processing the agent's action.

        Args:
            action: CustomerSupportAction containing the agent's decisions

        Returns:
            CustomerSupportObservation with reward and done flag
        """
        self._state.step_count += 1

        # Grade the action using the task-specific grader
        score = self.grader(action, self.ground_truth, self.current_observation)

        # Compute detailed reward
        reward = self._compute_reward(action, self.ground_truth)

        # Update cumulative reward
        self.cumulative_reward += reward

        # Check if episode is done (single-step tasks for now)
        max_steps = self.task_configs[self.task_id]["max_steps"]
        done = self._state.step_count >= max_steps

        # Create metadata for debugging/analysis
        metadata = {
            "task_id": self.task_id,
            "task_name": self.task_configs[self.task_id]["name"],
            "episode_id": self._state.episode_id,
            "step_count": self._state.step_count,
            "grader_score": score,
            "cumulative_reward": self.cumulative_reward,
            "ground_truth": {
                "category": self.ground_truth["category"],
                "priority": self.ground_truth["priority"],
                "team": self.ground_truth["team"],
            },
            "agent_action": {
                "category": action.category,
                "priority": action.priority,
                "team": action.assigned_team,
                "escalate": action.escalate,
            },
        }

        # Generate next observation if not done
        if not done:
            self.current_observation, self.ground_truth = generate_ticket(
                seed=self.seed + self._state.step_count if self.seed else None,
                task_id=self.task_id,
            )
        else:
            # Keep current observation for final state
            pass

        # Update observation with reward and done flag
        self.current_observation.reward = reward
        self.current_observation.done = done
        self.current_observation.metadata = metadata

        return self.current_observation

    @property
    def state(self) -> State:
        """
        Get the current environment state.

        Returns:
            Current State with episode_id and step_count
        """
        return self._state

    def _compute_reward(
        self, action: CustomerSupportAction, ground_truth: dict
    ) -> float:
        """
        Compute detailed reward signal with partial progress tracking.

        The reward function provides:
        - Individual scores for each component (category, priority, team, response)
        - Bonuses for premium customer handling
        - Penalties for poor decisions

        Args:
            action: The action taken by the agent
            ground_truth: Ground truth labels for the ticket

        Returns:
            float: Total reward value
        """
        # Component scores
        category_correct = 0.25 if action.category == ground_truth["category"] else 0.0
        priority_correct = 0.20 if action.priority == ground_truth["priority"] else 0.0
        team_correct = 0.25 if action.assigned_team == ground_truth["team"] else 0.0

        # Response quality evaluation
        response_quality = (
            self._evaluate_response_quality(
                action.response_draft, ground_truth["keywords"]
            )
            * 0.20
        )

        # Efficiency bonus for correct responses
        efficiency_bonus = 0.0
        if category_correct > 0 and priority_correct > 0 and team_correct > 0:
            efficiency_bonus = 0.10

        # Premium customer handling bonus
        if ground_truth["is_premium"]:
            response_lower = action.response_draft.lower()
            if action.priority in ["high", "critical"] and "value" in response_lower:
                efficiency_bonus += 0.05

        # Penalties
        penalty = 0.0

        # Penalty for extremely short responses
        if len(action.response_draft) < 20:
            penalty -= 0.15

        # Penalty for mismatched priority-team assignment
        if action.priority == "critical" and action.assigned_team == "tier1":
            penalty -= 0.10

        # Penalty for not escalating critical issues
        if ground_truth["priority"] == "critical" and not action.escalate:
            penalty -= 0.05

        # Calculate total reward
        total = (
            category_correct
            + priority_correct
            + team_correct
            + response_quality
            + efficiency_bonus
            + penalty
        )

        # Ensure total is in valid range
        total = max(min(total, 1.0), -0.5)

        return total

    def _evaluate_response_quality(self, response: str, keywords: list) -> float:
        """
        Evaluate the quality of the response draft.

        Checks for:
        - Appropriate length
        - Keyword relevance
        - Professional tone

        Args:
            response: The drafted response
            keywords: Relevant keywords for the ticket

        Returns:
            float: Quality score between 0.0 and 1.0
        """
        if len(response) < 20:
            return 0.0

        score = 0.0
        response_lower = response.lower()

        # Check for keyword relevance
        keyword_matches = sum(1 for kw in keywords if kw.lower() in response_lower)
        keyword_score = min(keyword_matches / max(len(keywords), 1), 1.0)
        score += keyword_score * 0.4

        # Check for professional language
        professional_terms = [
            "help",
            "assist",
            "sorry",
            "apologize",
            "thank",
            "appreciate",
            "resolve",
        ]
        professional_count = sum(1 for term in professional_terms if term in response_lower)
        score += min(professional_count / 3, 1.0) * 0.3

        # Check response length is reasonable
        word_count = len(response.split())
        if 10 <= word_count <= 200:
            score += 0.2
        elif word_count > 200:
            score += 0.1

        # Bonus for premium customer language
        if self.ground_truth["is_premium"] and (
            "value" in response_lower or "priority" in response_lower
        ):
            score += 0.1

        return min(score, 1.0)