Update env.py
Browse files
env.py
CHANGED
|
@@ -31,7 +31,7 @@ class CustomerSupportEnv(es.Environment):
|
|
| 31 |
"status": "OPEN",
|
| 32 |
"episode_id": episode_id
|
| 33 |
}
|
| 34 |
-
return self._get_obs(reward=0.
|
| 35 |
|
| 36 |
@property
|
| 37 |
def state(self) -> State:
|
|
@@ -46,7 +46,9 @@ class CustomerSupportEnv(es.Environment):
|
|
| 46 |
step_count=self.step_count
|
| 47 |
)
|
| 48 |
|
| 49 |
-
def _get_obs(self, reward: float = 0.
|
|
|
|
|
|
|
| 50 |
return Observation(
|
| 51 |
ticket_id=self.state_data["ticket_id"],
|
| 52 |
customer_message=self.state_data["customer_message"],
|
|
@@ -55,16 +57,16 @@ class CustomerSupportEnv(es.Environment):
|
|
| 55 |
status=self.state_data["status"],
|
| 56 |
refund_processed=self.state_data["refund_processed"],
|
| 57 |
done=self.done,
|
| 58 |
-
reward=
|
| 59 |
metadata={"feedback": feedback, "state": self.state_data}
|
| 60 |
)
|
| 61 |
|
| 62 |
def step(self, action: Action, timeout_s: Optional[float] = None, **kwargs: Any) -> Observation:
|
| 63 |
if self.done:
|
| 64 |
-
return self._get_obs(reward=0.
|
| 65 |
|
| 66 |
self.step_count += 1
|
| 67 |
-
reward_val = 0.
|
| 68 |
feedback = ""
|
| 69 |
task = TASKS[self.current_task_idx]
|
| 70 |
|
|
@@ -84,27 +86,27 @@ class CustomerSupportEnv(es.Environment):
|
|
| 84 |
reply = f"Here is my {req}: [MOCK_DATA]"
|
| 85 |
self.state_data["history"].extend([f"Agent: {action.argument}", f"Customer: {reply}"])
|
| 86 |
self.state_data["customer_message"] = reply
|
| 87 |
-
reward_val = 0.
|
| 88 |
feedback = f"Successfully collected {req}"
|
| 89 |
found = True
|
| 90 |
break
|
| 91 |
if not found:
|
| 92 |
-
reward_val =
|
| 93 |
feedback = "Asked for unnecessary information."
|
| 94 |
|
| 95 |
elif action.action_type == "REFUND":
|
| 96 |
if task.needs_refund and "order_id" in self.state_data["collected_info"]:
|
| 97 |
self.state_data["refund_processed"] = True
|
| 98 |
-
reward_val = 0.
|
| 99 |
feedback = "Refund processed successfully."
|
| 100 |
else:
|
| 101 |
-
reward_val =
|
| 102 |
feedback = "Cannot process refund without order ID or refund not required."
|
| 103 |
|
| 104 |
elif action.action_type == "ROUTE":
|
| 105 |
self.state_data["route"] = action.argument
|
| 106 |
if self.state_data["missing_info"]:
|
| 107 |
-
reward_val =
|
| 108 |
feedback = "Routed prematurely without gathering required info."
|
| 109 |
else:
|
| 110 |
self.done = True
|
|
|
|
| 31 |
"status": "OPEN",
|
| 32 |
"episode_id": episode_id
|
| 33 |
}
|
| 34 |
+
return self._get_obs(reward=0.01, feedback="")
|
| 35 |
|
| 36 |
@property
|
| 37 |
def state(self) -> State:
|
|
|
|
| 46 |
step_count=self.step_count
|
| 47 |
)
|
| 48 |
|
| 49 |
+
def _get_obs(self, reward: float = 0.01, feedback: str = "") -> Observation:
|
| 50 |
+
# Clamp every reward to [0.01, 0.99] so it always lies strictly inside the open interval (0, 1)
|
| 51 |
+
clamped_reward = max(0.01, min(0.99, reward))
|
| 52 |
return Observation(
|
| 53 |
ticket_id=self.state_data["ticket_id"],
|
| 54 |
customer_message=self.state_data["customer_message"],
|
|
|
|
| 57 |
status=self.state_data["status"],
|
| 58 |
refund_processed=self.state_data["refund_processed"],
|
| 59 |
done=self.done,
|
| 60 |
+
reward=clamped_reward,
|
| 61 |
metadata={"feedback": feedback, "state": self.state_data}
|
| 62 |
)
|
| 63 |
|
| 64 |
def step(self, action: Action, timeout_s: Optional[float] = None, **kwargs: Any) -> Observation:
|
| 65 |
if self.done:
|
| 66 |
+
return self._get_obs(reward=0.01, feedback="Episode already done")
|
| 67 |
|
| 68 |
self.step_count += 1
|
| 69 |
+
reward_val = 0.01
|
| 70 |
feedback = ""
|
| 71 |
task = TASKS[self.current_task_idx]
|
| 72 |
|
|
|
|
| 86 |
reply = f"Here is my {req}: [MOCK_DATA]"
|
| 87 |
self.state_data["history"].extend([f"Agent: {action.argument}", f"Customer: {reply}"])
|
| 88 |
self.state_data["customer_message"] = reply
|
| 89 |
+
reward_val = 0.3
|
| 90 |
feedback = f"Successfully collected {req}"
|
| 91 |
found = True
|
| 92 |
break
|
| 93 |
if not found:
|
| 94 |
+
reward_val = 0.05
|
| 95 |
feedback = "Asked for unnecessary information."
|
| 96 |
|
| 97 |
elif action.action_type == "REFUND":
|
| 98 |
if task.needs_refund and "order_id" in self.state_data["collected_info"]:
|
| 99 |
self.state_data["refund_processed"] = True
|
| 100 |
+
reward_val = 0.4
|
| 101 |
feedback = "Refund processed successfully."
|
| 102 |
else:
|
| 103 |
+
reward_val = 0.05
|
| 104 |
feedback = "Cannot process refund without order ID or refund not required."
|
| 105 |
|
| 106 |
elif action.action_type == "ROUTE":
|
| 107 |
self.state_data["route"] = action.argument
|
| 108 |
if self.state_data["missing_info"]:
|
| 109 |
+
reward_val = 0.05
|
| 110 |
feedback = "Routed prematurely without gathering required info."
|
| 111 |
else:
|
| 112 |
self.done = True
|