exec-assistant-arena / server /exec_assistant_arena_environment.py
SidraMiconi's picture
Upload folder using huggingface_hub
378cf8e verified
"""Executive Assistant Arena Environment Implementation."""
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from models import AssistantAction, AssistantObservation, AssistantState
from .scenario_generator import generate_scenario, Scenario, CalendarEvent, TIME_SLOTS
from .reward import score_reschedule, score_email_reply, score_terminal, RewardBreakdown
class ExecAssistantArenaEnvironment(Environment):
    """
    An environment that simulates a personal assistant's morning inbox.

    The agent must resolve calendar conflicts, draft email replies,
    infer user preferences, and handle late-breaking changes.

    Episodes are 10-20 steps. Rewards are rule-based and decomposed
    into 6 components for training visibility.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Hard cap on the number of steps in one episode.
    MAX_STEPS: int = 20
    # First step at which a late-breaking change may be injected.
    LATE_CHANGE_STEP: int = 7
    # Magnitude of the penalty applied to wasted / malformed actions.
    WASTED_ACTION_PENALTY: float = 0.2

    def __init__(self):
        """Create the environment with empty per-episode bookkeeping.

        No scenario is generated here; ``reset`` (or the first ``step``)
        creates one.
        """
        self._state = AssistantState(episode_id=str(uuid4()), step_count=0)
        self.scenario: Scenario | None = None
        self.late_change_injected = False
        self.late_change_step: int | None = None
        self.replied_emails: set[str] = set()
        self.reward_breakdown = RewardBreakdown()

    def reset(self, seed=None, difficulty="medium", **kwargs) -> AssistantObservation:
        """Reset the environment with a new procedural scenario.

        Args:
            seed: Seed forwarded to ``generate_scenario``. A string seed is
                mapped to an integer in ``[0, 2**31)`` with a stable digest,
                so the same string produces the same seed in every process.
            difficulty: Difficulty label forwarded to ``generate_scenario``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The welcome observation (``done=False``, ``reward=0.0``).
        """
        if isinstance(seed, str):
            # The built-in hash() of a str is salted per process
            # (PYTHONHASHSEED), which made string seeds irreproducible across
            # runs. A cryptographic digest is stable everywhere.
            import hashlib

            digest = hashlib.sha256(seed.encode("utf-8")).digest()
            seed = int.from_bytes(digest[:8], "big") % (2**31)

        self.scenario = generate_scenario(difficulty, seed)
        self.late_change_injected = False
        self.late_change_step = None
        self.replied_emails = set()
        self.reward_breakdown = RewardBreakdown()
        self._state = AssistantState(
            episode_id=str(uuid4()),
            step_count=0,
            total_conflicts=len(self.scenario.conflicts),
            total_emails=sum(1 for e in self.scenario.emails if e.requires_reply),
            total_preferences=len(self.scenario.preferences),
            total_late_changes=len(self.scenario.late_changes),
        )

        # Build the welcome observation
        pref_hints = "\n".join(f" - {desc}" for _, desc in self.scenario.preferences)
        welcome = (
            f"Good morning. You have {len(self.scenario.conflicts)} scheduling conflicts "
            f"and {self._state.total_emails} emails needing replies."
            f"\n\nUser preferences:\n{pref_hints}"
        )
        return self._build_observation(welcome, done=False, reward=0.0)

    def step(self, action: AssistantAction, **kwargs) -> AssistantObservation:
        """Process one assistant action.

        Args:
            action: The tool call to execute; ``action.tool`` selects the tool
                and ``action.arguments`` carries its arguments.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            An observation holding the tool output, the reward earned on this
            step, and ``done=True`` when the agent called ``done`` or the step
            limit was reached.
        """
        if self.scenario is None:
            self.reset()

        self._state.step_count += 1
        reward = 0.0
        tool_result = ""

        # Inject the late change once the episode is far enough along.
        if self._state.step_count >= self.LATE_CHANGE_STEP and not self.late_change_injected:
            change_desc = self.scenario.inject_late_change()
            if change_desc:
                self.late_change_injected = True
                self.late_change_step = self._state.step_count
                tool_result = f"*** LATE CHANGE: {change_desc} ***\n\n"

        # Process tool call
        tool = action.tool
        # Tolerate a missing arguments payload instead of raising AttributeError.
        args = action.arguments or {}

        if tool == "check_calendar":
            # Free action - no reward
            tool_result += self.scenario.calendar_text()
        elif tool == "check_inbox":
            # Free action - no reward
            tool_result += self.scenario.inbox_text()
        elif tool == "reschedule":
            delta, message = self._handle_reschedule(args)
            reward += delta
            tool_result += message
        elif tool == "draft_reply":
            delta, message = self._handle_draft_reply(args)
            reward += delta
            tool_result += message
        elif tool == "delegate_task":
            delta, message = self._handle_delegate(args)
            reward += delta
            tool_result += message
        elif tool == "done":
            delta, message = self._handle_done()
            reward += delta
            tool_result += message
        else:
            reward += self._penalize_wasted_action()
            tool_result += (
                f"Unknown tool: {tool}. Available: check_calendar, check_inbox, "
                "reschedule, draft_reply, delegate_task, done"
            )

        out_of_steps = self._state.step_count >= self.MAX_STEPS
        done = tool == "done" or out_of_steps
        self._state.cumulative_reward += reward

        # If we hit max steps without "done", apply the terminal scoring too.
        # Late-change recovery is only credited on an explicit "done".
        if out_of_steps and tool != "done":
            terminal_total = self._apply_terminal_scoring(credit_late_change=False)
            reward += terminal_total
            self._state.cumulative_reward += terminal_total
            tool_result += "\n[Max steps reached - episode terminated]"

        return self._build_observation(tool_result, done=done, reward=reward)

    @property
    def state(self) -> AssistantState:
        """Return the current episode state."""
        return self._state

    def _build_observation(self, tool_result: str, done: bool, reward: float) -> AssistantObservation:
        """Assemble an observation from the current scenario views."""
        return AssistantObservation(
            inbox_summary=self.scenario.inbox_text(),
            calendar_view=self.scenario.calendar_text(),
            pending_tasks=self.scenario.pending_tasks_text(),
            tool_result=tool_result,
            conflicts=self.scenario.conflicts_text(),
            done=done,
            reward=reward,
        )

    def _penalize_wasted_action(self) -> float:
        """Record one unnecessary action and return its (negative) reward."""
        self._state.unnecessary_actions += 1
        self.reward_breakdown.efficiency_penalty -= self.WASTED_ACTION_PENALTY
        return -self.WASTED_ACTION_PENALTY

    def _handle_reschedule(self, args) -> tuple[float, str]:
        """Score a ``reschedule`` call; return ``(reward, message)``."""
        event_id = args.get("event_id", "")
        new_time = args.get("new_time", "")
        conflict_r, pref_r, msg = score_reschedule(
            self.scenario, event_id, new_time, self.scenario.preferences
        )
        self.reward_breakdown.conflict_resolution += conflict_r
        self.reward_breakdown.preference_inference += pref_r
        if conflict_r > 0:
            self._state.conflicts_resolved += 1
        if pref_r > 0:
            self._state.preferences_inferred += 1
        return conflict_r + pref_r, msg

    def _handle_draft_reply(self, args) -> tuple[float, str]:
        """Score a ``draft_reply`` call; return ``(reward, message)``."""
        email_id = args.get("email_id", "")
        body = args.get("body", "")

        # A second reply to the same email is a wasted action.
        if email_id in self.replied_emails:
            return self._penalize_wasted_action(), f"Already replied to {email_id}."

        email_r, pref_r, msg = score_email_reply(
            email_id, body, self.scenario, self.scenario.preferences
        )
        self.reward_breakdown.email_quality += email_r
        self.reward_breakdown.preference_inference += pref_r
        self._state.emails_drafted += 1
        if pref_r > 0:
            self._state.preferences_inferred += 1
        self.replied_emails.add(email_id)

        # Mark deadline as met.
        # NOTE(review): the 0.5 is added to the breakdown only, not to the
        # step reward; the deadline credit reaches the reward at episode end.
        # Confirm this asymmetry is intended.
        for email in self.scenario.emails:
            if email.email_id == email_id and email.deadline:
                self._state.deadlines_met += 1
                self.reward_breakdown.deadline_adherence += 0.5

        return email_r + pref_r, msg

    def _handle_delegate(self, args) -> tuple[float, str]:
        """Process a ``delegate_task`` call; return ``(reward, message)``."""
        task_desc = args.get("task", "")
        to = args.get("to", "")

        if not (task_desc and to):
            return (
                self._penalize_wasted_action(),
                "Delegate requires 'task' and 'to' arguments.",
            )

        reward = 0.0
        # Small positive once a late change has been injected.
        # NOTE(review): every delegation after the late change is rewarded,
        # whether or not it relates to the change, so late_changes_handled can
        # exceed total_late_changes. Confirm whether a cap is wanted.
        if self.late_change_injected and self.late_change_step is not None:
            reward += 0.5
            self.reward_breakdown.late_change_recovery += 0.5
            self._state.late_changes_handled += 1
        return reward, f"Delegated '{task_desc}' to {to}."

    def _handle_done(self) -> tuple[float, str]:
        """Finish the episode on an explicit ``done``; return ``(reward, message)``."""
        reward = self._apply_terminal_scoring(credit_late_change=True)
        state = self._state
        summary = (
            "Episode complete. Final breakdown:\n"
            f" Conflicts resolved: {state.conflicts_resolved}/{state.total_conflicts}\n"
            f" Emails drafted: {state.emails_drafted}/{state.total_emails}\n"
            f" Preferences inferred: {state.preferences_inferred}/{state.total_preferences}\n"
            f" Deadlines met: {state.deadlines_met}\n"
            f" Late changes handled: {state.late_changes_handled}/{state.total_late_changes}\n"
        )
        return reward, summary

    def _apply_terminal_scoring(self, credit_late_change: bool) -> float:
        """Compute the terminal score, fold it into the breakdown, return its total.

        Args:
            credit_late_change: When true, add the late-change recovery bonus
                if a late change was injected and at least one was handled.
        """
        terminal = score_terminal(self.scenario)
        # Credit back deadlines that were met
        terminal.deadline_adherence += self._state.deadlines_met * 1.0
        if (
            credit_late_change
            and self.late_change_injected
            and self._state.late_changes_handled > 0
        ):
            terminal.late_change_recovery += 2.0

        # Keep the running breakdown in step with what is added to the reward,
        # on both the "done" path and the max-steps path.
        self.reward_breakdown.deadline_adherence += terminal.deadline_adherence
        self.reward_breakdown.late_change_recovery += terminal.late_change_recovery
        self.reward_breakdown.conflict_resolution += terminal.conflict_resolution
        # Read the total only after the adjustments above.
        return terminal.total