# SidraMiconi's picture
# Upload folder using huggingface_hub
# 378cf8e verified
"""Decomposed reward computation for the Executive Assistant Arena.
All rewards are rule-based and deterministic. No LLM judges.
Each component is logged separately for W&B tracking.
"""
import re
from dataclasses import dataclass

from .scenario_generator import Scenario, TIME_SLOTS
@dataclass
class RewardBreakdown:
    """Container for the individual reward components.

    Each component is kept as its own float field (all default to 0.0);
    ``total`` adds them together.
    """

    conflict_resolution: float = 0.0
    preference_inference: float = 0.0
    email_quality: float = 0.0
    deadline_adherence: float = 0.0
    efficiency_penalty: float = 0.0
    late_change_recovery: float = 0.0

    @property
    def total(self) -> float:
        """Return the sum of all six components."""
        # Accumulate in declaration order so the float result matches a
        # plain left-to-right chain of additions.
        running = self.conflict_resolution
        running = running + self.preference_inference
        running = running + self.email_quality
        running = running + self.deadline_adherence
        running = running + self.efficiency_penalty
        running = running + self.late_change_recovery
        return running
def score_reschedule(
    scenario: Scenario,
    event_id: str,
    new_time: str,
    preferences: list[tuple[str, str]],
) -> tuple[float, float, str]:
    """Score a reschedule action. Returns (conflict_reward, pref_reward, message).

    Rejections (nothing is mutated, pref_reward is 0.0):
      * unknown ``event_id``                         -> conflict_reward -0.2
      * event has ``can_reschedule`` false           -> conflict_reward -0.5
      * ``new_time`` is not in ``TIME_SLOTS``        -> conflict_reward -0.2
      * the move would overlap another calendar event -> conflict_reward -0.5

    Accepted moves set ``event.time`` to ``new_time``. If the event appeared
    in ``scenario.conflicts``, every pair containing it is removed from that
    list and conflict_reward is 1.0; otherwise conflict_reward is 0.0.

    pref_reward is then adjusted using the preference ids (first element of
    each tuple in ``preferences``): -0.3 for moving into an early slot or the
    lunch block, +0.5 for moving out of an early slot, and -0.3 for landing
    in a slot adjacent to another event's start slot.
    """
    event = next((e for e in scenario.calendar if e.event_id == event_id), None)
    if event is None:
        return -0.2, 0.0, f"Event {event_id} not found."
    if not event.can_reschedule:
        return -0.5, 0.0, f"Event {event_id} cannot be rescheduled (high priority)."
    if new_time not in TIME_SLOTS:
        return -0.2, 0.0, f"Invalid time slot: {new_time}."

    time_index = {t: i for i, t in enumerate(TIME_SLOTS)}
    old_time = event.time
    was_in_conflict = any(event_id in (a, b) for a, b in scenario.conflicts)

    # Overlap test on half-open slot ranges [start, start + slots).
    # Durations are converted with floor division by 30 minutes.
    n_start = time_index[new_time]
    e_slots = event.duration_min // 30
    creates_new_conflict = False
    for other in scenario.calendar:
        if other.event_id == event_id:
            continue
        o_start = time_index.get(other.time)
        if o_start is None:
            # Events whose time is not a known slot cannot be compared.
            continue
        o_slots = other.duration_min // 30
        if n_start < o_start + o_slots and o_start < n_start + e_slots:
            creates_new_conflict = True
            break

    if creates_new_conflict:
        # The move is rejected and the event stays at its old time, so no
        # preference bonus or penalty applies: nothing was actually moved.
        return (
            -0.5,
            0.0,
            f"Cannot move {event_id} to {new_time} - creates new conflict.",
        )

    # The move is accepted; only now is the calendar mutated.
    event.time = new_time
    if was_in_conflict:
        conflict_reward = 1.0
        # Drop every recorded conflict pair that involves the moved event.
        scenario.conflicts = [
            (a, b) for a, b in scenario.conflicts
            if event_id not in (a, b)
        ]
        msg = f"Conflict resolved: {event_id} moved to {new_time}."
    else:
        conflict_reward = 0.0
        msg = f"Moved {event_id} to {new_time} (no conflict impact)."

    # Check preference alignment
    early_slots = ("9:00am", "9:30am")
    lunch_slots = ("12:00pm", "12:30pm")
    pref_ids = {p[0] for p in preferences}
    pref_reward = 0.0

    if "no_early_meetings" in pref_ids and new_time in early_slots:
        pref_reward -= 0.3
        msg += " Warning: user prefers no early meetings."
    if "lunch_block" in pref_ids and new_time in lunch_slots:
        pref_reward -= 0.3
        msg += " Warning: moved into lunch block."
    if (
        "no_early_meetings" in pref_ids
        and old_time in early_slots
        and new_time not in early_slots
    ):
        pref_reward += 0.5
        msg += " Good: moved away from early slot per preference."
    if "buffer_time" in pref_ids or "no_back_to_back" in pref_ids:
        # Adjacency is judged on start slots only. Events with an unknown
        # time are skipped rather than given a sentinel index, which would
        # otherwise look adjacent to the first slot.
        for other in scenario.calendar:
            if other.event_id == event_id:
                continue
            o_idx = time_index.get(other.time)
            if o_idx is not None and abs(n_start - o_idx) == 1:
                pref_reward -= 0.3
                msg += " Warning: back-to-back meeting created."
                break  # penalise at most once

    return conflict_reward, pref_reward, msg
