"""Pure utility functions for the meeting-scheduling RL environment.
Calendar format: Dict[str, List[List]]
Each entry is [start_iso, end_iso, priority_int, summary_str].
All datetimes are timezone-aware ISO 8601 strings.
"""
from __future__ import annotations
import json
from datetime import datetime, date, timedelta
from pathlib import Path
from typing import Dict, List, Optional
def parse_iso(s: str) -> datetime:
    """Parse an ISO 8601 string into a datetime object.

    A trailing ``Z`` (the ISO 8601 UTC designator) is accepted on every
    Python version: ``datetime.fromisoformat`` only understands it from
    Python 3.11 onwards, so it is rewritten to the equivalent ``+00:00``
    offset before parsing.

    Args:
        s: ISO 8601 date-time string, e.g. ``2024-01-15T09:00:00+05:30``.

    Returns:
        The parsed datetime; it is timezone-aware when ``s`` carries a UTC
        offset (or ``Z``) and naive otherwise.

    Raises:
        ValueError: If ``s`` is not a valid ISO 8601 string.
    """
    if s.endswith("Z"):
        s = s[:-1] + "+00:00"
    return datetime.fromisoformat(s)
def load_scenario(scenario_path: str) -> dict:
    """Load a scenario JSON file and return the parsed dict.

    The file is always decoded as UTF-8 (the encoding JSON documents are
    expected to use) instead of the platform's locale encoding, so
    non-ASCII summaries load identically on every operating system.

    Args:
        scenario_path: Path of the scenario JSON file.

    Returns:
        The deserialised JSON content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(scenario_path, "r", encoding="utf-8") as f:
        return json.load(f)
def find_conflicts(
    calendars: Dict[str, List[List]],
    proposed_start_iso: str,
    proposed_end_iso: str,
    attendee_ids: List[str],
) -> List[Dict]:
    """Collect every existing meeting that overlaps a proposed slot.

    A meeting overlaps the slot when the slot starts before the meeting
    ends and the meeting starts before the slot ends, so intervals that
    merely touch at an endpoint do not count as a conflict.

    Args:
        calendars: Mapping of attendee id to calendar entries of the form
            ``[start_iso, end_iso, priority, summary]``.
        proposed_start_iso: ISO 8601 start of the proposed slot.
        proposed_end_iso: ISO 8601 end of the proposed slot.
        attendee_ids: Attendees whose calendars are inspected; ids missing
            from ``calendars`` are treated as having no meetings.

    Returns:
        List of conflict dicts with keys: attendee, start, end, priority,
        summary, meeting_id (``"<attendee>_<start_iso>"``), ordered by
        attendee and then by calendar order.
    """
    window_start = parse_iso(proposed_start_iso)
    window_end = parse_iso(proposed_end_iso)

    clashes: List[Dict] = []
    for person in attendee_ids:
        for booked_from, booked_to, priority, summary in calendars.get(person, []):
            begins = parse_iso(booked_from)
            finishes = parse_iso(booked_to)
            # Skip meetings that lie entirely before or after the window.
            if window_start >= finishes or begins >= window_end:
                continue
            clashes.append(
                {
                    "attendee": person,
                    "start": booked_from,
                    "end": booked_to,
                    "priority": priority,
                    "summary": summary,
                    "meeting_id": f"{person}_{booked_from}",
                }
            )
    return clashes
def calculate_collective_hours(preferences: Dict) -> Dict[str, int]:
    """Find the intersection of all users' preferred working hours.

    Each user preference has 'preferred_hours': {'start': int, 'end': int}.
    A user without 'preferred_hours' (or without one of its keys) is
    assumed to work 9-17. The same 9-17 window is returned when
    ``preferences`` is empty, instead of failing on an empty ``max``/``min``.

    Args:
        preferences: Mapping of user id to that user's preference dict.

    Returns:
        {"min_start_hour": <max of all starts>, "max_end_hour": <min of all ends>}
        Note that ``min_start_hour`` can exceed ``max_end_hour`` when the
        users' windows do not overlap at all.
    """
    hour_ranges = [p.get("preferred_hours", {}) for p in preferences.values()]
    return {
        "min_start_hour": max((h.get("start", 9) for h in hour_ranges), default=9),
        "max_end_hour": min((h.get("end", 17) for h in hour_ranges), default=17),
    }
def within_collective_hours(
    start_iso: str,
    end_iso: str,
    collective_hours: Dict[str, int],
) -> bool:
    """Check if a proposed slot falls within collective working hours.

    The start hour must be >= min_start_hour and the end must be
    <= max_end_hour (ending exactly on the hour boundary is allowed).
    Hours are read from each timestamp as written, i.e. in its own UTC
    offset.

    The end is measured in hours from midnight of the start's calendar day,
    so a slot that runs past midnight is rejected rather than being judged
    by the small hour value of the following day.

    Args:
        start_iso: ISO 8601 start of the slot.
        end_iso: ISO 8601 end of the slot.
        collective_hours: Dict with integer keys 'min_start_hour' and
            'max_end_hour'.

    Returns:
        True when the slot lies inside the working window, else False.
    """
    start = parse_iso(start_iso)
    end = parse_iso(end_iso)
    min_start = collective_hours["min_start_hour"]
    max_end = collective_hours["max_end_hour"]

    if start.hour < min_start:
        return False

    # Hour of the end relative to the start's day: 00:30 on the next day
    # becomes hour 24, which can never fit a window that closes at <= 24.
    days_later = (end.date() - start.date()).days
    end_hour = end.hour + 24 * days_later

    # Ending exactly on the hour may touch max_end; any later instant within
    # that hour (minutes, seconds or microseconds) must stay strictly below.
    on_hour_boundary = end.minute == 0 and end.second == 0 and end.microsecond == 0
    if on_hour_boundary:
        return end_hour <= max_end
    return end_hour < max_end
def count_meetings_on_date(calendar_entries: List[List], target_date: date) -> int:
    """Return the number of calendar entries that start on ``target_date``.

    Only the start timestamp (first element of each entry) is considered,
    and its date is taken in the timestamp's own UTC offset.
    """
    return sum(
        1
        for meeting in calendar_entries
        if parse_iso(meeting[0]).date() == target_date
    )
def check_back_to_back(
    calendar_entries: List[List],
    proposed_start_iso: str,
    proposed_end_iso: str,
    buffer_minutes: int,
) -> bool:
    """Tell whether a proposed meeting sits too close to an existing one.

    Returns True as soon as some existing meeting ends less than
    buffer_minutes before the proposed start, or starts less than
    buffer_minutes after the proposed end. A gap of exactly
    buffer_minutes is acceptable, and overlapping meetings (negative gap)
    are not reported here.
    """
    new_start = parse_iso(proposed_start_iso)
    new_end = parse_iso(proposed_end_iso)
    window = timedelta(minutes=buffer_minutes)
    no_gap = timedelta(0)

    for meeting in calendar_entries:
        meeting_start = parse_iso(meeting[0])
        meeting_end = parse_iso(meeting[1])
        # First test: existing meeting finishes shortly before the new one.
        # Second test: existing meeting begins shortly after the new one.
        if (
            no_gap <= new_start - meeting_end < window
            or no_gap <= meeting_start - new_end < window
        ):
            return True
    return False
def calculate_preference_score(
    proposed_start_iso: str,
    duration_minutes: int,
    participant_preferences: Dict,
    calendars: Dict[str, List[List]],
) -> float:
    """Sum the preference-violation penalties of a proposed meeting.

    Each participant contributes independently:
      - 50 when the slot is outside their preferred hours (default 9-17),
      - 30 when the meeting would push them past max_meetings_per_day
        (default 8),
      - 20 when they set avoid_back_to_back with a positive buffer_minutes
        and the slot is closer than that buffer to an existing meeting.

    Returns:
        Total penalty sum (float); 0.0 when nothing is violated.
    """
    slot_start = parse_iso(proposed_start_iso)
    slot_end_iso = (slot_start + timedelta(minutes=duration_minutes)).isoformat()
    slot_day = slot_start.date()

    penalty = 0.0
    for person, prefs in participant_preferences.items():
        hours = prefs.get("preferred_hours", {})
        # The per-person window reuses the collective-hours check.
        personal_window = {
            "min_start_hour": hours.get("start", 9),
            "max_end_hour": hours.get("end", 17),
        }
        daily_cap = prefs.get("max_meetings_per_day", 8)
        wants_buffer = prefs.get("avoid_back_to_back", False)
        buffer_len = prefs.get("buffer_minutes", 0)

        if not within_collective_hours(proposed_start_iso, slot_end_iso, personal_window):
            total_for_hours = 50
            penalty += total_for_hours

        booked = calendars.get(person, [])
        # "+ 1" accounts for the meeting being proposed.
        if count_meetings_on_date(booked, slot_day) + 1 > daily_cap:
            penalty += 30

        if (
            wants_buffer
            and buffer_len > 0
            and check_back_to_back(booked, proposed_start_iso, slot_end_iso, buffer_len)
        ):
            penalty += 20

    return penalty
def is_slot_free(
    attendee: str,
    start_iso: str,
    end_iso: str,
    calendars: Dict[str, List[List]],
) -> bool:
    """Return True when none of the attendee's meetings overlap the slot.

    An attendee without a calendar is always free. Meetings that only touch
    the slot at an endpoint are not considered overlapping.
    """
    slot_start = parse_iso(start_iso)
    slot_end = parse_iso(end_iso)
    # Lazily parsed (start, end) pairs, so scanning stops at the first clash.
    booked_spans = (
        (parse_iso(meeting[0]), parse_iso(meeting[1]))
        for meeting in calendars.get(attendee, [])
    )
    return not any(
        slot_start < booked_end and booked_start < slot_end
        for booked_start, booked_end in booked_spans
    )
def calculate_final_reward(
    preference_penalty: float,
    num_rescheduled: int,
    steps_taken: int,
    success: bool = True,
) -> float:
    """Compute the multi-component reward for an episode, clamped to [0.0, 1.0].

    Components (deducted from 1.0):
      - Preference deduction: min(0.75, (preference_penalty ** 1.2) / 200.0)
      - Rescheduling deduction: min(0.30, 0.05 * (1.8 ** num_rescheduled))
        (only applied when num_rescheduled > 0)
      - Time deduction: steps_taken * 0.015

    A negative preference_penalty is treated as 0: raising a negative
    number to the power 1.2 yields a complex value in Python, which would
    make the ``min`` comparison raise ``TypeError``.

    Args:
        preference_penalty: Accumulated preference penalty points.
        num_rescheduled: Number of meetings that had to be rescheduled.
        steps_taken: Number of environment steps used in the episode.
        success: Whether the episode ended successfully.

    Returns:
        The reward in [0.0, 1.0]; always 0.0 if the episode was not
        successful.
    """
    if not success:
        return 0.0

    # Written as a "< 0" test so that every non-negative input flows through
    # the formula untouched.
    penalty = 0.0 if preference_penalty < 0 else preference_penalty
    pref_deduction = min(0.75, (penalty ** 1.2) / 200.0)

    # Exponential in the number of rescheduled meetings, capped at 0.30.
    if num_rescheduled > 0:
        reschedule_deduction = min(0.30, 0.05 * (1.8 ** num_rescheduled))
    else:
        reschedule_deduction = 0.0

    time_deduction = steps_taken * 0.015

    reward = 1.0 - pref_deduction - reschedule_deduction - time_deduction
    return max(0.0, min(1.0, reward))
def build_busy_slots(
    calendars: Dict[str, List[List]],
    attendee_ids: List[str],
) -> List[Dict]:
    """Flatten the requested attendees' calendars into busy-slot dicts.

    Entries keep the order of ``attendee_ids`` and, within one attendee,
    the order of that attendee's calendar. Attendees without a calendar
    contribute nothing.

    Returns:
        List of dicts with keys: start, end, priority, summary, attendee.
    """
    return [
        {
            "start": booked_from,
            "end": booked_to,
            "priority": priority,
            "summary": summary,
            "attendee": person,
        }
        for person in attendee_ids
        for booked_from, booked_to, priority, summary in calendars.get(person, [])
    ]
def find_earliest_free_slot(
    calendars: Dict[str, List[List]],
    attendees: List[str],
    duration_minutes: int,
    search_date_iso: str,
    collective_hours: Dict[str, int],
) -> Optional[str]:
    """Find the earliest free slot on a given date for all attendees.

    Candidate start times run from min_start_hour in 15-minute increments,
    and a candidate is only considered when the whole meeting ends no later
    than max_end_hour. The window boundaries are built as hour offsets from
    midnight, so a window that closes at the end of the day
    (``max_end_hour == 24``) is supported.

    Args:
        calendars: Mapping of attendee id to calendar entries.
        attendees: Attendees who must all be free.
        duration_minutes: Length of the meeting to place.
        search_date_iso: ISO 8601 timestamp whose date and tzinfo select
            the day to search; its time of day is ignored.
        collective_hours: Dict with 'min_start_hour' and 'max_end_hour'.

    Returns:
        The ISO 8601 string of the first conflict-free slot start, or None
        when no candidate fits.
    """
    anchor = parse_iso(search_date_iso)
    base_date = anchor.date()
    midnight = datetime(
        base_date.year, base_date.month, base_date.day, tzinfo=anchor.tzinfo
    )

    # Offsets instead of datetime(hour=...) so that hour 24 is valid.
    candidate = midnight + timedelta(hours=collective_hours["min_start_hour"])
    end_boundary = midnight + timedelta(hours=collective_hours["max_end_hour"])

    duration = timedelta(minutes=duration_minutes)
    step = timedelta(minutes=15)

    while candidate + duration <= end_boundary:
        candidate_iso = candidate.isoformat()
        candidate_end_iso = (candidate + duration).isoformat()
        if all(
            is_slot_free(attendee, candidate_iso, candidate_end_iso, calendars)
            for attendee in attendees
        ):
            return candidate_iso
        candidate += step
    return None
|