Spaces:
Sleeping
Sleeping
| """Pure utility functions for the meeting-scheduling RL environment. | |
| Calendar format: Dict[str, List[List]] | |
| Each entry is [start_iso, end_iso, priority_int, summary_str]. | |
| All datetimes are timezone-aware ISO 8601 strings. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| from datetime import datetime, date, timedelta | |
| from pathlib import Path | |
| from typing import Dict, List, Optional | |
def parse_iso(s: str) -> datetime:
    """Parse an ISO 8601 string into a datetime object.

    A trailing ``Z``/``z`` UTC designator is rewritten to ``+00:00`` before
    parsing, because ``datetime.fromisoformat`` only accepts ``Z`` on
    Python 3.11 and later.

    Args:
        s: ISO 8601 date/datetime string, e.g. ``2024-03-01T09:30:00+02:00``.

    Returns:
        The parsed datetime; it is timezone-aware when ``s`` carries an offset
        (or ``Z``) and naive otherwise.

    Raises:
        ValueError: If ``s`` is not a format ``fromisoformat`` understands.
    """
    # Slicing keeps the empty string safe; fromisoformat will reject it.
    if s[-1:] in ("Z", "z"):
        s = s[:-1] + "+00:00"
    return datetime.fromisoformat(s)
def load_scenario(scenario_path: str) -> dict:
    """Load a scenario JSON file and return the parsed document.

    The file is decoded as UTF-8, the encoding JSON interchange requires,
    instead of the platform's locale default. This way non-ASCII text
    (e.g. meeting summaries) loads identically on every OS.

    Args:
        scenario_path: Path to the scenario ``.json`` file.

    Returns:
        The decoded JSON document (expected to be a dict).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(scenario_path, "r", encoding="utf-8") as f:
        return json.load(f)
def find_conflicts(
    calendars: Dict[str, List[List]],
    proposed_start_iso: str,
    proposed_end_iso: str,
    attendee_ids: List[str],
) -> List[Dict]:
    """Collect every existing meeting that overlaps the proposed slot.

    Intervals are compared as half-open ranges: ``[a, b)`` and ``[c, d)``
    overlap exactly when ``a < d`` and ``c < b``. A meeting that merely
    touches the slot's start or end is therefore not a conflict.

    Args:
        calendars: Attendee id -> list of
            ``[start_iso, end_iso, priority, summary]`` entries.
        proposed_start_iso: Start of the candidate slot.
        proposed_end_iso: End of the candidate slot.
        attendee_ids: Attendees whose calendars are checked, in this order.
            Ids missing from ``calendars`` contribute no conflicts.

    Returns:
        One dict per overlapping entry, with keys ``attendee``, ``start``,
        ``end``, ``priority``, ``summary`` and ``meeting_id``. The
        ``meeting_id`` is the attendee id and the entry's start string
        joined by ``_``.
    """
    slot_start = parse_iso(proposed_start_iso)
    slot_end = parse_iso(proposed_end_iso)

    found: List[Dict] = []
    for who in attendee_ids:
        # Unpacking in the loop target enforces the 4-field entry shape.
        for begin_iso, finish_iso, prio, title in calendars.get(who, []):
            busy_from = parse_iso(begin_iso)
            busy_until = parse_iso(finish_iso)
            overlapping = slot_start < busy_until and busy_from < slot_end
            if not overlapping:
                continue
            found.append(
                {
                    "attendee": who,
                    "start": begin_iso,
                    "end": finish_iso,
                    "priority": prio,
                    "summary": title,
                    "meeting_id": f"{who}_{begin_iso}",
                }
            )
    return found
def calculate_collective_hours(preferences: Dict) -> Dict[str, int]:
    """Intersect all participants' preferred working hours.

    Each value in ``preferences`` may carry
    ``'preferred_hours': {'start': int, 'end': int}``. A missing or null
    ``preferred_hours`` (or a missing key inside it) falls back to 9 for
    ``start`` and 17 for ``end``.

    Args:
        preferences: Mapping whose values are per-participant preference
            dicts; only the values are read.

    Returns:
        ``{"min_start_hour": <latest start>, "max_end_hour": <earliest end>}``.
        With no participants the default 9-17 window is returned, instead of
        ``max()``/``min()`` raising ``ValueError`` on an empty sequence. The
        window can be empty (start >= end) when preferences do not overlap.
    """
    # `or {}` tolerates an explicit JSON null for preferred_hours.
    hour_ranges = [p.get("preferred_hours") or {} for p in preferences.values()]
    return {
        "min_start_hour": max((h.get("start", 9) for h in hour_ranges), default=9),
        "max_end_hour": min((h.get("end", 17) for h in hour_ranges), default=17),
    }
