Harshit2N committed on
Commit
923bb71
·
unverified ·
1 Parent(s): 883ba26

Update inference.py

Browse files

improve agent scoring with multi-phase review strategy

Files changed (1) hide show
  1. inference.py +161 -89
inference.py CHANGED
@@ -1,4 +1,6 @@
1
  #!/usr/bin/env python3
 
 
2
 
3
  import os
4
  import json
@@ -21,26 +23,21 @@ if not API_BASE_URL:
21
  print("\nPlease set the following environment variables:\n")
22
  print(" API_BASE_URL - Your API endpoint")
23
  print(" MODEL_NAME - Model identifier")
24
- print(" HF_TOKEN - Your Hugging Face / API key\n")
25
  print("Examples:\n")
26
- print(" OpenAI:")
27
- print(" export API_BASE_URL=https://api.openai.com/v1")
28
- print(" export MODEL_NAME=gpt-4")
29
- print(" export HF_TOKEN=sk-xxxxx\n")
30
- print(" Gemini:")
31
- print(" export API_BASE_URL=https://generativelanguage.googleapis.com")
32
- print(" export MODEL_NAME=gemini-1.5-pro")
33
- print(" export HF_TOKEN=AIzaSyxxxxx\n")
34
- print(" Local:")
35
  print(" export API_BASE_URL=http://localhost:11434/v1")
36
- print(" export MODEL_NAME=llama2")
37
  print(" export HF_TOKEN=not-needed\n")
38
  print("=" * 60)
39
  sys.exit(1)
40
 
41
  if not MODEL_NAME:
42
  print("ERROR: MODEL_NAME environment variable is required")
43
- print("Example: export MODEL_NAME=gpt-4")
44
  sys.exit(1)
45
 
46
  if not API_KEY:
@@ -55,14 +52,22 @@ FALLBACK_ACTION = json.dumps({
55
  })
56
 
57
 
 
 
 
 
 
58
  class LLMClient:
59
 
60
  def __init__(self, base_url: str, api_key: str, model: str):
61
  self.base_url = base_url.rstrip("/")
62
  self.api_key = api_key
63
  self.model = model
64
- self.client = OpenAI(base_url=self.base_url, api_key=self.api_key, timeout=REQUEST_TIMEOUT)
65
-
 
 
 
66
  print("Connected using OpenAI client")
67
  print(f"Endpoint: {self.base_url}")
68
  print(f"Model: {self.model}\n")
@@ -79,94 +84,162 @@ class LLMClient:
79
 
80
 
81
  class CodeReviewAgent:
82
-
83
  def __init__(self):
84
  self.client = LLMClient(API_BASE_URL, API_KEY, MODEL_NAME)
85
  self.history = []
86
-
 
87
  def get_action(self, observation: Dict[str, Any]) -> str:
88
-
89
- system_prompt = """You are an expert code reviewer. Your task is to review code changes and provide feedback.
90
 
91
- Review the code diff and identify issues. You can:
92
- 1. ADD_COMMENT: Add a comment about an issue on a specific line
93
- 2. SUGGEST_FIX: Suggest a specific code fix for an issue
94
- 3. APPROVE: Approve the code changes (only if no critical issues)
95
- 4. REQUEST_CHANGES: Request changes (if issues are found)
96
 
97
- Respond with a JSON object in this format:
 
 
 
 
 
 
 
 
 
 
 
98
  {
99
  "action_type": "add_comment" | "suggest_fix" | "approve" | "request_changes",
100
  "comments": [
101
  {
102
- "line_number": 10,
103
- "content": "This line has a potential bug...",
104
  "is_issue": true,
105
- "severity": "high"
106
  }
107
  ],
108
  "suggestions": [
109
  {
110
- "original_line": 10,
111
- "suggested_code": "if x != 0:",
112
- "explanation": "Prevents division by zero"
113
  }
114
  ],
115
- "final_decision": "approved" | "changes_requested" (only if action_type is approve or request_changes)
116
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
- Be thorough but concise. Focus on real issues like bugs, security vulnerabilities, performance problems, and code quality."""
119
-
120
  user_prompt = f"""
121
  Code Review Task:
122
  {observation.get('task_description', 'Review the following code changes')}
123
 
124
- Code Diff:
125
- {observation.get('code_diff', '')}
126
 
127
  File Context:
128
  {observation.get('file_context', '')}
129
 
130
- Current step: {observation.get('current_step', 0)}/{observation.get('max_steps', 50)}
131
- Previous actions taken: {len(observation.get('previous_comments', []))} comments, {len(observation.get('previous_suggestions', []))} suggestions
 
 
 
 
 
 
 
132
 
133
- Please provide your review action as JSON.
134
  """
