from models import Observation, Action, State
from tasks import TASKS, grader
import copy
from typing import Any, Optional
import openenv.core.env_server as es
class CustomerSupportEnv(es.Environment):
    """Customer-support ticket environment backed by the entries of ``TASKS``.

    An episode works on a single ticket.  The agent acts through ``step`` with
    one of four action types:

    * ``ASK_INFO`` - ask the customer for a piece of required information.
    * ``REFUND``   - process a refund (needs the task to require a refund and
      ``order_id`` to have been collected).
    * ``ROUTE``    - route the ticket; ends the episode once nothing is missing.
    * ``CLOSE``    - close the ticket; always ends the episode.

    Episodes also end when ``step_count`` reaches ``max_steps``.  Every reward
    placed on an observation is clamped to the range [0.01, 0.99].
    """

    def __init__(self, **kwargs):
        """Initialise bookkeeping fields and start an episode on task 0."""
        super().__init__(**kwargs)
        self.current_task_idx = 0
        self.state_data = {}
        self.step_count = 0
        self.max_steps = 10
        self.done = False
        self.reset()

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_idx: int = 0,
        **kwargs: Any,
    ) -> Observation:
        """Start a fresh episode for ``TASKS[task_idx]``.

        Args:
            seed: Accepted but not used by this implementation.
            episode_id: Stored verbatim in the episode state.
            task_idx: Index into ``TASKS`` selecting the ticket scenario.
            **kwargs: Accepted and ignored.

        Returns:
            The initial observation, carrying the minimum reward (0.01) and
            empty feedback.

        Raises:
            IndexError: If ``task_idx`` is not a valid index into ``TASKS``.
        """
        # Resolve the task before mutating anything so that an invalid index
        # leaves the previous episode fully intact.
        task = TASKS[task_idx]

        self.current_task_idx = task_idx
        self.step_count = 0
        self.done = False
        self.state_data = {
            "ticket_id": f"TKT-{1000 + task_idx}",
            "customer_message": task.initial_msg,
            "history": [f"Customer: {task.initial_msg}"],
            # Copy so that collecting info never mutates the task definition.
            "missing_info": task.required_info.copy(),
            "collected_info": [],
            "route": None,
            "refund_processed": False,
            "status": "OPEN",
            "episode_id": episode_id,
        }
        return self._get_obs(reward=0.01, feedback="")

    @property
    def state(self) -> State:
        """Return a ``State`` snapshot of the current episode.

        Mutable fields are deep-copied, so the returned object does not alias
        the environment's internal lists.
        """
        return State(
            ticket_id=self.state_data.get("ticket_id", ""),
            customer_message=self.state_data.get("customer_message", ""),
            history=copy.deepcopy(self.state_data.get("history", [])),
            missing_info=copy.deepcopy(self.state_data.get("missing_info", [])),
            status=self.state_data.get("status", "OPEN"),
            refund_processed=self.state_data.get("refund_processed", False),
            episode_id=self.state_data.get("episode_id", None),
            step_count=self.step_count,
        )

    def _get_obs(self, reward: float = 0.01, feedback: str = "") -> Observation:
        """Build an observation from the current episode state.

        Args:
            reward: Raw reward; clamped to [0.01, 0.99] before being reported.
            feedback: Human-readable explanation stored under
                ``metadata["feedback"]``.

        Returns:
            An ``Observation`` whose mutable fields are copies of the internal
            state.
        """
        # Clamp ALL rewards strictly between 0 and 1 (exclusive).
        clamped_reward = max(0.01, min(0.99, reward))

        # Hand out copies only: sharing the live lists/dict would make
        # previously returned observations change as later steps run.
        return Observation(
            ticket_id=self.state_data["ticket_id"],
            customer_message=self.state_data["customer_message"],
            history=list(self.state_data["history"]),
            missing_info=list(self.state_data["missing_info"]),
            status=self.state_data["status"],
            refund_processed=self.state_data["refund_processed"],
            done=self.done,
            reward=clamped_reward,
            metadata={
                "feedback": feedback,
                "state": copy.deepcopy(self.state_data),
            },
        )

    def step(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Observation:
        """Apply one agent action and return the resulting observation.

        Args:
            action: Object exposing ``action_type`` and ``argument``.
            timeout_s: Accepted but not used by this implementation.
            **kwargs: Accepted and ignored.

        Returns:
            The observation after the action.  Intermediate rewards are 0.3
            (required info collected), 0.4 (refund processed) or 0.05
            (unhelpful action); terminal rewards come from ``grader``.
        """
        if self.done:
            return self._get_obs(reward=0.01, feedback="Episode already done")

        self.step_count += 1
        task = TASKS[self.current_task_idx]

        # Penalize infinite loops / max steps.  The step that reaches the
        # limit is graded immediately; its action is not applied.
        if self.step_count >= self.max_steps:
            self.done = True
            final_score = grader(task, self.state_data)
            return self._get_obs(
                reward=float(final_score), feedback="Max steps reached"
            )

        reward_val = 0.01
        feedback = ""
        action_type = action.action_type

        if action_type == "ASK_INFO":
            # A missing argument is treated as an empty question rather than
            # crashing on ``None.lower()``.
            asked = (action.argument or "").lower()
            matched = next(
                (
                    req
                    for req in self.state_data["missing_info"]
                    if req.lower() in asked
                ),
                None,
            )
            if matched is None:
                reward_val = 0.05
                feedback = "Asked for unnecessary information."
            else:
                # Only the first matching requirement is collected per step.
                self.state_data["missing_info"].remove(matched)
                self.state_data["collected_info"].append(matched)
                reply = f"Here is my {matched}: [MOCK_DATA]"
                self.state_data["history"].extend(
                    [f"Agent: {action.argument}", f"Customer: {reply}"]
                )
                self.state_data["customer_message"] = reply
                reward_val = 0.3
                feedback = f"Successfully collected {matched}"

        elif action_type == "REFUND":
            if task.needs_refund and "order_id" in self.state_data["collected_info"]:
                self.state_data["refund_processed"] = True
                reward_val = 0.4
                feedback = "Refund processed successfully."
            else:
                reward_val = 0.05
                feedback = (
                    "Cannot process refund without order ID or refund not required."
                )

        elif action_type == "ROUTE":
            # The requested route is recorded even when routing is premature.
            self.state_data["route"] = action.argument
            if self.state_data["missing_info"]:
                reward_val = 0.05
                feedback = "Routed prematurely without gathering required info."
            else:
                self.done = True
                final_score = grader(task, self.state_data)
                reward_val = float(final_score)
                feedback = f"Ticket routed. Final Score: {final_score}"

        elif action_type == "CLOSE":
            self.done = True
            self.state_data["status"] = "CLOSED"
            final_score = grader(task, self.state_data)
            reward_val = float(final_score)
            feedback = f"Ticket closed. Final Score: {final_score}"

        else:
            # Unrecognised actions keep the minimum reward but now say why,
            # instead of returning silently with empty feedback.
            feedback = f"Unknown action type: {action_type}"

        return self._get_obs(reward=reward_val, feedback=feedback)