import contextlib
import io
import sys
import threading
from typing import Any, Callable, Dict, List, Optional, Tuple

from models import Observation, Action, Reward, State
from grader import grade_comment, grade_question, grade_fix

# ------------------------- Simulated CI / Unit tests -------------------------


def _exec_fix(fix_code: str, namespace: Dict[str, Any]) -> None:
    """Execute ``fix_code`` as a module body, using ``namespace`` as its globals.

    One dict serves as both globals and locals. Functions defined by the fix
    can therefore see each other and the pre-seeded fixtures. exec() with two
    separate dicts behaves like a class body, where they cannot.

    WARNING: this runs agent-supplied code in-process with full interpreter
    privileges and no time limit. It is NOT a sandbox.
    """
    exec(compile(fix_code, "<fix_code>", "exec"), namespace)


def _check_easy(fix_code: str) -> bool:
    """Return True if ``get_user`` returns known users and tolerates unknown ids.

    The task is to stop ``users[id]`` raising KeyError for missing ids, so a
    KeyError on the missing id counts as a failure.
    """
    fixture_users = {"alice": "Alice"}
    namespace: Dict[str, Any] = {"users": fixture_users}
    _exec_fix(fix_code, namespace)
    # Re-seed in case the fix rebinds ``users``. A global lookup inside
    # get_user resolves at call time, so it will see this dict.
    namespace["users"] = fixture_users
    get_user = namespace.get("get_user")
    if not callable(get_user):
        return False
    if get_user("alice") != "Alice":
        return False
    try:
        get_user("bob")
    except KeyError:
        return False  # still the original bug
    return True


def _check_medium(fix_code: str) -> bool:
    """Return True if the loop rewrite runs and any ``process`` calls cover ``items`` in order."""
    expected = [1, 2, 3]
    processed: List[Any] = []
    namespace: Dict[str, Any] = {"items": list(expected), "process": processed.append}
    _exec_fix(fix_code, namespace)
    # A fix that only defines a function never calls process(). As with the
    # previous "compiles and runs" check, accept it as long as it executes.
    return not processed or processed == expected


def _check_hard(fix_code: str) -> bool:
    """Return True if ``calculate_average([])`` is 0 and normal input still works."""
    namespace: Dict[str, Any] = {}
    _exec_fix(fix_code, namespace)
    calculate_average = namespace.get("calculate_average")
    if not callable(calculate_average):
        return False
    try:
        if calculate_average([]) != 0:
            return False
    except ZeroDivisionError:
        return False
    # Guard against "fixes" that break the non-empty path.
    return calculate_average([1, 2, 3]) == 2


def _check_harder(fix_code: str) -> bool:
    """Return True if the fix runs and mentions a lock."""
    namespace: Dict[str, Any] = {"threading": threading, "counter": 0}
    _exec_fix(fix_code, namespace)
    return "lock" in fix_code.lower()


def _check_hardest(fix_code: str) -> bool:
    """Return True if the fix runs and mentions a consistent lock ordering."""
    namespace: Dict[str, Any] = {
        "threading": threading,
        # RLock, so a fix that re-enters a lock on this single thread cannot hang.
        "lock1": threading.RLock(),
        "lock2": threading.RLock(),
        "do_work": lambda: None,
    }
    _exec_fix(fix_code, namespace)
    lowered = fix_code.lower()
    return "same order" in lowered or "lock order" in lowered


# Task name -> check. Unknown tasks fall back to the "hardest" check.
_TASK_CHECKS: Dict[str, Callable[[str], bool]] = {
    "easy": _check_easy,
    "medium": _check_medium,
    "hard": _check_hard,
    "harder": _check_harder,
    "hardest": _check_hardest,
}


def run_unit_tests(fix_code: str, task: str) -> float:
    """Run the simulated CI check for ``task`` against ``fix_code``.

    The fix is executed once, in a namespace pre-seeded with the names the
    task's snippet relies on. It is then verified directly in Python, never by
    splicing ``fix_code`` into generated source (which broke on newlines and
    quotes).

    Args:
        fix_code: Python source proposed by the agent.
        task: Task name. Any value other than easy/medium/hard/harder is
            graded with the "hardest" check.

    Returns:
        1.0 if the check passes, otherwise 0.0. Any error raised while running
        the fix (including SyntaxError and SystemExit) scores 0.0. Output
        printed by the fix is captured and discarded.
    """
    check = _TASK_CHECKS.get(task, _check_hardest)
    sink = io.StringIO()
    try:
        with contextlib.redirect_stdout(sink), contextlib.redirect_stderr(sink):
            passed = check(fix_code)
    except (Exception, SystemExit):
        return 0.0
    return 1.0 if passed else 0.0


