Omkar1806 commited on
Commit
7622c8f
·
verified ·
1 Parent(s): 30771b2

Update env.py

Browse files
Files changed (1) hide show
  1. env.py +32 -17
env.py CHANGED
@@ -13,34 +13,44 @@ class EmailTriageEnv:
13
  self._index = 0
14
  self._done = False
15
 
 
16
  def _generate_emails(self) -> List[Dict]:
17
- emails = [
18
- {"description": "Password reset not working", "label": [2, 1, 2]},
19
- {"description": "Billing refund request", "label": [1, 2, 2]},
20
- {"description": "App is slow and buggy", "label": [0, 1, 1]},
21
- ]
22
-
23
- if self.task in ["medium", "hard"]:
24
- emails += [
 
 
 
 
 
 
 
 
 
25
  {"description": "Possible phishing attempt detected", "label": [2, 2, 2]},
26
  {"description": "Invoice mismatch and payment issue", "label": [1, 2, 2]},
27
- ]
28
-
29
- if self.task == "hard":
30
- emails += [
31
  {"description": "Ransomware attack suspected on system", "label": [2, 2, 2]},
32
  {"description": "User reports data breach and performance issues", "label": [2, 2, 2]},
33
- ]
 
34
 
 
35
  random.shuffle(emails)
36
  return emails
37
 
 
38
  def reset(self) -> Dict[str, Any]:
39
  self._queue = self._generate_emails()
40
  self._index = 0
41
  self._done = False
42
  return self.state()
43
 
 
44
  def state(self) -> Dict[str, Any]:
45
  if self._done:
46
  return {"done": True}
@@ -50,18 +60,23 @@ class EmailTriageEnv:
50
  "description": current["description"],
51
  "step": self._index,
52
  "remaining": len(self._queue) - self._index,
 
53
  }
54
 
 
55
  def step(self, action: List[int]) -> Tuple[Dict, float, bool, Dict, Dict]:
56
  if self._done:
57
  return self.state(), 0.0, True, {}, {}
58
 
59
  correct = self._queue[self._index]["label"]
60
 
61
- reward = sum(
62
- 1.0 if a == b else 0.0
63
- for a, b in zip(action, correct)
64
- ) / 3.0
 
 
 
65
 
66
  self._index += 1
67
 
 
13
  self._index = 0
14
  self._done = False
15
 
16
+ # ✅ TASK-WISE DATA (required for grader)
17
  def _generate_emails(self) -> List[Dict]:
18
+ task_data = {
19
+ "easy": [
20
+ {"description": "Password reset not working", "label": [2, 1, 2]},
21
+ {"description": "Billing refund request", "label": [1, 2, 2]},
22
+ {"description": "App is slow and buggy", "label": [0, 1, 1]},
23
+ ],
24
+ "medium": [
25
+ {"description": "Password reset not working", "label": [2, 1, 2]},
26
+ {"description": "Billing refund request", "label": [1, 2, 2]},
27
+ {"description": "App is slow and buggy", "label": [0, 1, 1]},
28
+ {"description": "Possible phishing attempt detected", "label": [2, 2, 2]},
29
+ {"description": "Invoice mismatch and payment issue", "label": [1, 2, 2]},
30
+ ],
31
+ "hard": [
32
+ {"description": "Password reset not working", "label": [2, 1, 2]},
33
+ {"description": "Billing refund request", "label": [1, 2, 2]},
34
+ {"description": "App is slow and buggy", "label": [0, 1, 1]},
35
  {"description": "Possible phishing attempt detected", "label": [2, 2, 2]},
36
  {"description": "Invoice mismatch and payment issue", "label": [1, 2, 2]},
 
 
 
 
37
  {"description": "Ransomware attack suspected on system", "label": [2, 2, 2]},
38
  {"description": "User reports data breach and performance issues", "label": [2, 2, 2]},
39
+ ],
40
+ }
41
 
42
+ emails = task_data.get(self.task, task_data["easy"])
43
  random.shuffle(emails)
44
  return emails
45
 
46
+ # ✅ RESET
47
  def reset(self) -> Dict[str, Any]:
48
  self._queue = self._generate_emails()
49
  self._index = 0
50
  self._done = False
51
  return self.state()
52
 
53
+ # ✅ STATE
54
  def state(self) -> Dict[str, Any]:
55
  if self._done:
56
  return {"done": True}
 
60
  "description": current["description"],
61
  "step": self._index,
62
  "remaining": len(self._queue) - self._index,
63
+ "done": False
64
  }
65
 
66
+ # ✅ STEP (GRADER LOGIC)
67
  def step(self, action: List[int]) -> Tuple[Dict, float, bool, Dict, Dict]:
68
  if self._done:
69
  return self.state(), 0.0, True, {}, {}
70
 
71
  correct = self._queue[self._index]["label"]
72
 
73
+ # 🎯 PARTIAL REWARD (important)
74
+ matches = sum(1 for a, b in zip(action, correct) if a == b)
75
+ reward = matches / 3.0 # normalized [0,1]
76
+
77
+ # 🔥 BONUS for perfect prediction
78
+ if matches == 3:
79
+ reward = 1.0
80
 
81
  self._index += 1
82