Spaces:
Sleeping
Sleeping
File size: 5,516 Bytes
af3f703 aa4f7bc af3f703 aa4f7bc af3f703 aa4f7bc af3f703 aa4f7bc af3f703 aa4f7bc af3f703 aa4f7bc af3f703 1724801 aa4f7bc af3f703 aa4f7bc 31f4f64 1724801 aa4f7bc ba2722e aa4f7bc ba2722e 1724801 aa4f7bc ba2722e aa4f7bc 1724801 aa4f7bc af3f703 aa4f7bc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 | from typing import Tuple, Dict, Any, Optional, cast
from .models import Action, Observation, EnvironmentState, TicketInfo, UserData
from .tasks import TASKS
from .graders import grade
class SupportTicketEnv:
    """Single-ticket customer-support environment with a gym-style API.

    An episode works one ticket taken from ``TASKS[task_id]``. Each call to
    :meth:`step` applies one tool action, re-scores the whole state with
    ``grade``, and returns the *change* in that score as the step reward.
    Because the accumulated score starts at 0.0 on :meth:`reset`, the step
    rewards of an episode sum to its final graded score.

    An episode ends when the agent escalates, closes the ticket, or uses up
    the step budget (``max_steps``).
    """

    # Action types the environment understands, in the order advertised to agents.
    AVAILABLE_ACTIONS = (
        "fetch_user_data", "check_policy", "issue_refund",
        "reply_to_customer", "escalate", "close_ticket",
    )

    def __init__(self, task_id: str = "task_easy_1", *, max_steps: int = 10):
        """Create an environment for ``task_id`` and start its first episode.

        Args:
            task_id: Key into ``TASKS`` selecting the ticket, user data,
                policy table and difficulty used by this environment.
            max_steps: Step budget per episode; once this many actions have
                been taken the episode is forced to end. Defaults to 10.

        Raises:
            ValueError: If ``task_id`` is not a known task, or if
                ``max_steps`` is smaller than 1.
        """
        if task_id not in TASKS:
            raise ValueError(f"Unknown task_id: {task_id}")
        if max_steps < 1:
            raise ValueError(f"max_steps must be >= 1, got {max_steps}")
        self.task_id = task_id
        self.task_data = TASKS[task_id]
        # Current episode state; only None until reset() below has run.
        self.state: Optional[EnvironmentState] = None
        self.max_steps = max_steps
        self.reset()

    def reset(self) -> Observation:
        """Start a fresh episode for the configured task.

        Rebuilds the ticket from the task definition, clears the action
        history and accumulated reward, and returns the initial observation.
        """
        ticket_data = cast(Dict[str, Any], self.task_data["ticket"])
        self.state = EnvironmentState(
            current_task_id=self.task_id,
            step_count=0,
            ticket=TicketInfo(**ticket_data),
            action_history=[],
            is_done=False,
            final_reward=0.0,
            task_difficulty=str(self.task_data["difficulty"]),
        )
        return self._get_observation("System initialized. Ticket assigned.")

    def _get_observation(self, system_message: str, tool_output: Optional[str] = None) -> Observation:
        """Build the agent-facing observation from the current state.

        Args:
            system_message: Status line describing what just happened.
            tool_output: Text produced by the last tool call, if any.
        """
        assert self.state is not None  # reset() in __init__ guarantees a state
        return Observation(
            ticket=self.state.ticket,
            available_actions=list(self.AVAILABLE_ACTIONS),
            system_message=system_message,
            history=[f"{a.action_type}({a.parameters})" for a in self.state.action_history],
            tool_output=tool_output,
            step_count=self.state.step_count,
        )

    def _info(self) -> Dict[str, Any]:
        """Return the auxiliary info dict reported alongside every step."""
        assert self.state is not None
        return {
            "current_reward": self.state.final_reward,
            "step_count": self.state.step_count,
        }

    def _execute_action(self, action: Action) -> Tuple[Optional[str], str]:
        """Run the tool named by ``action.action_type`` against the state.

        Mutates ``self.state`` for tools with side effects: loading user data,
        changing the ticket status, or ending the episode.

        Returns:
            ``(tool_output, system_message)`` for the resulting observation.
        """
        assert self.state is not None
        state = self.state
        params = action.parameters
        kind = action.action_type
        system_message = f"Action {kind} executed."

        if kind == "fetch_user_data":
            # Only the record of the ticket's own user can be fetched.
            if params.get("user_id") != state.ticket.user_id:
                return "Error: Invalid user_id.", "Failed to fetch user data."
            user_data = cast(Dict[str, Any], self.task_data["user_data"])
            state.user_data = UserData(**user_data)
            # Omit the chargeback field when the task supplies no value for it.
            chargebacks = getattr(state.user_data, "chargeback_history", None)
            chargeback_info = f", Chargebacks = {chargebacks}" if chargebacks is not None else ""
            tool_output = (
                f"User Data: Tier = {state.user_data.account_tier}, "
                f"Joined = {state.user_data.join_date}{chargeback_info}"
            )
            return tool_output, system_message

        if kind == "check_policy":
            # Defaults to the ticket's own issue type when none is given.
            issue_type = params.get("issue_type", state.ticket.issue_type)
            policy_map = cast(Dict[str, str], self.task_data["policy"])
            policy = policy_map.get(issue_type, "No specific policy found.")
            return f"Policy for {issue_type}: {policy}", system_message

        if kind == "issue_refund":
            # The chargeback block only applies once user data has been fetched
            # (getattr on a missing/None user_data yields None).
            chargebacks = getattr(state.user_data, "chargeback_history", None)
            if chargebacks is not None and chargebacks > 0:
                return "Refund denied due to chargeback history.", "Refund action blocked."
            amount = params.get("amount", "fully")
            return f"Refund issued for {amount}.", system_message

        if kind == "reply_to_customer":
            message = params.get("message", "")
            return f"Replied: '{message}'", system_message

        if kind == "escalate":
            reason = params.get("reason", "support_tier2")
            state.ticket.status = "escalated"
            state.is_done = True
            return f"Escalated to {reason}.", system_message

        if kind == "close_ticket":
            resolution = params.get("resolution", "")
            state.ticket.status = "closed"
            state.is_done = True
            return f"Ticket closed. Resolution: {resolution}", system_message

        return "Invalid action.", "Action unrecognized."

    def step(self, action: Action) -> Tuple[Observation, float, bool, Dict[str, Any]]:
        """Apply one agent action and advance the episode.

        Args:
            action: The action to execute; ``action.action_type`` selects the
                tool and ``action.parameters`` supplies its arguments.

        Returns:
            ``(observation, reward, done, info)``. ``reward`` is the change in
            graded score caused by this step; ``info`` holds ``current_reward``
            (the cumulative graded score) and ``step_count``. Once the episode
            is over, further calls change nothing and return a reward of 0.0
            with ``done=True``.
        """
        import logging  # function-scope import keeps this block self-contained

        logger = logging.getLogger(__name__)
        assert self.state is not None
        if self.state.is_done:
            return self._get_observation("Episode is over."), 0.0, True, self._info()

        self.state.step_count += 1
        self.state.action_history.append(action)
        tool_output, system_message = self._execute_action(action)

        # The step budget forces termination, but must not mask the message of
        # an action (escalate / close_ticket) that already ended the episode.
        if self.state.step_count >= self.max_steps:
            if not self.state.is_done:
                system_message = "Max steps reached."
            self.state.is_done = True

        # grade() scores the whole state; the step reward is the delta, so the
        # per-step rewards telescope to the final graded score.
        new_total_reward = grade(self.state)
        reward = new_total_reward - self.state.final_reward
        self.state.final_reward = new_total_reward

        info = self._info()
        # Diagnostics go to logging rather than stdout so callers that parse
        # stdout are not disturbed.
        if self.state.is_done:
            logger.debug("Final reward calculated: %s", self.state.final_reward)
        logger.debug("Updated info dictionary: %s", info)
        return self._get_observation(system_message, tool_output), reward, self.state.is_done, info

    def get_state(self) -> EnvironmentState:
        """Return the live (mutable) state of the current episode."""
        assert self.state is not None
        return self.state
|