100XZX001 committed on
Commit
7694e55
·
verified ·
1 Parent(s): 0a5683a

Update environment.py

Browse files
Files changed (1) hide show
  1. environment.py +55 -23
environment.py CHANGED
@@ -1,6 +1,6 @@
1
  from typing import Tuple, Dict, Any
2
  from models import Observation, Action, Reward, State
3
- from grader import grade_comment
4
 
5
  class CodeReviewEnv:
6
  def __init__(self, task: str = "easy"):
@@ -18,47 +18,49 @@ class CodeReviewEnv:
18
  self.step_count = 0
19
  self.agent_comment = None
20
  self.done = False
 
21
 
22
- # Task definitions with richer context
23
  if self.task == "easy":
24
  self.pr_title = "Fix missing null check in user lookup"
25
  self.pr_description = "The current code does not handle missing user IDs. It raises a KeyError."
26
  self.code_snippet = "def get_user(id):\n return users[id] # missing null check"
27
- self.comments = ["Looks good!", "Maybe add a comment?"]
28
  self.expected_keywords = ["null", "key", "missing", "check", "exists", "handle"]
29
  self.expert_comment = "Add a check to ensure the key exists before accessing the dictionary to avoid KeyError."
30
-
31
  elif self.task == "medium":
32
  self.pr_title = "Improve loop efficiency"
33
  self.pr_description = "The loop uses `range(len(items))` which is inefficient and less readable."
34
  self.code_snippet = "for i in range(len(items)):\n process(items[i])\n# O(n^2) when it could be O(n)"
35
- self.comments = ["Nice code"]
36
  self.expected_keywords = ["enumerate", "for item in", "range", "inefficient", "optimize"]
37
  self.expert_comment = "Use `for item in items:` for a more Pythonic and efficient loop."
38
-
39
  elif self.task == "hard":
40
  self.pr_title = "Handle division by zero in average calculation"
41
  self.pr_description = "The function crashes when the input list is empty."
42
  self.code_snippet = "def calculate_average(data):\n total = sum(data)\n return total / len(data) # what if data is empty?"
43
- self.comments = ["LGTM"]
44
  self.expected_keywords = ["empty", "zero", "length", "check", "handle", "exception"]
45
  self.expert_comment = "Check if the list is empty and return a sensible default (e.g., 0) or raise a descriptive error."
46
-
47
  elif self.task == "harder":
48
  self.pr_title = "Fix race condition in counter increment"
49
  self.pr_description = "Multiple threads increment a counter without synchronization, causing lost updates."
50
  self.code_snippet = "counter = 0\ndef increment():\n global counter\n counter += 1\n# called from multiple threads"
51
- self.comments = ["Simple code, should be fine?"]
52
  self.expected_keywords = ["thread", "lock", "synchronization", "atomic", "race", "concurrent"]
53
  self.expert_comment = "Use a threading.Lock to protect the counter increment, or use an atomic operation like `threading.atomic`."
54
-
55
  else: # hardest
56
  self.pr_title = "Fix deadlock in database transaction"
57
  self.pr_description = "Two threads acquire locks in opposite order, leading to potential deadlock."
58
  self.code_snippet = "with lock1:\n with lock2:\n do_work()\n# another thread does lock2 then lock1"
59
- self.comments = ["Seems okay?"]
60
  self.expected_keywords = ["deadlock", "lock order", "acquire", "release", "trylock", "timeout"]
61
  self.expert_comment = "Ensure all threads acquire locks in the same order to prevent deadlock. Consider using a timeout or a single lock."
 
62
 
63
  return self._get_observation()
64
 
@@ -72,16 +74,34 @@ class CodeReviewEnv:
72
  if action.action_type == "write_comment":
73
  self.agent_comment = action.comment_text or ""
74
  reward = 0.2 # dense bonus for writing
75
- # Semantic grader
76
- quality_score = grade_comment(
77
- self.agent_comment,
78
- self.expected_keywords,
79
- self.expert_comment
80
- )
81
  reward += quality_score
82
- if reward > 1.0:
83
- reward = 1.0
84
  self.done = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  elif action.action_type == "skip":
86
  reward = -0.1
87
  self.done = True
@@ -96,13 +116,25 @@ class CodeReviewEnv:
96
  obs = self._get_observation()
97
  return obs, Reward(value=reward), self.done, info
98
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  def _get_observation(self) -> Observation:
100
  return Observation(
101
  pr_title=self.pr_title,
102
  pr_description=self.pr_description,
103
  code_snippet=self.code_snippet,
104
- comments=self.comments,
105
- agent_comment=self.agent_comment,
106
  step=self.step_count,
107
  done=self.done
108
  )
@@ -112,8 +144,8 @@ class CodeReviewEnv:
112
  pr_title=self.pr_title,
113
  pr_description=self.pr_description,
114
  code_snippet=self.code_snippet,
115
- comments=self.comments,
116
- agent_comment=self.agent_comment,
117
  step=self.step_count,
118
  done=self.done
119
  )
 
1
  from typing import Tuple, Dict, Any
2
  from models import Observation, Action, Reward, State
3
+ from grader import grade_comment, grade_question, grade_fix
4
 
5
  class CodeReviewEnv:
6
  def __init__(self, task: str = "easy"):
 
18
  self.step_count = 0
19
  self.agent_comment = None
20
  self.done = False
21
+ self.test_results = None
22
 
23
+ # Task definitions (same as before)
24
  if self.task == "easy":
25
  self.pr_title = "Fix missing null check in user lookup"
26
  self.pr_description = "The current code does not handle missing user IDs. It raises a KeyError."
27
  self.code_snippet = "def get_user(id):\n return users[id] # missing null check"