135
-
136
  messages = [
137
  {"role": "system", "content": system_prompt},
138
  {"role": "user", "content": user_prompt}
139
  ]
140
-
141
  try:
142
  response = self.client.chat_completion(messages, TEMPERATURE, MAX_TOKENS)
143
-
144
  response = response.strip()
145
-
146
  if "```json" in response:
147
  response = response.split("```json")[1].split("```")[0]
148
  elif "```" in response:
149
  response = response.split("```")[1].split("```")[0]
150
-
151
  action_data = json.loads(response.strip())
152
-
153
  if "action_type" not in action_data:
154
  action_data["action_type"] = "request_changes"
155
  if "comments" not in action_data:
156
  action_data["comments"] = []
157
  if "suggestions" not in action_data:
158
  action_data["suggestions"] = []
159
-
 
160
  return json.dumps(action_data)
161
-
162
  except json.JSONDecodeError as e:
163
  print(f"Failed to parse JSON response: {e}")
164
  print(f"Raw response: {response[:200]}...")
 
165
  return FALLBACK_ACTION
166
  except Exception as e:
167
  print(f"Error getting action from LLM: {e}")
168
  return FALLBACK_ACTION
169
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  def parse_action(self, action_str: str) -> Dict[str, Any]:
171
  try:
172
  return json.loads(action_str)
@@ -175,89 +248,88 @@ Please provide your review action as JSON.
175
 
176
 
177
  def main():
178
- import sys
179
  sys.path.append('.')
180
-
181
  try:
182
  from environment.env import CodeReviewEnv
183
  except ImportError as e:
184
  print(f"Failed to import environment: {e}")
185
  print("Make sure you're in the correct directory and environment is installed.")
186
  sys.exit(1)
187
-
188
- parser = argparse.ArgumentParser(description="Run code review agent with any LLM provider")
189
- parser.add_argument("--task-id", type=str, default="bug_detection_easy_1",
190
- help="Task ID to run (e.g. bug_detection_easy_1, memory_leak_medium_1, security_hard_1)")
191
- parser.add_argument("--max-steps", type=int, default=50,
192
- help="Maximum steps per episode")
193
- parser.add_argument("--output", type=str, default="baseline_results.json",
194
- help="Output file for results")
195
  args = parser.parse_args()
196
-
197
  print("=" * 60)
198
- print("Code Review Agent - Running with ANY LLM Provider")
199
  print("=" * 60)
200
-
201
  env = CodeReviewEnv()
202
  env.max_steps = args.max_steps
203
-
204
  agent = CodeReviewAgent()
205
-
206
  obs = env.reset(task_id=args.task_id)
207
  done = False
208
  step = 0
209
  total_reward = 0.0
210
-
211
- print(f"\nTask: {args.task_id}")
212
- print(f"Description: {obs.get('task_description', 'N/A')}")
 
213
  print("-" * 60)
214
-
215
  while not done and step < args.max_steps:
216
  action_str = agent.get_action(obs)
217
  action = agent.parse_action(action_str)
218
-
 
219
  obs, reward, done, info = env.step(action)
220
-
221
  total_reward += reward
222
  step += 1
223
-
224
  print(f"\nStep {step}/{args.max_steps}:")
225
- print(f" Action: {action.get('action_type')}")
226
- print(f" Comments: {len(action.get('comments', []))}")
227
- print(f" Suggestions: {len(action.get('suggestions', []))}")
228
- print(f" Reward: {reward:.3f}")
229
- print(f" Total: {total_reward:.3f}")
230
-
 
 
231
  if info.get('last_action_valid') is False:
232
- print(f" Warning: {info.get('error', 'Invalid action')}")
233
-
234
  final_score = env.get_task_score()
 
235
  print("\n" + "=" * 60)
236
  print("Final Results:")
237
- print(f" Task: {args.task_id}")
238
- print(f" Total Reward: {total_reward:.3f}")
239
- print(f" Task Score: {final_score:.3f}/1.0")
240
- print(f" Steps: {step}")
241
  print("=" * 60)
