# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. """ Meeting Scheduling RL Environment. Teaches agents to optimally schedule meetings across multiple attendees by proposing time slots, rescheduling lower-priority conflicts, and balancing participant preferences. """ from __future__ import annotations import copy import json import logging from datetime import timedelta from pathlib import Path from uuid import uuid4 from openenv.core.env_server.interfaces import Environment try: from ..models import SchedulingAction, SchedulingObservation, SchedulingState except ImportError: from models import SchedulingAction, SchedulingObservation, SchedulingState from .scheduling_logic import ( build_busy_slots, calculate_collective_hours, calculate_final_reward, calculate_preference_score, find_conflicts, is_slot_free, parse_iso, within_collective_hours, ) from .scenario_generator import generate_scenario logger = logging.getLogger(__name__) SCENARIOS_DIR = Path(__file__).parent / "scenarios" MAX_STEPS = 20 class SchedulingEnvironment(Environment): """RL environment for intelligent meeting scheduling. The agent must learn to: 1. Propose valid time slots satisfying hard constraints 2. Minimize preference violations 3. Handle cascading rescheduling when conflicts exist 4. Balance speed vs. quality """ SUPPORTS_CONCURRENT_SESSIONS: bool = True def __init__(self): self._state = SchedulingState(episode_id=str(uuid4()), step_count=0) self._scenario: dict = {} self._collective_hours: dict = {} # ------------------------------------------------------------------ # OpenEnv interface # ------------------------------------------------------------------ def reset(self, **kwargs) -> SchedulingObservation: """Reset environment for a new episode. Accepts ``task_id`` kwarg. Static tasks (``"task1_easy"`` etc.) load from JSON. Random tasks (``"random_easy"``, ``"random_medium"``, ``"random_hard"``) generate a fresh scenario every call. An optional ``seed`` kwarg makes random generation reproducible. """ task_id = kwargs.get("task_id", "task1_easy") # ── random scenario generation ── if task_id.startswith("random_"): difficulty = task_id.split("_", 1)[1] seed = kwargs.get("seed", None) try: self._scenario = generate_scenario(difficulty, seed=seed) except ValueError: return SchedulingObservation( error_message=f"Unknown difficulty in task_id: {task_id}", done=True, reward=0.0, ) else: # ── static JSON scenario ── scenario_path = SCENARIOS_DIR / f"{task_id}.json" if not scenario_path.exists(): return SchedulingObservation( error_message=f"Unknown task_id: {task_id}", done=True, reward=0.0, ) with open(scenario_path) as f: self._scenario = json.load(f) req = self._scenario["meeting_request"] prefs = self._scenario["preferences"] self._collective_hours = calculate_collective_hours(prefs) self._state = SchedulingState( episode_id=str(uuid4()), step_count=0, task_id=task_id, scenario_name=self._scenario.get("description", task_id), meeting_request=req, calendars=copy.deepcopy(self._scenario["calendars"]), participant_preferences=prefs, proposed_slot=None, rescheduled_meetings=[], total_preference_penalty=0.0, total_steps=0, final_reward=0.0, completed=False, ) attendees = req["attendees"] return SchedulingObservation( requested_duration=req["duration"], requested_priority=req["priority"], attendee_ids=attendees, busy_slots=build_busy_slots(self._state.calendars, attendees), collective_work_hours=self._collective_hours, preference_constraints=self._aggregate_preferences(prefs), current_proposal=None, conflicts=[], preference_penalty=0.0, num_rescheduled=0, steps_taken=0, max_steps=MAX_STEPS, success=False, error_message=None, done=False, reward=0.0, ) def step(self, action: SchedulingAction) -> SchedulingObservation: # type: ignore[override] """Process one agent action and return an observation.""" if self._state.completed: return self._obs(error_message="Episode already completed", done=True, reward=0.0) self._state.step_count += 1 self._state.total_steps += 1 # Timeout check if self._state.step_count >= MAX_STEPS: return self._handle_timeout() action_type = action.action_type if action_type == "propose_slot": return self._process_propose_slot(action) elif action_type == "reschedule_meeting": return self._process_reschedule_meeting(action) elif action_type == "finalize": return self._process_finalize() elif action_type == "reject": return self._process_reject() else: return self._obs(error_message=f"Unknown action_type: {action_type}", reward=-0.1) @property def state(self) -> SchedulingState: return self._state # ------------------------------------------------------------------ # Action handlers # ------------------------------------------------------------------ def _process_propose_slot(self, action: SchedulingAction) -> SchedulingObservation: if not action.proposed_start or not action.proposed_duration: return self._obs( error_message="propose_slot requires proposed_start and proposed_duration", reward=-0.1, ) try: start = parse_iso(action.proposed_start) except (ValueError, TypeError): return self._obs(error_message="Invalid proposed_start format", reward=-0.1) end = start + timedelta(minutes=action.proposed_duration) start_iso = start.isoformat() end_iso = end.isoformat() attendees = self._state.meeting_request["attendees"] req_priority = self._state.meeting_request["priority"] # Validate working hours if not within_collective_hours(start_iso, end_iso, self._collective_hours): return self._obs( error_message="Proposed slot outside working hours", reward=-0.2, ) # Find conflicts conflicts = find_conflicts( self._state.calendars, start_iso, end_iso, attendees ) # Calculate preference penalty pref_penalty = calculate_preference_score( start_iso, action.proposed_duration, self._state.participant_preferences, self._state.calendars, ) # Update state self._state.proposed_slot = [start_iso, end_iso] self._state.total_preference_penalty = pref_penalty # Step reward if len(conflicts) == 0 and pref_penalty < 100: step_reward = 0.5 elif len(conflicts) > 0: if all(c["priority"] > req_priority for c in conflicts): step_reward = 0.2 else: step_reward = -0.3 else: step_reward = 0.0 return self._obs( current_proposal={"start": start_iso, "end": end_iso}, conflicts=conflicts, preference_penalty=pref_penalty, reward=step_reward, ) def _process_reschedule_meeting(self, action: SchedulingAction) -> SchedulingObservation: if not action.meeting_id_to_move or not action.new_start_time: return self._obs( error_message="reschedule_meeting requires meeting_id_to_move and new_start_time", reward=-0.1, ) if self._state.proposed_slot is None: return self._obs( error_message="Must propose a slot before rescheduling", reward=-0.2, ) # Find the meeting to move meeting = self._find_meeting(action.meeting_id_to_move) if meeting is None: return self._obs( error_message=f"Meeting not found: {action.meeting_id_to_move}", reward=-0.2, ) req_priority = self._state.meeting_request["priority"] if meeting["priority"] <= req_priority: return self._obs( error_message="Cannot reschedule equal or higher priority meeting", reward=-0.5, ) # Validate new slot try: new_start = parse_iso(action.new_start_time) except (ValueError, TypeError): return self._obs(error_message="Invalid new_start_time format", reward=-0.1) old_start = parse_iso(meeting["start"]) old_end = parse_iso(meeting["end"]) duration = old_end - old_start new_end = new_start + duration new_start_iso = new_start.isoformat() new_end_iso = new_end.isoformat() attendee = meeting["attendee"] if not is_slot_free(attendee, new_start_iso, new_end_iso, self._state.calendars): return self._obs(error_message="New slot not free for attendee", reward=-0.2) # Update calendar: remove old, add new cal = self._state.calendars[attendee] self._state.calendars[attendee] = [ e for e in cal if e[0] != meeting["start"] ] self._state.calendars[attendee].append( [new_start_iso, new_end_iso, meeting["priority"], meeting["summary"]] ) self._state.rescheduled_meetings.append({ "meeting_id": action.meeting_id_to_move, "old_start": meeting["start"], "new_start": new_start_iso, "attendee": attendee, }) # Recalculate conflicts for current proposal attendees = self._state.meeting_request["attendees"] new_conflicts = find_conflicts( self._state.calendars, self._state.proposed_slot[0], self._state.proposed_slot[1], attendees, ) num_rescheduled = len(self._state.rescheduled_meetings) step_reward = 0.5 if len(new_conflicts) == 0 else 0.3 return self._obs( conflicts=new_conflicts, num_rescheduled=num_rescheduled, reward=step_reward, ) def _process_finalize(self) -> SchedulingObservation: if self._state.proposed_slot is None: self._state.completed = True return self._obs( error_message="No slot proposed", success=False, reward=0.0, done=True, ) attendees = self._state.meeting_request["attendees"] conflicts = find_conflicts( self._state.calendars, self._state.proposed_slot[0], self._state.proposed_slot[1], attendees, ) if len(conflicts) > 0: self._state.completed = True return self._obs( error_message=f"Unresolved conflicts: {len(conflicts)} meetings", conflicts=conflicts, success=False, reward=0.0, done=True, ) final_reward = calculate_final_reward( preference_penalty=self._state.total_preference_penalty, num_rescheduled=len(self._state.rescheduled_meetings), steps_taken=self._state.step_count, success=True, ) self._state.completed = True self._state.final_reward = final_reward return self._obs( success=True, reward=final_reward, done=True, ) def _process_reject(self) -> SchedulingObservation: self._state.completed = True return self._obs( success=False, reward=0.0, done=True, error_message="Agent rejected scheduling task", ) def _handle_timeout(self) -> SchedulingObservation: """Give partial credit when max steps reached.""" self._state.completed = True if self._state.proposed_slot is None: return self._obs( success=False, reward=0.0, done=True, error_message="Timeout: No slot proposed", ) attendees = self._state.meeting_request["attendees"] conflicts = find_conflicts( self._state.calendars, self._state.proposed_slot[0], self._state.proposed_slot[1], attendees, ) if len(conflicts) == 0: theoretical = calculate_final_reward( self._state.total_preference_penalty, len(self._state.rescheduled_meetings), self._state.step_count, ) partial = theoretical * 0.7 else: progress = 1.0 - (len(conflicts) / max(1, len(attendees))) partial = 0.2 * progress self._state.final_reward = partial return self._obs( success=False, reward=partial, done=True, error_message=f"Timeout after {self._state.step_count} steps (partial credit: {partial:.2f})", ) # ------------------------------------------------------------------ # Helpers # ------------------------------------------------------------------ def _obs(self, **overrides) -> SchedulingObservation: """Build an observation from current state, applying overrides.""" req = self._state.meeting_request attendees = req.get("attendees", []) defaults = dict( requested_duration=req.get("duration", 0), requested_priority=req.get("priority", 3), attendee_ids=attendees, busy_slots=build_busy_slots(self._state.calendars, attendees), collective_work_hours=self._collective_hours, preference_constraints=self._aggregate_preferences( self._state.participant_preferences ), current_proposal=( {"start": self._state.proposed_slot[0], "end": self._state.proposed_slot[1]} if self._state.proposed_slot else None ), conflicts=[], preference_penalty=self._state.total_preference_penalty, num_rescheduled=len(self._state.rescheduled_meetings), steps_taken=self._state.step_count, max_steps=MAX_STEPS, success=False, error_message=None, done=False, reward=0.0, ) defaults.update(overrides) return SchedulingObservation(**defaults) def _find_meeting(self, meeting_id: str) -> dict | None: """Look up a meeting by its id (format: attendee_startiso).""" parts = meeting_id.split("_", 1) if len(parts) != 2: return None attendee, start_iso = parts for entry in self._state.calendars.get(attendee, []): if entry[0] == start_iso: return { "attendee": attendee, "start": entry[0], "end": entry[1], "priority": entry[2], "summary": entry[3], } return None @staticmethod def _aggregate_preferences(prefs: dict) -> dict: """Summarize preferences for the observation.""" if not prefs: return {} max_meetings = min(p.get("max_meetings_per_day", 99) for p in prefs.values()) any_buffer = any(p.get("avoid_back_to_back", False) for p in prefs.values()) buffer_mins = max( (p.get("buffer_minutes", 0) for p in prefs.values() if p.get("avoid_back_to_back")), default=0, ) return { "max_meetings_per_day": max_meetings, "requires_buffer": any_buffer, "buffer_minutes": buffer_mins, }