Spaces:
Runtime error
Runtime error
| """Executive Assistant Arena Environment Implementation.""" | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from models import AssistantAction, AssistantObservation, AssistantState | |
| from .scenario_generator import generate_scenario, Scenario, CalendarEvent, TIME_SLOTS | |
| from .reward import score_reschedule, score_email_reply, score_terminal, RewardBreakdown | |
class ExecAssistantArenaEnvironment(Environment):
    """An environment that simulates a personal assistant's morning inbox.

    The agent must resolve calendar conflicts, draft email replies,
    infer user preferences, and handle late-breaking changes.

    Episodes are 10-20 steps. Rewards are rule-based and decomposed
    into 6 components (conflict resolution, email quality, preference
    inference, deadline adherence, late-change recovery, efficiency
    penalty) for training visibility.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Episode terminates unconditionally once this many steps have elapsed.
    MAX_STEPS: int = 20
    # Earliest step at which the scenario's late-breaking change is injected.
    LATE_CHANGE_MIN_STEP: int = 7

    # Flat penalty applied to every wasted action (duplicate reply,
    # malformed delegate, unknown tool).
    _WASTE_PENALTY: float = -0.2

    def __init__(self):
        self._state = AssistantState(episode_id=str(uuid4()), step_count=0)
        self.scenario: Scenario | None = None
        self.late_change_injected = False
        self.late_change_step: int | None = None
        self.replied_emails: set[str] = set()
        self.reward_breakdown = RewardBreakdown()

    def reset(self, seed=None, difficulty="medium", **kwargs) -> AssistantObservation:
        """Reset the environment with a new procedural scenario.

        Args:
            seed: Optional int or str. A string is mapped to a stable int so
                the same string always yields the same scenario.
            difficulty: Passed through to the scenario generator.

        Returns:
            The initial observation, including a summary of conflicts,
            emails needing replies, and the user's preference hints.
        """
        if isinstance(seed, str):
            # Builtin hash() on str is randomized per process (PYTHONHASHSEED),
            # which would make string seeds non-reproducible across runs.
            # Use a stable SHA-256-derived value instead.
            import hashlib

            digest = hashlib.sha256(seed.encode("utf-8")).digest()
            seed = int.from_bytes(digest[:4], "big") % (2**31)
        self.scenario = generate_scenario(difficulty, seed)
        self.late_change_injected = False
        self.late_change_step = None
        self.replied_emails = set()
        self.reward_breakdown = RewardBreakdown()
        self._state = AssistantState(
            episode_id=str(uuid4()),
            step_count=0,
            total_conflicts=len(self.scenario.conflicts),
            total_emails=len([e for e in self.scenario.emails if e.requires_reply]),
            total_preferences=len(self.scenario.preferences),
            total_late_changes=len(self.scenario.late_changes),
        )
        # Build the welcome observation.
        pref_hints = "\n".join(f"  - {desc}" for _, desc in self.scenario.preferences)
        return AssistantObservation(
            inbox_summary=self.scenario.inbox_text(),
            calendar_view=self.scenario.calendar_text(),
            pending_tasks=self.scenario.pending_tasks_text(),
            tool_result=f"Good morning. You have {len(self.scenario.conflicts)} scheduling conflicts and {self._state.total_emails} emails needing replies.\n\nUser preferences:\n{pref_hints}",
            conflicts=self.scenario.conflicts_text(),
            done=False,
            reward=0.0,
        )

    def step(self, action: AssistantAction, **kwargs) -> AssistantObservation:
        """Process one assistant action and return the resulting observation.

        Dispatches on ``action.tool`` to the matching handler, accumulates
        per-component rewards in ``self.reward_breakdown``, and terminates
        the episode on the "done" tool or after ``MAX_STEPS`` steps.
        """
        if self.scenario is None:
            # Defensive: allow step() before an explicit reset().
            self.reset()
        self._state.step_count += 1

        tool_result = self._maybe_inject_late_change()
        reward = 0.0

        tool = action.tool
        args = action.arguments
        if tool == "check_calendar":
            # Free action - no reward.
            tool_result += self.scenario.calendar_text()
        elif tool == "check_inbox":
            # Free action.
            tool_result += self.scenario.inbox_text()
        elif tool == "reschedule":
            delta, msg = self._apply_reschedule(args)
            reward += delta
            tool_result += msg
        elif tool == "draft_reply":
            delta, msg = self._apply_draft_reply(args)
            reward += delta
            tool_result += msg
        elif tool == "delegate_task":
            delta, msg = self._apply_delegate(args)
            reward += delta
            tool_result += msg
        elif tool == "done":
            delta, msg = self._apply_done()
            reward += delta
            tool_result += msg
        else:
            reward += self._efficiency_penalty()
            tool_result += f"Unknown tool: {tool}. Available: check_calendar, check_inbox, reschedule, draft_reply, delegate_task, done"

        done = tool == "done" or self._state.step_count >= self.MAX_STEPS
        self._state.cumulative_reward += reward

        # If we hit max steps without "done", compute terminal rewards anyway.
        if self._state.step_count >= self.MAX_STEPS and tool != "done":
            reward += self._finalize_timeout()
            tool_result += "\n[Max steps reached - episode terminated]"

        return AssistantObservation(
            inbox_summary=self.scenario.inbox_text(),
            calendar_view=self.scenario.calendar_text(),
            pending_tasks=self.scenario.pending_tasks_text(),
            tool_result=tool_result,
            conflicts=self.scenario.conflicts_text(),
            done=done,
            reward=reward,
        )

    def _maybe_inject_late_change(self) -> str:
        """Inject the scenario's late change once step >= LATE_CHANGE_MIN_STEP; returns the banner text or ""."""
        if (
            self._state.step_count >= self.LATE_CHANGE_MIN_STEP
            and not self.late_change_injected
        ):
            change_desc = self.scenario.inject_late_change()
            if change_desc:
                self.late_change_injected = True
                self.late_change_step = self._state.step_count
                return f"*** LATE CHANGE: {change_desc} ***\n\n"
        return ""

    def _efficiency_penalty(self) -> float:
        """Record one wasted action and return the flat penalty."""
        self._state.unnecessary_actions += 1
        self.reward_breakdown.efficiency_penalty += self._WASTE_PENALTY
        return self._WASTE_PENALTY

    def _apply_reschedule(self, args: dict) -> tuple[float, str]:
        """Score a reschedule attempt; credits conflict resolution and preference inference."""
        event_id = args.get("event_id", "")
        new_time = args.get("new_time", "")
        conflict_r, pref_r, msg = score_reschedule(
            self.scenario, event_id, new_time, self.scenario.preferences
        )
        self.reward_breakdown.conflict_resolution += conflict_r
        self.reward_breakdown.preference_inference += pref_r
        if conflict_r > 0:
            self._state.conflicts_resolved += 1
        if pref_r > 0:
            self._state.preferences_inferred += 1
        return conflict_r + pref_r, msg

    def _apply_draft_reply(self, args: dict) -> tuple[float, str]:
        """Score a drafted email reply; duplicate replies to the same email are penalized."""
        email_id = args.get("email_id", "")
        body = args.get("body", "")
        if email_id in self.replied_emails:
            return self._efficiency_penalty(), f"Already replied to {email_id}."
        email_r, pref_r, msg = score_email_reply(
            email_id, body, self.scenario, self.scenario.preferences
        )
        self.reward_breakdown.email_quality += email_r
        self.reward_breakdown.preference_inference += pref_r
        self._state.emails_drafted += 1
        if pref_r > 0:
            self._state.preferences_inferred += 1
        self.replied_emails.add(email_id)
        # Replying to an email that carries a deadline counts as meeting it.
        for e in self.scenario.emails:
            if e.email_id == email_id and e.deadline:
                self._state.deadlines_met += 1
                self.reward_breakdown.deadline_adherence += 0.5
        return email_r + pref_r, msg

    def _apply_delegate(self, args: dict) -> tuple[float, str]:
        """Handle task delegation; small positive reward if it follows a late change."""
        task_desc = args.get("task", "")
        to = args.get("to", "")
        if not (task_desc and to):
            return self._efficiency_penalty(), "Delegate requires 'task' and 'to' arguments."
        reward = 0.0
        # Any delegation after the late change counts as recovery effort.
        # NOTE(review): relatedness to the change itself is not checked — the
        # original applied the same blanket credit; confirm this is intended.
        if self.late_change_injected and self.late_change_step:
            reward += 0.5
            self.reward_breakdown.late_change_recovery += 0.5
            self._state.late_changes_handled += 1
        return reward, f"Delegated '{task_desc}' to {to}."

    def _apply_done(self) -> tuple[float, str]:
        """Compute terminal rewards for an explicit 'done' and the final summary text."""
        terminal = score_terminal(self.scenario)
        # Credit back deadlines that were met during the episode.
        terminal.deadline_adherence += self._state.deadlines_met * 1.0
        # Bonus if the agent took any recovery action after the late change.
        if self.late_change_injected and self._state.late_changes_handled > 0:
            terminal.late_change_recovery += 2.0
        self.reward_breakdown.deadline_adherence += terminal.deadline_adherence
        self.reward_breakdown.late_change_recovery += terminal.late_change_recovery
        self.reward_breakdown.conflict_resolution += terminal.conflict_resolution
        summary = "Episode complete. Final breakdown:\n"
        summary += f"  Conflicts resolved: {self._state.conflicts_resolved}/{self._state.total_conflicts}\n"
        summary += f"  Emails drafted: {self._state.emails_drafted}/{self._state.total_emails}\n"
        summary += f"  Preferences inferred: {self._state.preferences_inferred}/{self._state.total_preferences}\n"
        summary += f"  Deadlines met: {self._state.deadlines_met}\n"
        summary += f"  Late changes handled: {self._state.late_changes_handled}/{self._state.total_late_changes}\n"
        return terminal.total, summary

    def _finalize_timeout(self) -> float:
        """Compute terminal rewards when the episode times out without 'done'.

        Mirrors the explicit-"done" path for the recorded breakdown, except
        the +2.0 late-change bonus is withheld (the episode never finished).
        """
        terminal = score_terminal(self.scenario)
        terminal.deadline_adherence += self._state.deadlines_met * 1.0
        # Keep the recorded breakdown complete on timeout too.
        self.reward_breakdown.deadline_adherence += terminal.deadline_adherence
        self.reward_breakdown.conflict_resolution += terminal.conflict_resolution
        self._state.cumulative_reward += terminal.total
        return terminal.total

    def state(self) -> AssistantState:
        """Return the mutable per-episode state object."""
        return self._state