def within_collective_hours(
    start_iso: str,
    end_iso: str,
    collective_hours: Dict[str, int],
) -> bool:
    """Check whether a slot fits inside the collective working-hours window.

    The window runs from ``min_start_hour:00`` to ``max_end_hour:00`` on the
    calendar date of the slot's start, in the start's own timezone. The slot
    must begin at or after the window opens and end no later than it closes.
    Ending exactly on the closing hour is allowed.

    Full datetimes are compared rather than bare ``hour`` fields, so a slot
    that runs past midnight (e.g. 16:00 -> 00:00 the next day) is rejected.
    An hour-only comparison would accept it because ``0 <= max_end_hour``.
    A ``max_end_hour`` of 24 means "until midnight".

    Args:
        start_iso: Slot start, ISO 8601.
        end_iso: Slot end, ISO 8601.
        collective_hours: Dict with int entries ``min_start_hour`` and
            ``max_end_hour``.

    Returns:
        True if the slot lies entirely inside the window, else False.
    """
    start = parse_iso(start_iso)
    end = parse_iso(end_iso)

    # Anchor both window edges on the start's own date and tzinfo.
    midnight = start.replace(hour=0, minute=0, second=0, microsecond=0)
    window_open = midnight + timedelta(hours=collective_hours["min_start_hour"])
    window_close = midnight + timedelta(hours=collective_hours["max_end_hour"])

    return window_open <= start and end <= window_close
def count_meetings_on_date(calendar_entries: List[List], target_date: date) -> int:
    """Return how many calendar entries start on ``target_date``.

    Only each entry's start timestamp (``entry[0]``) is examined, using the
    date in that timestamp's own timezone. A meeting spanning midnight
    therefore counts only for the day it starts.
    """
    return sum(
        1
        for entry in calendar_entries
        if parse_iso(entry[0]).date() == target_date
    )
def check_back_to_back(
    calendar_entries: List[List],
    proposed_start_iso: str,
    proposed_end_iso: str,
    buffer_minutes: int,
) -> bool:
    """Report whether the proposed slot sits too close to an existing entry.

    An entry is "too close" in either of two cases:
    - it ends ``g`` before the proposed start, with ``0 <= g < buffer``;
    - it starts ``g`` after the proposed end, with ``0 <= g < buffer``.

    Overlapping entries produce negative gaps and are not flagged here. With
    ``buffer_minutes <= 0`` the result is always False.

    Returns:
        True as soon as one such entry is found, otherwise False.
    """
    slot_start = parse_iso(proposed_start_iso)
    slot_end = parse_iso(proposed_end_iso)
    min_gap = timedelta(minutes=buffer_minutes)
    zero = timedelta(0)

    def hugs_slot(entry: List) -> bool:
        busy_start = parse_iso(entry[0])
        busy_end = parse_iso(entry[1])
        ends_right_before = zero <= slot_start - busy_end < min_gap
        return ends_right_before or zero <= busy_start - slot_end < min_gap

    return any(hugs_slot(entry) for entry in calendar_entries)
def calculate_preference_score(
    proposed_start_iso: str,
    duration_minutes: int,
    participant_preferences: Dict,
    calendars: Dict[str, List[List]],
) -> float:
    """Sum the preference-violation penalty of a proposed meeting.

    Every participant in ``participant_preferences`` is scored on their own:

    * +50 if the slot is outside their ``preferred_hours`` (default 9-17),
      judged by ``within_collective_hours`` with that single-user window.
    * +30 if their meetings on the slot's start date, plus this one, exceed
      ``max_meetings_per_day`` (default 8).
    * +20 if ``avoid_back_to_back`` is truthy, ``buffer_minutes`` (default 0)
      is positive, and ``check_back_to_back`` finds an entry within that
      buffer.

    Args:
        proposed_start_iso: Slot start, ISO 8601.
        duration_minutes: Slot length; the end is start plus this many minutes.
        participant_preferences: Participant id -> preference dict.
        calendars: Participant id -> calendar entries.

    Returns:
        Total penalty as a float (0.0 when nothing is violated).
    """
    slot_start = parse_iso(proposed_start_iso)
    slot_end_iso = (slot_start + timedelta(minutes=duration_minutes)).isoformat()
    slot_day = slot_start.date()

    penalty = 0.0
    for who, prefs in participant_preferences.items():
        hours = prefs.get("preferred_hours", {})
        window = {
            "min_start_hour": hours.get("start", 9),
            "max_end_hour": hours.get("end", 17),
        }
        daily_limit = prefs.get("max_meetings_per_day", 8)
        wants_gap = prefs.get("avoid_back_to_back", False)
        gap_minutes = prefs.get("buffer_minutes", 0)

        if not within_collective_hours(proposed_start_iso, slot_end_iso, window):
            penalty += 50

        their_entries = calendars.get(who, [])
        # "+ 1" accounts for the meeting being proposed right now.
        if count_meetings_on_date(their_entries, slot_day) + 1 > daily_limit:
            penalty += 30

        if (
            wants_gap
            and gap_minutes > 0
            and check_back_to_back(their_entries, proposed_start_iso, slot_end_iso, gap_minutes)
        ):
            penalty += 20

    return penalty
