File size: 5,729 Bytes
084325c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d1d8c2
084325c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d1d8c2
084325c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d1d8c2
084325c
 
1d1d8c2
084325c
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import hashlib
import os
import uuid
from typing import List, Dict, Optional

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

# Ensure relative imports resolve correctly based on execution context
try:
    from models import CustomerSupportAction, CustomerSupportObservation
except ImportError:
    from ..models import CustomerSupportAction, CustomerSupportObservation

# Task name -> ordered list of tickets the agent works through in one episode.
# Each ticket dict carries:
#   "id"      - identifier exposed in observations as active_ticket_id / summary id
#   "content" - the customer's message shown to the agent
#   "type"    - category the environment's step() grader uses to decide which
#               action counts as correct for this ticket
# The environment copies these dicts per episode before adding a "status" key,
# so this table itself is never mutated.
# NOTE(review): tasks appear to scale in size (1, 3, 5 tickets); confirm whether
# they are also meant to scale in difficulty.
TASKS = {
    "task1": [
        # Single password ticket.
        {"id": "t1", "content": "I forgot my password and cannot log into my account. Help!", "type": "password"}
    ],
    "task2": [
        # Mixed routing: billing, password, and a vague report needing clarification.
        {"id": "t2_1", "content": "How do I update my billing email?", "type": "billing"},
        {"id": "t2_2", "content": "The system says invalid credentials.", "type": "password"},
        {"id": "t2_3", "content": "My app crashed!", "type": "vague"}
    ],
    "task3": [
        # Adds churn, security and sales tickets on top of the routing cases.
        {"id": "t3_1", "content": "How to change password?", "type": "password"},
        {"id": "t3_2", "content": "I want an immediate refund, this is garbage! Cancel my account!", "type": "churn"},
        {"id": "t3_3", "content": "Found a way to bypass authentication on the user portal.", "type": "security"},
        {"id": "t3_4", "content": "Charge on my credit card is double what it should be.", "type": "billing"},
        {"id": "t3_5", "content": "Is there a student discount?", "type": "sales"}
    ]
}

