File size: 9,109 Bytes
378cf8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
"""Executive Assistant Arena Environment Implementation."""

import zlib
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment

from models import AssistantAction, AssistantObservation, AssistantState

from .reward import RewardBreakdown, score_email_reply, score_reschedule, score_terminal
from .scenario_generator import TIME_SLOTS, CalendarEvent, Scenario, generate_scenario


class ExecAssistantArenaEnvironment(Environment):
    """
    An environment that simulates a personal assistant's morning inbox.

    The agent must resolve calendar conflicts, draft email replies,
    infer user preferences, and handle late-breaking changes.

    Episodes are 10-20 steps. Rewards are rule-based and decomposed
    into 6 components for training visibility.
    """

    # Each session gets its own instance; all episode state is per-instance.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        # Per-episode bookkeeping. A concrete scenario is created by reset();
        # step() bootstraps a default one if called first.
        self._state = AssistantState(episode_id=str(uuid4()), step_count=0)
        self.scenario: Scenario | None = None
        self.late_change_injected = False
        self.late_change_step: int | None = None
        self.replied_emails: set[str] = set()
        self.reward_breakdown = RewardBreakdown()

    def reset(self, seed=None, difficulty="medium", **kwargs) -> AssistantObservation:
        """Reset the environment with a new procedural scenario.

        Args:
            seed: Optional reproducibility seed (int or str). String seeds
                are reduced to a 31-bit integer.
            difficulty: Difficulty label forwarded to ``generate_scenario``.

        Returns:
            The initial observation: a morning briefing in ``tool_result``,
            the inbox/calendar views, and zero reward.
        """
        if isinstance(seed, str):
            # FIX: builtin hash() of a str is salted per interpreter run
            # (PYTHONHASHSEED), so string seeds were not reproducible across
            # processes. CRC-32 is stable everywhere.
            seed = zlib.crc32(seed.encode("utf-8")) % (2**31)

        self.scenario = generate_scenario(difficulty, seed)
        self.late_change_injected = False
        self.late_change_step = None
        self.replied_emails = set()
        self.reward_breakdown = RewardBreakdown()

        self._state = AssistantState(
            episode_id=str(uuid4()),
            step_count=0,
            total_conflicts=len(self.scenario.conflicts),
            total_emails=len([e for e in self.scenario.emails if e.requires_reply]),
            total_preferences=len(self.scenario.preferences),
            total_late_changes=len(self.scenario.late_changes),
        )

        # Build the welcome observation. Preferences are (key, description)
        # pairs; only the human-readable descriptions are shown to the agent.
        pref_hints = "\n".join(f"  - {desc}" for _, desc in self.scenario.preferences)

        return AssistantObservation(
            inbox_summary=self.scenario.inbox_text(),
            calendar_view=self.scenario.calendar_text(),
            pending_tasks=self.scenario.pending_tasks_text(),
            tool_result=f"Good morning. You have {len(self.scenario.conflicts)} scheduling conflicts and {self._state.total_emails} emails needing replies.\n\nUser preferences:\n{pref_hints}",
            conflicts=self.scenario.conflicts_text(),
            done=False,
            reward=0.0,
        )

    def step(self, action: AssistantAction, **kwargs) -> AssistantObservation:
        """Process one assistant action and return the resulting observation.

        Dispatches on ``action.tool``; unknown tools incur a small penalty.
        The episode ends when the agent calls "done" or after 20 steps.
        """
        if self.scenario is None:
            # step() before reset(): bootstrap a default-difficulty episode.
            self.reset()

        self._state.step_count += 1
        reward = 0.0
        tool_result = ""

        # Inject at most one late-breaking change once the episode is well
        # underway (step 7+), surfacing it ahead of the tool output.
        if self._state.step_count >= 7 and not self.late_change_injected:
            change_desc = self.scenario.inject_late_change()
            if change_desc:
                self.late_change_injected = True
                self.late_change_step = self._state.step_count
                tool_result = f"*** LATE CHANGE: {change_desc} ***\n\n"

        # Dispatch the tool call to the matching handler.
        tool = action.tool
        args = action.arguments

        if tool == "check_calendar":
            tool_result += self.scenario.calendar_text()
            # Free action - no reward

        elif tool == "check_inbox":
            tool_result += self.scenario.inbox_text()
            # Free action

        elif tool == "reschedule":
            delta, msg = self._handle_reschedule(args)
            reward += delta
            tool_result += msg

        elif tool == "draft_reply":
            delta, msg = self._handle_draft_reply(args)
            reward += delta
            tool_result += msg

        elif tool == "delegate_task":
            delta, msg = self._handle_delegate(args)
            reward += delta
            tool_result += msg

        elif tool == "done":
            delta, msg = self._handle_done()
            reward += delta
            tool_result += msg

        else:
            self._state.unnecessary_actions += 1
            reward -= 0.2
            self.reward_breakdown.efficiency_penalty -= 0.2
            tool_result += f"Unknown tool: {tool}. Available: check_calendar, check_inbox, reschedule, draft_reply, delegate_task, done"

        done = tool == "done" or self._state.step_count >= 20
        self._state.cumulative_reward += reward

        # If we hit max steps without "done", compute terminal rewards.
        if self._state.step_count >= 20 and tool != "done":
            terminal = score_terminal(self.scenario)
            terminal.deadline_adherence += self._state.deadlines_met * 1.0
            reward += terminal.total
            self._state.cumulative_reward += terminal.total
            # FIX: mirror the "done" branch's breakdown bookkeeping so the
            # component totals stay consistent however the episode ends.
            self.reward_breakdown.deadline_adherence += terminal.deadline_adherence
            self.reward_breakdown.late_change_recovery += terminal.late_change_recovery
            self.reward_breakdown.conflict_resolution += terminal.conflict_resolution
            tool_result += "\n[Max steps reached - episode terminated]"

        return AssistantObservation(
            inbox_summary=self.scenario.inbox_text(),
            calendar_view=self.scenario.calendar_text(),
            pending_tasks=self.scenario.pending_tasks_text(),
            tool_result=tool_result,
            conflicts=self.scenario.conflicts_text(),
            done=done,
            reward=reward,
        )

    def _handle_reschedule(self, args) -> tuple[float, str]:
        """Score a reschedule attempt; returns (reward_delta, message)."""
        event_id = args.get("event_id", "")
        new_time = args.get("new_time", "")
        conflict_r, pref_r, msg = score_reschedule(
            self.scenario, event_id, new_time, self.scenario.preferences
        )
        self.reward_breakdown.conflict_resolution += conflict_r
        self.reward_breakdown.preference_inference += pref_r
        if conflict_r > 0:
            self._state.conflicts_resolved += 1
        if pref_r > 0:
            self._state.preferences_inferred += 1
        return conflict_r + pref_r, msg

    def _handle_draft_reply(self, args) -> tuple[float, str]:
        """Score an email draft; replying twice to one email is penalized."""
        email_id = args.get("email_id", "")
        body = args.get("body", "")

        if email_id in self.replied_emails:
            self._state.unnecessary_actions += 1
            self.reward_breakdown.efficiency_penalty -= 0.2
            return -0.2, f"Already replied to {email_id}."

        email_r, pref_r, msg = score_email_reply(
            email_id, body, self.scenario, self.scenario.preferences
        )
        self.reward_breakdown.email_quality += email_r
        self.reward_breakdown.preference_inference += pref_r
        self._state.emails_drafted += 1
        if pref_r > 0:
            self._state.preferences_inferred += 1
        self.replied_emails.add(email_id)

        # Record the deadline as met (breakdown credit now; the matching
        # reward credit is granted at episode end by the terminal scoring).
        for e in self.scenario.emails:
            if e.email_id == email_id and e.deadline:
                self._state.deadlines_met += 1
                self.reward_breakdown.deadline_adherence += 0.5
                # FIX: stop after the first match; assumes email ids are
                # unique — confirm against scenario_generator.
                break

        return email_r + pref_r, msg

    def _handle_delegate(self, args) -> tuple[float, str]:
        """Handle delegate_task; rewards delegation after a late change."""
        task_desc = args.get("task", "")
        to = args.get("to", "")
        if not (task_desc and to):
            self._state.unnecessary_actions += 1
            self.reward_breakdown.efficiency_penalty -= 0.2
            return -0.2, "Delegate requires 'task' and 'to' arguments."

        # Small positive if it's related to a late change.
        # NOTE(review): any delegation after the late change earns +0.5 and
        # is repeatable every step — relatedness is never checked. This is
        # exploitable by a policy; confirm it is intended before training.
        reward = 0.0
        if self.late_change_injected and self.late_change_step:
            reward += 0.5
            self.reward_breakdown.late_change_recovery += 0.5
            self._state.late_changes_handled += 1
        return reward, f"Delegated '{task_desc}' to {to}."

    def _handle_done(self) -> tuple[float, str]:
        """Compute terminal rewards and build the final summary message."""
        terminal = score_terminal(self.scenario)

        # Credit back deadlines that were met during the episode.
        terminal.deadline_adherence += self._state.deadlines_met * 1.0

        # Credit late-change recovery if the agent reacted at all after the
        # injection. (Removed a redundant max(1, ...): the guard already
        # guarantees late_changes_handled >= 1.)
        if self.late_change_injected and self._state.late_changes_handled > 0:
            terminal.late_change_recovery += 2.0

        self.reward_breakdown.deadline_adherence += terminal.deadline_adherence
        self.reward_breakdown.late_change_recovery += terminal.late_change_recovery
        self.reward_breakdown.conflict_resolution += terminal.conflict_resolution

        msg = "Episode complete. Final breakdown:\n"
        msg += f"  Conflicts resolved: {self._state.conflicts_resolved}/{self._state.total_conflicts}\n"
        msg += f"  Emails drafted: {self._state.emails_drafted}/{self._state.total_emails}\n"
        msg += f"  Preferences inferred: {self._state.preferences_inferred}/{self._state.total_preferences}\n"
        msg += f"  Deadlines met: {self._state.deadlines_met}\n"
        msg += f"  Late changes handled: {self._state.late_changes_handled}/{self._state.total_late_changes}\n"
        return terminal.total, msg

    @property
    def state(self) -> AssistantState:
        """Current episode state (step/reward counters)."""
        return self._state