"""Graders for the meeting-scheduling RL environment. Provides programmatic scoring (0.0–1.0) per episode and validation that graders produce diverse scores across different agent trajectories. """ from __future__ import annotations import logging from typing import List from .scheduling_logic import ( calculate_collective_hours, calculate_final_reward, find_conflicts, parse_iso, ) logger = logging.getLogger(__name__) class SchedulingGrader: """Programmatic grader for scheduling tasks.""" def grade_episode(self, final_state, final_observation) -> float: """Score an episode in [0.0, 1.0]. Returns ``final_state.final_reward`` when the episode completed successfully, with a 50 % penalty applied if any hard constraint violations are detected. """ if not final_state.completed or not final_observation.success: return 0.0 score = final_state.final_reward violations = self._check_violations(final_state) if violations: score *= 0.5 logger.warning("Constraint violations: %s", violations) return max(0.0, min(1.0, score)) def _check_violations(self, state) -> List[str]: """Detect hard constraint violations in the final state.""" violations: List[str] = [] req_priority = state.meeting_request.get("priority", 99) # Violation 1: Rescheduled equal-or-higher priority meeting for rm in state.rescheduled_meetings: attendee = rm["attendee"] old_start = rm["old_start"] for entry in state.calendars.get(attendee, []): if entry[0] == old_start and entry[2] <= req_priority: violations.append( f"Rescheduled higher priority meeting: " f"{attendee} {old_start}" ) # Violation 2: Proposed slot outside collective working hours if state.proposed_slot: collective = calculate_collective_hours(state.participant_preferences) start = parse_iso(state.proposed_slot[0]) end = parse_iso(state.proposed_slot[1]) if start.hour < collective["min_start_hour"]: violations.append( f"Slot starts before working hours: {state.proposed_slot[0]}" ) if end.hour > collective["max_end_hour"] or ( end.hour == collective["max_end_hour"] and end.minute > 0 ): violations.append( f"Slot ends after working hours: {state.proposed_slot[1]}" ) # Violation 3: Overlapping meetings after rescheduling for user_id, calendar in state.calendars.items(): sorted_cal = sorted(calendar, key=lambda e: e[0]) for i in range(len(sorted_cal) - 1): end_i = parse_iso(sorted_cal[i][1]) start_next = parse_iso(sorted_cal[i + 1][0]) if end_i > start_next: violations.append( f"Overlap for {user_id}: {sorted_cal[i][3]} / {sorted_cal[i+1][3]}" ) return violations