File size: 5,729 Bytes
084325c 1d1d8c2 084325c 1d1d8c2 084325c 1d1d8c2 084325c 1d1d8c2 084325c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 | import hashlib
import os
import uuid
from typing import List, Dict, Optional
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
# Ensure relative imports resolve correctly based on execution context
try:
from models import CustomerSupportAction, CustomerSupportObservation
except ImportError:
from ..models import CustomerSupportAction, CustomerSupportObservation
# Ticket queues, keyed by task name. Each ticket records its id, the
# customer's message, and the ground-truth category that agent actions are
# graded against. Tasks grow from one ticket to a mixed five-ticket queue.
TASKS = dict(
    # task1: a single, unambiguous password problem.
    task1=[
        dict(
            id="t1",
            content="I forgot my password and cannot log into my account. Help!",
            type="password",
        ),
    ],
    # task2: three tickets, including one too vague to route directly.
    task2=[
        dict(id="t2_1", content="How do I update my billing email?", type="billing"),
        dict(id="t2_2", content="The system says invalid credentials.", type="password"),
        dict(id="t2_3", content="My app crashed!", type="vague"),
    ],
    # task3: five tickets spanning routing, churn, and security handling.
    task3=[
        dict(id="t3_1", content="How to change password?", type="password"),
        dict(
            id="t3_2",
            content="I want an immediate refund, this is garbage! Cancel my account!",
            type="churn",
        ),
        dict(
            id="t3_3",
            content="Found a way to bypass authentication on the user portal.",
            type="security",
        ),
        dict(
            id="t3_4",
            content="Charge on my credit card is double what it should be.",
            type="billing",
        ),
        dict(id="t3_5", content="Is there a student discount?", type="sales"),
    ],
)
class CustomerSupportEnvironment(Environment):
    """Customer Support Environment for testing RL agents.

    An episode walks through the ticket queue of one task from ``TASKS`` in
    order. Each ``step`` grades the agent's action against the active ticket,
    marks that ticket ``"resolved"`` or ``"failed"``, and advances to the next
    one. The episode is done once every ticket has been handled.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    # Reward values; kept strictly between 0.0 and 1.0 per the hackathon spec.
    _RESET_REWARD = 0.01
    _CORRECT_REWARD = 0.95
    _INCORRECT_REWARD = 0.05

    # Number of content characters shown per ticket in ``tickets_summary``.
    _SUMMARY_LEN = 30

    # Grading rules: ticket type -> accepted (action_type, department, priorities).
    # A ``None`` department or priorities entry means "not checked". Ticket types
    # missing from this table can never be handled correctly.
    _ACCEPTED_ACTIONS = {
        "password": (("assign", "TechSupport", None),),
        "billing": (("assign", "Billing", None),),
        "sales": (("assign", "Sales", None),),
        "vague": (("ask_user", None, None),),
        "churn": (("escalate", None, None),),
        "security": (
            ("escalate", None, None),
            # Routing straight to TechSupport is acceptable only if flagged urgent.
            ("assign", "TechSupport", ("High", "Urgent")),
        ),
    }

    def __init__(self, task_name: Optional[str] = None, **kwargs):
        """Create an environment bound to one task's ticket queue.

        Args:
            task_name: Key into ``TASKS``. When falsy, the ``TASK_NAME``
                environment variable is used, defaulting to ``"task1"``.
                Unknown names fall back to ``"task1"``.
            **kwargs: Forwarded to ``Environment.__init__``.
        """
        super().__init__(**kwargs)
        self._session_id = str(uuid.uuid4())
        self._state = State(episode_id=self._session_id, step_count=0)
        # Priority: explicit arg -> env var -> default
        self.task_name = task_name or os.getenv("TASK_NAME", "task1")
        if self.task_name not in TASKS:
            self.task_name = "task1"
        self.tickets: List[Dict] = []
        self.current_ticket_index = 0
        self._load_tickets()

    def _load_tickets(self) -> None:
        """Replace ``self.tickets`` with fresh copies of the task's tickets, all ``"open"``."""
        # Copy each ticket dict so per-episode status updates never mutate TASKS.
        self.tickets = [{**t, "status": "open"} for t in TASKS[self.task_name]]

    def _get_active_ticket(self) -> Optional[Dict]:
        """Return the ticket awaiting an action, or None once the queue is exhausted."""
        if self.current_ticket_index < len(self.tickets):
            return self.tickets[self.current_ticket_index]
        return None

    def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, task_name: Optional[str] = None, **kwargs) -> CustomerSupportObservation:
        """Start a new episode and return its first observation.

        Args:
            seed: Accepted for interface compatibility; the ticket order is
                fixed, so it is not used.
            episode_id: Identifier for the new episode. When omitted, a fresh
                UUID is generated so successive episodes stay distinguishable.
            task_name: Switch to this task if it is a key of ``TASKS``;
                otherwise the current task is kept.

        Returns:
            The observation for the first ticket, with a small positive reward.
        """
        # Previously the old id was reused, so every episode of an instance
        # reported the same State.episode_id.
        self._session_id = episode_id if episode_id is not None else str(uuid.uuid4())
        if task_name is not None and task_name in TASKS:
            self.task_name = task_name
        self._state = State(episode_id=self._session_id, step_count=0)
        self._load_tickets()
        self.current_ticket_index = 0
        return self._make_observation(reward=self._RESET_REWARD, done=False)

    @classmethod
    def _summarize(cls, content: str) -> str:
        """Shorten ``content`` for the queue summary, adding "..." only if truncated."""
        if len(content) <= cls._SUMMARY_LEN:
            return content
        return content[: cls._SUMMARY_LEN] + "..."

    def _make_observation(self, reward: float = 0.0, done: bool = False) -> CustomerSupportObservation:
        """Build an observation of the active ticket plus a summary of the whole queue.

        Ticket-specific fields are None/empty once every ticket has been handled.
        """
        t = self._get_active_ticket()
        unresolved = sum(1 for x in self.tickets if x["status"] == "open")
        summary = [
            {"id": x["id"], "summary": self._summarize(x["content"]), "status": x["status"]}
            for x in self.tickets
        ]
        return CustomerSupportObservation(
            active_ticket_id=t["id"] if t else None,
            ticket_content=t["content"] if t else None,
            ticket_metadata={"type": t["type"]} if t else {},
            unresolved_count=unresolved,
            step_count=self._state.step_count,
            tickets_summary=summary,
            reward=float(reward),
            done=done,
        )

    @classmethod
    def _is_correct(cls, ticket_type: str, action: CustomerSupportAction) -> bool:
        """Return True if ``action`` is an accepted way to handle a ticket of ``ticket_type``.

        ``action_type`` is matched case-insensitively. A missing (None)
        action_type is graded as incorrect instead of raising.
        """
        # NOTE(review): department/priority are compared case-sensitively, unlike
        # action_type — confirm this is what the action model guarantees.
        action_type = (action.action_type or "").lower()
        for want_type, want_dept, allowed_priorities in cls._ACCEPTED_ACTIONS.get(ticket_type, ()):
            if action_type != want_type:
                continue
            if want_dept is not None and action.department != want_dept:
                continue
            if allowed_priorities is not None and action.priority not in allowed_priorities:
                continue
            return True
        return False

    def step(self, action: CustomerSupportAction, timeout_s: Optional[float] = None, **kwargs) -> CustomerSupportObservation:
        """Grade ``action`` against the active ticket and advance the queue.

        Args:
            action: The agent's handling decision for the active ticket.
            timeout_s: Accepted for interface compatibility; unused.

        Returns:
            The next observation. Its reward is ``_CORRECT_REWARD`` or
            ``_INCORRECT_REWARD``, and ``done`` is True once the queue is
            exhausted. Stepping a finished episode returns
            ``_INCORRECT_REWARD`` with ``done=True``.
        """
        self._state.step_count += 1
        t = self._get_active_ticket()
        if t is None:
            return self._make_observation(reward=self._INCORRECT_REWARD, done=True)

        if self._is_correct(t["type"], action):
            reward = self._CORRECT_REWARD
            t["status"] = "resolved"
        else:
            reward = self._INCORRECT_REWARD
            t["status"] = "failed"

        self.current_ticket_index += 1
        done = self.current_ticket_index >= len(self.tickets)
        return self._make_observation(reward=reward, done=done)

    @property
    def state(self) -> State:
        """The current episode's ``State`` (episode id and step count)."""
        return self._state
|