Spaces:
Runtime error
Runtime error
| """Decomposed reward computation for the Executive Assistant Arena. | |
| All rewards are rule-based and deterministic. No LLM judges. | |
| Each component is logged separately for W&B tracking. | |
| """ | |
| from dataclasses import dataclass | |
| from .scenario_generator import Scenario, TIME_SLOTS | |
@dataclass
class RewardBreakdown:
    """Container for the individual reward components of an episode.

    Every component defaults to 0.0, so ``RewardBreakdown()`` is an empty
    breakdown that callers adjust in place. Components are kept separate so
    they can be reported individually; ``total()`` sums them.
    """

    # Reward/penalty related to calendar conflicts.
    conflict_resolution: float = 0.0
    # Reward for acting in line with user preferences.
    preference_inference: float = 0.0
    # Reward for the quality of email replies.
    email_quality: float = 0.0
    # Reward/penalty related to reply deadlines.
    deadline_adherence: float = 0.0
    # Added (not subtracted) in total(), so presumably stored as a
    # non-positive value -- TODO confirm against the code that sets it.
    efficiency_penalty: float = 0.0
    # Reward related to handling late changes.
    late_change_recovery: float = 0.0

    # NOTE(review): defined as a regular method here; confirm that callers
    # invoke ``total()`` rather than reading ``total`` as an attribute.
    def total(self) -> float:
        """Return the sum of all six reward components."""
        return (
            self.conflict_resolution
            + self.preference_inference
            + self.email_quality
            + self.deadline_adherence
            + self.efficiency_penalty
            + self.late_change_recovery
        )
