File size: 12,964 Bytes
96a5caf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ed3421
96a5caf
4ed3421
96a5caf
4ed3421
96a5caf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
"""
Email Triage & Response Environment
OpenEnv-compatible environment for agent evaluation.
"""

import json
from typing import Optional, Literal
from pydantic import BaseModel, Field

class Email(BaseModel):
    """One inbox message together with the triage state applied to it.

    The sender is stored as ``from_`` because ``from`` is a Python keyword;
    it is exposed under the alias ``"from"`` when validating input and when
    dumping with ``by_alias=True``.
    """

    # Pydantic v2 configuration.  This replaces the nested ``class Config``,
    # which pydantic v2 still accepts but reports as deprecated.
    # ``populate_by_name`` lets the model be built with either ``from`` (the
    # alias) or ``from_`` (the field name).
    model_config = {"populate_by_name": True}

    id: str
    from_: str = Field(..., alias="from")  # required; serialised as "from"
    subject: str
    body: str
    labels: list[str] = []  # priority labels ("urgent" / "normal" / "low") end up here
    replied: bool = False  # set once a reply has been drafted for this email
    archived: bool = False
    flagged: bool = False  # set when the email is flagged for human review
    flag_reason: Optional[str] = None  # reason recorded alongside ``flagged``
    reply_body: Optional[str] = None  # text of the drafted reply, if any


class InboxState(BaseModel):
    """Episode state held by the environment: inbox, sent log and step counter."""

    inbox: list[Email]  # emails loaded for the current task
    sent: list[dict] = []  # one {"to", "subject", "body"} dict is appended per drafted reply
    step_count: int = 0  # incremented once for every step that is processed


class Observation(BaseModel):
    """What the environment reports back to the agent after reset() or step()."""

    status: str  # values used in this file: "ok", "error", "warning", "done"
    message: str  # human-readable description of what happened
    data: Optional[dict] = None  # optional structured payload (e.g. inbox summaries, email contents)
    step_count: int = 0  # overwritten by step() with the state's current step count


class Action(BaseModel):
    """A single agent command; which optional fields matter depends on ``action``."""

    action: Literal["label", "draft_reply", "archive", "flag", "read", "list_inbox"]
    email_id: Optional[str] = None  # target email; every handler except list_inbox looks it up
    priority: Optional[Literal["urgent", "normal", "low"]] = None  # read by "label"
    body: Optional[str] = None  # reply text, read by "draft_reply"
    reason: Optional[str] = None  # read by "flag"; stored as the email's flag_reason


class StepResult(BaseModel):
    """Return value of ``EmailTriageEnv.step``: observation, reward, done flag, info."""

    observation: Observation
    reward: float  # per-step reward returned by the action handler
    done: bool  # mirrors the environment's internal done flag
    info: dict = {}  # left at its default by step() in this file

import os

# The curated dataset (written by curate_dataset.py) is expected at
# ./data/emails.json next to this module.
DATASET_PATH = os.path.join(os.path.dirname(__file__), "data", "emails.json")
try:
    with open(DATASET_PATH, "r", encoding="utf-8") as f:
        _dataset = json.load(f)
except FileNotFoundError:
    # Not curated yet: use an empty skeleton with the same layout so the
    # module still imports. curate_dataset.py should be run first.
    _dataset = {
        "task1": {"emails": [], "ground_truth": {}},
        "task2": {"emails": []},
        "task3": {
            "emails": [],
            "ground_truth": {},
            "urgent_ids": [],
            "archive_ids": [],
            "flag_ids": [],
        },
    }

# Bind each task's section once; a missing section raises KeyError at import.
_task1, _task2, _task3 = (_dataset[_name] for _name in ("task1", "task2", "task3"))

# Task 1: emails plus the expected priority label per email id.
TASK1_EMAILS = _task1["emails"]
TASK1_GROUND_TRUTH = _task1.get("ground_truth", {})

# Task 2: only the first email is used; {} when the list is empty.
TASK2_EMAIL = _task2["emails"][0] if _task2["emails"] else {}

# Task 3: emails, expected labels, and the id sets the grader checks.
TASK3_EMAILS = _task3["emails"]
TASK3_GROUND_TRUTH = _task3.get("ground_truth", {})

TASK3_URGENT_IDS = set(_task3.get("urgent_ids", []))
TASK3_ARCHIVE_IDS = set(_task3.get("archive_ids", []))
TASK3_FLAG_IDS = set(_task3.get("flag_ids", []))

