| import hashlib |
| import os |
| import uuid |
| from typing import List, Dict, Optional |
|
|
| from openenv.core.env_server.interfaces import Environment |
| from openenv.core.env_server.types import State |
|
|
| |
| try: |
| from models import CustomerSupportAction, CustomerSupportObservation |
| except ImportError: |
| from ..models import CustomerSupportAction, CustomerSupportObservation |
|
|
# Ticket fixtures keyed by task name. Each ticket carries an ``id``, the
# customer's message (``content``) and a ``type`` that the environment's
# ``step`` grades the agent's action against.
TASKS = {
    "task1": [
        {
            "id": "t1",
            "content": "I forgot my password and cannot log into my account. Help!",
            "type": "password",
        },
    ],
    "task2": [
        {
            "id": "t2_1",
            "content": "How do I update my billing email?",
            "type": "billing",
        },
        {
            "id": "t2_2",
            "content": "The system says invalid credentials.",
            "type": "password",
        },
        {
            "id": "t2_3",
            "content": "My app crashed!",
            "type": "vague",
        },
    ],
    "task3": [
        {
            "id": "t3_1",
            "content": "How to change password?",
            "type": "password",
        },
        {
            "id": "t3_2",
            "content": "I want an immediate refund, this is garbage! Cancel my account!",
            "type": "churn",
        },
        {
            "id": "t3_3",
            "content": "Found a way to bypass authentication on the user portal.",
            "type": "security",
        },
        {
            "id": "t3_4",
            "content": "Charge on my credit card is double what it should be.",
            "type": "billing",
        },
        {
            "id": "t3_5",
            "content": "Is there a student discount?",
            "type": "sales",
        },
    ],
}
|
|
class CustomerSupportEnvironment(Environment):
    """Customer Support Environment for testing RL agents.

    An episode works through the tickets of one task in ``TASKS``, one
    ticket per ``step``. Each step grades the agent's action against the
    active ticket's ``type``, marks the ticket ``"resolved"`` (correct) or
    ``"failed"`` (incorrect), and advances to the next ticket. The episode
    is done once every ticket of the task has been handled.

    Rewards: 0.95 for a correct action, 0.05 for an incorrect action or for
    stepping after the episode is done, and 0.01 from ``reset``.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    # Ticket types whose correct handling is ``assign`` to this department.
    _ASSIGN_ROUTES = {
        "password": "TechSupport",
        "billing": "Billing",
        "sales": "Sales",
    }
    # Ticket types whose correct handling is one specific action type.
    _ACTION_ROUTES = {
        "vague": "ask_user",
        "churn": "escalate",
    }
    # Priorities that make ``assign`` to TechSupport acceptable for security tickets.
    _SECURITY_PRIORITIES = ("High", "Urgent")
    # Ticket summaries longer than this many characters are truncated with "...".
    _SUMMARY_MAX_LEN = 30

    def __init__(self, task_name: Optional[str] = None, **kwargs):
        """Create an environment bound to one task.

        Args:
            task_name: Key into ``TASKS``. When falsy, the ``TASK_NAME``
                environment variable is used (default ``"task1"``). Names
                not present in ``TASKS`` fall back to ``"task1"``.
            **kwargs: Forwarded to ``Environment.__init__``.
        """
        super().__init__(**kwargs)
        self._session_id = str(uuid.uuid4())
        self._state = State(episode_id=self._session_id, step_count=0)

        self.task_name = task_name if task_name else os.getenv("TASK_NAME", "task1")
        if self.task_name not in TASKS:
            self.task_name = "task1"

        self.tickets: List[Dict] = []
        self._load_tickets()
        self.current_ticket_index = 0

    def _load_tickets(self) -> None:
        """Copy the current task's tickets into ``self.tickets``, each marked ``"open"``.

        Each ticket dict is shallow-copied so status updates never mutate ``TASKS``.
        """
        self.tickets = [{**ticket, "status": "open"} for ticket in TASKS[self.task_name]]

    def _get_active_ticket(self) -> Optional[Dict]:
        """Return the ticket awaiting an action, or ``None`` once all are handled."""
        if self.current_ticket_index < len(self.tickets):
            return self.tickets[self.current_ticket_index]
        return None

    def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, task_name: Optional[str] = None, **kwargs) -> CustomerSupportObservation:
        """Reset the environment and return the first observation.

        Args:
            seed: Accepted for interface compatibility; not used (ticket order is fixed).
            episode_id: If given, becomes the id stored in the new ``State``.
                NOTE(review): when omitted, the previous id is reused, so
                consecutive episodes share an id — confirm this is intended.
            task_name: If given and a key of ``TASKS``, switches to that task;
                otherwise the current task is kept.
            **kwargs: Ignored.

        Returns:
            Observation for the first ticket, with reward 0.01 and ``done=False``.
        """
        if episode_id is not None:
            self._session_id = episode_id

        if task_name is not None and task_name in TASKS:
            self.task_name = task_name

        self._state = State(episode_id=self._session_id, step_count=0)
        self._load_tickets()
        self.current_ticket_index = 0

        return self._make_observation(reward=0.01, done=False)

    @classmethod
    def _summarize(cls, content: str) -> str:
        """Shorten ``content`` for the ticket overview.

        Only content longer than ``_SUMMARY_MAX_LEN`` is cut and suffixed
        with "..."; shorter content is returned unchanged so the ellipsis
        never falsely implies truncation.
        """
        if len(content) <= cls._SUMMARY_MAX_LEN:
            return content
        return content[: cls._SUMMARY_MAX_LEN] + "..."

    def _make_observation(self, reward: float = 0.0, done: bool = False) -> CustomerSupportObservation:
        """Build an observation of the active ticket plus an overview of all tickets.

        ``unresolved_count`` counts tickets still ``"open"`` (not yet acted on);
        ``"failed"`` tickets are not included. When no ticket is active, the
        ticket fields are ``None`` and the metadata is empty.
        """
        ticket = self._get_active_ticket()
        unresolved = sum(1 for x in self.tickets if x["status"] == "open")
        summary = [
            {"id": x["id"], "summary": self._summarize(x["content"]), "status": x["status"]}
            for x in self.tickets
        ]

        return CustomerSupportObservation(
            active_ticket_id=ticket["id"] if ticket else None,
            ticket_content=ticket["content"] if ticket else None,
            ticket_metadata={"type": ticket["type"]} if ticket else {},
            unresolved_count=unresolved,
            step_count=self._state.step_count,
            tickets_summary=summary,
            reward=float(reward),
            done=done,
        )

    @classmethod
    def _is_correct_action(cls, ticket_type: str, action_type: str, action: CustomerSupportAction) -> bool:
        """Return True if ``action`` is the expected handling for ``ticket_type``.

        ``action_type`` must already be lower-cased. ``action.department`` and
        ``action.priority`` are only read when the action type makes them
        relevant. Unknown ticket types are never handled correctly.
        """
        department = cls._ASSIGN_ROUTES.get(ticket_type)
        if department is not None:
            return action_type == "assign" and action.department == department

        expected_action = cls._ACTION_ROUTES.get(ticket_type)
        if expected_action is not None:
            return action_type == expected_action

        if ticket_type == "security":
            # Either escalate, or hand to TechSupport with an elevated priority.
            return action_type == "escalate" or (
                action_type == "assign"
                and action.department == "TechSupport"
                and action.priority in cls._SECURITY_PRIORITIES
            )

        return False

    def step(self, action: CustomerSupportAction, timeout_s: Optional[float] = None, **kwargs) -> CustomerSupportObservation:
        """Grade ``action`` against the active ticket and advance to the next one.

        Args:
            action: The agent's action; ``action_type`` is compared case-insensitively.
            timeout_s: Accepted for interface compatibility; not used.
            **kwargs: Ignored.

        Returns:
            Observation with reward 0.95 (correct) or 0.05 (incorrect) and
            ``done=True`` once the last ticket is handled. Stepping after the
            episode is done returns reward 0.05 and ``done=True`` without
            changing any ticket.
        """
        self._state.step_count += 1
        ticket = self._get_active_ticket()

        if ticket is None:
            return self._make_observation(reward=0.05, done=True)

        # A missing action_type is graded as an incorrect action instead of raising.
        action_type = (action.action_type or "").lower()

        if self._is_correct_action(ticket["type"], action_type, action):
            reward = 0.95
            ticket["status"] = "resolved"
        else:
            reward = 0.05
            ticket["status"] = "failed"

        self.current_ticket_index += 1
        done = self.current_ticket_index >= len(self.tickets)

        return self._make_observation(reward=reward, done=done)

    @property
    def state(self) -> State:
        """The current episode ``State`` (episode id and step count)."""
        return self._state
|
|