242
-
243
  env.close()
244
-
245
  results = {
246
  "task_id": args.task_id,
247
- "total_reward": total_reward,
248
- "task_score": final_score,
249
  "steps": step,
250
  "max_steps": args.max_steps,
251
  "provider": "openai-client",
252
  "model": MODEL_NAME,
253
  "api_base_url": API_BASE_URL
254
  }
255
-
256
  with open(args.output, "w") as f:
257
  json.dump(results, f, indent=2)
258
-
259
  print(f"\nResults saved to {args.output}")
260
 
261
 
262
  if __name__ == "__main__":
263
- main()
 
1
  #!/usr/bin/env python3
2
+ from dotenv import load_dotenv
3
+ load_dotenv()
4
 
5
  import os
6
  import json
 
23
  print("\nPlease set the following environment variables:\n")
24
  print(" API_BASE_URL - Your API endpoint")
25
  print(" MODEL_NAME - Model identifier")
26
+ print(" HF_TOKEN - Your API key\n")
27
  print("Examples:\n")
28
+ print(" Groq:")
29
+ print(" export API_BASE_URL=https://api.groq.com/openai/v1")
30
+ print(" export MODEL_NAME=llama-3.3-70b-versatile")
31
+ print(" export HF_TOKEN=gsk_xxxxx\n")
32
+ print(" Local Ollama:")
 
 
 
 
33
  print(" export API_BASE_URL=http://localhost:11434/v1")
34
+ print(" export MODEL_NAME=llama3")
35
  print(" export HF_TOKEN=not-needed\n")
36
  print("=" * 60)
37
  sys.exit(1)
38
 
39
  if not MODEL_NAME:
40
  print("ERROR: MODEL_NAME environment variable is required")
 
41
  sys.exit(1)
42
 
43
  if not API_KEY:
 
52
  })
53
 
54
 
55
def add_line_numbers(code: str) -> str:
    """Return *code* with each line prefixed by its 1-based number ("N: line")."""
    # enumerate(start=1) keeps the numbering 1-based without manual i+1 math.
    numbered = (f"{num}: {text}" for num, text in enumerate(code.split("\n"), start=1))
    return "\n".join(numbered)
58
+
59
+
60
  class LLMClient:
61
 
62
def __init__(self, base_url: str, api_key: str, model: str):
    """Record connection settings and build the shared OpenAI client.

    Prints a short connection banner so the operator can confirm which
    endpoint and model the run is using.
    """
    self.base_url = base_url.rstrip("/")  # normalise: no trailing slash
    self.api_key = api_key
    self.model = model
    # One client instance is reused for every subsequent completion call.
    self.client = OpenAI(base_url=self.base_url, api_key=self.api_key, timeout=REQUEST_TIMEOUT)
    print("Connected using OpenAI client")
    print(f"Endpoint: {self.base_url}")
    print(f"Model: {self.model}\n")
 
84
 
85
 
86
  class CodeReviewAgent:
87
+
88
def __init__(self):
    """Create the LLM client and initialise per-episode review state."""
    self.client = LLMClient(API_BASE_URL, API_KEY, MODEL_NAME)
    # phase drives the multi-step strategy: 1 = add comments,
    # 2 = suggest fixes, 3+ = final decision. It only ever increments.
    self.phase = 1
    self.history = []
92
+
93
  def get_action(self, observation: Dict[str, Any]) -> str:
 
 
94
 