def grade_task1(state: InboxState) -> float:
    """Score task 1: 0.2 per email whose priority label matches the ground truth.

    An email counts when its expected label (looked up by id in
    ``TASK1_GROUND_TRUTH``) is one of ``urgent`` / ``normal`` / ``low`` and
    that label is present on the email.  The total is capped at 1.0 and
    rounded to two decimals.
    """
    priorities = ("urgent", "normal", "low")
    total = 0.0
    for mail in state.inbox:
        expected = TASK1_GROUND_TRUTH.get(mail.id)
        # Emails without a recognised ground-truth label never score.
        if expected in priorities and expected in mail.labels:
            total += 0.2
    return round(min(total, 1.0), 2)


def grade_task2(state: InboxState) -> float:
    """Score the reply drafted for email ``t2_001`` with keyword heuristics.

    Returns 0.0 when the email is missing, not marked replied, or has no
    reply body.  Otherwise the score is the sum of:

    * 0.1 per topic mentioned (order / delivery-or-refund / apology), up to 0.3
    * 0.05 per professional phrase found (capped at 0.3), minus 0.1 per rude phrase
    * 0.2 if the reply is longer than 50 characters
    * 0.2 if none of the "fabrication" markers appear

    The result is clamped to [0, 1] and rounded to two decimals.
    """
    target = None
    for candidate in state.inbox:
        if candidate.id == "t2_001":
            target = candidate
            break
    if target is None or not (target.replied and target.reply_body):
        return 0.0

    text = target.reply_body.lower()
    total = 0.0

    # Topic coverage.  The original order check also accepted "48291", but
    # only in conjunction with "order", so it reduces to this single test.
    topics_hit = [
        "order" in text,
        any(stem in text for stem in ("refund", "deliver", "shipment", "track")),
        any(stem in text for stem in ("compensat", "apologi", "sorry", "inconvenien")),
    ]
    total += 0.1 * sum(topics_hit)  # up to 0.3

    # Tone: reward professional phrases, penalise rude ones.
    polite_phrases = ("dear", "sincerely", "regards", "thank you", "we apologize",
                      "we understand", "please", "we will")
    rude_phrases = ("whatever", "not our fault", "calm down")
    polite_hits = len([p for p in polite_phrases if p in text])
    rude_hits = len([p for p in rude_phrases if p in text])
    total += min(0.3, polite_hits * 0.05) - (rude_hits * 0.1)

    # Length check on the un-lowercased reply (known to be non-empty here).
    if len(target.reply_body) > 50:
        total += 0.2

    # Heuristic against invented specifics (amounts, dates, tracking numbers).
    fabrication_markers = ("$", "€", "refund amount", "exact date", "tracking number is")
    if not any(marker in text for marker in fabrication_markers):
        total += 0.2

    return round(max(0.0, min(total, 1.0)), 2)


def grade_task3(state: InboxState, penalties: dict) -> float:
    """Score task 3 from labels, replies, archiving, flagging and penalties.

    Components, added in this order:

    * labels: 0.2 per email carrying its ground-truth label, capped at 2.0
      and scaled by 0.25 (so at most 0.5)
    * replies: for each id in ``TASK3_URGENT_IDS`` with a reply, up to 0.1,
      proportional to reply length (full credit at 200 characters)
    * 0.05 per archived email in ``TASK3_ARCHIVE_IDS``
    * 0.05 per flagged email in ``TASK3_FLAG_IDS``
    * minus 0.1 per ``destructive_actions`` and 0.05 per ``loop_actions``
      read from *penalties* (missing keys count as 0)

    The individual components are not capped beyond the label cap; only the
    final total is clamped to [0, 1] and rounded to two decimals.
    """
    by_id = {mail.id: mail for mail in state.inbox}
    total = 0.0

    # Priority labels.
    label_points = 0.0
    for eid, expected in TASK3_GROUND_TRUTH.items():
        mail = by_id.get(eid)
        if not mail:
            continue
        if expected in mail.labels:
            label_points += 0.2
    total += min(label_points, 2.0) * 0.25

    # Replies to urgent emails.
    total += sum(
        min(len(mail.reply_body) / 200, 1.0) * 0.1
        for mail in (by_id.get(eid) for eid in TASK3_URGENT_IDS)
        if mail and mail.replied and mail.reply_body
    )

    # Archived spam first, then flagged ambiguous emails: 0.05 each.
    for wanted_ids, attr in ((TASK3_ARCHIVE_IDS, "archived"), (TASK3_FLAG_IDS, "flagged")):
        for eid in wanted_ids:
            mail = by_id.get(eid)
            if mail and getattr(mail, attr):
                total += 0.05

    # Penalties accumulated by the environment.
    for key, weight in (("destructive_actions", 0.1), ("loop_actions", 0.05)):
        total -= penalties.get(key, 0) * weight

    return round(max(0.0, min(total, 1.0)), 2)


