Harshit2N committed on
Commit
5e92b80
·
1 Parent(s): b61cfff

Enhance Code Review Environment with Action History, Valid Actions, and Improved Grading

Browse files

- Added action history tracking in CodeReviewEnv to store recent actions.
- Implemented valid_actions method to return available actions based on the current state.
- Updated reset method to accept a seed for randomization.
- Improved step method to handle action processing and state completion more robustly.
- Enhanced TaskGrader with new grading metrics for false positives and efficiency.
- Updated diagnostics to include efficiency bonus and false positive penalties.
- Added render and summary methods in CodeReviewEnv for better visualization and reporting.
- Refactored inference.py to support batch processing of tasks and improved output handling.
- Added difficulty levels to tasks in TaskDefinitions for better task categorization.

Files changed (5) hide show
  1. environment/env.py +142 -32
  2. environment/graders.py +103 -57
  3. environment/__init__.py +3 -1
  4. environment/tasks.py +24 -18
  5. inference.py +135 -153
environment/env.py CHANGED
@@ -1,4 +1,6 @@
1
- from typing import Dict, Any, Tuple, Optional
 
 
2
  from environment.models import (
3
  ReviewAction,
4
  ReviewState,
@@ -9,24 +11,30 @@ from environment.graders import TaskGrader, RewardCalculator
9
 
10
 
11
  class CodeReviewEnv:
12
-
13
  def __init__(self):
14
  self._state: Optional[ReviewState] = None
15
  self.grader: Optional[TaskGrader] = None
16
  self.reward_calculator = RewardCalculator()
17
  self.max_steps = 50
18
  self.current_task_id: Optional[str] = None
19
-
20
- def reset(self, task_id: str = None) -> Dict[str, Any]:
 
 
 
 
 
 
21
  if task_id is None:
22
  task_id = "bug_detection_easy_1"
23
-
24
  self.current_task_id = task_id
25
  task_data = TaskDefinitions.get_task(task_id)
26
-
27
  code_context = TaskDefinitions.create_code_context(task_data)
28
  task_metadata = TaskDefinitions.create_task_metadata(task_data)
29
-
30
  self._state = ReviewState(
31
  code_context=code_context,
32
  task_metadata=task_metadata,
@@ -38,75 +46,93 @@ class CodeReviewEnv:
38
  last_action_valid=True,
39
  last_error=None
40
  )
41
-
42
  self.grader = TaskGrader(task_metadata.expected_issues)
43
  self.reward_calculator.reset()
44
-
 
45
  return self._get_observation()
46
-
47
  def step(self, action: Dict[str, Any]) -> Tuple[Dict[str, Any], float, bool, Dict[str, Any]]:
48
  if self._state is None:
49
  return {}, -0.1, True, {"error": "Environment not initialized. Call reset() first."}
50
 
51
  if self._state.is_complete:
52
  return self._get_observation(), 0.0, True, {"error": "Episode already complete"}
53
-
 
 
 
 
54
  try:
55
  review_action = ReviewAction(**action)
56
  except Exception as e:
57
  self._state.last_action_valid = False
58
  self._state.last_error = str(e)
59
  return self._get_observation(), -0.1, False, {"error": str(e), "last_action_valid": False}
60
-
61
  self._state.current_step += 1
62
  self._process_action(review_action)
63
 
 
 
 
 
 
 
 
 
64
  if review_action.action_type.value == "approve" and not review_action.final_decision:
65
  review_action.final_decision = "approved"
66
  elif review_action.action_type.value == "request_changes" and not review_action.final_decision:
67
  review_action.final_decision = "changes_requested"
68
-
69
- if self._state.current_step >= self.max_steps:
70
- self._state.is_complete = True
71
- if not self._state.final_decision:
72
- self._state.final_decision = "changes_requested"
73
-
74
  if review_action.final_decision and not self._state.is_complete:
75
  self._state.is_complete = True
76
  self._state.final_decision = review_action.final_decision
77
-
 
 
 
 
 
78
  reward = self.reward_calculator.calculate_reward(
79
  review_action,
80
  self._state.comments_made,
81
  self._state.suggestions_made,
82
  self._state.final_decision or "changes_requested",
83
- self.grader,
84
  self._state.last_action_valid,
 
 
85
  )
86
 
87
- diagnostics = self.grader.get_diagnostics(
88
  comments=self._state.comments_made,
89
  suggestions=self._state.suggestions_made,
90
  final_decision=self._state.final_decision or "changes_requested",
 
 
91
  )
92
-
93
  info = {
94
  "step": self._state.current_step,
95
  "last_action_valid": self._state.last_action_valid,
96
  "error": self._state.last_error,
97
  "task_score": self.get_task_score(),
98
  "diagnostics": diagnostics,
 
99
  }
100
-
101
  return self._get_observation(), reward, self._state.is_complete, info
102
-
103
  def _process_action(self, action: ReviewAction):
104
  if self._state is None:
105
  return
106
 
107
  self._state.last_action_valid = True
108
  self._state.last_error = None
109
-
110
  if action.action_type.value == "add_comment":
111
  for comment in action.comments:
112
  if comment.line_number <= self._state.code_context.line_count:
@@ -114,7 +140,7 @@ class CodeReviewEnv:
114
  else:
115
  self._state.last_action_valid = False
116
  self._state.last_error = f"Line {comment.line_number} out of range"
117
-
118
  elif action.action_type.value == "suggest_fix":
119
  for suggestion in action.suggestions:
120
  if suggestion.original_line <= self._state.code_context.line_count:
@@ -122,18 +148,34 @@ class CodeReviewEnv:
122
  else:
123
  self._state.last_action_valid = False
124
  self._state.last_error = f"Line {suggestion.original_line} out of range"
125
-
126
  elif action.action_type.value == "mark_as_resolved":
 
 
 
 
127
  for comment in action.comments:
128
  for existing_comment in self._state.comments_made:
129
  if existing_comment.line_number == comment.line_number:
130
  existing_comment.resolved = True
131
-
 
 
 
 
 
 
 
 
 
 
 
 
132
  def _get_observation(self) -> Dict[str, Any]:
133
  if self._state is None:
134
  return {}
135
 
136
- return Observation(
137
  code_diff=self._state.code_context.code_diff,
138
  file_context=self._state.code_context.surrounding_code,
139
  file_path=self._state.code_context.file_path,
@@ -147,7 +189,73 @@ class CodeReviewEnv:
147
  review_complete=self._state.is_complete,
148
  final_decision_made=self._state.final_decision
149
  ).model_dump()
150
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  def get_task_score(self) -> float:
152
  if not self.grader or self._state is None:
153
  return 0.0
@@ -156,11 +264,13 @@ class CodeReviewEnv:
156
  comments=self._state.comments_made,
157
  suggestions=self._state.suggestions_made,
158
  final_decision=self._state.final_decision or "changes_requested",
 
 
159
  )
160
-
161
  def close(self):
162
  pass
163
-
164
  def state(self) -> Dict[str, Any]:
165
  if self._state:
166
  return self._state.model_dump()
 