# ------------------------- Simulated PR Author -------------------------
class SimulatedAuthor:
    """Responds to the agent's questions and comments as if they were the PR author."""

    def __init__(self, task: str):
        # Stored for reference. respond() does not currently read it.
        self.task = task

    def respond(self, agent_comment: str = "", agent_question: Optional[str] = None) -> str:
        """Return a canned author reply.

        A non-empty ``agent_question`` takes priority over the comment.
        ``agent_comment`` defaults to "", so callers may pass only
        ``agent_question=...``.

        NOTE(review): the question answers describe the "easy" (user lookup)
        task and are returned for every task.
        """
        if agent_question:
            q = agent_question.lower()
            if "what" in q and "purpose" in q:
                return "The purpose is to retrieve a user safely."
            elif "expected" in q:
                return "It should return the user or raise KeyError."
            else:
                return "Could you be more specific?"
        # Generic response to a comment; tolerate None.
        if "good" in (agent_comment or "").lower():
            return "Thanks for the feedback!"
        return "I'll consider your suggestion."
# ------------------------- Main Environment -------------------------
class CodeReviewEnv:
    """Single-PR code-review environment with a reset()/step() interface.

    Each task presents one PR (title, description, snippet). The agent can
    comment, ask the simulated author questions, or propose a fix, and is
    rewarded by the grader functions and the simulated CI in run_unit_tests().
    """

    _TASKS = ("easy", "medium", "hard", "harder", "hardest")

    def __init__(self, task: str = "easy"):
        self.task = task
        # Create the author up front. step() calls self.author.respond(), which
        # previously crashed (author was None) unless set_task() had been called.
        self.author = SimulatedAuthor(task)
        self.reset()

    def set_task(self, task: str):
        """Switch to ``task`` and create a fresh simulated author.

        The PR fields are not reloaded until reset() is called.

        Raises:
            ValueError: if ``task`` is not a known task name.
        """
        if task not in self._TASKS:
            raise ValueError(f"Unknown task: {task}")
        self.task = task
        self.author = SimulatedAuthor(task)

    def reset(self) -> Observation:
        """Start a new episode for the current task and return the first observation.

        Raises:
            RuntimeError: if no task has been set.
        """
        if self.task is None:
            raise RuntimeError("Task not set. Call set_task() first.")
        self.step_count = 0
        self.agent_comment = None
        self.done = False
        self.test_results = None
        self.comments = []

        # Per-task PR content plus the reference answers used for grading.
        # Any name other than easy/medium/hard/harder loads the "hardest" task.
        if self.task == "easy":
            self.pr_title = "Fix missing null check in user lookup"
            self.pr_description = "The current code does not handle missing user IDs. It raises a KeyError."
            self.code_snippet = "def get_user(id):\n return users[id] # missing null check"
            self.expected_keywords = ["null", "key", "missing", "check", "exists", "handle"]
            self.expert_comment = "Add a check to ensure the key exists before accessing the dictionary to avoid KeyError."
            self.expected_fix_keywords = ["if id in users"]
        elif self.task == "medium":
            self.pr_title = "Improve loop efficiency"
            self.pr_description = "The loop uses `range(len(items))` which is inefficient and less readable."
            # NOTE(review): indexing a list is O(1), so the original loop is O(n).
            # The "O(n^2)" remark in this snippet is inaccurate.
            self.code_snippet = "for i in range(len(items)):\n process(items[i])\n# O(n^2) when it could be O(n)"
            self.expected_keywords = ["enumerate", "for item in", "range", "inefficient", "optimize"]
            self.expert_comment = "Use `for item in items:` for a more Pythonic and efficient loop."
            self.expected_fix_keywords = ["for item in items", "for i, item in enumerate"]
        elif self.task == "hard":
            self.pr_title = "Handle division by zero in average calculation"
            self.pr_description = "The function crashes when the input list is empty."
            self.code_snippet = "def calculate_average(data):\n total = sum(data)\n return total / len(data) # what if data is empty?"
            self.expected_keywords = ["empty", "zero", "length", "check", "handle", "exception"]
            self.expert_comment = "Check if the list is empty and return a sensible default (e.g., 0) or raise a descriptive error."
            self.expected_fix_keywords = ["if not data", "if len(data)==0"]
        elif self.task == "harder":
            self.pr_title = "Fix race condition in counter increment"
            self.pr_description = "Multiple threads increment a counter without synchronization, causing lost updates."
            self.code_snippet = "counter = 0\ndef increment():\n global counter\n counter += 1\n# called from multiple threads"
            self.expected_keywords = ["thread", "lock", "synchronization", "atomic", "race", "concurrent"]
            # NOTE(review): the stdlib threading module has no `threading.atomic`.
            # This reference answer recommends an API that does not exist.
            self.expert_comment = "Use a threading.Lock to protect the counter increment, or use an atomic operation like `threading.atomic`."
            self.expected_fix_keywords = ["lock", "threading.Lock", "with lock"]
        else:  # hardest
            self.pr_title = "Fix deadlock in database transaction"
            self.pr_description = "Two threads acquire locks in opposite order, leading to potential deadlock."
            self.code_snippet = "with lock1:\n with lock2:\n do_work()\n# another thread does lock2 then lock1"
            self.expected_keywords = ["deadlock", "lock order", "acquire", "release", "trylock", "timeout"]
            self.expert_comment = "Ensure all threads acquire locks in the same order to prevent deadlock. Consider using a timeout or a single lock."
            self.expected_fix_keywords = ["same order", "lock order", "acquire lock1 then lock2"]
        return self._get_observation()

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Apply one agent action and return ``(observation, reward, done, info)``.

        Action types (``action.action_type``):
          * ``write_comment``: 0.2 + grade_comment(...). The author replies and
            the episode ends.
          * ``ask_question``: 0.1 + grade_question(...), or -0.1 if the question
            is empty. The author answers and the episode continues.
          * ``propose_fix``: 0.3 + 0.7 * CI score + 0.3 * keyword score, or
            -0.2 if no code is given. The episode ends.
          * ``skip``: -0.1. ``done``: -0.5. Anything else: -0.2. All of these
            end the episode.

        Raises:
            RuntimeError: if the episode has already finished.
        """
        if self.done:
            raise RuntimeError("Episode already finished")

        reward = 0.0
        info: Dict[str, Any] = {}

        if action.action_type == "write_comment":
            self.agent_comment = action.comment_text or ""
            reward = 0.2  # dense bonus for writing
            quality_score = grade_comment(self.agent_comment, self.expected_keywords, self.expert_comment)
            reward += quality_score
            author_response = self.author.respond(self.agent_comment)
            self.comments.append(f"Agent: {self.agent_comment}")
            self.comments.append(f"Author: {author_response}")
            self.done = True
        elif action.action_type == "ask_question":
            if not action.question:
                reward = -0.1
            else:
                reward = 0.1 + grade_question(action.question)
                # Pass the comment slot explicitly: respond() takes it positionally.
                answer = self.author.respond("", agent_question=action.question)
                self.comments.append(f"Agent: {action.question}")
                self.comments.append(f"Author: {answer}")
            # The episode continues. step_count is advanced once below, like every
            # other action (it was previously incremented twice here).
        elif action.action_type == "propose_fix":
            if not action.fix_code:
                reward = -0.2
            else:
                test_score = run_unit_tests(action.fix_code, self.task)
                # Keyword match gives partial credit when the simulated CI fails.
                kw_score = grade_fix(action.fix_code, self.expected_fix_keywords, None)
                combined_score = 0.7 * test_score + 0.3 * kw_score
                reward = 0.3 + combined_score
                self.test_results = f"CI tests passed: {test_score:.0%}, Keywords: {kw_score:.0%}"
            self.done = True
        elif action.action_type == "skip":
            reward = -0.1
            self.done = True
        elif action.action_type == "done":
            reward = -0.5
            self.done = True
        else:
            reward = -0.2
            self.done = True

        self.step_count += 1
        obs = self._get_observation()
        return obs, Reward(value=reward), self.done, info

    def _snapshot(self) -> Dict[str, Any]:
        """Collect the fields shared by Observation and State."""
        return dict(
            pr_title=self.pr_title,
            pr_description=self.pr_description,
            code_snippet=self.code_snippet,
            # Copy so callers cannot append to the environment's own history.
            comments=self.comments.copy(),
            test_results=self.test_results,
            step=self.step_count,
            done=self.done,
        )

    def _get_observation(self) -> Observation:
        """Return the current observation."""
        return Observation(**self._snapshot())

    def state(self) -> State:
        """Return the current environment state."""
        return State(**self._snapshot())