| |
| from dotenv import load_dotenv |
| load_dotenv() |
|
|
| import os |
| import json |
| import argparse |
| import sys |
| from typing import Dict, Any |
| from openai import OpenAI |
|
|
|
|
def resolve_api_key() -> str:
    """Return the first non-empty auth token found in the environment.

    Checks API_KEY (canonical) first, then the backward-compatibility
    aliases HF_TOKEN and OPENAI_API_KEY. Values are stripped of
    surrounding whitespace; returns "" when none is set.
    """
    for var_name in ("API_KEY", "HF_TOKEN", "OPENAI_API_KEY"):
        token = (os.environ.get(var_name) or "").strip()
        if token:
            return token
    return ""
|
|
|
|
# --- Runtime configuration, read once from the environment at import time ---
API_BASE_URL = os.environ.get("API_BASE_URL", "")  # OpenAI-compatible endpoint, e.g. https://api.openai.com/v1
MODEL_NAME = os.environ.get("MODEL_NAME", "")  # model identifier passed to the chat API
API_KEY = resolve_api_key()  # first non-empty of API_KEY / HF_TOKEN / OPENAI_API_KEY
TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.7"))  # sampling temperature
MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2000"))  # completion token cap per request
REQUEST_TIMEOUT = int(os.environ.get("REQUEST_TIMEOUT", "60"))  # client timeout, seconds
|
|
# Fail fast at import time when the endpoint is missing, and print full
# setup instructions with provider-specific examples.
if not API_BASE_URL:
    print("=" * 60)
    print("API Configuration Required")
    print("=" * 60)
    print("\nPlease set the following environment variables:\n")
    print("  API_BASE_URL - OpenAI-compatible API endpoint")
    print("  MODEL_NAME   - Model identifier")
    print("  API_KEY      - API key (canonical)\n")
    print("Supported auth aliases (backward compatibility):")
    print("  HF_TOKEN")
    print("  OPENAI_API_KEY\n")
    print("Examples:\n")
    print("  OpenAI:")
    print("    export API_BASE_URL=https://api.openai.com/v1")
    print("    export MODEL_NAME=gpt-4o-mini")
    print("    export API_KEY=sk-xxxxx\n")
    print("  Groq:")
    print("    export API_BASE_URL=https://api.groq.com/openai/v1")
    print("    export MODEL_NAME=llama-3.3-70b-versatile")
    print("    export API_KEY=gsk_xxxxx\n")
    print("  Local Ollama:")
    print("    export API_BASE_URL=http://localhost:11434/v1")
    print("    export MODEL_NAME=llama3")
    print("    export API_KEY=not-needed\n")
    print("=" * 60)
    sys.exit(1)

# The model name has no sensible default; require it explicitly.
if not MODEL_NAME:
    print("ERROR: MODEL_NAME environment variable is required")
    sys.exit(1)

# The auth token may come from any supported alias (see resolve_api_key).
if not API_KEY:
    print("ERROR: Missing auth token. Set API_KEY (preferred), or HF_TOKEN/OPENAI_API_KEY.")
    sys.exit(1)
|
|
# Safe default action returned whenever the LLM response is unusable:
# request changes with no comments or suggestions attached.
_FALLBACK_PAYLOAD = {
    "action_type": "request_changes",
    "comments": [],
    "suggestions": [],
    "final_decision": "changes_requested",
}
FALLBACK_ACTION = json.dumps(_FALLBACK_PAYLOAD)
|
|
|
|
def add_line_numbers(code: str) -> str:
    """Prefix each line of *code* with its 1-based line number ("N: line")."""
    numbered = []
    for lineno, text in enumerate(code.split("\n"), start=1):
        numbered.append(f"{lineno}: {text}")
    return "\n".join(numbered)
|
|
|
|
class LLMClient:
    """Thin wrapper around the OpenAI SDK for non-streaming chat completions."""

    def __init__(self, base_url: str, api_key: str, model: str):
        """Create the underlying OpenAI client and announce the connection.

        Uses the module-level REQUEST_TIMEOUT for every request.
        """
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.client = OpenAI(
            base_url=self.base_url,
            api_key=self.api_key,
            timeout=REQUEST_TIMEOUT,
        )
        print("Connected using OpenAI client")
        print(f"Endpoint: {self.base_url}")
        print(f"Model: {self.model}\n")

    def chat_completion(self, messages: list, temperature: float = 0.7, max_tokens: int = 2000) -> str:
        """Send one chat request and return the message text ("" when content is None)."""
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=False,
        )
        content = response.choices[0].message.content
        return content if content is not None else ""
