Spaces:
Sleeping
Sleeping
File size: 2,496 Bytes
6792e60 1d7c05f 694dd8d 6792e60 900731f 6792e60 c6579ff 1d7c05f 6792e60 1d7c05f 5e47949 1d7c05f 6792e60 1d7c05f 5e47949 1d7c05f 73d1307 6792e60 1d7c05f 5e47949 1d7c05f 73d1307 c6579ff 1d7c05f 73d1307 1d7c05f 5e47949 1d7c05f 6792e60 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 | import uuid
from typing import List, Dict
from openenv.core.env_server import Environment
from models import SupportAction, SupportObservation, SupportState
class SupportEnvironment(Environment):
    """Support-ticket routing environment.

    Each episode walks through the tickets of one task ("easy", "medium" or
    "hard").  On every step the agent either looks up the internal record for
    the current ticket (``action_type == "search"``) or routes the ticket to a
    department; routing scores the choice and advances to the next ticket.
    All rewards are kept strictly between 0 and 1.
    """

    # Rewards stay strictly inside (0, 1): a correct route earns the maximum,
    # everything else (wrong route, search, reset, post-episode steps) the minimum.
    _REWARD_CORRECT = 0.99
    _REWARD_MIN = 0.01
    # Task used when no task, or an unknown task, is requested.
    _DEFAULT_TASK = "easy"

    def __init__(self):
        super().__init__()
        # Internal records returned by "search" actions, keyed by ticket id.
        # Tickets without an entry yield "No record.".
        self._db = {
            "T1": "Customer History: Frequent refunder. Status: Valid invoice.",
            "T4": "System Log: Payment gateway timed out. Card ending in 4242.",
            "T6": "Dev Log: 500 error on /api/auth. Cluster: us-east-1.",
        }
        # Task name -> ordered tickets; "dept" is the expected routing target.
        self.tasks = {
            "easy": [{"id": "T1", "text": "Refund status?", "dept": "Billing"}],
            "medium": [
                {"id": "T2", "text": "Upgrade account?", "dept": "Sales"},
                {"id": "T3", "text": "App keeps crashing.", "dept": "Tech"},
            ],
            "hard": [
                {"id": "T4", "text": "Payment failed!", "dept": "Billing"},
                {"id": "T5", "text": "Bulk pricing?", "dept": "Sales"},
                {"id": "T6", "text": "API Auth failure.", "dept": "Tech"},
            ],
        }
        # Start a default episode so ``step`` and ``state`` are usable before
        # the first explicit ``reset`` instead of raising AttributeError.
        self._start_episode(self._DEFAULT_TASK)

    def _start_episode(self, task_name: str) -> None:
        """Select the task (falling back to the default) and create fresh state."""
        if task_name not in self.tasks:
            task_name = self._DEFAULT_TASK
        self.current_task = self.tasks[task_name]
        # task_id records the task actually being played, not the raw request.
        self._state = SupportState(
            episode_id=str(uuid.uuid4()),
            task_id=task_name,
            current_ticket_index=0,
        )

    @staticmethod
    def _ticket_observation(ticket: dict, reward: float) -> SupportObservation:
        """Build a non-terminal observation showing ``ticket``."""
        return SupportObservation(
            done=False, reward=reward, ticket_id=ticket["id"], content=ticket["text"]
        )

    @staticmethod
    def _done_observation(reward: float) -> SupportObservation:
        """Build the terminal observation emitted once all tickets are routed."""
        return SupportObservation(done=True, reward=reward, ticket_id="EOF", content="Done.")

    def reset(self, task_name: str = "easy") -> SupportObservation:
        """Start a new episode and return the observation for its first ticket.

        Args:
            task_name: "easy", "medium" or "hard".  Unknown names fall back to
                "easy", and the state's ``task_id`` reflects that fallback.

        Returns:
            A non-terminal observation for the first ticket with the minimum reward.
        """
        self._start_episode(task_name)
        return self._ticket_observation(self.current_task[0], self._REWARD_MIN)

    def step(self, action: SupportAction) -> SupportObservation:
        """Apply ``action`` to the current ticket.

        A "search" action returns the current ticket again along with its
        internal record and does not advance.  Any other action is treated as a
        routing decision: ``action.department`` is compared case-insensitively
        (after stripping whitespace) with the ticket's department, and the
        episode advances to the next ticket.

        Returns:
            The observation for the next ticket, or the terminal "EOF"
            observation once the last ticket has been routed.  Steps taken after
            the episode has finished return the terminal observation again
            with the minimum reward.
        """
        index = self._state.current_ticket_index
        if index >= len(self.current_task):
            # Episode already finished: do not index past the ticket list.
            return self._done_observation(self._REWARD_MIN)

        self._state.step_count += 1
        ticket = self.current_task[index]

        if action.action_type == "search":
            # Small penalty but staying > 0; the ticket does not advance.
            return SupportObservation(
                done=False,
                reward=self._REWARD_MIN,
                ticket_id=ticket["id"],
                content=ticket["text"],
                search_result=self._db.get(ticket["id"], "No record."),
            )

        # NOTE(review): department may be None/empty on some actions — confirm
        # against SupportAction; a missing department is scored as a wrong route.
        chosen = (action.department or "").strip().lower()
        correct = chosen == ticket["dept"].lower()
        reward = self._REWARD_CORRECT if correct else self._REWARD_MIN

        next_index = index + 1
        self._state.current_ticket_index = next_index
        if next_index < len(self.current_task):
            return self._ticket_observation(self.current_task[next_index], reward)
        return self._done_observation(reward)

    @property
    def state(self) -> SupportState:
        """Current episode state (episode id, task id, ticket index, step count)."""
        return self._state
|