File size: 3,167 Bytes
cf0020c
 
 
 
 
 
 
 
 
 
 
 
 
 
c5fdb67
cf0020c
c5fdb67
 
7622c8f
 
c5fdb67
 
 
 
 
7622c8f
 
c5fdb67
7622c8f
c5fdb67
 
 
 
 
7622c8f
 
c5fdb67
cf0020c
c5fdb67
 
 
 
cf0020c
c5fdb67
 
cf0020c
c5fdb67
cf0020c
 
 
 
c5fdb67
 
 
 
 
 
 
cf0020c
7622c8f
cf0020c
 
 
 
 
 
 
 
 
7622c8f
cf0020c
 
c5fdb67
cf0020c
 
 
 
 
 
c5fdb67
7622c8f
 
 
cf0020c
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from typing import List, Dict, Any, Tuple

# Human-readable names for the three label slots. Each email "label" in
# EmailTriageEnv is a list of three ints; presumably those ints index these
# lists in order (urgency, routing, resolution) -- NOTE(review): confirm.
URGENCY_LABELS = [
    "low",
    "medium",
    "high",
]
ROUTING_LABELS = [
    "general",
    "support",
    "security",
]
RESOLUTION_LABELS = [
    "ignore",
    "respond",
    "escalate",
]


class EmailTriageEnv:
    """Deterministic, queue-based email-triage environment.

    An episode walks through a fixed list of emails chosen by ``task``
    ("easy", "medium" or "hard"). Each step the caller submits an action:
    a sequence of integer predictions that is compared position-by-position
    with the current email's ``label``. The reward is the fraction of
    positions that match, so it lies in [0, 1].
    """

    # (description, label) pairs shared by every task. Each task uses a
    # prefix of this table, so "medium" extends "easy" and "hard" extends
    # "medium". Presumably the three label ints index URGENCY_LABELS,
    # ROUTING_LABELS and RESOLUTION_LABELS in order -- NOTE(review): confirm.
    _EMAILS: Tuple[Tuple[str, Tuple[int, int, int]], ...] = (
        ("Password reset not working", (2, 1, 2)),
        ("Billing refund request", (1, 2, 2)),
        ("App is slow", (0, 1, 1)),
        ("Possible phishing attempt detected", (2, 2, 2)),
        ("Invoice mismatch issue", (1, 2, 2)),
        ("Ransomware attack suspected", (2, 2, 2)),
        ("Data breach reported", (2, 2, 2)),
    )

    # How many leading entries of _EMAILS each task uses.
    _TASK_SIZES: Dict[str, int] = {"easy": 3, "medium": 5, "hard": 7}

    def __init__(self, task: str = "easy"):
        """Create an environment for ``task``; call ``reset()`` before use."""
        self.task = task
        self._queue: List[Dict] = []
        self._index = 0
        self._done = False

    def _generate_emails(self) -> List[Dict]:
        """Return a fresh, deterministic email queue for ``self.task``.

        Each item is ``{"description": str, "label": [int, int, int]}``.
        An unknown task yields an empty list. New dicts and lists are built
        on every call, so mutating one episode's queue never leaks into a
        later episode.
        """
        count = self._TASK_SIZES.get(self.task, 0)
        return [
            {"description": description, "label": list(label)}
            for description, label in self._EMAILS[:count]
        ]

    def reset(self) -> Dict[str, Any]:
        """Start a new episode and return the first observation.

        For a non-empty queue the observation is
        ``{"description", "step": 0, "remaining": len(queue), "done": False}``.
        If the task produced no emails (unknown task), the episode is already
        over and ``{"done": True}`` is returned instead of raising IndexError.
        """
        self._queue = self._generate_emails()
        self._index = 0
        # An empty queue has nothing to show; mark it finished up front so
        # state() does not index past the end of the list.
        self._done = not self._queue
        return self.state()

    def state(self) -> Dict[str, Any]:
        """Return the current observation without advancing the episode.

        Once the episode is finished, this is just ``{"done": True}``.
        """
        if self._done:
            return {"done": True}

        current = self._queue[self._index]
        return {
            "description": current["description"],
            "step": self._index,
            "remaining": len(self._queue) - self._index,
            "done": False,
        }

    def step(self, action: List[int]) -> Tuple[Dict, float, bool, Dict, Dict]:
        """Grade ``action`` against the current email and advance the queue.

        Args:
            action: predicted label values, compared position-by-position with
                the current email's label. Extra or missing positions are
                ignored by ``zip`` (missing ones simply do not score).

        Returns:
            ``(observation, reward, done, {}, {})``. ``reward`` is
            matches / label length. Once the episode is over, further calls
            return reward 0.0 and do not advance. The two trailing empty
            dicts are placeholders whose meaning is not defined here.
        """
        if self._done:
            return self.state(), 0.0, True, {}, {}

        correct = self._queue[self._index]["label"]

        matches = sum(
            1 for predicted, expected in zip(action, correct) if predicted == expected
        )
        # Normalize by the label length so the reward always lies in [0, 1].
        reward = matches / len(correct)

        self._index += 1
        if self._index >= len(self._queue):
            self._done = True

        return self.state(), reward, self._done, {}, {}