1
+ from typing import Dict, Any, Tuple, Optional, List, Deque
2
+ from collections import deque
3
+ import random
4
  from environment.models import (
5
  ReviewAction,
6
  ReviewState,
 
11
 
12
 
13
  class CodeReviewEnv:
14
+
15
  def __init__(self):
16
  self._state: Optional[ReviewState] = None
17
  self.grader: Optional[TaskGrader] = None
18
  self.reward_calculator = RewardCalculator()
19
  self.max_steps = 50
20
  self.current_task_id: Optional[str] = None
21
+ self._action_history: Deque[Dict[str, Any]] = deque(maxlen=5)
22
+ self._seed: Optional[int] = None
23
+
24
+ def reset(self, task_id: Optional[str] = None, seed: Optional[int] = None) -> Dict[str, Any]:
25
+ if seed is not None:
26
+ self._seed = seed
27
+ random.seed(seed)
28
+
29
  if task_id is None:
30
  task_id = "bug_detection_easy_1"
31
+
32
  self.current_task_id = task_id
33
  task_data = TaskDefinitions.get_task(task_id)
34
+
35
  code_context = TaskDefinitions.create_code_context(task_data)
36
  task_metadata = TaskDefinitions.create_task_metadata(task_data)
37
+
38
  self._state = ReviewState(
39
  code_context=code_context,
40
  task_metadata=task_metadata,
 
46
  last_action_valid=True,
47
  last_error=None
48
  )
49
+
50
  self.grader = TaskGrader(task_metadata.expected_issues)
51
  self.reward_calculator.reset()
52
+ self._action_history.clear()
53
+
54
  return self._get_observation()
55
+
56
  def step(self, action: Dict[str, Any]) -> Tuple[Dict[str, Any], float, bool, Dict[str, Any]]:
57
  if self._state is None:
58
  return {}, -0.1, True, {"error": "Environment not initialized. Call reset() first."}
59
 
60
  if self._state.is_complete:
61
  return self._get_observation(), 0.0, True, {"error": "Episode already complete"}
62
+
63
+ if self.grader is None:
64
+ return self._get_observation(), -0.1, True, {"error": "Environment not initialized. Call reset() first."}
65
+ grader = self.grader
66
+
67
  try:
68
  review_action = ReviewAction(**action)
69
  except Exception as e:
70
  self._state.last_action_valid = False
71
  self._state.last_error = str(e)
72
  return self._get_observation(), -0.1, False, {"error": str(e), "last_action_valid": False}
73
+
74
  self._state.current_step += 1
75
  self._process_action(review_action)
76
 
77
+ self._action_history.append({
78
+ "step": self._state.current_step,
79
+ "action_type": review_action.action_type.value,
80
+ "num_comments": len(review_action.comments),
81
+ "num_suggestions": len(review_action.suggestions),
82
+ "final_decision": review_action.final_decision,
83
+ })
84
+
85
  if review_action.action_type.value == "approve" and not review_action.final_decision:
86
  review_action.final_decision = "approved"
87
  elif review_action.action_type.value == "request_changes" and not review_action.final_decision:
88
  review_action.final_decision = "changes_requested"
89
+
 
 
 
 
 
90
  if review_action.final_decision and not self._state.is_complete:
91
  self._state.is_complete = True
92
  self._state.final_decision = review_action.final_decision
93
+
94
+ if self._state.current_step >= self.max_steps and not self._state.is_complete:
95
+ self._state.is_complete = True
96
+ if not self._state.final_decision:
97
+ self._state.final_decision = "changes_requested"
98
+
99
  reward = self.reward_calculator.calculate_reward(
100
  review_action,
101
  self._state.comments_made,
102
  self._state.suggestions_made,
103
  self._state.final_decision or "changes_requested",
104
+ grader,
105
  self._state.last_action_valid,
106
+ steps_taken=self._state.current_step,
107
+ max_steps=self.max_steps,
108
  )
109
 
110
+ diagnostics = grader.get_diagnostics(
111
  comments=self._state.comments_made,
112
  suggestions=self._state.suggestions_made,
113
  final_decision=self._state.final_decision or "changes_requested",
114
+ steps_taken=self._state.current_step,
115
+ max_steps=self.max_steps,
116
  )
117
+
118
  info = {
119
  "step": self._state.current_step,
120
  "last_action_valid": self._state.last_action_valid,
121
  "error": self._state.last_error,
122
  "task_score": self.get_task_score(),
123
  "diagnostics": diagnostics,
124
+ "valid_actions": self.valid_actions(),
125
  }
126
+
127
  return self._get_observation(), reward, self._state.is_complete, info
128
+
129
  def _process_action(self, action: ReviewAction):
130
  if self._state is None:
131
  return
132
 
133
  self._state.last_action_valid = True
134
  self._state.last_error = None
135
+
136
  if action.action_type.value == "add_comment":
137
  for comment in action.comments:
138
  if comment.line_number <= self._state.code_context.line_count:
 
140
  else:
141
  self._state.last_action_valid = False
142
  self._state.last_error = f"Line {comment.line_number} out of range"
143
+
144
  elif action.action_type.value == "suggest_fix":
145
  for suggestion in action.suggestions:
146
  if suggestion.original_line <= self._state.code_context.line_count:
 
148
  else:
149
  self._state.last_action_valid = False
150
  self._state.last_error = f"Line {suggestion.original_line} out of range"
151
+
152
  elif action.action_type.value == "mark_as_resolved":
153
+ if not self._state.comments_made:
154
+ self._state.last_action_valid = False
155
+ self._state.last_error = "No comments exist to mark as resolved"
156
+ return
157
  for comment in action.comments:
158
  for existing_comment in self._state.comments_made:
159
  if existing_comment.line_number == comment.line_number:
160
  existing_comment.resolved = True
161
+
162
+ def valid_actions(self) -> List[str]:
163
+ if self._state is None:
164
+ return []
165
+
166
+ actions = ["add_comment", "approve", "request_changes"]
167
+
168
+ if self._state.comments_made:
169
+ actions.append("suggest_fix")
170
+ actions.append("mark_as_resolved")
171
+
172
+ return actions
173
+
174
  def _get_observation(self) -> Dict[str, Any]:
175
  if self._state is None:
176
  return {}
177
 
178
+ obs = Observation(
179
  code_diff=self._state.code_context.code_diff,
180
  file_context=self._state.code_context.surrounding_code,
181
  file_path=self._state.code_context.file_path,
 
189
  review_complete=self._state.is_complete,
190
  final_decision_made=self._state.final_decision
191
  ).model_dump()
192
+
193
+ obs["action_history"] = list(self._action_history)
194
+ obs["valid_actions"] = self.valid_actions()
195
+ obs["line_count"] = self._state.code_context.line_count
196
+
197
+ return obs
198
+
199
+ def render(self):
200
+ if self._state is None:
201
+ print("Environment not initialized.")
202
+ return
203
+
204
+ print("=" * 60)
205
+ print(f"FILE : {self._state.code_context.file_path}")
206
+ print(f"LANGUAGE : {self._state.code_context.language}")
207
+ print(f"STEP : {self._state.current_step}/{self.max_steps}")
208
+ print(f"DONE : {self._state.is_complete}")
209
+ print(f"DECISION : {self._state.final_decision or 'pending'}")
210
+ print(f"SCORE : {self.get_task_score():.3f}")
211
+ print("-" * 60)
212
+ print("CODE DIFF:")
213
+ for i, line in enumerate(self._state.code_context.code_diff.split("\n"), start=1):
214
+ print(f" {i}: {line}")
215
+ print("-" * 60)
216
+
217
+ if self._state.comments_made:
218
+ print(f"COMMENTS ({len(self._state.comments_made)}):")
219
+ for c in self._state.comments_made:
220
+ print(f" Line {c.line_number} [{c.severity}]: {c.content}")
221
+
222
+ if self._state.suggestions_made:
223
+ print(f"SUGGESTIONS ({len(self._state.suggestions_made)}):")
224
+ for s in self._state.suggestions_made:
225
+ print(f" Line {s.original_line}: {s.suggested_code}")
226
+
227
+ print(f"VALID ACTIONS: {self.valid_actions()}")
228
+ print("=" * 60)
229
+
230
+ def summary(self) -> Dict[str, Any]:
231
+ if not self.grader or self._state is None:
232
+ return {}
233
+
234
+ diagnostics = self.grader.get_diagnostics(
235
+ comments=self._state.comments_made,
236
+ suggestions=self._state.suggestions_made,
237
+ final_decision=self._state.final_decision or "changes_requested",
238
+ steps_taken=self._state.current_step,
239
+ max_steps=self.max_steps,
240
+ )
241
+
242
+ print("\n--- Episode Summary ---")
243
+ print(f" Task : {self.current_task_id}")
244
+ print(f" Steps taken : {self._state.current_step}/{self.max_steps}")
245
+ print(f" Final decision : {self._state.final_decision or 'none'}")
246
+ print(f" Score : {diagnostics['score']}")
247
+ print(f" Precision : {diagnostics['precision']}")
248
+ print(f" Recall : {diagnostics['recall']}")
249
+ print(f" F1 : {diagnostics['f1']}")
250
+ print(f" True positives : {diagnostics['true_positive_count']}")
251
+ print(f" False positives : {diagnostics['false_positive_count']}")
252
+ print(f" False negatives : {diagnostics['false_negative_count']}")
253
+ print(f" FP penalty : {diagnostics['false_positive_penalty']}")
254
+ print(f" Efficiency bonus: {diagnostics['efficiency_bonus']}")
255
+ print("-----------------------")
256
+
257
+ return diagnostics
258
+
259
  def get_task_score(self) -> float:
260
  if not self.grader or self._state is None:
261
  return 0.0
 
264
  comments=self._state.comments_made,
265
  suggestions=self._state.suggestions_made,
266
  final_decision=self._state.final_decision or "changes_requested",
267
+ steps_taken=self._state.current_step,
268
+ max_steps=self.max_steps,
269
  )
270
+
271
  def close(self):
272
  pass
273
+
274
  def state(self) -> Dict[str, Any]:
275
  if self._state:
276
  return self._state.model_dump()
environment/graders.py CHANGED
@@ -3,7 +3,7 @@ from environment.models import Comment, Suggestion, ReviewAction
3
 
4
 
5
  class TaskGrader:
6
-
7
  def __init__(self, expected_issues: List[Dict[str, Any]]):
8
  self.expected_issues = expected_issues
9
 
@@ -24,54 +24,105 @@ class TaskGrader:
24
  return True
25
 
26
  return expected_type in comment_text
27
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def grade_detection(self, comments: List[Comment]) -> float:
29
  if not self.expected_issues:
