# Last commit b2b8c23 by Aryan:
# "Fix openenv Observation inheritance and pydantic schema return types to fix HTTP 500"
import uuid
from datetime import datetime, timedelta
import dateutil.parser
import pytz
from typing import Literal, List, Dict, Optional, Any
from pydantic import BaseModel, Field
from openenv.core import Environment
import openenv.core as openenv_core
class Observation(openenv_core.Observation):
    """Observation returned by SchedulingEnv after reset, step and state."""

    # Field descriptions mirror the style used by Action so the JSON schema
    # served over HTTP documents every field.
    current_simulated_time: str = Field(
        ..., description="Current simulated time as an ISO 8601 string"
    )
    task_description: str = Field(
        ..., description="Natural-language description of the active task"
    )
    last_action_result: str = Field(
        ..., description="Textual result of the last action (may be empty)"
    )
    error_message: str = Field(
        "", description="Error text produced by the last action; empty when none"
    )
class Action(openenv_core.Action):
    # Agent command. `action_type` selects the operation; the remaining
    # fields are optional arguments consumed only by some operations.
    action_type: Literal[
        "lookup_employee",
        "view_calendar",
        "book_meeting",
        "cancel_meeting",
        "submit_task",
    ]
    # Read by lookup_employee, view_calendar and book_meeting.
    employee_ids: List[str] = Field(
        default_factory=list,
        description="List of employee IDs for lookups or meetings",
    )
    # Read by book_meeting.
    start_time: Optional[str] = Field(
        default=None,
        description="ISO 8601 start time for the meeting",
    )
    end_time: Optional[str] = Field(
        default=None,
        description="ISO 8601 end time for the meeting",
    )
    # Read by cancel_meeting.
    meeting_id: Optional[str] = Field(
        default=None,
        description="ID of meeting to cancel",
    )