def is_slot_free(
    attendee: str,
    start_iso: str,
    end_iso: str,
    calendars: Dict[str, List[List]],
) -> bool:
    """Return True when ``attendee`` has no entry overlapping ``[start, end)``.

    Overlap uses the strict half-open test, so entries that only touch the
    slot's boundaries do not block it. An attendee absent from ``calendars``
    is always free.
    """
    window_start = parse_iso(start_iso)
    window_end = parse_iso(end_iso)

    def clashes(entry: List) -> bool:
        busy_start = parse_iso(entry[0])
        busy_end = parse_iso(entry[1])
        return window_start < busy_end and busy_start < window_end

    return not any(clashes(entry) for entry in calendars.get(attendee, []))
def calculate_final_reward(
    preference_penalty: float,
    num_rescheduled: int,
    steps_taken: int,
    success: bool = True,
) -> float:
    """Compute the multi-component episode reward, clamped to [0.0, 1.0].

    Starting from 1.0, three deductions are subtracted:

    - Preference: ``min(0.75, (preference_penalty ** 1.2) / 200.0)``.
      Negative penalties are treated as 0. A negative float raised to the
      non-integer power 1.2 is a complex number, which would make ``min()``
      raise ``TypeError``.
    - Rescheduling: ``min(0.30, 0.05 * (1.8 ** num_rescheduled))``, applied
      only when ``num_rescheduled > 0``.
    - Time: ``steps_taken * 0.015``.

    Args:
        preference_penalty: Accumulated preference penalty points.
        num_rescheduled: Number of existing meetings that were moved.
        steps_taken: Number of environment steps used.
        success: Whether the episode ended successfully.

    Returns:
        The clamped reward, or 0.0 if the episode was not successful.
    """
    if not success:
        return 0.0

    penalty = preference_penalty
    if penalty < 0:
        # Avoid a complex result from (-x) ** 1.2.
        penalty = 0.0
    pref_deduction = min(0.75, (penalty ** 1.2) / 200.0)

    reschedule_deduction = 0.0
    if num_rescheduled > 0:
        # Exponential growth, capped so rescheduling alone costs at most 0.30.
        reschedule_deduction = min(0.30, 0.05 * (1.8 ** num_rescheduled))

    time_deduction = steps_taken * 0.015

    reward = 1.0 - pref_deduction - reschedule_deduction - time_deduction
    return max(0.0, min(1.0, reward))
def build_busy_slots(
    calendars: Dict[str, List[List]],
    attendee_ids: List[str],
) -> List[Dict]:
    """Flatten attendees' calendars into observation-ready busy-slot dicts.

    Attendees are visited in ``attendee_ids`` order, and each attendee's
    entries in calendar order. Ids missing from ``calendars`` contribute
    nothing.

    Returns:
        Dicts with keys ``start``, ``end``, ``priority``, ``summary`` and
        ``attendee``.
    """
    return [
        {
            "start": begin_iso,
            "end": finish_iso,
            "priority": prio,
            "summary": title,
            "attendee": who,
        }
        for who in attendee_ids
        for begin_iso, finish_iso, prio, title in calendars.get(who, [])
    ]
def find_earliest_free_slot(
    calendars: Dict[str, List[List]],
    attendees: List[str],
    duration_minutes: int,
    search_date_iso: str,
    collective_hours: Dict[str, int],
) -> Optional[str]:
    """Find the earliest slot on a given date that is free for every attendee.

    Candidate starts are generated every 15 minutes from
    ``min_start_hour:00`` on the date of ``search_date_iso``, keeping that
    value's tzinfo. A candidate qualifies only if it ends no later than
    ``max_end_hour:00``.

    The window bounds are computed as midnight plus a ``timedelta``, so a
    ``max_end_hour`` of 24 (end of day) works. Passing ``24`` as the hour to
    the ``datetime`` constructor would raise ``ValueError``.

    Args:
        calendars: Attendee id -> calendar entries.
        attendees: Attendees who must all be free.
        duration_minutes: Length of the meeting to place.
        search_date_iso: ISO 8601 string whose date (and tzinfo) is searched.
        collective_hours: Dict with ``min_start_hour`` and ``max_end_hour``.

    Returns:
        ISO 8601 start of the first conflict-free candidate, or None if no
        candidate fits inside the window.
    """
    anchor = parse_iso(search_date_iso)
    midnight = datetime(anchor.year, anchor.month, anchor.day, tzinfo=anchor.tzinfo)

    duration = timedelta(minutes=duration_minutes)
    step = timedelta(minutes=15)
    candidate = midnight + timedelta(hours=collective_hours["min_start_hour"])
    end_boundary = midnight + timedelta(hours=collective_hours["max_end_hour"])

    while candidate + duration <= end_boundary:
        candidate_iso = candidate.isoformat()
        candidate_end_iso = (candidate + duration).isoformat()
        if all(
            is_slot_free(attendee, candidate_iso, candidate_end_iso, calendars)
            for attendee in attendees
        ):
            return candidate_iso
        candidate += step
    return None