File size: 825 Bytes
dc762fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class EmailEnv:
    """Tiny episodic environment for labelling emails one at a time.

    An episode walks through a fixed list of emails in order. ``reset``
    returns the text of the first email; each ``step`` scores the supplied
    label against the current email and advances to the next one.

    Attributes:
        data: Emails of the current episode, each a dict with ``"text"`` and
            ``"label"`` keys. Empty until ``reset`` is called.
        index: Position in ``data`` of the email awaiting a label.
    """

    # Episode used when the caller does not supply a dataset.
    DEFAULT_EMAILS = (
        {"text": "Meeting at 5pm", "label": "important"},
        {"text": "Win a free iPhone", "label": "spam"},
        {"text": "Weekly report attached", "label": "normal"},
    )

    def __init__(self, emails=None, *, correct_reward=10, incorrect_reward=-5):
        """Create the environment.

        Args:
            emails: Optional iterable of dicts with ``"text"`` and ``"label"``
                keys. Defaults to ``DEFAULT_EMAILS``.
            correct_reward: Reward returned when the action matches the label.
            incorrect_reward: Reward returned when the action does not match.
        """
        source = self.DEFAULT_EMAILS if emails is None else emails
        # Copy every entry so later mutation of the caller's objects (or of
        # ``self.data``) cannot leak between episodes.
        self._emails = [dict(email) for email in source]
        self.correct_reward = correct_reward
        self.incorrect_reward = incorrect_reward
        # No episode is active until reset() is called.
        self.data = []
        self.index = 0

    def reset(self):
        """Start a new episode.

        Returns:
            The text of the first email, or ``None`` if the dataset is empty.
        """
        self.data = [dict(email) for email in self._emails]
        self.index = 0
        if not self.data:
            return None
        return self.data[self.index]["text"]

    def step(self, action):
        """Score ``action`` against the current email and advance.

        Args:
            action: The label predicted for the current email.

        Returns:
            A ``(next_obs, reward, done, info)`` tuple. ``next_obs`` is the
            text of the next email, or ``None`` once the episode is over.
            ``info`` holds the true label under ``"correct"``. When called
            with no email left (before ``reset`` or after the episode has
            ended) the result is ``(None, 0, True, {})``.
        """
        if self.index >= len(self.data):
            return None, 0, True, {}

        correct = self.data[self.index]["label"]
        reward = self.correct_reward if action == correct else self.incorrect_reward

        self.index += 1
        done = self.index >= len(self.data)
        next_obs = None if done else self.data[self.index]["text"]

        return next_obs, reward, done, {"correct": correct}