class CustomerSupportEnvironment(Environment):
    """Customer Support Environment for testing RL agents.

    An episode walks through the tickets of one task from ``TASKS`` in order.
    Each ``step`` grades the agent's action against the active ticket, marks
    that ticket ``"resolved"`` or ``"failed"`` and advances to the next one.
    The episode is done once every ticket has been handled.

    Rewards stay strictly inside (0, 1):
    - 0.95 for a correct action.
    - 0.05 for an incorrect action, or for a step taken after the episode ended.
    - 0.01 on reset.
    """
    SUPPORTS_CONCURRENT_SESSIONS = True

    # Ticket types whose correct handling is ``assign`` to exactly this department.
    _ROUTING_DEPARTMENTS = {"password": "TechSupport", "billing": "Billing", "sales": "Sales"}
    # Ticket types whose correct handling is one specific action, regardless of department.
    _REQUIRED_ACTIONS = {"vague": "ask_user", "churn": "escalate"}
    # Priorities that make a TechSupport assignment acceptable for a security report.
    _URGENT_PRIORITIES = ("High", "Urgent")
    # Ticket content longer than this is truncated in ``tickets_summary``.
    _SUMMARY_LENGTH = 30

    def __init__(self, task_name: Optional[str] = None, **kwargs):
        """Create the environment with its first episode's tickets loaded.

        Args:
            task_name: Key into ``TASKS``. When falsy, the ``TASK_NAME``
                environment variable is used, then ``"task1"``. Unknown names
                also fall back to ``"task1"``.
            **kwargs: Forwarded to ``Environment.__init__``.
        """
        super().__init__(**kwargs)
        self._session_id = str(uuid.uuid4())
        self._state = State(episode_id=self._session_id, step_count=0)

        # Priority: explicit arg -> env var -> default
        self.task_name = task_name if task_name else os.getenv("TASK_NAME", "task1")
        if self.task_name not in TASKS:
            self.task_name = "task1"

        self.tickets: List[Dict] = []
        self._load_tickets()
        self.current_ticket_index = 0

    def _load_tickets(self):
        """Load fresh per-episode copies of the current task's tickets, all marked ``"open"``."""
        # Shallow-copy each ticket so setting "status" never mutates the shared TASKS table.
        self.tickets = [dict(t) for t in TASKS[self.task_name]]
        for t in self.tickets:
            t["status"] = "open"

    def _get_active_ticket(self) -> Optional[Dict]:
        """Return the ticket awaiting an action, or None once all tickets were handled."""
        if self.current_ticket_index < len(self.tickets):
            return self.tickets[self.current_ticket_index]
        return None

    def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, task_name: Optional[str] = None, **kwargs) -> CustomerSupportObservation:
        """Start a new episode and return its first observation (reward 0.01).

        Args:
            seed: Accepted for interface compatibility. Ticket order is fixed,
                so it is not used.
            episode_id: Identifier for the new episode. When omitted, a fresh
                UUID is generated so consecutive episodes never share an id.
            task_name: Switch to this task if it is a key of ``TASKS``.
                Unknown names are ignored and the current task is kept.
            **kwargs: Ignored.
        """
        # Reusing the previous id would make separate episodes indistinguishable in ``state``.
        self._session_id = episode_id if episode_id is not None else str(uuid.uuid4())

        if task_name in TASKS:
            self.task_name = task_name

        self._state = State(episode_id=self._session_id, step_count=0)
        self._load_tickets()
        self.current_ticket_index = 0

        return self._make_observation(reward=0.01, done=False)

    @classmethod
    def _summarize(cls, content: str) -> str:
        """Shorten ticket content for the summary, appending ``...`` only when it was truncated."""
        if len(content) <= cls._SUMMARY_LENGTH:
            return content
        return content[:cls._SUMMARY_LENGTH] + "..."

    def _make_observation(self, reward: float = 0.0, done: bool = False) -> CustomerSupportObservation:
        """Build an observation of the active ticket plus a summary of every ticket.

        Args:
            reward: Reward attached to this observation.
            done: Whether the episode has ended.
        """
        ticket = self._get_active_ticket()
        # Only tickets not yet acted on count; "failed" tickets are not "open".
        unresolved = sum(1 for x in self.tickets if x["status"] == "open")
        summary = [
            {"id": x["id"], "summary": self._summarize(x["content"]), "status": x["status"]}
            for x in self.tickets
        ]

        return CustomerSupportObservation(
            active_ticket_id=ticket["id"] if ticket else None,
            ticket_content=ticket["content"] if ticket else None,
            ticket_metadata={"type": ticket["type"]} if ticket else {},
            unresolved_count=unresolved,
            step_count=self._state.step_count,
            tickets_summary=summary,
            reward=float(reward),
            done=done
        )

    def _is_correct_action(self, ticket_type: str, action: CustomerSupportAction) -> bool:
        """Return True when ``action`` is the expected handling for a ticket of ``ticket_type``.

        The action type is compared case-insensitively. Department and priority
        names must match exactly. Unknown ticket types never match.
        """
        # Treat a missing action type as "no action" instead of crashing on .lower().
        action_type = (action.action_type or "").lower()

        if ticket_type in self._ROUTING_DEPARTMENTS:
            return action_type == "assign" and action.department == self._ROUTING_DEPARTMENTS[ticket_type]
        if ticket_type in self._REQUIRED_ACTIONS:
            return action_type == self._REQUIRED_ACTIONS[ticket_type]
        if ticket_type == "security":
            # Escalation is always acceptable; otherwise it must reach TechSupport at high priority.
            return action_type == "escalate" or (
                action_type == "assign"
                and action.department == "TechSupport"
                and action.priority in self._URGENT_PRIORITIES
            )
        return False

    def step(self, action: CustomerSupportAction, timeout_s: Optional[float] = None, **kwargs) -> CustomerSupportObservation:
        """Grade ``action`` against the active ticket and advance to the next one.

        Args:
            action: The agent's handling of the active ticket.
            timeout_s: Accepted for interface compatibility. Grading is
                synchronous, so it is not used.
            **kwargs: Ignored.

        Returns:
            An observation with reward 0.95 for a correct action and 0.05
            otherwise. ``done`` is True once every ticket has been handled.
            Stepping after the episode ended returns reward 0.05 with
            ``done=True``.
        """
        self._state.step_count += 1
        ticket = self._get_active_ticket()

        if ticket is None:
            # Episode already finished: nothing left to grade.
            return self._make_observation(reward=0.05, done=True)

        if self._is_correct_action(ticket["type"], action):
            reward = 0.95  # High reward but strictly < 1.0 per hackathon spec
            ticket["status"] = "resolved"
        else:
            reward = 0.05  # Low reward but strictly > 0.0 per hackathon spec
            ticket["status"] = "failed"

        self.current_ticket_index += 1
        done = self.current_ticket_index >= len(self.tickets)

        return self._make_observation(reward=reward, done=done)

    @property
    def state(self) -> State:
        """Current episode state (episode id and step count)."""
        return self._state