30
- # No-issue tasks reward restraint and penalize false positives.
31
  issue_comments = [c for c in comments if c.is_issue]
32
  return 1.0 if not issue_comments else 0.0
33
-
34
  if not comments:
35
  return 0.0
36
-
37
- matched_expected_indexes: Set[int] = set()
38
- for idx, expected in enumerate(self.expected_issues):
 
39
  for comment in comments:
40
- if not comment.is_issue:
41
- continue
42
- if self._match_issue(expected, comment):
43
- matched_expected_indexes.add(idx)
44
- break
 
45
 
46
- return len(matched_expected_indexes) / len(self.expected_issues)
47
-
48
  def grade_suggestions(self, suggestions: List[Suggestion]) -> float:
49
  if not self.expected_issues:
50
  return 1.0 if not suggestions else 0.0
51
-
52
  if not suggestions:
53
  return 0.0
54
-
55
  matched_expected_indexes: Set[int] = set()
56
  for idx, expected in enumerate(self.expected_issues):
57
  for suggestion in suggestions:
58
- if suggestion.original_line == expected.get("line"):
 
59
  matched_expected_indexes.add(idx)
60
  break
61
 
62
  return min(1.0, len(matched_expected_indexes) / len(self.expected_issues))
63
-
64
  def grade_decision(self, final_decision: str) -> float:
65
  if not self.expected_issues:
66
  return 1.0 if final_decision == "approved" else 0.0
67
-
68
- # If task includes expected issues, a safe review should request changes.
69
  return 1.0 if final_decision == "changes_requested" else 0.0
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  def get_diagnostics(self,
72
  comments: List[Comment],
73
  suggestions: List[Suggestion],
74
- final_decision: str) -> Dict[str, Any]:
 
 
75
  issue_comments = [c for c in comments if c.is_issue]
76
  expected_count = len(self.expected_issues)
77
 
@@ -96,30 +147,16 @@ class TaskGrader:
96
  precision = true_positives / max(1, len(issue_comments))
97
  recall = true_positives / expected_count
98
 
99
- severity_weights = {
100
- "low": 0.25,
101
- "medium": 0.5,
102
- "high": 0.75,
103
- "critical": 1.0,
104
- }
105
- weighted_found = 0.0
106
- weighted_total = 0.0
107
- for expected_idx, expected in enumerate(self.expected_issues):
108
- weight = severity_weights.get(str(expected.get("severity", "medium")).lower(), 0.5)
109
- weighted_total += weight
110
- if expected_idx in matched_expected_indexes:
111
- weighted_found += weight
112
- severity_weighted_detection = 1.0 if weighted_total == 0 else (weighted_found / weighted_total)
113
 
114
  detection_score = self.grade_detection(comments)
115
  suggestion_score = self.grade_suggestions(suggestions)
116
  decision_score = self.grade_decision(final_decision)
117
-
118
- false_positive_rate = false_positives / max(1, len(issue_comments))
119
- false_positive_penalty = min(0.4, false_positive_rate * 0.25)
120
 
121
  raw_score = (detection_score * 0.4) + (suggestion_score * 0.3) + (decision_score * 0.3)
122
- final_score = max(0.0, min(1.0, raw_score - false_positive_penalty))
123
 
124
  return {
125
  "expected_issue_count": expected_count,
@@ -128,61 +165,70 @@ class TaskGrader:
128
  "false_negative_count": false_negatives,
129
  "precision": round(precision, 4),
130
  "recall": round(recall, 4),
131
- "severity_weighted_detection": round(severity_weighted_detection, 4),
132
  "detection_score": round(detection_score, 4),
133
  "suggestion_score": round(suggestion_score, 4),
134
  "decision_score": round(decision_score, 4),
135
  "false_positive_penalty": round(false_positive_penalty, 4),
 
136
  "score": round(final_score, 4),
137
  }
138
 
139
  def compute_score(self,
140
  comments: List[Comment],
141
  suggestions: List[Suggestion],
142
- final_decision: str) -> float:
143
- diagnostics = self.get_diagnostics(comments, suggestions, final_decision)
 
 
144
  return float(diagnostics["score"])
145
 
146
  def compute_score_from_state(self,
147
  comments: List[Comment],
148
  suggestions: List[Suggestion],
149
- final_decision: str) -> float:
150
- return self.compute_score(comments, suggestions, final_decision)
 
 
151
 
152
 
153
  class RewardCalculator:
154
-
155
  def __init__(self):
156
  self.last_score = 0.0
157
-
158
- def calculate_reward(self,
159
  current_action: ReviewAction,
160
  all_comments: List[Comment],
161
  all_suggestions: List[Suggestion],
162
  final_decision: str,
163
  grader: TaskGrader,
164
- last_action_valid: bool) -> float:
 
 
165
 
166
  current_score = grader.compute_score(
167
  comments=all_comments,
168
  suggestions=all_suggestions,
169
  final_decision=final_decision,
 
 
170
  )
171
-
172
  reward = current_score - self.last_score
173
-
174
  if current_action.action_type.value in ["add_comment", "suggest_fix"]:
175
  reward += 0.03
176
 
177
  if not last_action_valid:
178
  reward -= 0.15
179
-
180
  if not current_action.comments and not current_action.suggestions:
181
  if current_action.action_type.value in ["approve", "request_changes"]:
182
  pass
183
  else:
184
  reward -= 0.1
185
-
186
  for comment in current_action.comments:
187
  if comment.severity == "critical":
188
  reward += 0.2
@@ -190,19 +236,19 @@ class RewardCalculator:
190
  reward += 0.1
191
  elif comment.severity == "medium":
192
  reward += 0.05
193
-
194
  if len(current_action.suggestions) > 0:
195
  reward += 0.05 * len(current_action.suggestions)
196
-
197
  if current_action.final_decision:
198
  optimal_decision = "changes_requested" if grader.expected_issues else "approved"
199
  reward += 0.1 if current_action.final_decision == optimal_decision else -0.1
200
 
201
  reward = max(-0.5, min(1.0, reward))
202
-
203
  self.last_score = current_score
204
-
205
  return reward
206
-
207
  def reset(self):
208
  self.last_score = 0.0
 
3
 
4
 
5
  class TaskGrader:
6
+
7
  def __init__(self, expected_issues: List[Dict[str, Any]]):
8
  self.expected_issues = expected_issues
9
 
 
24
  return True
25
 
26
  return expected_type in comment_text
27
+
28
+ def _partial_credit(self, expected: Dict[str, Any], comment: Comment) -> float:
29
+ expected_line = int(expected.get("line", 0) or 0)
30
+ expected_type = self._normalize(expected.get("type", ""))
31
+ comment_text = self._normalize(comment.content)
32
+
33
+ if not comment.is_issue:
34
+ return 0.0
35
+
36
+ keyword_tokens = expected_type.replace("_", " ").split()
37
+ content_match = expected_type in comment_text or any(t in comment_text for t in keyword_tokens)
38
+
39
+ if comment.line_number == expected_line and content_match:
40
+ return 1.0
41
+
42
+ distance = abs(comment.line_number - expected_line)
43
+ if distance <= 2 and content_match:
44
+ return max(0.0, 1.0 - distance * 0.25)
45
+
46
+ if content_match:
47
+ return 0.2
48
+
49
+ return 0.0
50
+
51
  def grade_detection(self, comments: List[Comment]) -> float:
52
  if not self.expected_issues:
 
53
  issue_comments = [c for c in comments if c.is_issue]
54
  return 1.0 if not issue_comments else 0.0
55
+
56
  if not comments:
57
  return 0.0
58
+
59
+ total_credit = 0.0
60
+ for expected in self.expected_issues:
61
+ best_credit = 0.0
62
  for comment in comments:
63
+ credit = self._partial_credit(expected, comment)
64
+ if credit > best_credit:
65
+ best_credit = credit
66
+ total_credit += best_credit
67
+
68
+ return min(1.0, total_credit / len(self.expected_issues))
69
 
 
 
70
  def grade_suggestions(self, suggestions: List[Suggestion]) -> float:
71
  if not self.expected_issues:
72
  return 1.0 if not suggestions else 0.0
73
+
74
  if not suggestions:
75
  return 0.0
76
+
77
  matched_expected_indexes: Set[int] = set()
78
  for idx, expected in enumerate(self.expected_issues):
79
  for suggestion in suggestions:
80
+ distance = abs(suggestion.original_line - expected.get("line", 0))
81
+ if distance <= 1:
82
  matched_expected_indexes.add(idx)
83
  break
84
 
85
  return min(1.0, len(matched_expected_indexes) / len(self.expected_issues))
86
+
87
  def grade_decision(self, final_decision: str) -> float:
88
  if not self.expected_issues:
89
  return 1.0 if final_decision == "approved" else 0.0
 
 
90
  return 1.0 if final_decision == "changes_requested" else 0.0
91
 
