# scheduling_env/server/graders.py
# Uploaded by Akshaykumarbm ("Upload folder using huggingface_hub")
# Commit 7bdbe90 (verified)
"""Graders for the meeting-scheduling RL environment.
Provides programmatic scoring (0.0–1.0) per episode and validation
that graders produce diverse scores across different agent trajectories.
"""
from __future__ import annotations
import logging
from typing import List
from .scheduling_logic import (
calculate_collective_hours,
calculate_final_reward,
find_conflicts,
parse_iso,
)
logger = logging.getLogger(__name__)
class SchedulingGrader:
    """Programmatic grader for scheduling tasks.

    ``grade_episode`` is the public entry point.  It turns the final state
    of an episode into a score in ``[0.0, 1.0]`` and halves that score when
    ``_check_violations`` reports any hard-constraint violation.

    Calendar entries are accessed positionally: index 0 and 1 are parsed as
    ISO start/end timestamps, index 2 is compared against the request
    priority, and index 3 is used as a label in overlap messages.
    NOTE(review): the exact entry schema is defined outside this class --
    confirm against the code that builds ``state.calendars``.
    """

    def grade_episode(self, final_state, final_observation) -> float:
        """Score an episode in [0.0, 1.0].

        Args:
            final_state: Episode state; reads ``completed``,
                ``final_reward`` and the fields used by
                ``_check_violations``.
            final_observation: Last observation; reads ``success``.

        Returns:
            ``0.0`` when the episode did not complete or the last
            observation was not successful.  Otherwise
            ``final_state.final_reward`` clamped to ``[0.0, 1.0]``, with a
            50 % penalty applied first if any hard constraint violation is
            detected.  A NaN reward is scored as ``0.0``.
        """
        if not final_state.completed or not final_observation.success:
            return 0.0

        score = final_state.final_reward
        violations = self._check_violations(final_state)
        if violations:
            score *= 0.5
            logger.warning("Constraint violations: %s", violations)

        # NaN compares False against everything, so the naive
        # ``max(0.0, min(1.0, score))`` would turn a NaN reward into a
        # perfect 1.0.  Testing ``score > 0.0`` first sends NaN (and every
        # non-positive value) to 0.0 instead.
        if not score > 0.0:
            return 0.0
        return min(1.0, score)

    def _check_violations(self, state) -> List[str]:
        """Detect hard constraint violations in the final state.

        Runs the three checks in a fixed order (priority, working hours,
        overlaps) and concatenates their messages.

        Returns:
            Human-readable violation messages; empty when none were found.
        """
        return [
            *self._priority_violations(state),
            *self._working_hours_violations(state),
            *self._overlap_violations(state),
        ]

    def _priority_violations(self, state) -> List[str]:
        """Flag rescheduled meetings that should not have been moved."""
        # A request without an explicit priority falls back to 99.
        # NOTE(review): the ``<=`` test below implies that a smaller number
        # means a more important meeting -- confirm with the task spec.
        req_priority = state.meeting_request.get("priority", 99)

        found: List[str] = []
        for moved in state.rescheduled_meetings:
            attendee = moved["attendee"]
            old_start = moved["old_start"]
            for entry in state.calendars.get(attendee, []):
                # Match the calendar entry by its start value, then compare
                # its priority (index 2) with the incoming request.
                if entry[0] == old_start and entry[2] <= req_priority:
                    found.append(
                        f"Rescheduled higher priority meeting: "
                        f"{attendee} {old_start}"
                    )
        return found

    def _working_hours_violations(self, state) -> List[str]:
        """Flag a proposed slot that leaves the collective working hours."""
        if not state.proposed_slot:
            return []

        collective = calculate_collective_hours(state.participant_preferences)
        slot_start, slot_end = state.proposed_slot[0], state.proposed_slot[1]
        start = parse_iso(slot_start)
        end = parse_iso(slot_end)

        found: List[str] = []
        if start.hour < collective["min_start_hour"]:
            found.append(f"Slot starts before working hours: {slot_start}")

        # Ending exactly on the closing hour (minute == 0) is allowed; any
        # later hour, or any minutes past the closing hour, is not.
        closing_hour = collective["max_end_hour"]
        if end.hour > closing_hour or (
            end.hour == closing_hour and end.minute > 0
        ):
            found.append(f"Slot ends after working hours: {slot_end}")
        return found

    def _overlap_violations(self, state) -> List[str]:
        """Flag neighbouring calendar entries that overlap for any user."""
        found: List[str] = []
        for user_id, calendar in state.calendars.items():
            # Entries are ordered by their raw start value (index 0); only
            # neighbours in that order are compared.
            ordered = sorted(calendar, key=lambda entry: entry[0])
            for current, following in zip(ordered, ordered[1:]):
                current_end = parse_iso(current[1])
                following_start = parse_iso(following[0])
                if current_end > following_start:
                    found.append(
                        f"Overlap for {user_id}: {current[3]} / {following[3]}"
                    )
        return found