100XZX001 committed on
Commit
5fdf7b4
·
verified ·
1 Parent(s): f951d64

Update environment.py

Browse files
Files changed (1) hide show
  1. environment.py +49 -26
environment.py CHANGED
@@ -1,5 +1,6 @@
1
  from typing import Tuple, Dict, Any
2
  from models import Observation, Action, Reward, State
 
3
 
4
  class CodeReviewEnv:
5
  def __init__(self, task: str = "easy"):
@@ -7,7 +8,7 @@ class CodeReviewEnv:
7
  self.reset()
8
 
9
  def set_task(self, task: str):
10
- if task not in ["easy", "medium", "hard"]:
11
  raise ValueError(f"Unknown task: {task}")
12
  self.task = task
13
 
@@ -18,17 +19,46 @@ class CodeReviewEnv:
18
  self.agent_comment = None
19
  self.done = False
20
 
 
21
  if self.task == "easy":
22
- self.pr_code = "def get_user(id):\n return users[id] # missing null check"
 
 
23
  self.comments = ["Looks good!", "Maybe add a comment?"]
 
 
 
24
  elif self.task == "medium":
25
- self.pr_code = "for i in range(len(items)):\n process(items[i])\n# O(n^2) when it could be O(n)"
 
 
26
  self.comments = ["Nice code"]
 
 
 
27
  elif self.task == "hard":
28
- self.pr_code = "def calculate_average(data):\n total = sum(data)\n return total / len(data) # what if data is empty?"
 
 
29
  self.comments = ["LGTM"]