92
+ def grade_false_positives(self, comments: List[Comment]) -> float:
93
+ if not self.expected_issues:
94
+ return 0.0
95
+
96
+ issue_comments = [c for c in comments if c.is_issue]
97
+ if not issue_comments:
98
+ return 0.0
99
+
100
+ matched_comment_indexes: Set[int] = set()
101
+ for expected in self.expected_issues:
102
+ for idx, comment in enumerate(issue_comments):
103
+ if self._partial_credit(expected, comment) > 0:
104
+ matched_comment_indexes.add(idx)
105
+
106
+ false_positive_count = len(issue_comments) - len(matched_comment_indexes)
107
+ false_positive_rate = false_positive_count / max(1, len(issue_comments))
108
+ return min(0.4, false_positive_rate * 0.25)
109
+
110
+ def grade_efficiency(self, steps_taken: int, max_steps: int) -> float:
111
+ if max_steps <= 0:
112
+ return 0.0
113
+ ratio = steps_taken / max_steps
114
+ if ratio <= 0.1:
115
+ return 0.1
116
+ if ratio <= 0.2:
117
+ return 0.05
118
+ return 0.0
119
+
120
  def get_diagnostics(self,
121
  comments: List[Comment],
122
  suggestions: List[Suggestion],
123
+ final_decision: str,
124
+ steps_taken: int = 0,
125
+ max_steps: int = 50) -> Dict[str, Any]:
126
  issue_comments = [c for c in comments if c.is_issue]
127
  expected_count = len(self.expected_issues)
128
 
 
147
  precision = true_positives / max(1, len(issue_comments))
148
  recall = true_positives / expected_count
149
 
150
+ f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
  detection_score = self.grade_detection(comments)
153
  suggestion_score = self.grade_suggestions(suggestions)
154
  decision_score = self.grade_decision(final_decision)
155
+ false_positive_penalty = self.grade_false_positives(comments)
156
+ efficiency_bonus = self.grade_efficiency(steps_taken, max_steps)
 
157
 
158
  raw_score = (detection_score * 0.4) + (suggestion_score * 0.3) + (decision_score * 0.3)
159
+ final_score = max(0.0, min(1.0, raw_score - false_positive_penalty + efficiency_bonus))
160
 
161
  return {
162
  "expected_issue_count": expected_count,
 
165
  "false_negative_count": false_negatives,
166
  "precision": round(precision, 4),
167
  "recall": round(recall, 4),
168
+ "f1": round(f1, 4),
169
  "detection_score": round(detection_score, 4),
170
  "suggestion_score": round(suggestion_score, 4),
171
  "decision_score": round(decision_score, 4),
172
  "false_positive_penalty": round(false_positive_penalty, 4),
173
+ "efficiency_bonus": round(efficiency_bonus, 4),
174
  "score": round(final_score, 4),
175
  }
176
 
177
  def compute_score(self,
178
  comments: List[Comment],
179
  suggestions: List[Suggestion],
180
+ final_decision: str,
181
+ steps_taken: int = 0,
182
+ max_steps: int = 50) -> float:
183
+ diagnostics = self.get_diagnostics(comments, suggestions, final_decision, steps_taken, max_steps)
184
  return float(diagnostics["score"])
185
 
186
  def compute_score_from_state(self,
187
  comments: List[Comment],
188
  suggestions: List[Suggestion],
189
+ final_decision: str,
190
+ steps_taken: int = 0,
191
+ max_steps: int = 50) -> float:
192
+ return self.compute_score(comments, suggestions, final_decision, steps_taken, max_steps)
193
 
194
 
195
  class RewardCalculator:
196
+
197
  def __init__(self):
198
  self.last_score = 0.0
199
+
200
+ def calculate_reward(self,
201
  current_action: ReviewAction,
202
  all_comments: List[Comment],
203
  all_suggestions: List[Suggestion],
204
  final_decision: str,
205
  grader: TaskGrader,
206
+ last_action_valid: bool,
207
+ steps_taken: int = 0,
208
+ max_steps: int = 50) -> float:
209
 
210
  current_score = grader.compute_score(
211
  comments=all_comments,
212
  suggestions=all_suggestions,
213
  final_decision=final_decision,
214
+ steps_taken=steps_taken,
215
+ max_steps=max_steps,
216
  )
217
+
218
  reward = current_score - self.last_score
219
+
220
  if current_action.action_type.value in ["add_comment", "suggest_fix"]:
221
  reward += 0.03
222
 
223
  if not last_action_valid:
224
  reward -= 0.15
225
+
226
  if not current_action.comments and not current_action.suggestions:
227
  if current_action.action_type.value in ["approve", "request_changes"]:
228
  pass
229
  else:
230
  reward -= 0.1
231
+
232
  for comment in current_action.comments:
233
  if comment.severity == "critical":
234
  reward += 0.2
 
236
  reward += 0.1
237
  elif comment.severity == "medium":
238
  reward += 0.05
239
+
240
  if len(current_action.suggestions) > 0:
241
  reward += 0.05 * len(current_action.suggestions)
242
+
243
  if current_action.final_decision:
244
  optimal_decision = "changes_requested" if grader.expected_issues else "approved"
245
  reward += 0.1 if current_action.final_decision == optimal_decision else -0.1
246
 
247
  reward = max(-0.5, min(1.0, reward))
248
+
249
  self.last_score = current_score
250
+
251
  return reward
252
+
253
  def reset(self):
254
  self.last_score = 0.0
environment/__init__.py CHANGED
@@ -9,6 +9,7 @@ from environment.models import (
9
  ReviewState,
10
  Observation
11
  )
 
12
 
13
  __all__ = [
14
  "CodeReviewEnv",
@@ -19,5 +20,6 @@ __all__ = [
19
  "CodeContext",
20
  "TaskMetadata",
21
  "ReviewState",
22
- "Observation"
 
23
  ]
 
9
  ReviewState,
10
  Observation
11
  )
12
+ from environment.tasks import TaskDefinitions
13
 
14
  __all__ = [
15
  "CodeReviewEnv",
 
20
  "CodeContext",
21
  "TaskMetadata",
22
  "ReviewState",
23
+ "Observation",
24
+ "TaskDefinitions",
25
  ]
environment/tasks.py CHANGED
@@ -9,11 +9,12 @@ class TaskDefinitions:
9
  "bug_detection_medium": "memory_leak_medium_1",
10
  "bug_detection_hard": "security_hard_1",
11
  }
