# Source: openenv-rl-environment/env/environment.py
# Last commit 31f4f64 (by Sid8421):
#   "Fix CI validation workflow: Add strict None checks for Mypy compliance"
import logging
from typing import Tuple, Dict, Any, Optional, cast

from .models import Action, Observation, EnvironmentState, TicketInfo, UserData
from .tasks import TASKS
from .graders import grade
class SupportTicketEnv:
    """Episodic support-ticket environment with a ``reset``/``step`` interface.

    Each instance is bound to one entry of ``TASKS``. An episode ends when the
    agent escalates or closes the ticket, or when ``max_steps`` actions have
    been taken. After every step the whole state is scored with ``grade`` and
    the change in that score is returned as the step reward.
    """

    # Action names advertised to the agent in every observation. They mirror
    # the branches handled in ``_execute_action``.
    AVAILABLE_ACTIONS: Tuple[str, ...] = (
        "fetch_user_data", "check_policy", "issue_refund",
        "reply_to_customer", "escalate", "close_ticket",
    )

    _logger = logging.getLogger(__name__)

    def __init__(self, task_id: str = "task_easy_1", max_steps: int = 10) -> None:
        """Bind the environment to a task and start the first episode.

        Args:
            task_id: Key into ``TASKS`` selecting the scenario to run.
            max_steps: Maximum number of actions per episode.

        Raises:
            ValueError: If ``task_id`` is not in ``TASKS`` or ``max_steps``
                is smaller than 1.
        """
        if task_id not in TASKS:
            raise ValueError(f"Unknown task_id: {task_id}")
        if max_steps < 1:
            raise ValueError(f"max_steps must be >= 1, got {max_steps}")
        self.task_id = task_id
        self.task_data = TASKS[task_id]
        self.state: Optional[EnvironmentState] = None
        self.max_steps = max_steps
        self.reset()

    def reset(self) -> Observation:
        """Rebuild the episode state from the task data.

        Returns:
            The initial observation for the freshly assigned ticket.
        """
        ticket_data = cast(Dict[str, Any], self.task_data["ticket"])
        self.state = EnvironmentState(
            current_task_id=self.task_id,
            step_count=0,
            ticket=TicketInfo(**ticket_data),
            action_history=[],
            is_done=False,
            final_reward=0.0,
            task_difficulty=str(self.task_data["difficulty"]),
        )
        return self._get_observation("System initialized. Ticket assigned.")

    def _require_state(self) -> EnvironmentState:
        """Return the current state, raising if it has not been initialized.

        An explicit check is used instead of ``assert`` so that it still runs
        under ``python -O`` while narrowing the Optional for type checkers.

        Raises:
            RuntimeError: If ``self.state`` is ``None``.
        """
        if self.state is None:
            raise RuntimeError(
                "Environment state is not initialized; call reset() first."
            )
        return self.state

    def _get_observation(self, system_message: str, tool_output: Optional[str] = None) -> Observation:
        """Build the agent-facing observation from the current state.

        Args:
            system_message: Status text describing the outcome of the last step.
            tool_output: Result text of the executed action, if any.
        """
        state = self._require_state()
        return Observation(
            ticket=state.ticket,
            available_actions=list(self.AVAILABLE_ACTIONS),
            system_message=system_message,
            history=[f"{a.action_type}({a.parameters})" for a in state.action_history],
            tool_output=tool_output,
            step_count=state.step_count,
        )

    def _execute_action(self, state: EnvironmentState, action: Action) -> Tuple[str, str]:
        """Apply one action to ``state``.

        Returns:
            A ``(tool_output, system_message)`` pair describing the result.
        """
        system_message = f"Action {action.action_type} executed."

        if action.action_type == "fetch_user_data":
            if action.parameters.get("user_id") != state.ticket.user_id:
                return "Error: Invalid user_id.", "Failed to fetch user data."
            user_data = cast(Dict[str, Any], self.task_data["user_data"])
            user = UserData(**user_data)
            state.user_data = user
            # Only mention chargebacks when a value is actually recorded, so
            # the agent never sees "Chargebacks = None".
            chargebacks = user.chargeback_history
            chargeback_info = (
                f", Chargebacks = {chargebacks}" if chargebacks is not None else ""
            )
            tool_output = (
                f"User Data: Tier = {user.account_tier}, "
                f"Joined = {user.join_date}{chargeback_info}"
            )
        elif action.action_type == "check_policy":
            # Fall back to the ticket's own issue type when none is supplied.
            issue_type = action.parameters.get("issue_type", state.ticket.issue_type)
            policy_map = cast(Dict[str, str], self.task_data["policy"])
            policy = policy_map.get(issue_type, "No specific policy found.")
            tool_output = f"Policy for {issue_type}: {policy}"
        elif action.action_type == "issue_refund":
            # The chargeback block only applies once user data has been
            # fetched; before that the refund goes through.
            user_info = state.user_data
            if (
                user_info
                and user_info.chargeback_history is not None
                and user_info.chargeback_history > 0
            ):
                tool_output = "Refund denied due to chargeback history."
                system_message = "Refund action blocked."
            else:
                amount = action.parameters.get("amount", "fully")
                tool_output = f"Refund issued for {amount}."
        elif action.action_type == "reply_to_customer":
            msg = action.parameters.get("message", "")
            tool_output = f"Replied: '{msg}'"
        elif action.action_type == "escalate":
            reason = action.parameters.get("reason", "support_tier2")
            tool_output = f"Escalated to {reason}."
            state.ticket.status = "escalated"
            state.is_done = True
        elif action.action_type == "close_ticket":
            res = action.parameters.get("resolution", "")
            tool_output = f"Ticket closed. Resolution: {res}"
            state.ticket.status = "closed"
            state.is_done = True
        else:
            tool_output = "Invalid action."
            system_message = "Action unrecognized."

        return tool_output, system_message

    def step(self, action: Action) -> Tuple[Observation, float, bool, Dict[str, Any]]:
        """Execute one agent action and advance the episode.

        Args:
            action: The action to apply; it is recorded in the history even
                when its type is not recognized.

        Returns:
            ``(observation, reward, done, info)``. ``reward`` is the change in
            the ``grade`` score caused by this step. ``info`` carries the
            cumulative score and the step count. Once the episode is over,
            further calls return a zero reward and an empty ``info`` without
            modifying the state.
        """
        state = self._require_state()
        if state.is_done:
            return self._get_observation("Episode is over."), 0.0, True, {}

        state.step_count += 1
        state.action_history.append(action)

        tool_output, system_message = self._execute_action(state, action)

        # Only report the step cap when it is what ends the episode; an
        # explicit close/escalate on the last step keeps its own message.
        if not state.is_done and state.step_count >= self.max_steps:
            state.is_done = True
            system_message = "Max steps reached."

        # grade() returns the cumulative score; the step reward is its delta.
        new_total_reward = grade(state)
        reward = new_total_reward - state.final_reward
        state.final_reward = new_total_reward

        if state.is_done:
            self._logger.debug("Final reward calculated: %s", state.final_reward)

        info: Dict[str, Any] = {
            "current_reward": state.final_reward,
            "step_count": state.step_count,
        }
        self._logger.debug("Updated info dictionary: %s", info)

        return self._get_observation(system_message, tool_output), reward, state.is_done, info

    def get_state(self) -> EnvironmentState:
        """Return the live (mutable) environment state.

        Raises:
            RuntimeError: If the state has not been initialized.
        """
        return self._require_state()