def _count_tone_markers(text: str, markers: list[str]) -> int:
    """Count how many of *markers* occur in *text* as whole words or phrases."""
    # Lookarounds instead of \b because some markers end in punctuation
    # ("hi!", "thanks!"), where \b would not behave as a phrase boundary.
    return sum(
        1
        for marker in markers
        if re.search(rf"(?<!\w){re.escape(marker)}(?!\w)", text)
    )


def score_email_reply(
    email_id: str,
    reply_body: str,
    scenario: Scenario,
    preferences: list[tuple[str, str]],
) -> tuple[float, float, str]:
    """Score an email reply. Returns (email_reward, pref_reward, message).

    Returns ``(-0.2, 0.0, ...)`` when ``email_id`` is not in
    ``scenario.emails`` and ``(0.0, 0.0, "Reply too short.")`` when the
    stripped reply has fewer than 10 characters.

    Otherwise email_reward is the sum of three parts:
      * addresses (up to 0.4): an equal share per key point for which at
        least 30% of its words appear in the reply (case-insensitive
        substring match);
      * tone (up to 0.3): based on ``email.tone_expected`` and the counts of
        formal / informal marker phrases found in the reply;
      * preference (up to 0.3): 0.3 when the reply matches an
        ``informal_tone`` / ``formal_tone`` preference id, 0.15 when neither
        preference id is present.

    pref_reward is 0.5 only when a stated tone preference was matched.
    """
    email = next((e for e in scenario.emails if e.email_id == email_id), None)
    if email is None:
        return -0.2, 0.0, f"Email {email_id} not found."
    if not reply_body or len(reply_body.strip()) < 10:
        return 0.0, 0.0, "Reply too short."

    reply_lower = reply_body.lower()

    # Score: addresses_issue (0.4)
    addresses_score = 0.0
    for kp in email.key_points:
        # Simple keyword matching
        keywords = kp.lower().split()
        matches = sum(1 for kw in keywords if kw in reply_lower)
        # A key point with no words has nothing to address, so it must not
        # earn credit via the trivially true 0 >= 0 comparison.
        if keywords and matches >= len(keywords) * 0.3:
            addresses_score += 0.4 / len(email.key_points)

    # Score: tone (0.3)
    formal_markers = ["dear", "regards", "sincerely", "please find", "i would like to"]
    informal_markers = ["hey", "hi!", "thanks!", "sounds good", "sure thing", "no worries"]
    # Whole-word matching so that e.g. "they" does not count as "hey".
    formal_count = _count_tone_markers(reply_lower, formal_markers)
    informal_count = _count_tone_markers(reply_lower, informal_markers)

    tone_score = 0.0
    if email.tone_expected == "formal" and formal_count > informal_count:
        tone_score = 0.3
    elif email.tone_expected == "informal" and informal_count >= formal_count:
        tone_score = 0.3
    elif formal_count == 0 and informal_count == 0:
        tone_score = 0.15  # neutral is ok

    # Score: preference alignment (0.3)
    pref_ids = {p[0] for p in preferences}
    pref_score = 0.0
    matched_preference = False
    if "informal_tone" in pref_ids and informal_count > 0:
        pref_score += 0.3
        matched_preference = True
    elif "formal_tone" in pref_ids and formal_count > 0:
        pref_score += 0.3
        matched_preference = True
    elif "informal_tone" not in pref_ids and "formal_tone" not in pref_ids:
        pref_score += 0.15  # no tone preference

    email_reward = addresses_score + tone_score + pref_score

    # The inference bonus is only earned by matching a real preference; the
    # neutral 0.15 credit above (no tone preference at all) does not count.
    pref_reward = 0.5 if matched_preference else 0.0

    msg = f"Email reply scored: addresses={addresses_score:.2f}, tone={tone_score:.2f}, pref={pref_score:.2f}"
    return email_reward, pref_reward, msg
def score_terminal(scenario: Scenario) -> RewardBreakdown:
    """Compute terminal rewards at episode end.

    Fills a fresh ``RewardBreakdown`` with:
      * ``deadline_adherence``: -1.0 per email that requires a reply and has
        a truthy deadline, -0.5 per email that requires a reply and whose
        deadline is ``None``;
      * ``conflict_resolution``: -0.5 per entry left in ``scenario.conflicts``;
      * ``late_change_recovery``: currently unchanged (see note below).
    """
    result = RewardBreakdown()

    # Deadline adherence.
    # NOTE(review): nothing here checks whether a reply was actually sent;
    # presumably the caller clears ``requires_reply`` (or removes the email)
    # once it is answered - confirm against the environment code.
    for message in scenario.emails:
        has_deadline = bool(message.deadline)
        if not message.requires_reply:
            continue
        if has_deadline:
            result.deadline_adherence -= 1.0  # missed deadline (unreplied)
        elif message.deadline is None:
            result.deadline_adherence -= 0.5  # unreplied but no deadline

    # Unresolved conflicts: flat 0.5 penalty each.
    result.conflict_resolution -= len(scenario.conflicts) * 0.5

    # Late changes not handled.
    # NOTE(review): the per-change weight is 0.0, so this currently has no
    # effect on the score - looks like a placeholder; confirm intended value.
    injected_changes = [lc for lc in scenario.late_changes if lc.injected]
    for _ in injected_changes:
        result.late_change_recovery += 0.0  # was injected but not handled

    return result