30
- else:
31
- raise RuntimeError(f"Invalid task: {self.task}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  return self._get_observation()
34
 
@@ -41,8 +71,13 @@ class CodeReviewEnv:
41
 
42
  if action.action_type == "write_comment":
43
  self.agent_comment = action.comment_text or ""
44
- reward = 0.2
45
- quality_score = self._grade_comment(self.agent_comment)
 
 
 
 
 
46
  reward += quality_score
47
  if reward > 1.0:
48
  reward = 1.0
@@ -61,25 +96,11 @@ class CodeReviewEnv:
61
  obs = self._get_observation()
62
  return obs, Reward(value=reward), self.done, info
63
 
64
- def _grade_comment(self, comment: str) -> float:
65
- if self.task == "easy":
66
- keywords = ["null", "key", "missing", "check", "exists", "handle"]
67
- matched = sum(1 for kw in keywords if kw in comment.lower())
68
- return min(1.0, matched / 3)
69
- elif self.task == "medium":
70
- keywords = ["enumerate", "for item in", "range", "inefficient", "optimize"]
71
- matched = sum(1 for kw in keywords if kw in comment.lower())
72
- return min(1.0, matched / 3)
73
- elif self.task == "hard":
74
- keywords = ["empty", "zero", "length", "check", "handle", "exception"]
75
- matched = sum(1 for kw in keywords if kw in comment.lower())
76
- return min(1.0, matched / 3)
77
- else:
78
- return 0.0
79
-
80
  def _get_observation(self) -> Observation:
81
  return Observation(
82
- pr_code=self.pr_code,
 
 
83
  comments=self.comments,
84
  agent_comment=self.agent_comment,
85
  step=self.step_count,
@@ -88,7 +109,9 @@ class CodeReviewEnv:
88
 
89
  def state(self) -> State:
90
  return State(
91
- pr_code=self.pr_code,
 
 
92
  comments=self.comments,
93
  agent_comment=self.agent_comment,
94
  step=self.step_count,
 
1
  from typing import Tuple, Dict, Any
2
  from models import Observation, Action, Reward, State
3
+ from grader import grade_comment
4
 
5
  class CodeReviewEnv:
6
  def __init__(self, task: str = "easy"):
 
8
  self.reset()
9
 
10
  def set_task(self, task: str):
11
+ if task not in ["easy", "medium", "hard", "harder", "hardest"]:
12
  raise ValueError(f"Unknown task: {task}")
13
  self.task = task
14
 
 
19
  self.agent_comment = None
20
  self.done = False
21
 
22
+ # Task definitions with richer context
23
  if self.task == "easy":
24
+ self.pr_title = "Fix missing null check in user lookup"
25
+ self.pr_description = "The current code does not handle missing user IDs. It raises a KeyError."
26
+ self.code_snippet = "def get_user(id):\n return users[id] # missing null check"
27
  self.comments = ["Looks good!", "Maybe add a comment?"]
28
+ self.expected_keywords = ["null", "key", "missing", "check", "exists", "handle"]
29
+ self.expert_comment = "Add a check to ensure the key exists before accessing the dictionary to avoid KeyError."
30
+
31
  elif self.task == "medium":
32
+ self.pr_title = "Improve loop efficiency"
33
+ self.pr_description = "The loop uses `range(len(items))` which is inefficient and less readable."
34
+ self.code_snippet = "for i in range(len(items)):\n process(items[i])\n# O(n^2) when it could be O(n)"
35
  self.comments = ["Nice code"]
36
+ self.expected_keywords = ["enumerate", "for item in", "range", "inefficient", "optimize"]
37
+ self.expert_comment = "Use `for item in items:` for a more Pythonic and efficient loop."
38
+
39
  elif self.task == "hard":
40
+ self.pr_title = "Handle division by zero in average calculation"
41
+ self.pr_description = "The function crashes when the input list is empty."
42
+ self.code_snippet = "def calculate_average(data):\n total = sum(data)\n return total / len(data) # what if data is empty?"
43
  self.comments = ["LGTM"]
44
+ self.expected_keywords = ["empty", "zero", "length", "check", "handle", "exception"]
45
+ self.expert_comment = "Check if the list is empty and return a sensible default (e.g., 0) or raise a descriptive error."
46
+
47
+ elif self.task == "harder":
48
+ self.pr_title = "Fix race condition in counter increment"
49
+ self.pr_description = "Multiple threads increment a counter without synchronization, causing lost updates."
50
+ self.code_snippet = "counter = 0\ndef increment():\n global counter\n counter += 1\n# called from multiple threads"
51
+ self.comments = ["Simple code, should be fine?"]
52
+ self.expected_keywords = ["thread", "lock", "synchronization", "atomic", "race", "concurrent"]
53
+ self.expert_comment = "Use a threading.Lock to protect the counter increment, or use an atomic operation like `threading.atomic`."
54
+
55
+ else: # hardest
56
+ self.pr_title = "Fix deadlock in database transaction"
57
+ self.pr_description = "Two threads acquire locks in opposite order, leading to potential deadlock."
58
+ self.code_snippet = "with lock1:\n with lock2:\n do_work()\n# another thread does lock2 then lock1"
59
+ self.comments = ["Seems okay?"]
60
+ self.expected_keywords = ["deadlock", "lock order", "acquire", "release", "trylock", "timeout"]
61
+ self.expert_comment = "Ensure all threads acquire locks in the same order to prevent deadlock. Consider using a timeout or a single lock."
62
 
63
  return self._get_observation()
64
 
 
71
 
72
  if action.action_type == "write_comment":
73
  self.agent_comment = action.comment_text or ""
74
+ reward = 0.2 # dense bonus for writing
75
+ # Semantic grader
76
+ quality_score = grade_comment(
77
+ self.agent_comment,
78
+ self.expected_keywords,
79
+ self.expert_comment
80
+ )
81
  reward += quality_score
82
  if reward > 1.0:
83
  reward = 1.0
 
96
  obs = self._get_observation()
97
  return obs, Reward(value=reward), self.done, info
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  def _get_observation(self) -> Observation:
100
  return Observation(
101
+ pr_title=self.pr_title,
102
+ pr_description=self.pr_description,
103
+ code_snippet=self.code_snippet,
104
  comments=self.comments,
105
  agent_comment=self.agent_comment,
106
  step=self.step_count,
 
109
 
110
  def state(self) -> State:
111
  return State(
112
+ pr_title=self.pr_title,
113
+ pr_description=self.pr_description,
114
+ code_snippet=self.code_snippet,
115
  comments=self.comments,
116
  agent_comment=self.agent_comment,
117
  step=self.step_count,