"""Executive Assistant Arena Environment Implementation."""
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from models import AssistantAction, AssistantObservation, AssistantState
from .scenario_generator import generate_scenario, Scenario, CalendarEvent, TIME_SLOTS
from .reward import score_reschedule, score_email_reply, score_terminal, RewardBreakdown
class ExecAssistantArenaEnvironment(Environment):
    """
    An environment that simulates a personal assistant's morning inbox.

    The agent must resolve calendar conflicts, draft email replies,
    infer user preferences, and handle late-breaking changes.

    Episodes are 10-20 steps. Rewards are rule-based and decomposed
    into 6 components for training visibility.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Episode limits (previously hard-coded inline in step()).
    MAX_STEPS: int = 20            # episode force-terminates after this many steps
    LATE_CHANGE_MIN_STEP: int = 7  # earliest step at which a late change may fire

    def __init__(self):
        # All per-episode bookkeeping below is re-created in reset().
        self._state = AssistantState(episode_id=str(uuid4()), step_count=0)
        self.scenario: Scenario | None = None
        self.late_change_injected = False
        self.late_change_step: int | None = None
        self.replied_emails: set[str] = set()
        self.reward_breakdown = RewardBreakdown()

    def reset(self, seed=None, difficulty="medium", **kwargs) -> AssistantObservation:
        """Reset the environment with a new procedural scenario.

        Args:
            seed: Optional int or str. String seeds are folded to a stable
                31-bit integer via CRC32. (The previous ``hash(seed)`` fold
                was salted per process — PYTHONHASHSEED — so string seeds
                were not reproducible across runs.)
            difficulty: Scenario difficulty, forwarded to generate_scenario().

        Returns:
            The initial AssistantObservation (reward 0.0, done False).
        """
        if isinstance(seed, str):
            import zlib  # local import: only needed for string seeds

            seed = zlib.crc32(seed.encode("utf-8")) % (2**31)
        self.scenario = generate_scenario(difficulty, seed)
        self.late_change_injected = False
        self.late_change_step = None
        self.replied_emails = set()
        self.reward_breakdown = RewardBreakdown()
        self._state = AssistantState(
            episode_id=str(uuid4()),
            step_count=0,
            total_conflicts=len(self.scenario.conflicts),
            total_emails=len([e for e in self.scenario.emails if e.requires_reply]),
            total_preferences=len(self.scenario.preferences),
            total_late_changes=len(self.scenario.late_changes),
        )
        # Build the welcome observation, surfacing the preference hints the
        # agent is expected to apply during the episode.
        pref_hints = "\n".join(f" - {desc}" for _, desc in self.scenario.preferences)
        return AssistantObservation(
            inbox_summary=self.scenario.inbox_text(),
            calendar_view=self.scenario.calendar_text(),
            pending_tasks=self.scenario.pending_tasks_text(),
            tool_result=f"Good morning. You have {len(self.scenario.conflicts)} scheduling conflicts and {self._state.total_emails} emails needing replies.\n\nUser preferences:\n{pref_hints}",
            conflicts=self.scenario.conflicts_text(),
            done=False,
            reward=0.0,
        )

    def step(self, action: AssistantAction, **kwargs) -> AssistantObservation:
        """Process one assistant action and return the next observation."""
        if self.scenario is None:
            # Defensive: step() before reset(). The initial observation that
            # reset() returns is intentionally discarded.
            self.reset()
        self._state.step_count += 1
        reward = 0.0
        tool_result = ""

        # Inject at most one late-breaking change once the episode is far
        # enough along.
        if (
            self._state.step_count >= self.LATE_CHANGE_MIN_STEP
            and not self.late_change_injected
        ):
            change_desc = self.scenario.inject_late_change()
            if change_desc:
                self.late_change_injected = True
                self.late_change_step = self._state.step_count
                tool_result = f"*** LATE CHANGE: {change_desc} ***\n\n"

        # Dispatch the tool call. Handlers update state/breakdown in place
        # and return (reward_delta, message).
        tool = action.tool
        args = action.arguments
        if tool == "check_calendar":
            tool_result += self.scenario.calendar_text()  # free action - no reward
        elif tool == "check_inbox":
            tool_result += self.scenario.inbox_text()  # free action
        elif tool == "reschedule":
            delta, msg = self._handle_reschedule(args)
            reward += delta
            tool_result += msg
        elif tool == "draft_reply":
            delta, msg = self._handle_draft_reply(args)
            reward += delta
            tool_result += msg
        elif tool == "delegate_task":
            delta, msg = self._handle_delegate(args)
            reward += delta
            tool_result += msg
        elif tool == "done":
            delta, msg = self._handle_done()
            reward += delta
            tool_result += msg
        else:
            # Unknown tool: small efficiency penalty, list the valid tools.
            self._state.unnecessary_actions += 1
            reward -= 0.2
            self.reward_breakdown.efficiency_penalty -= 0.2
            tool_result += f"Unknown tool: {tool}. Available: check_calendar, check_inbox, reschedule, draft_reply, delegate_task, done"

        done = tool == "done" or self._state.step_count >= self.MAX_STEPS
        self._state.cumulative_reward += reward
        # If we hit max steps without "done", apply terminal scoring here.
        if self._state.step_count >= self.MAX_STEPS and tool != "done":
            terminal = score_terminal(self.scenario)
            terminal.deadline_adherence += self._state.deadlines_met * 1.0
            reward += terminal.total
            self._state.cumulative_reward += terminal.total
            # Fix: mirror the "done" branch so the per-component breakdown
            # also reflects terminal scoring on a timeout (previously the
            # timeout path added terminal.total to the reward but recorded
            # nothing in reward_breakdown).
            self.reward_breakdown.deadline_adherence += terminal.deadline_adherence
            self.reward_breakdown.late_change_recovery += terminal.late_change_recovery
            self.reward_breakdown.conflict_resolution += terminal.conflict_resolution
            tool_result += "\n[Max steps reached - episode terminated]"
        return AssistantObservation(
            inbox_summary=self.scenario.inbox_text(),
            calendar_view=self.scenario.calendar_text(),
            pending_tasks=self.scenario.pending_tasks_text(),
            tool_result=tool_result,
            conflicts=self.scenario.conflicts_text(),
            done=done,
            reward=reward,
        )

    def _handle_reschedule(self, args) -> tuple[float, str]:
        """Score a reschedule attempt; returns (reward_delta, message)."""
        event_id = args.get("event_id", "")
        new_time = args.get("new_time", "")
        conflict_r, pref_r, msg = score_reschedule(
            self.scenario, event_id, new_time, self.scenario.preferences
        )
        self.reward_breakdown.conflict_resolution += conflict_r
        self.reward_breakdown.preference_inference += pref_r
        if conflict_r > 0:
            self._state.conflicts_resolved += 1
        if pref_r > 0:
            self._state.preferences_inferred += 1
        return conflict_r + pref_r, msg

    def _handle_draft_reply(self, args) -> tuple[float, str]:
        """Score a drafted email reply; returns (reward_delta, message).

        Duplicate replies are penalized; a first reply to an email that
        carries a deadline also earns deadline-adherence credit.
        """
        email_id = args.get("email_id", "")
        body = args.get("body", "")
        if email_id in self.replied_emails:
            self._state.unnecessary_actions += 1
            self.reward_breakdown.efficiency_penalty -= 0.2
            return -0.2, f"Already replied to {email_id}."
        email_r, pref_r, msg = score_email_reply(
            email_id, body, self.scenario, self.scenario.preferences
        )
        self.reward_breakdown.email_quality += email_r
        self.reward_breakdown.preference_inference += pref_r
        self._state.emails_drafted += 1
        if pref_r > 0:
            self._state.preferences_inferred += 1
        self.replied_emails.add(email_id)
        # Mark deadline as met for the email just replied to.
        for e in self.scenario.emails:
            if e.email_id == email_id and e.deadline:
                self._state.deadlines_met += 1
                self.reward_breakdown.deadline_adherence += 0.5
        return email_r + pref_r, msg

    def _handle_delegate(self, args) -> tuple[float, str]:
        """Score a delegation; returns (reward_delta, message)."""
        task_desc = args.get("task", "")
        to = args.get("to", "")
        if not (task_desc and to):
            self._state.unnecessary_actions += 1
            self.reward_breakdown.efficiency_penalty -= 0.2
            return -0.2, "Delegate requires 'task' and 'to' arguments."
        msg = f"Delegated '{task_desc}' to {to}."
        # Small positive if it's related to a late change.
        # NOTE(review): any delegation after a late change is credited,
        # regardless of content — confirm this is the intended heuristic.
        if self.late_change_injected and self.late_change_step:
            self.reward_breakdown.late_change_recovery += 0.5
            self._state.late_changes_handled += 1
            return 0.5, msg
        return 0.0, msg

    def _handle_done(self) -> tuple[float, str]:
        """Compute terminal rewards and the final summary message."""
        terminal = score_terminal(self.scenario)
        # Credit back deadlines that were met during the episode.
        terminal.deadline_adherence += self._state.deadlines_met * 1.0
        # Credit late changes handled.
        if self.late_change_injected:
            if self._state.late_changes_handled > 0:
                terminal.late_change_recovery += 2.0
            # NOTE(review): this floors the counter at 1 even when no
            # recovery action was taken — confirm intended (preserved as-is).
            self._state.late_changes_handled = max(1, self._state.late_changes_handled)
        self.reward_breakdown.deadline_adherence += terminal.deadline_adherence
        self.reward_breakdown.late_change_recovery += terminal.late_change_recovery
        self.reward_breakdown.conflict_resolution += terminal.conflict_resolution
        msg = (
            "Episode complete. Final breakdown:\n"
            f" Conflicts resolved: {self._state.conflicts_resolved}/{self._state.total_conflicts}\n"
            f" Emails drafted: {self._state.emails_drafted}/{self._state.total_emails}\n"
            f" Preferences inferred: {self._state.preferences_inferred}/{self._state.total_preferences}\n"
            f" Deadlines met: {self._state.deadlines_met}\n"
            f" Late changes handled: {self._state.late_changes_handled}/{self._state.total_late_changes}\n"
        )
        return terminal.total, msg

    @property
    def state(self) -> AssistantState:
        """Current mutable episode state (counters, cumulative reward)."""
        return self._state