def _slots_spanned(duration_min):
    """Return how many 30-minute slots a duration touches, rounding up."""
    # Ceiling division: a 45-minute event occupies two slots, not one.
    return -(-duration_min // 30)


def score_reschedule(
    scenario: Scenario,
    event_id: str,
    new_time: str,
    preferences: list[tuple[str, str]],
) -> tuple[float, float, str]:
    """Score a reschedule action and apply it to the scenario if accepted.

    Args:
        scenario: Scenario whose ``calendar`` and ``conflicts`` are inspected.
            On an accepted move the event's ``time`` is updated and, when the
            move resolves a conflict, every conflict pair containing
            ``event_id`` is removed from ``scenario.conflicts``.
        event_id: Identifier of the calendar event to move.
        new_time: Target slot; must be a member of ``TIME_SLOTS``.
        preferences: Pairs whose first element is a preference id
            (e.g. ``"no_early_meetings"``, ``"lunch_block"``,
            ``"buffer_time"``, ``"no_back_to_back"``).

    Returns:
        ``(conflict_reward, pref_reward, message)``.

        * Unknown event or invalid slot: ``(-0.2, 0.0, ...)``.
        * Event that cannot be rescheduled: ``(-0.5, 0.0, ...)``.
        * Move that would overlap another event: ``(-0.5, 0.0, ...)``; the
          calendar is left untouched and no preference scoring is applied.
        * Accepted move: ``conflict_reward`` is 1.0 if the event was part of
          a recorded conflict, else 0.0; ``pref_reward`` accumulates the
          preference bonuses/penalties described inline below.
    """
    event = next((e for e in scenario.calendar if e.event_id == event_id), None)
    if event is None:
        return -0.2, 0.0, f"Event {event_id} not found."
    if not event.can_reschedule:
        return -0.5, 0.0, f"Event {event_id} cannot be rescheduled (high priority)."
    if new_time not in TIME_SLOTS:
        return -0.2, 0.0, f"Invalid time slot: {new_time}."

    time_index = {slot: i for i, slot in enumerate(TIME_SLOTS)}
    new_idx = time_index[new_time]
    others = [o for o in scenario.calendar if o.event_id != event_id]

    # Validate the target slot BEFORE touching the event, so a rejected move
    # (or an exception raised while checking) never leaves it half-moved.
    event_slots = _slots_spanned(event.duration_min)
    creates_new_conflict = any(
        other.time in time_index
        and new_idx < time_index[other.time] + _slots_spanned(other.duration_min)
        and time_index[other.time] < new_idx + event_slots
        for other in others
    )
    if creates_new_conflict:
        # The move did not happen, so preference bonuses/penalties that
        # describe the new position must not be applied.
        return -0.5, 0.0, f"Cannot move {event_id} to {new_time} - creates new conflict."

    old_time = event.time
    was_in_conflict = any(event_id in (a, b) for a, b in scenario.conflicts)
    event.time = new_time

    if was_in_conflict:
        conflict_reward = 1.0
        # The new position overlaps nothing, so every recorded conflict
        # involving this event is resolved.
        scenario.conflicts = [
            (a, b) for a, b in scenario.conflicts
            if event_id not in (a, b)
        ]
        msg = f"Conflict resolved: {event_id} moved to {new_time}."
    else:
        conflict_reward = 0.0
        msg = f"Moved {event_id} to {new_time} (no conflict impact)."

    # Preference alignment for the accepted move.
    early_slots = ("9:00am", "9:30am")
    lunch_slots = ("12:00pm", "12:30pm")
    pref_ids = {p[0] for p in preferences}
    pref_reward = 0.0

    if "no_early_meetings" in pref_ids and new_time in early_slots:
        pref_reward -= 0.3
        msg += " Warning: user prefers no early meetings."
    if "lunch_block" in pref_ids and new_time in lunch_slots:
        pref_reward -= 0.3
        msg += " Warning: moved into lunch block."
    if (
        "no_early_meetings" in pref_ids
        and old_time in early_slots
        and new_time not in early_slots
    ):
        pref_reward += 0.5
        msg += " Good: moved away from early slot per preference."
    if "buffer_time" in pref_ids or "no_back_to_back" in pref_ids:
        # Only events with a known slot can be adjacent; an unknown time must
        # not be mistaken for "the slot before index 0".
        # NOTE(review): adjacency compares start slots only and ignores
        # durations -- confirm whether that is intended.
        back_to_back = any(
            other.time in time_index
            and abs(new_idx - time_index[other.time]) == 1
            for other in others
        )
        if back_to_back:
            pref_reward -= 0.3
            msg += " Warning: back-to-back meeting created."

    return conflict_reward, pref_reward, msg
def _count_markers(text, markers):
    """Count how many of ``markers`` occur in ``text`` as whole words/phrases."""
    import re

    hits = 0
    for marker in markers:
        # Anchor on word boundaries so "hey" does not fire inside "they".
        pattern = r"(?<!\w)" + re.escape(marker)
        if marker[-1:].isalnum():
            # Markers ending in punctuation ("hi!") need no trailing anchor.
            pattern += r"(?!\w)"
        if re.search(pattern, text):
            hits += 1
    return hits


def score_email_reply(
    email_id: str,
    reply_body: str,
    scenario: Scenario,
    preferences: list[tuple[str, str]],
) -> tuple[float, float, str]:
    """Score an email reply with rule-based keyword and tone checks.

    Args:
        email_id: Identifier of the email being answered; looked up in
            ``scenario.emails``.
        reply_body: Text of the reply.
        scenario: Scenario providing the emails (with ``key_points`` and
            ``tone_expected``).
        preferences: Pairs whose first element is a preference id
            (``"informal_tone"`` / ``"formal_tone"`` are consulted here).

    Returns:
        ``(email_reward, pref_reward, message)``.

        * Unknown email: ``(-0.2, 0.0, ...)``.
        * Empty reply or fewer than 10 non-padding characters:
          ``(0.0, 0.0, "Reply too short.")``.
        * Otherwise ``email_reward`` is the sum of three parts -- key points
          addressed (up to 0.4), tone match (up to 0.3) and tone-preference
          alignment (up to 0.3) -- and ``pref_reward`` is 0.5 whenever the
          preference part is positive, else 0.0.
    """
    email = next((e for e in scenario.emails if e.email_id == email_id), None)
    if email is None:
        return -0.2, 0.0, f"Email {email_id} not found."
    if not reply_body or len(reply_body.strip()) < 10:
        return 0.0, 0.0, "Reply too short."

    reply_lower = reply_body.lower()

    # Key points (0.4 total, split evenly): a point counts as addressed when
    # at least 30% of its words appear in the reply.
    addresses_score = 0.0
    for kp in email.key_points:
        keywords = kp.lower().split()
        if not keywords:
            # A blank key point carries nothing to address; do not credit it.
            continue
        matches = sum(1 for kw in keywords if kw in reply_lower)
        if matches >= len(keywords) * 0.3:
            addresses_score += 0.4 / len(email.key_points)

    # Tone (0.3): compare how many formal vs. informal markers are present.
    formal_markers = ["dear", "regards", "sincerely", "please find", "i would like to"]
    informal_markers = ["hey", "hi!", "thanks!", "sounds good", "sure thing", "no worries"]
    formal_count = _count_markers(reply_lower, formal_markers)
    informal_count = _count_markers(reply_lower, informal_markers)

    tone_score = 0.0
    if email.tone_expected == "formal" and formal_count > informal_count:
        tone_score = 0.3
    elif email.tone_expected == "informal" and informal_count >= formal_count:
        tone_score = 0.3
    elif formal_count == 0 and informal_count == 0:
        tone_score = 0.15  # neutral is ok

    # Preference alignment (0.3): reply tone vs. the user's tone preference.
    pref_ids = {p[0] for p in preferences}
    pref_score = 0.0
    if "informal_tone" in pref_ids and informal_count > 0:
        pref_score += 0.3
    elif "formal_tone" in pref_ids and formal_count > 0:
        pref_score += 0.3
    elif "informal_tone" not in pref_ids and "formal_tone" not in pref_ids:
        pref_score += 0.15  # no tone preference

    email_reward = addresses_score + tone_score + pref_score

    # NOTE(review): this also awards 0.5 when there is no tone preference at
    # all (pref_score == 0.15) -- confirm that is intended.
    pref_reward = 0.5 if pref_score > 0 else 0.0

    msg = (
        f"Email reply scored: addresses={addresses_score:.2f}, "
        f"tone={tone_score:.2f}, pref={pref_score:.2f}"
    )
    return email_reward, pref_reward, msg
def score_terminal(scenario: Scenario) -> RewardBreakdown:
    """Build the end-of-episode breakdown from what the scenario still holds.

    Args:
        scenario: Scenario whose ``emails``, ``conflicts`` and
            ``late_changes`` are read; nothing on it is modified.

    Returns:
        A fresh ``RewardBreakdown`` in which ``deadline_adherence`` and
        ``conflict_resolution`` carry the penalties described below and all
        other components remain 0.0.
    """
    result = RewardBreakdown()

    # Emails still flagged as requiring a reply: -1.0 when a deadline is set,
    # -0.5 when the deadline is None. A falsy deadline that is not None
    # matches neither case and costs nothing.
    for mail in scenario.emails:
        if not mail.requires_reply:
            continue
        if mail.deadline:
            result.deadline_adherence -= 1.0
        elif mail.deadline is None:
            result.deadline_adherence -= 0.5

    # Each conflict pair still recorded on the scenario costs 0.5.
    result.conflict_resolution -= len(scenario.conflicts) * 0.5

    # Injected late changes are visited, but the adjustment is 0.0, so this
    # currently leaves late_change_recovery unchanged.
    for change in scenario.late_changes:
        if change.injected:
            result.late_change_recovery += 0.0

    return result