95
+ system_prompt = """You are an expert code reviewer. You MUST follow this exact sequence:
 
 
 
 
96
 
97
+ PHASE 1 - Add Comments: Use action_type "add_comment" to identify ALL bugs with exact line numbers
98
+ PHASE 2 - Suggest Fixes: Use action_type "suggest_fix" to provide fixes for every bug found
99
+ PHASE 3 - Final Decision: Use action_type "request_changes" with final_decision "changes_requested"
100
+
101
+ RULES:
102
+ - NEVER skip straight to approve or request_changes without first adding comments and suggestions
103
+ - NEVER combine phases - each action should do ONE thing
104
+ - ALWAYS use the exact line numbers shown in the code diff
105
+ - ALWAYS set severity for comments: "critical", "high", "medium", or "low"
106
+ - If no bugs found in Phase 1, skip to Phase 3 with "approved"
107
+
108
+ Respond ONLY with a valid JSON object, no extra text:
109
  {
110
  "action_type": "add_comment" | "suggest_fix" | "approve" | "request_changes",
111
  "comments": [
112
  {
113
+ "line_number": <exact line number>,
114
+ "content": "Detailed explanation of the bug",
115
  "is_issue": true,
116
+ "severity": "critical" | "high" | "medium" | "low"
117
  }
118
  ],
119
  "suggestions": [
120
  {
121
+ "original_line": <exact line number>,
122
+ "suggested_code": "corrected code here",
123
+ "explanation": "why this fix works"
124
  }
125
  ],
126
+ "final_decision": "approved" | "changes_requested"
127
+ }"""
128
+
129
+ prev_comments = observation.get('previous_comments', [])
130
+ prev_suggestions = observation.get('previous_suggestions', [])
131
+
132
+ comments_text = "\n".join([
133
+ f" Line {c.get('line_number') if isinstance(c, dict) else c.line_number}: "
134
+ f"{c.get('content') if isinstance(c, dict) else c.content}"
135
+ for c in prev_comments
136
+ ]) or "None yet"
137
+
138
+ suggestions_text = "\n".join([
139
+ f" Line {s.get('original_line') if isinstance(s, dict) else s.original_line}: "
140
+ f"{s.get('suggested_code') if isinstance(s, dict) else s.suggested_code}"
141
+ for s in prev_suggestions
142
+ ]) or "None yet"
143
+
144
+ if self.phase == 1:
145
+ phase_instruction = """
146
+ YOUR TASK NOW (Phase 1 - Add Comments):
147
+ - action_type MUST be "add_comment"
148
+ - Carefully read the code diff line by line
149
+ - Find ALL bugs, vulnerabilities, or issues
150
+ - Comment on each one with the EXACT line number shown
151
+ - Do NOT make a final decision yet
152
+ - Do NOT suggest fixes yet
153
+ """
154
+ elif self.phase == 2:
155
+ phase_instruction = """
156
+ YOUR TASK NOW (Phase 2 - Suggest Fixes):
157
+ - action_type MUST be "suggest_fix"
158
+ - For every bug you commented on, provide a concrete code fix
159
+ - Use the same line numbers as your comments
160
+ - Do NOT make a final decision yet
161
+ """
162
+ else:
163
+ phase_instruction = """
164
+ YOUR TASK NOW (Phase 3 - Final Decision):
165
+ - action_type MUST be "request_changes"
166
+ - Set final_decision to "changes_requested"
167
+ - No new comments or suggestions needed
168
+ """
169
 
 
 
170
  user_prompt = f"""
171
  Code Review Task:
172
  {observation.get('task_description', 'Review the following code changes')}
173
 
174
+ Code Diff (USE THESE EXACT LINE NUMBERS in your response):
175
+ {add_line_numbers(observation.get('code_diff', ''))}
176
 
177
  File Context:
178
  {observation.get('file_context', '')}
179
 
180
+ Current Step: {observation.get('current_step', 0)}/{observation.get('max_steps', 50)}
181
+
182
+ Comments already made:
183
+ {comments_text}
184
+
185
+ Suggestions already made:
186
+ {suggestions_text}
187
+
188
+ {phase_instruction}
189
 
190
+ Respond with JSON only.
191
  """
192
+
193
  messages = [
194
  {"role": "system", "content": system_prompt},
195
  {"role": "user", "content": user_prompt}
196
  ]
197
+
198
  try:
199
  response = self.client.chat_completion(messages, TEMPERATURE, MAX_TOKENS)
 
200
  response = response.strip()
201
+
202
  if "```json" in response:
203
  response = response.split("```json")[1].split("```")[0]
204
  elif "```" in response:
205
  response = response.split("```")[1].split("```")[0]
206
+
207
  action_data = json.loads(response.strip())
208
+
209
  if "action_type" not in action_data:
210
  action_data["action_type"] = "request_changes"
211
  if "comments" not in action_data:
212
  action_data["comments"] = []
213
  if "suggestions" not in action_data:
214
  action_data["suggestions"] = []
215
+
216
+ self.phase += 1
217
  return json.dumps(action_data)
218
+
219
  except json.JSONDecodeError as e:
220
  print(f"Failed to parse JSON response: {e}")