28
+ self.comments = []
29
  self.expected_keywords = ["null", "key", "missing", "check", "exists", "handle"]
30
  self.expert_comment = "Add a check to ensure the key exists before accessing the dictionary to avoid KeyError."
31
+ self.expected_fix_keywords = ["if id in users"]
32
  elif self.task == "medium":
33
  self.pr_title = "Improve loop efficiency"
34
  self.pr_description = "The loop uses `range(len(items))` which is inefficient and less readable."
35
  self.code_snippet = "for i in range(len(items)):\n process(items[i])\n# O(n^2) when it could be O(n)"
36
+ self.comments = []
37
  self.expected_keywords = ["enumerate", "for item in", "range", "inefficient", "optimize"]
38
  self.expert_comment = "Use `for item in items:` for a more Pythonic and efficient loop."
39
+ self.expected_fix_keywords = ["for item in items", "for i, item in enumerate"]
40
  elif self.task == "hard":
41
  self.pr_title = "Handle division by zero in average calculation"
42
  self.pr_description = "The function crashes when the input list is empty."
43
  self.code_snippet = "def calculate_average(data):\n total = sum(data)\n return total / len(data) # what if data is empty?"
44
+ self.comments = []
45
  self.expected_keywords = ["empty", "zero", "length", "check", "handle", "exception"]
46
  self.expert_comment = "Check if the list is empty and return a sensible default (e.g., 0) or raise a descriptive error."
47
+ self.expected_fix_keywords = ["if not data", "if len(data)==0"]
48
  elif self.task == "harder":
49
  self.pr_title = "Fix race condition in counter increment"
50
  self.pr_description = "Multiple threads increment a counter without synchronization, causing lost updates."
51
  self.code_snippet = "counter = 0\ndef increment():\n global counter\n counter += 1\n# called from multiple threads"
52
+ self.comments = []
53
  self.expected_keywords = ["thread", "lock", "synchronization", "atomic", "race", "concurrent"]
54
  self.expert_comment = "Use a threading.Lock to protect the counter increment, or use an atomic operation like `threading.atomic`."
55
+ self.expected_fix_keywords = ["lock", "threading.Lock", "with lock"]
56
  else: # hardest
57
  self.pr_title = "Fix deadlock in database transaction"
58
  self.pr_description = "Two threads acquire locks in opposite order, leading to potential deadlock."
59
  self.code_snippet = "with lock1:\n with lock2:\n do_work()\n# another thread does lock2 then lock1"
60
+ self.comments = []
61
  self.expected_keywords = ["deadlock", "lock order", "acquire", "release", "trylock", "timeout"]
62
  self.expert_comment = "Ensure all threads acquire locks in the same order to prevent deadlock. Consider using a timeout or a single lock."
63
+ self.expected_fix_keywords = ["same order", "lock order", "acquire lock1 then lock2"]
64
 
65
  return self._get_observation()
66
 
 
74
  if action.action_type == "write_comment":
75
  self.agent_comment = action.comment_text or ""
76
  reward = 0.2 # dense bonus for writing
77
+ quality_score = grade_comment(self.agent_comment, self.expected_keywords, self.task)
 
 
 
 
 
78
  reward += quality_score
 
 
79
  self.done = True
80
+
81
+ elif action.action_type == "ask_question":
82
+ if not action.question:
83
+ reward = -0.1
84
+ else:
85
+ q_score = grade_question(action.question)
86
+ reward = 0.1 + q_score # small bonus + quality
87
+ # Simulate a helpful answer
88
+ answer = self._answer_question(action.question)
89
+ self.comments.append(f"Agent: {action.question}")
90
+ self.comments.append(f"Env: {answer}")
91
+ self.step_count += 1
92
+ # Episode continues, not done
93
+
94
+ elif action.action_type == "propose_fix":
95
+ if not action.fix_code:
96
+ reward = -0.2
97
+ else:
98
+ # We'll use a simple keyword check for demonstration
99
+ # In a full version, you'd run unit tests
100
+ fix_score = grade_fix(action.fix_code, self.expected_fix_keywords, None)
101
+ reward = 0.3 + fix_score
102
+ self.test_results = f"Fix evaluated with score {fix_score:.2f}"
103
+ self.done = True
104
+
105
  elif action.action_type == "skip":
106
  reward = -0.1
107
  self.done = True
 
116
  obs = self._get_observation()
117
  return obs, Reward(value=reward), self.done, info
118
 
119
+ def _answer_question(self, question: str) -> str:
120
+ # Simple rule‑based answers – you can expand
121
+ q = question.lower()
122
+ if "what" in q and "purpose" in q:
123
+ return "The purpose of this function is to retrieve a user by ID from a dictionary."
124
+ elif "expected" in q:
125
+ return "The function should return the user object if the ID exists, otherwise raise a KeyError."
126
+ elif "how" in q and "fix" in q:
127
+ return "You might consider adding a check for missing keys or using a safer dictionary method like `get`."
128
+ else:
129
+ return "I'm not sure. Could you be more specific?"
130
+
131
  def _get_observation(self) -> Observation:
132
  return Observation(
133
  pr_title=self.pr_title,
134
  pr_description=self.pr_description,
135
  code_snippet=self.code_snippet,
136
+ comments=self.comments.copy(),
137
+ test_results=self.test_results,
138
  step=self.step_count,
139
  done=self.done
140
  )
 
144
  pr_title=self.pr_title,
145
  pr_description=self.pr_description,
146
  code_snippet=self.code_snippet,
147
+ comments=self.comments.copy(),
148
+ test_results=self.test_results,
149
  step=self.step_count,
150
  done=self.done
151
  )