openenv-customer-support / server /environment.py
Utkarsh Sinha
Fix: scores strictly within (0,1) - rewards 0.95/0.05 instead of 1.0/0.0
1d1d8c2
import hashlib
import os
import uuid
from typing import List, Dict, Optional
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
# Ensure relative imports resolve correctly based on execution context
try:
from models import CustomerSupportAction, CustomerSupportObservation
except ImportError:
from ..models import CustomerSupportAction, CustomerSupportObservation
TASKS = {
"task1": [
{"id": "t1", "content": "I forgot my password and cannot log into my account. Help!", "type": "password"}
],
"task2": [
{"id": "t2_1", "content": "How do I update my billing email?", "type": "billing"},
{"id": "t2_2", "content": "The system says invalid credentials.", "type": "password"},
{"id": "t2_3", "content": "My app crashed!", "type": "vague"}
],
"task3": [
{"id": "t3_1", "content": "How to change password?", "type": "password"},
{"id": "t3_2", "content": "I want an immediate refund, this is garbage! Cancel my account!", "type": "churn"},
{"id": "t3_3", "content": "Found a way to bypass authentication on the user portal.", "type": "security"},
{"id": "t3_4", "content": "Charge on my credit card is double what it should be.", "type": "billing"},
{"id": "t3_5", "content": "Is there a student discount?", "type": "sales"}
]
}
class CustomerSupportEnvironment(Environment):
"""Customer Support Environment for testing RL agents."""
SUPPORTS_CONCURRENT_SESSIONS = True
def __init__(self, task_name: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
self._session_id = str(uuid.uuid4())
self._state = State(episode_id=self._session_id, step_count=0)
# Priority: explicit arg -> env var -> default
self.task_name = task_name if task_name else os.getenv("TASK_NAME", "task1")
if self.task_name not in TASKS:
self.task_name = "task1"
self.tickets = []
self._load_tickets()
self.current_ticket_index = 0
def _load_tickets(self):
self.tickets = [dict(t) for t in TASKS[self.task_name]]
for t in self.tickets:
t["status"] = "open"
def _get_active_ticket(self) -> Optional[Dict]:
if self.current_ticket_index < len(self.tickets):
return self.tickets[self.current_ticket_index]
return None
def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, task_name: Optional[str] = None, **kwargs) -> CustomerSupportObservation:
"""Reset the environment."""
if episode_id is not None:
self._session_id = episode_id
if task_name is not None and task_name in TASKS:
self.task_name = task_name
self._state = State(episode_id=self._session_id, step_count=0)
self._load_tickets()
self.current_ticket_index = 0
return self._make_observation(reward=0.01, done=False)
def _make_observation(self, reward: float = 0.0, done: bool = False) -> CustomerSupportObservation:
t = self._get_active_ticket()
unresolved = sum(1 for x in self.tickets if x["status"] == "open")
summary = [{"id": x["id"], "summary": x["content"][:30] + "...", "status": x["status"]} for x in self.tickets]
return CustomerSupportObservation(
active_ticket_id=t["id"] if t else None,
ticket_content=t["content"] if t else None,
ticket_metadata={"type": t["type"]} if t else {},
unresolved_count=unresolved,
step_count=self._state.step_count,
tickets_summary=summary,
reward=float(reward),
done=done
)
def step(self, action: CustomerSupportAction, timeout_s: Optional[float] = None, **kwargs) -> CustomerSupportObservation:
"""Execute action step."""
self._state.step_count += 1
t = self._get_active_ticket()
if not t:
return self._make_observation(reward=0.05, done=True)
action_type = action.action_type.lower()
ttype = t["type"]
is_correct = False
# Simple logical grader included inline for self-containment
if ttype == "password":
if action_type == "assign" and action.department == "TechSupport":
is_correct = True
elif ttype == "billing":
if action_type == "assign" and action.department == "Billing":
is_correct = True
elif ttype == "sales":
if action_type == "assign" and action.department == "Sales":
is_correct = True
elif ttype == "vague":
if action_type == "ask_user":
is_correct = True
elif ttype == "churn":
if action_type == "escalate":
is_correct = True
elif ttype == "security":
if action_type == "escalate":
is_correct = True
elif action_type == "assign" and action.department == "TechSupport" and action.priority in ["High", "Urgent"]:
is_correct = True
if is_correct:
reward = 0.95 # High reward but strictly < 1.0 per hackathon spec
t["status"] = "resolved"
else:
reward = 0.05 # Low reward but strictly > 0.0 per hackathon spec
t["status"] = "failed"
self.current_ticket_index += 1
done = self.current_ticket_index >= len(self.tickets)
return self._make_observation(reward=reward, done=done)
@property
def state(self) -> State:
return self._state