|
|
|
|
class CodeReviewAgent:
    """LLM-driven code-review agent that works in three sequential phases:
    1) add comments, 2) suggest fixes, 3) issue a final decision.

    Public interface: __init__, get_action, validate_action, parse_action.
    """

    # System prompt shared by all phases; the per-phase instruction is
    # appended to the user prompt by _phase_instruction().
    _SYSTEM_PROMPT = """You are an expert code reviewer. You MUST follow this exact sequence:

PHASE 1 - Add Comments: Use action_type "add_comment" to identify ALL bugs with exact line numbers
PHASE 2 - Suggest Fixes: Use action_type "suggest_fix" to provide fixes for every bug found
PHASE 3 - Final Decision: Use action_type "request_changes" with final_decision "changes_requested"

RULES:
- NEVER skip straight to approve or request_changes without first adding comments and suggestions
- NEVER combine phases - each action should do ONE thing
- ALWAYS use the exact line numbers shown in the code diff
- ALWAYS set severity for comments: "critical", "high", "medium", or "low"
- If no bugs found in Phase 1, skip to Phase 3 with "approved"

Respond ONLY with a valid JSON object, no extra text:
{
  "action_type": "add_comment" | "suggest_fix" | "approve" | "request_changes",
  "comments": [
    {
      "line_number": <exact line number>,
      "content": "Detailed explanation of the bug",
      "is_issue": true,
      "severity": "critical" | "high" | "medium" | "low"
    }
  ],
  "suggestions": [
    {
      "original_line": <exact line number>,
      "suggested_code": "corrected code here",
      "explanation": "why this fix works"
    }
  ],
  "final_decision": "approved" | "changes_requested"
}"""

    def __init__(self):
        # Client configuration comes from the module-level constants
        # resolved at import time.
        self.client = LLMClient(API_BASE_URL, API_KEY, MODEL_NAME)
        self.history = []  # currently unused; kept for interface stability
        self.phase = 1  # 1 = comment, 2 = suggest fixes, >= 3 = final decision

    def _phase_instruction(self) -> str:
        """Return the task instruction for the agent's current phase."""
        if self.phase == 1:
            return """
YOUR TASK NOW (Phase 1 - Add Comments):
- action_type MUST be "add_comment"
- Carefully read the code diff line by line
- Find ALL bugs, vulnerabilities, or issues
- Comment on each one with the EXACT line number shown
- Do NOT make a final decision yet
- Do NOT suggest fixes yet
"""
        if self.phase == 2:
            return """
YOUR TASK NOW (Phase 2 - Suggest Fixes):
- action_type MUST be "suggest_fix"
- For every bug you commented on, provide a concrete code fix
- Use the same line numbers as your comments
- Do NOT make a final decision yet
"""
        return """
YOUR TASK NOW (Phase 3 - Final Decision):
- action_type MUST be "request_changes"
- Set final_decision to "changes_requested"
- No new comments or suggestions needed
"""

    @staticmethod
    def _summarize(items, number_field: str, text_field: str) -> str:
        """Render previous comments/suggestions as " Line N: text" lines.

        Items may be plain dicts or attribute-style records; returns
        "None yet" when the list is empty.
        """
        rendered = []
        for item in items:
            if isinstance(item, dict):
                number, text = item.get(number_field), item.get(text_field)
            else:
                number, text = getattr(item, number_field), getattr(item, text_field)
            rendered.append(f" Line {number}: {text}")
        return "\n".join(rendered) or "None yet"

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """Extract the payload from a ```json ... ``` (or plain ```) fenced reply."""
        if "```json" in text:
            text = text.split("```json")[1].split("```")[0]
        elif "```" in text:
            text = text.split("```")[1].split("```")[0]
        return text.strip()

    def get_action(self, observation: Dict[str, Any]) -> str:
        """Query the LLM for the next review action and return it as a JSON string.

        Advances self.phase after a usable (or JSON-broken) response; on a
        transport/SDK error the phase is kept so it is retried. Falls back
        to FALLBACK_ACTION whenever the response cannot be used.
        """
        comments_text = self._summarize(
            observation.get('previous_comments', []), 'line_number', 'content')
        suggestions_text = self._summarize(
            observation.get('previous_suggestions', []), 'original_line', 'suggested_code')

        user_prompt = f"""
Code Review Task:
{observation.get('task_description', 'Review the following code changes')}

Code Diff (USE THESE EXACT LINE NUMBERS in your response):
{add_line_numbers(observation.get('code_diff', ''))}

File Context:
{observation.get('file_context', '')}

Current Step: {observation.get('current_step', 0)}/{observation.get('max_steps', 50)}

Comments already made:
{comments_text}

Suggestions already made:
{suggestions_text}

{self._phase_instruction()}

Respond with JSON only.
"""

        messages = [
            {"role": "system", "content": self._SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]

        response = ""
        try:
            response = self.client.chat_completion(messages, TEMPERATURE, MAX_TOKENS).strip()
            action_data = json.loads(self._strip_code_fences(response))

            # Defensive defaults: the model sometimes omits required keys.
            action_data.setdefault("action_type", "request_changes")
            action_data.setdefault("comments", [])
            action_data.setdefault("suggestions", [])

            self.phase += 1
            return json.dumps(action_data)
        except json.JSONDecodeError as e:
            print(f"Failed to parse JSON response: {e}")
            print(f"Raw response: {response[:200]}...")
            self.phase += 1
            return FALLBACK_ACTION
        except Exception as e:
            # NOTE(review): the phase is deliberately NOT advanced here, matching
            # the original behavior for transport/SDK failures.
            print(f"Error getting action from LLM: {e}")
            return FALLBACK_ACTION

    @staticmethod
    def _clamp_line(value, line_count: int) -> int:
        """Coerce *value* to an int and clamp it into [1, line_count].

        Fix: LLMs sometimes return line numbers as strings (e.g. "12");
        the previous min/max on a raw string raised TypeError. Unparseable
        values fall back to line 1.
        """
        try:
            number = int(value)
        except (TypeError, ValueError):
            number = 1
        return max(1, min(number, line_count))

    def validate_action(self, action: Dict, observation: Dict) -> Dict:
        """Normalize an action in place: clamp line numbers and fill defaults.

        Returns the same dict for convenience.
        """
        line_count = observation.get('line_count', 999)

        for comment in action.get("comments", []):
            comment["line_number"] = self._clamp_line(comment.get("line_number", 1), line_count)
            if not comment.get("severity"):
                comment["severity"] = "medium"
            if "is_issue" not in comment:
                comment["is_issue"] = True

        for suggestion in action.get("suggestions", []):
            suggestion["original_line"] = self._clamp_line(
                suggestion.get("original_line", 1), line_count)

        return action

    def parse_action(self, action_str: str) -> Dict[str, Any]:
        """Parse an action JSON string; return a safe request_changes action on failure."""
        try:
            return json.loads(action_str)
        except json.JSONDecodeError:
            return {"action_type": "request_changes", "comments": [], "suggestions": []}
|
|
|
|
def main():
    """CLI entry point: run the review agent against CodeReviewEnv and save results."""
    sys.path.append('.')

    # The environment lives in the surrounding project; import lazily so a
    # missing install produces a readable message instead of a traceback.
    try:
        from environment.env import CodeReviewEnv
    except ImportError as e:
        print(f"Failed to import environment: {e}")
        print("Make sure you're in the correct directory and environment is installed.")
        sys.exit(1)

    parser = argparse.ArgumentParser(description="Run code review agent")
    parser.add_argument("--task-id", type=str, default="bug_detection_easy_1")
    parser.add_argument("--max-steps", type=int, default=50)
    parser.add_argument("--output", type=str, default="baseline_results.json")
    args = parser.parse_args()

    banner = "=" * 60
    print(banner)
    print("Code Review Agent")
    print(banner)

    env = CodeReviewEnv()
    env.max_steps = args.max_steps
    agent = CodeReviewAgent()

    observation = env.reset(task_id=args.task_id)
    cumulative_reward = 0.0
    steps_taken = 0

    print(f"\nTask : {args.task_id}")
    print(f"Desc : {observation.get('task_description', 'N/A')}")
    print(f"Model : {MODEL_NAME}")
    print("-" * 60)

    while steps_taken < args.max_steps:
        raw_action = agent.get_action(observation)
        action = agent.validate_action(agent.parse_action(raw_action), observation)

        observation, reward, done, info = env.step(action)
        cumulative_reward += reward
        steps_taken += 1

        # Per-step progress report. agent.phase was already advanced by
        # get_action, hence the "- 1" when displaying the phase just run.
        print(f"\nStep {steps_taken}/{args.max_steps}:")
        print(f" Phase : {agent.phase - 1}")
        print(f" Action : {action.get('action_type')}")
        print(f" Comments : {len(action.get('comments', []))}")
        print(f" Suggestions : {len(action.get('suggestions', []))}")
        print(f" Reward : {reward:.3f}")
        print(f" Total : {cumulative_reward:.3f}")
        print(f" Score : {info.get('task_score', 0):.3f}")

        if info.get('last_action_valid') is False:
            print(f" Warning : {info.get('error', 'Invalid action')}")

        if done:
            break

    final_score = env.get_task_score()

    print("\n" + banner)
    print("Final Results:")
    print(f" Task : {args.task_id}")
    print(f" Total Reward : {cumulative_reward:.3f}")
    print(f" Task Score : {final_score:.3f}/1.0")
    print(f" Steps : {steps_taken}")
    print(banner)

    env.close()

    results = {
        "task_id": args.task_id,
        "total_reward": round(cumulative_reward, 4),
        "task_score": round(final_score, 4),
        "steps": steps_taken,
        "max_steps": args.max_steps,
        "provider": "openai-client",
        "model": MODEL_NAME,
        "api_base_url": API_BASE_URL,
    }
    with open(args.output, "w") as f:
        json.dump(results, f, indent=2)

    print(f"\nResults saved to {args.output}")
|
|
|
|
# Script entry point; module import alone runs only config validation above.
if __name__ == "__main__":
    main()
|
|