221
  print(f"Raw response: {response[:200]}...")
222
+ self.phase += 1
223
  return FALLBACK_ACTION
224
  except Exception as e:
225
  print(f"Error getting action from LLM: {e}")
226
  return FALLBACK_ACTION
227
+
228
def validate_action(self, action: Dict, observation: Dict) -> Dict:
    """Sanitise an LLM-produced action in place before sending it to the env.

    Clamps every comment/suggestion line reference into [1, line_count]
    and fills in missing comment defaults (severity, is_issue).

    Returns the same ``action`` dict, mutated.
    """
    # Large fallback bound when the observation lacks a line count, so
    # clamping never rejects a plausible line number outright.
    line_count = observation.get('line_count', 999)

    for comment in action.get("comments", []):
        # Fix: the LLM may emit "line_number": null — .get(key, 1) would
        # return None (key present) and min(None, int) raises TypeError.
        raw_line = comment.get("line_number") or 1
        comment["line_number"] = max(1, min(raw_line, line_count))
        if not comment.get("severity"):
            comment["severity"] = "medium"
        if "is_issue" not in comment:
            comment["is_issue"] = True

    for suggestion in action.get("suggestions", []):
        # Same None-safety as above for suggestion line references.
        raw_line = suggestion.get("original_line") or 1
        suggestion["original_line"] = max(1, min(raw_line, line_count))

    return action
242
+
243
  def parse_action(self, action_str: str) -> Dict[str, Any]:
244
  try:
245
  return json.loads(action_str)
 
248
 
249
 
250
def main():
    """Entry point: run the review agent on one task and write a results file."""
    sys.path.append('.')

    try:
        from environment.env import CodeReviewEnv
    except ImportError as exc:
        print(f"Failed to import environment: {exc}")
        print("Make sure you're in the correct directory and environment is installed.")
        sys.exit(1)

    cli = argparse.ArgumentParser(description="Run code review agent")
    cli.add_argument("--task-id", type=str, default="bug_detection_easy_1")
    cli.add_argument("--max-steps", type=int, default=50)
    cli.add_argument("--output", type=str, default="baseline_results.json")
    args = cli.parse_args()

    print("=" * 60)
    print("Code Review Agent")
    print("=" * 60)

    env = CodeReviewEnv()
    env.max_steps = args.max_steps
    agent = CodeReviewAgent()

    obs = env.reset(task_id=args.task_id)
    done, step, total_reward = False, 0, 0.0

    print(f"\nTask : {args.task_id}")
    print(f"Desc : {obs.get('task_description', 'N/A')}")
    print(f"Model : {MODEL_NAME}")
    print("-" * 60)

    while not done and step < args.max_steps:
        # Get, parse, then sanitise the action before stepping the env.
        raw = agent.get_action(obs)
        action = agent.validate_action(agent.parse_action(raw), obs)

        obs, reward, done, info = env.step(action)
        total_reward += reward
        step += 1

        print(f"\nStep {step}/{args.max_steps}:")
        print(f" Phase : {agent.phase - 1}")
        print(f" Action : {action.get('action_type')}")
        print(f" Comments : {len(action.get('comments', []))}")
        print(f" Suggestions : {len(action.get('suggestions', []))}")
        print(f" Reward : {reward:.3f}")
        print(f" Total : {total_reward:.3f}")
        print(f" Score : {info.get('task_score', 0):.3f}")

        if info.get('last_action_valid') is False:
            print(f" Warning : {info.get('error', 'Invalid action')}")

    final_score = env.get_task_score()

    print("\n" + "=" * 60)
    print("Final Results:")
    print(f" Task : {args.task_id}")
    print(f" Total Reward : {total_reward:.3f}")
    print(f" Task Score : {final_score:.3f}/1.0")
    print(f" Steps : {step}")
    print("=" * 60)

    env.close()

    results = {
        "task_id": args.task_id,
        "total_reward": round(total_reward, 4),
        "task_score": round(final_score, 4),
        "steps": step,
        "max_steps": args.max_steps,
        "provider": "openai-client",
        "model": MODEL_NAME,
        "api_base_url": API_BASE_URL,
    }
    with open(args.output, "w") as out_file:
        json.dump(results, out_file, indent=2)

    print(f"\nResults saved to {args.output}")


if __name__ == "__main__":
    main()