# ---------------------------------------------------------------------------
# Environment Class
# ---------------------------------------------------------------------------

class EmailTriageEnv:
    """OpenEnv-compatible Email Triage environment.

    An instance is bound to one task (1, 2 or 3).  ``reset()`` loads that
    task's emails, ``step()`` applies one ``Action`` to the inbox, and
    ``score()`` runs the task's grader over the current state.
    """

    TASKS = {1, 2, 3}

    def __init__(self, task: int = 1):
        """Bind the environment to *task*; call ``reset()`` before using it.

        Raises:
            AssertionError: If *task* is not in ``TASKS``.
        """
        # Raised explicitly rather than via ``assert`` so the check still runs
        # under ``python -O``; the exception type callers see is unchanged.
        if task not in self.TASKS:
            raise AssertionError(f"task must be one of {self.TASKS}")
        self.task = task
        self._state: Optional[InboxState] = None
        self._penalties = {"destructive_actions": 0, "loop_actions": 0}
        self._action_history: list[str] = []
        self._done = False

    # ------------------------------------------------------------------
    # OpenEnv interface
    # ------------------------------------------------------------------

    def reset(self) -> Observation:
        """Start a fresh episode: clear penalties/history and reload the inbox.

        Returns:
            An ``"ok"`` observation reporting the task number and inbox size.
        """
        self._penalties = {"destructive_actions": 0, "loop_actions": 0}
        self._action_history = []
        self._done = False

        if self.task == 1:
            emails = [Email.model_validate(e) for e in TASK1_EMAILS]
        elif self.task == 2:
            # TASK2_EMAIL is {} when the dataset has no task-2 email (e.g. the
            # fallback dataset).  Validating {} would raise, so start with an
            # empty inbox, as tasks 1 and 3 do in that situation.
            emails = [Email.model_validate(TASK2_EMAIL)] if TASK2_EMAIL else []
        else:
            emails = [Email.model_validate(e) for e in TASK3_EMAILS]

        self._state = InboxState(inbox=emails)

        return Observation(
            status="ok",
            message=f"Task {self.task} environment reset. Inbox contains {len(emails)} email(s).",
            data={"task": self.task, "inbox_size": len(emails)},
            step_count=0,
        )

    def state(self) -> dict:
        """Return the current state as a JSON-compatible dict (aliases applied).

        Raises:
            AssertionError: If ``reset()`` has not been called.
        """
        current = self._require_state()
        return json.loads(current.model_dump_json(by_alias=True))

    def step(self, action: Action) -> StepResult:
        """Apply *action* to the inbox and return the resulting ``StepResult``.

        Each processed step increments the step counter and records an
        ``"<action>:<email_id>"`` key.  If that key was already recorded at
        least twice, the ``loop_actions`` penalty counter is incremented.

        Raises:
            AssertionError: If ``reset()`` has not been called.
        """
        current = self._require_state()

        # NOTE(review): nothing in this class sets ``_done`` to True, so this
        # branch is only reached if something else flips the flag.
        if self._done:
            return StepResult(
                observation=Observation(status="done", message="Episode already finished.", step_count=current.step_count),
                reward=0.0,
                done=True,
            )

        current.step_count += 1
        action_key = f"{action.action}:{action.email_id}"

        # Loop detection: third and later repeats of the same action/email pair.
        if self._action_history.count(action_key) >= 2:
            self._penalties["loop_actions"] += 1

        self._action_history.append(action_key)
        obs, reward = self._dispatch(action)
        obs.step_count = current.step_count
        return StepResult(observation=obs, reward=reward, done=self._done)

    def score(self) -> float:
        """Return current cumulative score (0-1) from the task's grader.

        Raises:
            AssertionError: If ``reset()`` has not been called.
        """
        current = self._require_state()
        if self.task == 1:
            return grade_task1(current)
        elif self.task == 2:
            return grade_task2(current)
        else:
            return grade_task3(current, self._penalties)

    # ------------------------------------------------------------------
    # Action dispatch
    # ------------------------------------------------------------------

    def _dispatch(self, action: Action):
        """Route *action* to its handler; returns ``(Observation, reward)``."""
        handlers = {
            "list_inbox": self._act_list_inbox,
            "read":       self._act_read,
            "label":      self._act_label,
            "draft_reply":self._act_draft_reply,
            "archive":    self._act_archive,
            "flag":       self._act_flag,
        }
        handler = handlers.get(action.action)
        if handler is None:
            return Observation(status="error", message=f"Unknown action: {action.action}"), 0.0
        return handler(action)

    def _act_list_inbox(self, action: Action):
        """Return a summary (no body) of every email in the inbox; reward 0."""
        summaries = [
            {"id": e.id, "from": e.from_, "subject": e.subject,
             "labels": e.labels, "replied": e.replied, "archived": e.archived, "flagged": e.flagged}
            for e in self._state.inbox
        ]
        return Observation(status="ok", message="Inbox listed.", data={"emails": summaries}), 0.0

    def _act_read(self, action: Action):
        """Return the full contents of one email as JSON-compatible data; reward 0."""
        email = self._find(action.email_id)
        if email is None:
            return Observation(status="error", message=f"Email {action.email_id} not found."), 0.0
        return Observation(
            status="ok",
            message=f"Read email {action.email_id}.",
            data=json.loads(email.model_dump_json(by_alias=True)),
        ), 0.0

    def _act_label(self, action: Action):
        """Set the email's priority label, replacing any previous priority label.

        The reward comes from ``_incremental_label_reward``.
        """
        email = self._find(action.email_id)
        if email is None:
            return Observation(status="error", message=f"Email {action.email_id} not found."), 0.0
        if action.priority not in ("urgent", "normal", "low"):
            return Observation(status="error", message="priority must be urgent | normal | low"), 0.0

        # Remove existing priority labels then add new; other labels are kept.
        email.labels = [l for l in email.labels if l not in ("urgent", "normal", "low")]
        email.labels.append(action.priority)

        incremental = self._incremental_label_reward(email.id, action.priority)
        return Observation(
            status="ok",
            message=f"Labelled {action.email_id} as {action.priority}.",
            data={"email_id": action.email_id, "priority": action.priority},
        ), incremental

    def _act_draft_reply(self, action: Action):
        """Record a reply on the email and append it to the sent log; reward 0.

        Bodies shorter than 10 characters after stripping whitespace are rejected.
        """
        email = self._find(action.email_id)
        if email is None:
            return Observation(status="error", message=f"Email {action.email_id} not found."), 0.0
        if not action.body or len(action.body.strip()) < 10:
            return Observation(status="error", message="Reply body too short."), 0.0

        email.replied = True
        email.reply_body = action.body
        self._state.sent.append({"to": email.from_, "subject": f"Re: {email.subject}", "body": action.body})
        return Observation(status="ok", message=f"Reply drafted for {action.email_id}."), 0.0

    def _act_archive(self, action: Action):
        """Archive the email, or penalise the attempt if it is labelled urgent."""
        email = self._find(action.email_id)
        if email is None:
            return Observation(status="error", message=f"Email {action.email_id} not found."), 0.0

        # Penalty if archiving urgent email.
        # NOTE(review): the message says "Archived", but ``archived`` is left
        # False on this path -- confirm whether that is intended.
        if "urgent" in email.labels:
            self._penalties["destructive_actions"] += 1
            return Observation(
                status="warning",
                message=f"Archived urgent email {action.email_id} — penalty applied.",
            ), -0.1

        email.archived = True
        return Observation(status="ok", message=f"Email {action.email_id} archived."), 0.0

    def _act_flag(self, action: Action):
        """Flag the email for human review, storing the reason (default "unspecified")."""
        email = self._find(action.email_id)
        if email is None:
            return Observation(status="error", message=f"Email {action.email_id} not found."), 0.0
        email.flagged = True
        email.flag_reason = action.reason or "unspecified"
        return Observation(status="ok", message=f"Email {action.email_id} flagged for human review."), 0.0

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _require_state(self) -> InboxState:
        """Return the active state, or raise if ``reset()`` was never called.

        Raises ``AssertionError`` explicitly so the guard is not stripped
        when Python runs with ``-O``.
        """
        if self._state is None:
            raise AssertionError("Call reset() first.")
        return self._state

    def _find(self, email_id: Optional[str]) -> Optional[Email]:
        """Return the inbox email with *email_id*, or None if absent or id is None."""
        if email_id is None:
            return None
        return next((e for e in self._state.inbox if e.id == email_id), None)

    def _incremental_label_reward(self, email_id: str, priority: str) -> float:
        """Return +0.2 if label matches ground truth for task 1, else 0.0.

        Tasks other than 1 always get 0.0 here.
        """
        # NOTE(review): the reward is granted on every matching label action,
        # including repeats on the same email -- confirm this is intended.
        if self.task == 1:
            gt = TASK1_GROUND_TRUTH.get(email_id)
            return 0.2 if gt == priority else 0.0
        return 0.0