12
-
13
  EASY_TASKS = [
14
  {
15
  "task_id": "bug_detection_easy_1",
16
  "task_name": "Division by Zero",
 
17
  "description": "Find the division by zero vulnerability in the calculate_average function",
18
  "code_diff": """def calculate_average(numbers):
19
  total = sum(numbers)
@@ -21,11 +22,11 @@ class TaskDefinitions:
21
  "surrounding_code": """class StatisticsCalculator:
22
  def __init__(self):
23
  self.results = []
24
-
25
  def calculate_average(self, numbers):
26
  total = sum(numbers)
27
  return total / len(numbers)
28
-
29
  def add_result(self, value):
30
  self.results.append(value)""",
31
  "file_path": "statistics.py",
@@ -43,6 +44,7 @@ class TaskDefinitions:
43
  {
44
  "task_id": "bug_detection_easy_2",
45
  "task_name": "Off-by-One Error",
 
46
  "description": "Find the off-by-one error in the array iteration",
47
  "code_diff": """def process_items(items):
48
  for i in range(len(items)):
@@ -70,6 +72,7 @@ class TaskDefinitions:
70
  {
71
  "task_id": "approve_easy_3",
72
  "task_name": "Approve Safe Refactor",
 
73
  "description": "No issues expected: approve this small readability refactor",
74
  "code_diff": """def normalize_name(name):
75
  cleaned = name.strip()
@@ -86,11 +89,12 @@ def format_username(user):
86
  "expected_issues": []
87
  }
88
  ]
89
-
90
  MEDIUM_TASKS = [
91
  {
92
  "task_id": "memory_leak_medium_1",
93
  "task_name": "File Handle Leak",
 
94
  "description": "Find the memory leak where file handles are not properly closed",
95
  "code_diff": """def read_files(file_list):
96
  contents = []
@@ -127,6 +131,7 @@ def write_output(data, filename):
127
  {
128
  "task_id": "performance_medium_2",
129
  "task_name": "Inefficient String Concatenation",
 
130
  "description": "Find the performance issue with string concatenation in a loop",
131
  "code_diff": """def build_string(items):
132
  result = ""
@@ -156,6 +161,7 @@ def format_output(data):
156
  {
157
  "task_id": "approve_medium_3",
158
  "task_name": "Approve Safe Query Helper",
 
159
  "description": "No issues expected: approve this query helper cleanup",
160
  "code_diff": """def build_user_query(limit):
161
  safe_limit = max(1, int(limit))
@@ -173,11 +179,12 @@ def run_user_query(db, limit):
173
  "expected_issues": []
174
  }
175
  ]
176
-
177
  HARD_TASKS = [
178
  {
179
  "task_id": "security_hard_1",
180
  "task_name": "SQL Injection Vulnerability",
 
181
  "description": "Find the SQL injection vulnerability in the database query",
182
  "code_diff": """def get_user_data(user_id):
183
  query = f"SELECT * FROM users WHERE id = {user_id}"
@@ -205,11 +212,12 @@ def get_all_users():
205
  {
206
  "task_id": "race_condition_hard_2",
207
  "task_name": "Race Condition",
 
208
  "description": "Find the race condition in the thread-safe counter",
209
  "code_diff": """class Counter:
210
  def __init__(self):
211
  self.count = 0
212
-
213
  def increment(self):
214
  current = self.count
215
  self.count = current + 1
@@ -219,12 +227,12 @@ def get_all_users():
219
  class Counter:
220
  def __init__(self):
221
  self.count = 0
222
-
223
  def increment(self):
224
  current = self.count
225
  self.count = current + 1
226
  return self.count
227
-
228
  def get_count(self):
229
  return self.count""",
230
  "file_path": "counter.py",
@@ -242,6 +250,7 @@ class Counter:
242
  {
243
  "task_id": "approve_hard_3",
244
  "task_name": "Approve Thread-Safe Counter",
 
245
  "description": "No issues expected: approve this lock-based concurrency fix",
246
  "code_diff": """class Counter:
247
  def __init__(self):
@@ -269,7 +278,7 @@ class Counter:
269
  "expected_issues": []
270
  }
271
  ]
272
-
273
  @classmethod
274
  def get_task(cls, task_id: str) -> Dict[str, Any]:
275
  canonical_task_id = cls.TASK_ALIASES.get(task_id, task_id)
@@ -277,12 +286,13 @@ class Counter:
277
  for task in all_tasks:
278
  if task["task_id"] == canonical_task_id:
279
  return task
 
280
  return cls.EASY_TASKS[0]
281
 
282
  @classmethod
283
  def get_all_tasks(cls) -> List[Dict[str, Any]]:
284
  return cls.EASY_TASKS + cls.MEDIUM_TASKS + cls.HARD_TASKS
285
-
286
  @classmethod
287
  def get_tasks_by_difficulty(cls, difficulty: str) -> List[Dict[str, Any]]:
288
  if difficulty == "easy":
@@ -292,7 +302,7 @@ class Counter:
292
  elif difficulty == "hard":
293
  return cls.HARD_TASKS
294
  return []
295
-
296
  @classmethod
297
  def create_code_context(cls, task_data: Dict[str, Any]) -> CodeContext:
298
  return CodeContext(
@@ -303,15 +313,11 @@ class Counter:
303
  language=task_data["language"],
304
  line_count=task_data["line_count"]
305
  )
306
-
307
  @classmethod
308
  def create_task_metadata(cls, task_data: Dict[str, Any]) -> TaskMetadata:
309
- difficulty = "easy"
310
- if "medium" in task_data["task_id"]:
311
- difficulty = "medium"
312
- elif "hard" in task_data["task_id"]:
313
- difficulty = "hard"
314
-
315
  return TaskMetadata(
316
  task_id=task_data["task_id"],
317
  task_name=task_data["task_name"],
 
9
  "bug_detection_medium": "memory_leak_medium_1",
10
  "bug_detection_hard": "security_hard_1",
11
  }
12
+
13
  EASY_TASKS = [
14
  {
15
  "task_id": "bug_detection_easy_1",
16
  "task_name": "Division by Zero",
17
+ "difficulty": "easy",
18
  "description": "Find the division by zero vulnerability in the calculate_average function",
19
  "code_diff": """def calculate_average(numbers):
20
  total = sum(numbers)
 
22
  "surrounding_code": """class StatisticsCalculator:
23
  def __init__(self):
24
  self.results = []
25
+
26
  def calculate_average(self, numbers):
27
  total = sum(numbers)
28
  return total / len(numbers)
29
+
30
  def add_result(self, value):
31
  self.results.append(value)""",
32
  "file_path": "statistics.py",
 
44
  {
45
  "task_id": "bug_detection_easy_2",
46
  "task_name": "Off-by-One Error",
47
+ "difficulty": "easy",
48
  "description": "Find the off-by-one error in the array iteration",
49
  "code_diff": """def process_items(items):
50
  for i in range(len(items)):
 
72
  {
73
  "task_id": "approve_easy_3",
74
  "task_name": "Approve Safe Refactor",
75
+ "difficulty": "easy",
76
  "description": "No issues expected: approve this small readability refactor",
77
  "code_diff": """def normalize_name(name):
78
  cleaned = name.strip()
 
89
  "expected_issues": []
90
  }
91
  ]
92
+
93
  MEDIUM_TASKS = [
94
  {
95
  "task_id": "memory_leak_medium_1",
96
  "task_name": "File Handle Leak",
97
+ "difficulty": "medium",
98
  "description": "Find the memory leak where file handles are not properly closed",
99
  "code_diff": """def read_files(file_list):
100
  contents = []
 
131
  {
132
  "task_id": "performance_medium_2",
133
  "task_name": "Inefficient String Concatenation",
134
+ "difficulty": "medium",
135
  "description": "Find the performance issue with string concatenation in a loop",
136
  "code_diff": """def build_string(items):
137
  result = ""
 
161
  {
162
  "task_id": "approve_medium_3",
163
  "task_name": "Approve Safe Query Helper",
164
+ "difficulty": "medium",
165
  "description": "No issues expected: approve this query helper cleanup",
166
  "code_diff": """def build_user_query(limit):
167
  safe_limit = max(1, int(limit))
 
179
  "expected_issues": []
180
  }
181
  ]
182
+
183
  HARD_TASKS = [
184
  {
185
  "task_id": "security_hard_1",
186
  "task_name": "SQL Injection Vulnerability",
187
+ "difficulty": "hard",
188
  "description": "Find the SQL injection vulnerability in the database query",
189
  "code_diff": """def get_user_data(user_id):
190
  query = f"SELECT * FROM users WHERE id = {user_id}"
 
212
  {
213
  "task_id": "race_condition_hard_2",
214
  "task_name": "Race Condition",
215
+ "difficulty": "hard",
216
  "description": "Find the race condition in the thread-safe counter",
217
  "code_diff": """class Counter:
218
  def __init__(self):
219
  self.count = 0
220
+
221
  def increment(self):
222
  current = self.count
223
  self.count = current + 1
 
227
  class Counter:
228
  def __init__(self):
229
  self.count = 0
230
+
231
  def increment(self):
232
  current = self.count
233
  self.count = current + 1
234
  return self.count
235
+
236
  def get_count(self):
237
  return self.count""",
238
  "file_path": "counter.py",
 
250
  {
251
  "task_id": "approve_hard_3",
252
  "task_name": "Approve Thread-Safe Counter",
253
+ "difficulty": "hard",
254
  "description": "No issues expected: approve this lock-based concurrency fix",
255
  "code_diff": """class Counter:
256
  def __init__(self):
 
278
  "expected_issues": []
279
  }
280
  ]
281
+
282
  @classmethod
283
  def get_task(cls, task_id: str) -> Dict[str, Any]:
284
  canonical_task_id = cls.TASK_ALIASES.get(task_id, task_id)
 
286
  for task in all_tasks:
287
  if task["task_id"] == canonical_task_id:
288
  return task
289
+ print(f"WARNING: task_id '{task_id}' not found, falling back to bug_detection_easy_1")
290
  return cls.EASY_TASKS[0]
291
 
292
@classmethod
def get_all_tasks(cls) -> List[Dict[str, Any]]:
    """Return every defined task, across all difficulty tiers, as one new list."""
    return [*cls.EASY_TASKS, *cls.MEDIUM_TASKS, *cls.HARD_TASKS]
295
+
296
  @classmethod
297
  def get_tasks_by_difficulty(cls, difficulty: str) -> List[Dict[str, Any]]:
298
  if difficulty == "easy":
 
302
  elif difficulty == "hard":
303
  return cls.HARD_TASKS
304
  return []
305
+
306
  @classmethod
307
  def create_code_context(cls, task_data: Dict[str, Any]) -> CodeContext:
308
  return CodeContext(
 
313
  language=task_data["language"],
314
  line_count=task_data["line_count"]
315
  )
316
+
317
  @classmethod
318
  def create_task_metadata(cls, task_data: Dict[str, Any]) -> TaskMetadata:
319
+ difficulty = task_data.get("difficulty", "easy")
320
+
 
 
 
 
321
  return TaskMetadata(
322
  task_id=task_data["task_id"],
323
  task_name=task_data["task_name"],
inference.py CHANGED
@@ -6,7 +6,7 @@ import os
6
  import json
7
  import argparse
8
  import sys
9
- from typing import Dict, Any
10
  from openai import OpenAI
11
 
12
  API_BASE_URL = os.environ.get("API_BASE_URL", "")
@@ -72,9 +72,8 @@ class LLMClient:
72
  print(f"Endpoint: {self.base_url}")
73
  print(f"Model: {self.model}\n")
74
 
75
- def chat_completion(self, messages: list, temperature: float = 0.7, max_tokens: int = 2000) -> str:
76
  last_error = None
77
- # Retry once for flaky local-model responses.
78
  for _ in range(2):
79
  try:
80
  completion = self.client.chat.completions.create(
@@ -132,69 +131,34 @@ class CodeReviewAgent:
132
  if " / len(" in code_diff:
133
  line = self._line_number(code_diff, " / len(", 1)
134
  line = self._task_expected_line(observation, line)
135
- return {
136
- "line_number": line,
137
- "content": "Possible division_by_zero when list is empty before dividing by len(...).",
138
- "is_issue": True,
139
- "severity": "high",
140
- }
141
 
142
  if "open(" in code_diff and ".read(" in code_diff and "with open" not in code_diff:
143
  line = self._line_number(code_diff, "open(", 1)
144
  line = self._task_expected_line(observation, line)
145
- return {
146
- "line_number": line,
147
- "content": "Potential resource_leak: file handle opened without context manager or explicit close().",
148
- "is_issue": True,
149
- "severity": "high",
150
- }
151
 
152
  if "SELECT" in code_diff and "{" in code_diff and "}" in code_diff:
153
  line = self._line_number(code_diff, "SELECT", 1)
154
  line = self._task_expected_line(observation, line)
155
- return {
156
- "line_number": line,
157
- "content": "Potential sql_injection due to string interpolation in SQL query.",
158
- "is_issue": True,
159
- "severity": "critical",
160
- }
161
 
162
  if "i + 1" in code_diff and "range(len(" in code_diff:
163
  line = self._line_number(code_diff, "i + 1", 1)
164
  line = self._task_expected_line(observation, line)
165
- return {
166
- "line_number": line,
167
- "content": "Potential index_error: i + 1 can go out of bounds on the last iteration.",
168
- "is_issue": True,
169
- "severity": "medium",
170
- }
171
 
172
  if "result = result +" in code_diff:
173
  line = self._line_number(code_diff, "result = result +", 1)
174
  line = self._task_expected_line(observation, line)
175
- return {
176
- "line_number": line,
177
- "content": "Potential performance issue from repeated string concatenation in a loop.",
178
- "is_issue": True,
179
- "severity": "medium",
180
- }
181
 
182
  if "current = self.count" in code_diff and "self.count = current + 1" in code_diff:
183
  line = self._line_number(code_diff, "self.count = current + 1", 1)
184
  line = self._task_expected_line(observation, line)
185
- return {
186
- "line_number": line,
187
- "content": "Potential race_condition: increment is not atomic without synchronization.",
188
- "is_issue": True,
189
- "severity": "high",
190
- }
191
 
192
- return {
193
- "line_number": 1,
194
- "content": "Potential correctness issue requires manual validation.",
195
- "is_issue": True,
196
- "severity": "low",
197
- }
198
 
199
  def _heuristic_suggestion(self, observation: Dict[str, Any]) -> Dict[str, Any]:
200
  code_diff = observation.get("code_diff", "")
@@ -202,62 +166,34 @@ class CodeReviewAgent:
202
  if " / len(" in code_diff:
203
  line = self._line_number(code_diff, " / len(", 1)
204
  line = self._task_expected_line(observation, line)
205
- return {
206
- "original_line": line,
207
- "suggested_code": "return total / len(numbers) if numbers else 0",
208
- "explanation": "Guard against empty input before division.",
209
- }
210
 
211
  if "open(" in code_diff and ".read(" in code_diff and "with open" not in code_diff:
212
  line = self._line_number(code_diff, "open(", 1)
213
  line = self._task_expected_line(observation, line)
214
- return {
215
- "original_line": line,
216
- "suggested_code": "with open(filename, 'r') as f:\n data = f.read()",
217
- "explanation": "Use a context manager so file handles are always closed.",
218
- }
219
 
220
  if "SELECT" in code_diff and "{" in code_diff and "}" in code_diff:
221
  line = self._line_number(code_diff, "SELECT", 1)
222
  line = self._task_expected_line(observation, line)
223
- return {
224
- "original_line": line,
225
- "suggested_code": "query = \"SELECT * FROM users WHERE id = ?\"\nreturn database.execute(query, [user_id])",
226
- "explanation": "Use parameterized queries to prevent SQL injection.",
227
- }
228
 
229
  if "i + 1" in code_diff and "range(len(" in code_diff:
230
  line = self._line_number(code_diff, "i + 1", 1)
231
  line = self._task_expected_line(observation, line)
232
- return {
233
- "original_line": line,
234
- "suggested_code": "for i in range(len(items) - 1):\n item = items[i]\n next_item = items[i + 1]\n process_pair(item, next_item)",
235
- "explanation": "Stop one element early to avoid indexing past the array end.",
236
- }
237
 
238
  if "result = result +" in code_diff:
239
  line = self._line_number(code_diff, "result = result +", 1)
240
  line = self._task_expected_line(observation, line)
241
- return {
242
- "original_line": line,
243
- "suggested_code": "return \",\".join(items)",
244
- "explanation": "join() avoids quadratic-time string concatenation.",
245
- }
246
 
247
  if "current = self.count" in code_diff and "self.count = current + 1" in code_diff:
248
  line = self._line_number(code_diff, "self.count = current + 1", 1)
249
  line = self._task_expected_line(observation, line)
250
- return {
251
- "original_line": line,
252
- "suggested_code": "with self._lock:\n self.count += 1\n return self.count",
253
- "explanation": "Protect shared state with a lock for thread safety.",
254
- }
255
 
256
- return {
257
- "original_line": 1,
258
- "suggested_code": "# apply targeted fix here",
259
- "explanation": "Provide a minimal fix for the identified issue.",
260
- }
261
 
262
  def _coerce_action_for_phase(self, action_data: Dict[str, Any], observation: Dict[str, Any]) -> Dict[str, Any]:
263
  phase = self.phase
@@ -265,39 +201,19 @@ class CodeReviewAgent:
265
 
266
  if phase == 1:
267
  if no_issue_task:
268
- return {
269
- "action_type": "add_comment",
270
- "comments": [],
271
- "suggestions": [],
272
- "final_decision": None,
273
- }
274
  comments = action_data.get("comments") or []
275
  if action_data.get("action_type") != "add_comment" or not comments:
276
  comments = [self._heuristic_comment(observation)]
277
- return {
278
- "action_type": "add_comment",
279
- "comments": comments,
280
- "suggestions": [],
281
- "final_decision": None,
282
- }
283
 
284
  if phase == 2:
285
  if no_issue_task:
286
- return {
287
- "action_type": "suggest_fix",
288
- "comments": [],
289
- "suggestions": [],
290
- "final_decision": None,
291
- }
292
  suggestions = action_data.get("suggestions") or []
293
  if action_data.get("action_type") != "suggest_fix" or not suggestions:
294
  suggestions = [self._heuristic_suggestion(observation)]
295
- return {
296
- "action_type": "suggest_fix",
297
- "comments": [],
298
- "suggestions": suggestions,
299
- "final_decision": None,
300
- }
301
 
302
  prior_comments = observation.get("previous_comments", [])
303
  prior_suggestions = observation.get("previous_suggestions", [])
@@ -309,6 +225,11 @@ class CodeReviewAgent:
309
  "final_decision": final_decision,
310
  }
311
 
 
 
 
 
 
312
  def get_action(self, observation: Dict[str, Any]) -> str:
313
 
314
  system_prompt = """You are an expert code reviewer. You MUST follow this exact sequence:
@@ -360,6 +281,8 @@ Respond ONLY with a valid JSON object, no extra text:
360
  for s in prev_suggestions
361
  ]) or "None yet"
362
 
 
 
363
  if self.phase == 1:
364
  phase_instruction = """
365
  YOUR TASK NOW (Phase 1 - Add Comments):
@@ -397,6 +320,7 @@ File Context:
397
  {observation.get('file_context', '')}
398
 
399
  Current Step: {observation.get('current_step', 0)}/{observation.get('max_steps', 50)}
 
400
 
401
  Comments already made:
402
  {comments_text}
@@ -438,7 +362,6 @@ Respond with JSON only.
438
  action_data["suggestions"] = []
439
 
440
  action_data = self._coerce_action_for_phase(action_data, observation)
441
-
442
  self.phase += 1
443
  return json.dumps(action_data)
444
 
@@ -478,41 +401,18 @@ Respond with JSON only.
478
  return {"action_type": "request_changes", "comments": [], "suggestions": []}
479
 
480
 
481
- def main():
482
- sys.path.append('.')
483
-
484
- try:
485
- from environment.env import CodeReviewEnv
486
- except ImportError as e:
487
- print(f"Failed to import environment: {e}")
488
- print("Make sure you're in the correct directory and environment is installed.")
489
- sys.exit(1)
490
-
491
- parser = argparse.ArgumentParser(description="Run code review agent")
492
- parser.add_argument("--task-id", type=str, default="bug_detection_easy_1")
493
- parser.add_argument("--max-steps", type=int, default=50)
494
- parser.add_argument("--output", type=str, default="baseline_results.json")
495
- args = parser.parse_args()
496
-
497
- print("=" * 60)
498
- print("Code Review Agent")
499
- print("=" * 60)
500
-
501
- env = CodeReviewEnv()
502
- env.max_steps = args.max_steps
503
- agent = CodeReviewAgent()
504
-
505
- obs = env.reset(task_id=args.task_id)
506
  done = False
507
  step = 0
508
  total_reward = 0.0
509
 
510
- print(f"\nTask : {args.task_id}")
511
  print(f"Desc : {obs.get('task_description', 'N/A')}")
512
- print(f"Model : {MODEL_NAME}")
513
  print("-" * 60)
514
 
515
- while not done and step < args.max_steps:
516
  action_str = agent.get_action(obs)
517
  action = agent.parse_action(action_str)
518
  action = agent.validate_action(action, obs)
@@ -521,7 +421,7 @@ def main():
521
  total_reward += reward
522
  step += 1
523
 
524
- print(f"\nStep {step}/{args.max_steps}:")
525
  print(f" Phase : {agent.phase - 1}")
526
  print(f" Action : {action.get('action_type')}")
527
  print(f" Comments : {len(action.get('comments', []))}")
@@ -529,38 +429,120 @@ def main():
529
  print(f" Reward : {reward:.3f}")
530
  print(f" Total : {total_reward:.3f}")
531
  print(f" Score : {info.get('task_score', 0):.3f}")
 
532
 
533
  if info.get('last_action_valid') is False:
534
  print(f" Warning : {info.get('error', 'Invalid action')}")
535
 
536
  final_score = env.get_task_score()
 
537
 
538
- print("\n" + "=" * 60)
539
- print("Final Results:")
540
- print(f" Task : {args.task_id}")
541
- print(f" Total Reward : {total_reward:.3f}")
542
- print(f" Task Score : {final_score:.3f}/1.0")
543
- print(f" Steps : {step}")
544
- print("=" * 60)
545
-
546
- env.close()
547
-
548
- results = {
549
- "task_id": args.task_id,
550
  "total_reward": round(total_reward, 4),
551
  "task_score": round(final_score, 4),
552
  "steps": step,
553
- "max_steps": args.max_steps,
554
- "provider": "openai-client",
 
 
 
 
555
  "model": MODEL_NAME,
556
- "api_base_url": API_BASE_URL
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
557
  }
558
 
559
- with open(args.output, "w") as f:
560
- json.dump(results, f, indent=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
561
 
562
- print(f"\nResults saved to {args.output}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
563
 
564
 
565
  if __name__ == "__main__":
566
- main()
 
6
  import json
7
  import argparse
8
  import sys
9
+ from typing import Dict, Any, List
10
  from openai import OpenAI
11
 
12
  API_BASE_URL = os.environ.get("API_BASE_URL", "")
 
72
  print(f"Endpoint: {self.base_url}")
73
  print(f"Model: {self.model}\n")
74
 
75
+ def chat_completion(self, messages: list, temperature: float = 0.0, max_tokens: int = 2000) -> str:
76
  last_error = None
 
77
  for _ in range(2):
78
  try:
79
  completion = self.client.chat.completions.create(
 
131
  if " / len(" in code_diff:
132
  line = self._line_number(code_diff, " / len(", 1)
133
  line = self._task_expected_line(observation, line)
134
+ return {"line_number": line, "content": "Possible division_by_zero when list is empty before dividing by len(...).", "is_issue": True, "severity": "high"}
 
 
 
 
 
135
 
136
  if "open(" in code_diff and ".read(" in code_diff and "with open" not in code_diff:
137
  line = self._line_number(code_diff, "open(", 1)
138
  line = self._task_expected_line(observation, line)
139
+ return {"line_number": line, "content": "Potential resource_leak: file handle opened without context manager or explicit close().", "is_issue": True, "severity": "high"}
 
 
 
 
 
140
 
141
  if "SELECT" in code_diff and "{" in code_diff and "}" in code_diff:
142
  line = self._line_number(code_diff, "SELECT", 1)
143
  line = self._task_expected_line(observation, line)
144
+ return {"line_number": line, "content": "Potential sql_injection due to string interpolation in SQL query.", "is_issue": True, "severity": "critical"}
 
 
 
 
 
145
 
146
  if "i + 1" in code_diff and "range(len(" in code_diff:
147
  line = self._line_number(code_diff, "i + 1", 1)
148
  line = self._task_expected_line(observation, line)
149
+ return {"line_number": line, "content": "Potential index_error: i + 1 can go out of bounds on the last iteration.", "is_issue": True, "severity": "medium"}
 
 
 
 
 
150
 
151
  if "result = result +" in code_diff:
152
  line = self._line_number(code_diff, "result = result +", 1)
153
  line = self._task_expected_line(observation, line)
154
+ return {"line_number": line, "content": "Potential performance issue from repeated string concatenation in a loop.", "is_issue": True, "severity": "medium"}
 
 
 
 
 
155
 
156
  if "current = self.count" in code_diff and "self.count = current + 1" in code_diff:
157
  line = self._line_number(code_diff, "self.count = current + 1", 1)
158
  line = self._task_expected_line(observation, line)
159
+ return {"line_number": line, "content": "Potential race_condition: increment is not atomic without synchronization.", "is_issue": True, "severity": "high"}
 
 
 
 
 
160
 
161
+ return {"line_number": 1, "content": "Potential correctness issue requires manual validation.", "is_issue": True, "severity": "low"}
 
 
 
 
 
162
 
163
  def _heuristic_suggestion(self, observation: Dict[str, Any]) -> Dict[str, Any]:
164
  code_diff = observation.get("code_diff", "")
 
166
  if " / len(" in code_diff:
167
  line = self._line_number(code_diff, " / len(", 1)
168
  line = self._task_expected_line(observation, line)
169
+ return {"original_line": line, "suggested_code": "return total / len(numbers) if numbers else 0", "explanation": "Guard against empty input before division."}
 
 
 
 
170
 
171
  if "open(" in code_diff and ".read(" in code_diff and "with open" not in code_diff:
172
  line = self._line_number(code_diff, "open(", 1)
173
  line = self._task_expected_line(observation, line)
174
+ return {"original_line": line, "suggested_code": "with open(filename, 'r') as f:\n data = f.read()", "explanation": "Use a context manager so file handles are always closed."}
 
 
 
 
175
 
176
  if "SELECT" in code_diff and "{" in code_diff and "}" in code_diff:
177
  line = self._line_number(code_diff, "SELECT", 1)
178
  line = self._task_expected_line(observation, line)
179
+ return {"original_line": line, "suggested_code": "query = \"SELECT * FROM users WHERE id = ?\"\nreturn database.execute(query, [user_id])", "explanation": "Use parameterized queries to prevent SQL injection."}
 
 
 
 
180
 
181
  if "i + 1" in code_diff and "range(len(" in code_diff:
182
  line = self._line_number(code_diff, "i + 1", 1)
183
  line = self._task_expected_line(observation, line)
184
+ return {"original_line": line, "suggested_code": "for i in range(len(items) - 1):\n item = items[i]\n next_item = items[i + 1]\n process_pair(item, next_item)", "explanation": "Stop one element early to avoid indexing past the array end."}
 
 
 
 
185
 
186
  if "result = result +" in code_diff:
187
  line = self._line_number(code_diff, "result = result +", 1)
188
  line = self._task_expected_line(observation, line)
189
+ return {"original_line": line, "suggested_code": "return \",\".join(items)", "explanation": "join() avoids quadratic-time string concatenation."}
 
 
 
 
190
 
191
  if "current = self.count" in code_diff and "self.count = current + 1" in code_diff:
192
  line = self._line_number(code_diff, "self.count = current + 1", 1)
193
  line = self._task_expected_line(observation, line)
194
+ return {"original_line": line, "suggested_code": "with self._lock:\n self.count += 1\n return self.count", "explanation": "Protect shared state with a lock for thread safety."}
 
 
 
 
195
 
196
+ return {"original_line": 1, "suggested_code": "# apply targeted fix here", "explanation": "Provide a minimal fix for the identified issue."}
 
 
 
 
197
 
198
  def _coerce_action_for_phase(self, action_data: Dict[str, Any], observation: Dict[str, Any]) -> Dict[str, Any]:
199
  phase = self.phase
 
201
 
202
  if phase == 1:
203
  if no_issue_task:
204
+ return {"action_type": "add_comment", "comments": [], "suggestions": [], "final_decision": None}
 
 
 
 
 
205
  comments = action_data.get("comments") or []
206
  if action_data.get("action_type") != "add_comment" or not comments:
207
  comments = [self._heuristic_comment(observation)]
208
+ return {"action_type": "add_comment", "comments": comments, "suggestions": [], "final_decision": None}
 
 
 
 
 
209
 
210
  if phase == 2:
211
  if no_issue_task:
212
+ return {"action_type": "suggest_fix", "comments": [], "suggestions": [], "final_decision": None}
 
 
 
 
 
213
  suggestions = action_data.get("suggestions") or []
214
  if action_data.get("action_type") != "suggest_fix" or not suggestions:
215
  suggestions = [self._heuristic_suggestion(observation)]
216
+ return {"action_type": "suggest_fix", "comments": [], "suggestions": suggestions, "final_decision": None}
 
 
 
 
 
217
 
218
  prior_comments = observation.get("previous_comments", [])
219
  prior_suggestions = observation.get("previous_suggestions", [])
 
225
  "final_decision": final_decision,
226
  }
227
 
228
def reset(self):
    """Reset per-episode agent state so this instance can be reused across tasks.

    Sets the review phase counter back to 1 (the comment-adding phase used by
    get_action), clears the model-availability flag, and empties the history
    list (presumably recent actions/messages — confirm against get_action usage).
    """
    self.phase = 1
    self.model_unavailable = False
    self.history = []
232
+
233
  def get_action(self, observation: Dict[str, Any]) -> str:
234
 
235
  system_prompt = """You are an expert code reviewer. You MUST follow this exact sequence:
 
281
  for s in prev_suggestions
282
  ]) or "None yet"
283
 
284
+ valid_actions = observation.get("valid_actions", [])
285
+
286
  if self.phase == 1:
287
  phase_instruction = """
288
  YOUR TASK NOW (Phase 1 - Add Comments):
 
320
  {observation.get('file_context', '')}
321
 
322
  Current Step: {observation.get('current_step', 0)}/{observation.get('max_steps', 50)}
323
+ Valid Actions: {valid_actions}
324
 
325
  Comments already made:
326
  {comments_text}
 
362
  action_data["suggestions"] = []
363
 
364
  action_data = self._coerce_action_for_phase(action_data, observation)
 
365
  self.phase += 1
366
  return json.dumps(action_data)
367
 
 
401
  return {"action_type": "request_changes", "comments": [], "suggestions": []}
402
 
403
 
404
+ def run_episode(env, agent, task_id: str, max_steps: int) -> Dict[str, Any]:
405
+ agent.reset()
406
+ obs = env.reset(task_id=task_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
  done = False
408
  step = 0
409
  total_reward = 0.0
410
 
411
+ print(f"\nTask : {task_id}")
412
  print(f"Desc : {obs.get('task_description', 'N/A')}")
 
413
  print("-" * 60)
414
 
415
+ while not done and step < max_steps:
416
  action_str = agent.get_action(obs)
417
  action = agent.parse_action(action_str)
418
  action = agent.validate_action(action, obs)
 
421
  total_reward += reward
422
  step += 1
423
 
424
+ print(f"\nStep {step}/{max_steps}:")
425
  print(f" Phase : {agent.phase - 1}")
426
  print(f" Action : {action.get('action_type')}")
427
  print(f" Comments : {len(action.get('comments', []))}")
 
429
  print(f" Reward : {reward:.3f}")
430
  print(f" Total : {total_reward:.3f}")
431
  print(f" Score : {info.get('task_score', 0):.3f}")
432
+ print(f" Valid Actions: {info.get('valid_actions', [])}")
433
 
434
  if info.get('last_action_valid') is False:
435
  print(f" Warning : {info.get('error', 'Invalid action')}")
436
 
437
  final_score = env.get_task_score()
438
+ diagnostics = env.summary()
439
 
440
+ return {
441
+ "task_id": task_id,
 
 
 
 
 
 
 
 
 
 
442
  "total_reward": round(total_reward, 4),
443
  "task_score": round(final_score, 4),
444
  "steps": step,
445
+ "max_steps": max_steps,
446
+ "precision": diagnostics.get("precision", 0),
447
+ "recall": diagnostics.get("recall", 0),
448
+ "f1": diagnostics.get("f1", 0),
449
+ "false_positive_count": diagnostics.get("false_positive_count", 0),
450
+ "efficiency_bonus": diagnostics.get("efficiency_bonus", 0),
451
  "model": MODEL_NAME,
452
+ "api_base_url": API_BASE_URL,
453
+ }
454
+
455
+
456
def run_batch(env, agent, task_ids: List[str], max_steps: int, output: str):
    """Run the agent on each task in ``task_ids`` and write an aggregate report.

    Args:
        env: CodeReviewEnv instance, reused across episodes.
        agent: CodeReviewAgent instance (reset per episode inside run_episode).
        task_ids: Task identifiers to evaluate.
        max_steps: Per-episode step budget passed to run_episode.
        output: Path of the JSON file the batch summary is written to.
    """
    # Guard: an empty batch (e.g. a difficulty filter that matched nothing)
    # would otherwise crash with ZeroDivisionError in the averages below.
    if not task_ids:
        print("No tasks to evaluate; skipping batch run.")
        return

    print("=" * 60)
    print(f"Batch Evaluation: {len(task_ids)} tasks")
    print("=" * 60)

    all_results = []
    for task_id in task_ids:
        all_results.append(run_episode(env, agent, task_id, max_steps))

    avg_score = sum(r["task_score"] for r in all_results) / len(all_results)
    avg_reward = sum(r["total_reward"] for r in all_results) / len(all_results)
    avg_f1 = sum(r["f1"] for r in all_results) / len(all_results)

    print("\n" + "=" * 60)
    print("Batch Results:")
    print(f"  Tasks evaluated : {len(all_results)}")
    print(f"  Avg Task Score  : {avg_score:.3f}")
    print(f"  Avg Reward      : {avg_reward:.3f}")
    print(f"  Avg F1          : {avg_f1:.3f}")
    print("=" * 60)

    batch_output = {
        "summary": {
            "total_tasks": len(all_results),
            "avg_task_score": round(avg_score, 4),
            "avg_total_reward": round(avg_reward, 4),
            "avg_f1": round(avg_f1, 4),
            "model": MODEL_NAME,
        },
        "results": all_results,
    }

    with open(output, "w") as f:
        json.dump(batch_output, f, indent=2)

    print(f"\nBatch results saved to {output}")
494
+
495
def main():
    """CLI entry point: run one episode or a batch, print and save results."""
    sys.path.append('.')

    try:
        from environment.env import CodeReviewEnv
    except ImportError as e:
        print(f"Failed to import environment: {e}")
        print("Make sure you're in the correct directory and environment is installed.")
        sys.exit(1)

    parser = argparse.ArgumentParser(description="Run code review agent")
    parser.add_argument("--task-id", type=str, default="bug_detection_easy_1")
    parser.add_argument("--max-steps", type=int, default=50)
    parser.add_argument("--output", type=str, default="baseline_results.json")
    parser.add_argument("--batch", action="store_true", help="Run all tasks in batch mode")
    # choices makes a typo fail fast instead of silently producing an empty batch.
    parser.add_argument(
        "--difficulty",
        type=str,
        default=None,
        choices=["easy", "medium", "hard"],
        help="Filter batch by difficulty: easy, medium, hard",
    )
    args = parser.parse_args()

    print("=" * 60)
    print("Code Review Agent")
    print("=" * 60)

    env = CodeReviewEnv()
    env.max_steps = args.max_steps
    agent = CodeReviewAgent()

    # try/finally guarantees env.close() runs in both batch and single-task
    # modes, even if an episode raises mid-run.
    try:
        if args.batch:
            from environment.tasks import TaskDefinitions
            if args.difficulty:
                task_ids = [t["task_id"] for t in TaskDefinitions.get_tasks_by_difficulty(args.difficulty)]
            else:
                task_ids = [t["task_id"] for t in TaskDefinitions.get_all_tasks()]
            run_batch(env, agent, task_ids, args.max_steps, args.output)
        else:
            result = run_episode(env, agent, args.task_id, args.max_steps)

            print("\n" + "=" * 60)
            print("Final Results:")
            print(f"  Task         : {result['task_id']}")
            print(f"  Total Reward : {result['total_reward']:.3f}")
            print(f"  Task Score   : {result['task_score']:.3f}/1.0")
            print(f"  Steps        : {result['steps']}")
            print("=" * 60)

            with open(args.output, "w") as f:
                json.dump(result, f, indent=2)

            print(f"\nResults saved to {args.output}")
    finally:
        env.close()
 
546
 
547
  if __name__ == "__main__":
548
+ main()