# Employee directory keyed by lowercase ID. Each record repeats its ID and
# carries a pytz timezone name used for the 9-to-5 working-hours checks.
EMPLOYEES = {
    emp_id: {"id": emp_id, "name": display_name, "timezone": tz_name}
    for emp_id, display_name, tz_name in (
        ("alice", "Alice", "UTC"),
        ("bob", "Bob", "UTC"),
        ("charlie", "Charlie", "US/Pacific"),
        ("dave", "Dave", "US/Eastern"),
        ("eve", "Eve", "US/Central"),
        ("ceo", "CEO", "US/Eastern"),
        ("vp_sales", "VP of Sales", "US/Pacific"),
    )
}
class SchedulingEnv(Environment):
    """Meeting-scheduling environment with "easy", "medium" and "hard" tasks.

    Every episode starts at a fixed simulated time (Oct 10 2023 08:00 UTC) and
    asks the agent to book meetings "tomorrow" (Oct 11) for employees listed in
    ``EMPLOYEES``. Supported actions are ``lookup_employee``,
    ``view_calendar``, ``book_meeting``, ``cancel_meeting`` and
    ``submit_task``. Each step returns an ``Observation`` whose ``reward`` is
    the shaped task score clamped to [0, 1].
    """

    def __init__(self, task_level: str = "easy", max_steps: int = 15):
        """Create the environment and run an initial ``reset()``.

        Args:
            task_level: "easy", "medium" or "hard" (case-insensitive); any
                other value falls back to "easy".
            max_steps: Number of ``step()`` calls after which the episode is
                forced to end. The action of the final step is still applied.
        """
        self.task_level = task_level.lower()
        if self.task_level not in ("easy", "medium", "hard"):
            self.task_level = "easy"
        self.max_steps = max_steps
        self.reset()

    def reset(self) -> Observation:
        """Start a new episode and return the initial observation.

        Clears the step counter, every calendar and all task-progress flags,
        then seeds the scenario for ``self.task_level``.
        """
        self.current_step = 0
        # Fixed simulated "now": Oct 10 2023 08:00 UTC, so "tomorrow" is Oct 11.
        self.simulated_time = datetime(2023, 10, 10, 8, 0, 0, tzinfo=pytz.UTC)
        self.calendars = {emp_id: [] for emp_id in EMPLOYEES}
        self.task_state = {
            "calendar_lookups": set(),   # employee IDs whose calendar was viewed
            "found_valid_slot": False,   # never set in this class; kept for compatibility
            "booked_successfully": False,
            "canceled_blocker": False,
            "rescheduled_blocker": False,
            "wrong_cancellation": False,
        }
        self.target_meeting_id = None
        self.blocked_meeting_id = None
        self.high_priority_meeting_id = None
        # IDs of every seeded "High Priority" meeting; canceling any of them is penalised.
        self.protected_meeting_ids = set()
        self._setup_scenario()
        return self.state()

    def _setup_scenario(self) -> None:
        """Set ``task_description`` and seed the calendars for the current task level."""
        start_of_day = self.simulated_time.replace(hour=0, minute=0, second=0)
        tomorrow = start_of_day + timedelta(days=1)  # Oct 11 00:00 UTC

        if self.task_level == "easy":
            self.task_description = (
                "Book a 30-minute meeting between Alice and Bob (both in UTC). "
                "Find a non-conflicting time tomorrow (Oct 11) within their 9-to-5 working hours. "
                "When finished, run submit_task."
            )
            # Overlapping morning meetings (UTC): Alice 09:00-10:00, Bob 09:30-10:30.
            self._force_book("alice", tomorrow + timedelta(hours=9),
                             tomorrow + timedelta(hours=10), "Sync 1")
            self._force_book("bob", tomorrow + timedelta(hours=9, minutes=30),
                             tomorrow + timedelta(hours=10, minutes=30), "Sync 2")

        elif self.task_level == "medium":
            self.task_description = (
                "Schedule a 1-hour meeting for 4 people: Charlie (PST), Dave (EST), Alice (UTC), and Eve (CST). "
                "The meeting must fall specifically within the 9-to-5 local working hours of ALL 4 participants tomorrow (Oct 11). "
                "When finished, run submit_task."
            )
            # Eve is busy 10:00-11:00 UTC (05:00-06:00 CDT on Oct 11).
            self._force_book("eve", tomorrow + timedelta(hours=10),
                             tomorrow + timedelta(hours=11), "Sync")

        elif self.task_level == "hard":
            self.task_description = (
                "The CEO needs an urgent 1-hour meeting tomorrow (Oct 11) with the VP of Sales. "
                "Their calendars are full. You must find a 'low priority' internal sync blocking them, cancel it, "
                "book the CEO + VP of Sales in that slot, and reschedule the canceled sync to another non-conflicting time tomorrow. "
                "Do NOT cancel 'high priority' meetings. Submit the task when finished."
            )
            # Oct 11 2023 falls in US daylight saving time: the CEO (EDT, UTC-4)
            # works 13:00-21:00 UTC and the VP (PDT, UTC-7) works 16:00-24:00
            # UTC, so their shared window is 16:00-21:00 UTC. The bookings
            # below fill 17:00-21:00 UTC, leaving 16:00-17:00 UTC open.
            vp_start = tomorrow + timedelta(hours=17)  # 10:00 PDT / 13:00 EDT

            # 17:00-19:00 UTC, both executives.
            self.high_priority_meeting_id = self._force_book(
                ["ceo", "vp_sales"], vp_start, vp_start + timedelta(hours=2),
                "High Priority Client Pitch",
            )
            # 20:00-21:00 UTC = 13:00-14:00 PDT / 16:00-17:00 EDT: the intended blocker.
            self.blocked_meeting_id = self._force_book(
                ["ceo", "vp_sales"], vp_start + timedelta(hours=3), vp_start + timedelta(hours=4),
                "Low priority internal sync",
            )
            # 19:00-20:00 UTC, VP only.
            q3_review_id = self._force_book(
                ["vp_sales"], vp_start + timedelta(hours=2), vp_start + timedelta(hours=3),
                "High Priority Q3 Review",
            )
            self.protected_meeting_ids = {self.high_priority_meeting_id, q3_review_id}

    def _force_book(self, emp_ids, start_time: datetime, end_time: datetime, title: str) -> str:
        """Add a meeting to each listed calendar without validation and return its new ID.

        ``emp_ids`` may be a single employee ID or a list of IDs; each must be
        a key of ``self.calendars``.
        """
        if isinstance(emp_ids, str):
            emp_ids = [emp_ids]
        participants = list(emp_ids)
        meeting_id = str(uuid.uuid4())
        for emp_id in participants:
            self.calendars[emp_id].append({
                "id": meeting_id,
                "title": title,
                "start": start_time.isoformat(),
                "end": end_time.isoformat(),
                # Copy so calendar entries never alias each other or the caller's list.
                "participants": list(participants),
            })
        return meeting_id

    def _parse_time(self, time_str: str) -> Optional[datetime]:
        """Parse an ISO 8601 string into an aware datetime (naive input is taken as UTC).

        Returns None when the value cannot be parsed.
        """
        try:
            dt = dateutil.parser.isoparse(time_str)
        except (ValueError, TypeError, OverflowError):
            return None
        if dt.tzinfo is None:
            dt = pytz.UTC.localize(dt)
        return dt

    def _is_working_hours(self, emp_id: str, start_dt: datetime, end_dt: datetime) -> bool:
        """Return True if [start_dt, end_dt] lies within 09:00-17:00 of one local day for emp_id."""
        tz = pytz.timezone(EMPLOYEES[emp_id]["timezone"])
        local_start = start_dt.astimezone(tz)
        local_end = end_dt.astimezone(tz)
        if local_start.date() != local_end.date():
            return False  # crosses local midnight

        def seconds_into_day(dt: datetime) -> float:
            # Include seconds so e.g. an end of 17:00:30 counts as after hours.
            return dt.hour * 3600 + dt.minute * 60 + dt.second + dt.microsecond / 1_000_000

        return 9 * 3600 <= seconds_into_day(local_start) and seconds_into_day(local_end) <= 17 * 3600

    def _check_conflict(self, emp_id: str, start_dt: datetime, end_dt: datetime) -> bool:
        """Return True if [start_dt, end_dt) overlaps any meeting on emp_id's calendar."""
        return any(
            max(start_dt, self._parse_time(m["start"])) < min(end_dt, self._parse_time(m["end"]))
            for m in self.calendars.get(emp_id, [])
        )

    def _starts_tomorrow_for_all(self, participants: List[str], start_dt: datetime) -> bool:
        """Return True if start_dt falls on the day after simulated_time in every participant's timezone."""
        target_day = (self.simulated_time + timedelta(days=1)).date()
        return all(
            start_dt.astimezone(pytz.timezone(EMPLOYEES[e]["timezone"])).date() == target_day
            for e in participants
        )

    def step(self, action: Action) -> Observation:
        """Apply one agent action and return the resulting observation.

        Increments the step counter and dispatches on ``action.action_type``;
        outcomes are reported via ``last_action_result`` / ``error_message``.
        Once ``max_steps`` steps have been taken the episode is marked done,
        after the current action has been applied.
        """
        # Normalise IDs into a local list instead of mutating the caller's action.
        employee_ids = [e.lower() for e in (action.employee_ids or [])]
        self.current_step += 1
        done = False
        last_action_result = ""
        error_message = ""

        kind = action.action_type
        if kind == "lookup_employee":
            last_action_result, error_message = self._lookup_employees(employee_ids)
        elif kind == "view_calendar":
            last_action_result, error_message = self._view_calendars(employee_ids)
        elif kind == "book_meeting":
            last_action_result, error_message = self._book_meeting(
                employee_ids, action.start_time, action.end_time
            )
        elif kind == "cancel_meeting":
            last_action_result, error_message = self._cancel_meeting(action.meeting_id)
        elif kind == "submit_task":
            done = True
            last_action_result = "Task submitted."

        if not done and self.current_step >= self.max_steps:
            done = True
            limit_msg = f"Max steps ({self.max_steps}) reached."
            error_message = f"{error_message.rstrip()} {limit_msg}".strip()
        return self._finalize_step(last_action_result, error_message, done)

    def _lookup_employees(self, employee_ids: List[str]):
        """Handle lookup_employee; return (result_text, error_text)."""
        found = []
        errors = ""
        for e_id in employee_ids:
            if e_id in EMPLOYEES:
                found.append(EMPLOYEES[e_id])
            else:
                errors += f"Employee {e_id} not found. "
        return ("" if errors else str(found)), errors

    def _view_calendars(self, employee_ids: List[str]):
        """Handle view_calendar; return (result_text, error_text) and record which calendars were viewed."""
        found = {}
        errors = ""
        for e_id in employee_ids:
            if e_id in EMPLOYEES:
                found[e_id] = self.calendars[e_id]
                self.task_state["calendar_lookups"].add(e_id)
            else:
                errors += f"Employee {e_id} not found. "
        return ("" if errors else str(found)), errors

    def _book_meeting(self, employee_ids: List[str], start_time: Optional[str], end_time: Optional[str]):
        """Handle book_meeting; validate, book, and return (result_text, error_text)."""
        if not start_time or not end_time or not employee_ids:
            return "", "book_meeting requires start_time, end_time, and employee_ids."
        s_dt = self._parse_time(start_time)
        e_dt = self._parse_time(end_time)
        if s_dt is None or e_dt is None or s_dt >= e_dt:
            return "", "Invalid times provided."

        # Drop repeated IDs (order preserved) so nobody gets duplicate calendar entries.
        participants = list(dict.fromkeys(employee_ids))
        errors = ""
        for e_id in participants:
            if e_id not in EMPLOYEES:
                errors += f"Employee {e_id} missing. "
                continue
            if not self._is_working_hours(e_id, s_dt, e_dt):
                errors += f"Outside working hours for {e_id}. "
            if self._check_conflict(e_id, s_dt, e_dt):
                errors += f"Schedule conflict for {e_id}. "
        if errors:
            return "", errors

        meeting_id = self._force_book(participants, s_dt, e_dt, "Agent Booked Sync")
        self._record_booking_progress(participants, s_dt, e_dt)
        self.target_meeting_id = meeting_id
        return f"Successfully booked. Meeting ID: {meeting_id}", ""

    def _record_booking_progress(self, participants: List[str], s_dt: datetime, e_dt: datetime) -> None:
        """Update task_state after a successful agent booking."""
        if not self._starts_tomorrow_for_all(participants, s_dt):
            return  # every task requires the meeting to be on Oct 11
        people = set(participants)
        minutes = (e_dt - s_dt).total_seconds() / 60.0
        progress = self.task_state

        if self.task_level == "easy":
            if people == {"alice", "bob"} and minutes >= 30:
                progress["booked_successfully"] = True
        elif self.task_level == "medium":
            if people == {"charlie", "dave", "alice", "eve"} and minutes >= 60:
                progress["booked_successfully"] = True
        elif self.task_level == "hard" and people == {"ceo", "vp_sales"}:
            if progress["booked_successfully"] and progress["canceled_blocker"]:
                # A further CEO + VP booking after the blocker was canceled is the
                # rescheduled sync. It must be a separate booking from the urgent
                # meeting, not the same one.
                progress["rescheduled_blocker"] = True
            elif minutes >= 60:
                progress["booked_successfully"] = True

    def _cancel_meeting(self, meeting_id: Optional[str]):
        """Handle cancel_meeting; remove the meeting everywhere and return (result_text, error_text)."""
        if not meeting_id:
            return "", "cancel_meeting requires meeting_id."
        found = False
        for emp_id, meetings in self.calendars.items():
            kept = [m for m in meetings if m["id"] != meeting_id]
            if len(kept) < len(meetings):
                self.calendars[emp_id] = kept
                found = True
        if not found:
            return "", "Meeting ID not found."

        if self.task_level == "hard":
            if meeting_id == self.blocked_meeting_id:
                self.task_state["canceled_blocker"] = True
            elif meeting_id in self.protected_meeting_ids:
                self.task_state["wrong_cancellation"] = True
        return f"Meeting {meeting_id} canceled.", ""

    def _finalize_step(self, last_action_result: str, error_message: str, done: bool) -> Observation:
        """Build the observation for a step, including the current shaped reward."""
        return Observation(
            current_simulated_time=self.simulated_time.isoformat(),
            task_description=self.task_description,
            last_action_result=last_action_result,
            error_message=error_message,
            reward=self._calculate_reward(),
            done=done,
        )

    def _calculate_reward(self) -> float:
        """Return the shaped task score from task_state, clamped to [0, 1].

        Easy:   0.1 per viewed calendar (alice, bob) + 0.8 for the booking.
        Medium: 0.05 per viewed participant calendar + 0.8 for the booking.
        Hard:   0.1 per viewed calendar (ceo, vp_sales) + 0.2 cancel blocker
                + 0.3 booking + 0.3 reschedule, minus 0.5 for canceling a
                protected meeting.
        """
        r = 0.0
        lookups = self.task_state["calendar_lookups"]
        if self.task_level == "easy":
            if "alice" in lookups: r += 0.1
            if "bob" in lookups: r += 0.1
            if self.task_state["booked_successfully"]: r += 0.8
        elif self.task_level == "medium":
            for e in ["charlie", "dave", "alice", "eve"]:
                if e in lookups:
                    r += 0.05  # up to 0.2
            if self.task_state["booked_successfully"]:
                r += 0.8
        elif self.task_level == "hard":
            for e in ["ceo", "vp_sales"]:
                if e in lookups: r += 0.1
            if self.task_state["canceled_blocker"]: r += 0.2
            if self.task_state["wrong_cancellation"]: r -= 0.5
            if self.task_state["booked_successfully"]: r += 0.3
            if self.task_state["rescheduled_blocker"]: r += 0.3
        return max(min(r, 1.0), 0.0)

    def close(self) -> None:
        """Release resources (nothing to release)."""
        pass

    def state(self) -> Observation:
        """Return an observation of the current episode with empty action results, reward 0.0 and done=False."""
        return Observation(
            current_simulated_time=self.simulated_time.isoformat(),
            task_description=self.task_description,
            last_action_result="",
            error_message="",
            reward=0.0,
            done=False,
        )