Spaces:
Sleeping
Sleeping
| from typing import Tuple, Dict, Any, List, Optional | |
| from models import Observation, Action, Reward, State | |
| from grader import grade_comment, grade_question, grade_fix | |
| import sys | |
| import io | |
| import contextlib | |
| # ------------------------- Simulated CI / Unit tests ------------------------- | |
def _load_fix(fix_code: str, namespace: Dict[str, Any]) -> Dict[str, Any]:
    """Execute the agent's fix inside *namespace* and return that namespace.

    One dict serves as both globals and locals. Functions defined by the fix
    can therefore see each other and any fixtures pre-seeded in *namespace*.
    (Passing separate globals/locals dicts hides top-level defs from the
    functions that call them.)
    """
    # NOTE(review): exec() gives the agent's code full interpreter access and
    # no timeout; this is NOT a sandbox. Isolate it in a subprocess/container
    # if fixes can come from an untrusted policy.
    exec(fix_code, namespace)
    return namespace


def _check_easy(fix_code: str) -> bool:
    """get_user must return existing users and must not raise KeyError for missing ids."""
    ns = _load_fix(fix_code, {"users": {"alice": "Alice"}})
    get_user = ns.get("get_user")
    if not callable(get_user):
        return False
    if get_user("alice") != "Alice":
        return False
    try:
        get_user("bob")
    except KeyError:
        # The unfixed lookup raises KeyError here -- that is the bug the PR targets.
        return False
    return True


def _check_medium(fix_code: str) -> bool:
    """The rewritten loop must call process() once per item, in order."""
    seen = []
    _load_fix(fix_code, {"items": [1, 2, 3], "process": seen.append})
    return seen == [1, 2, 3]


def _check_hard(fix_code: str) -> bool:
    """calculate_average([]) must return 0; non-empty input must still average correctly."""
    ns = _load_fix(fix_code, {})
    calculate_average = ns.get("calculate_average")
    if not callable(calculate_average):
        return False
    try:
        empty_result = calculate_average([])
    except ZeroDivisionError:
        return False
    return empty_result == 0 and calculate_average([2, 4]) == 3


def _check_harder(fix_code: str) -> bool:
    """The fix must execute cleanly and mention a lock (keyword heuristic)."""
    _load_fix(fix_code, {})
    return "lock" in fix_code.lower()


def _check_hardest(fix_code: str) -> bool:
    """The fix must execute with the snippet's names available and mention lock ordering."""
    import threading

    # Provide the names used by the original deadlock snippet so a fix that
    # keeps them can run at module level.
    _load_fix(
        fix_code,
        {"lock1": threading.Lock(), "lock2": threading.Lock(), "do_work": lambda: None},
    )
    text = fix_code.lower()
    return "same order" in text or "lock order" in text


# Task name -> checker; unknown task names fall back to the "hardest" check.
_TASK_CHECKS = {
    "easy": _check_easy,
    "medium": _check_medium,
    "hard": _check_hard,
    "harder": _check_harder,
    "hardest": _check_hardest,
}


def run_unit_tests(fix_code: str, task: str) -> float:
    """
    Runs a small set of unit tests for the given task.
    Returns a score in [0,1] based on passed tests.

    The fix is executed once, with stdout/stderr captured. Any exception it
    raises -- SyntaxError, NameError, an explicit ``sys.exit()`` -- counts as
    a failed test (0.0) rather than propagating to the caller.

    Args:
        fix_code: Python source proposed by the agent.
        task: Task name; unknown names are checked as "hardest".

    Returns:
        1.0 if the task's check passes, otherwise 0.0.
    """
    check = _TASK_CHECKS.get(task, _check_hardest)
    sink = io.StringIO()
    try:
        with contextlib.redirect_stdout(sink), contextlib.redirect_stderr(sink):
            passed = check(fix_code)
    except (Exception, SystemExit):
        return 0.0
    return 1.0 if passed else 0.0
| # ------------------------- Simulated PR Author ------------------------- | |
class SimulatedAuthor:
    """Responds to the agent's questions and comments as if they were the PR author.

    Replies are canned and chosen by simple keyword matching. The stored
    ``task`` is not consulted when choosing a reply.
    """

    def __init__(self, task: str):
        self.task = task

    def respond(
        self,
        agent_comment: Optional[str] = None,
        agent_question: Optional[str] = None,
    ) -> str:
        """Return the author's reply to a question or, failing that, a comment.

        Args:
            agent_comment: Review comment text. It is used only when no question
                is given; None is treated as an empty comment. It defaults to
                None so callers can pass only ``agent_question=...``.
            agent_question: Question text; when non-empty it takes precedence.

        Returns:
            The canned reply string.
        """
        if agent_question:
            q = agent_question.lower()
            if "what" in q and "purpose" in q:
                return "The purpose is to retrieve a user safely."
            if "expected" in q:
                return "It should return the user or raise KeyError."
            return "Could you be more specific?"
        # Generic response to a comment (None/empty comments get the neutral reply).
        if "good" in (agent_comment or "").lower():
            return "Thanks for the feedback!"
        return "I'll consider your suggestion."
| # ------------------------- Main Environment ------------------------- | |
class CodeReviewEnv:
    """Gym-style environment in which an agent reviews one simulated pull request.

    ``reset()`` loads the PR for ``self.task`` and returns an ``Observation``.
    ``step(action)`` returns ``(observation, reward, done, info)``. Asking a
    question keeps the episode open; every other action type ends it.
    """

    # PR fixture for each supported task. reset() falls back to "hardest" for any
    # other task name; set_task() validates against these keys.
    _TASK_SPECS: Dict[str, Dict[str, Any]] = {
        "easy": {
            "pr_title": "Fix missing null check in user lookup",
            "pr_description": "The current code does not handle missing user IDs. It raises a KeyError.",
            "code_snippet": "def get_user(id):\n return users[id] # missing null check",
            "expected_keywords": ["null", "key", "missing", "check", "exists", "handle"],
            "expert_comment": "Add a check to ensure the key exists before accessing the dictionary to avoid KeyError.",
            "expected_fix_keywords": ["if id in users"],
        },
        "medium": {
            "pr_title": "Improve loop efficiency",
            "pr_description": "The loop uses `range(len(items))` which is inefficient and less readable.",
            # NOTE(review): indexing inside range(len(items)) is O(n), not O(n^2);
            # the snippet's trailing comment is inaccurate unless it is deliberate bait.
            "code_snippet": "for i in range(len(items)):\n process(items[i])\n# O(n^2) when it could be O(n)",
            "expected_keywords": ["enumerate", "for item in", "range", "inefficient", "optimize"],
            "expert_comment": "Use `for item in items:` for a more Pythonic and efficient loop.",
            "expected_fix_keywords": ["for item in items", "for i, item in enumerate"],
        },
        "hard": {
            "pr_title": "Handle division by zero in average calculation",
            "pr_description": "The function crashes when the input list is empty.",
            "code_snippet": "def calculate_average(data):\n total = sum(data)\n return total / len(data) # what if data is empty?",
            "expected_keywords": ["empty", "zero", "length", "check", "handle", "exception"],
            "expert_comment": "Check if the list is empty and return a sensible default (e.g., 0) or raise a descriptive error.",
            "expected_fix_keywords": ["if not data", "if len(data)==0"],
        },
        "harder": {
            "pr_title": "Fix race condition in counter increment",
            "pr_description": "Multiple threads increment a counter without synchronization, causing lost updates.",
            "code_snippet": "counter = 0\ndef increment():\n global counter\n counter += 1\n# called from multiple threads",
            "expected_keywords": ["thread", "lock", "synchronization", "atomic", "race", "concurrent"],
            # NOTE(review): Python's threading module has no `atomic` API, yet this
            # reference comment names one -- confirm before relying on it for grading.
            "expert_comment": "Use a threading.Lock to protect the counter increment, or use an atomic operation like `threading.atomic`.",
            "expected_fix_keywords": ["lock", "threading.Lock", "with lock"],
        },
        "hardest": {
            "pr_title": "Fix deadlock in database transaction",
            "pr_description": "Two threads acquire locks in opposite order, leading to potential deadlock.",
            "code_snippet": "with lock1:\n with lock2:\n do_work()\n# another thread does lock2 then lock1",
            "expected_keywords": ["deadlock", "lock order", "acquire", "release", "trylock", "timeout"],
            "expert_comment": "Ensure all threads acquire locks in the same order to prevent deadlock. Consider using a timeout or a single lock.",
            "expected_fix_keywords": ["same order", "lock order", "acquire lock1 then lock2"],
        },
    }

    def __init__(self, task: str = "easy"):
        self.task = task
        # The author must exist before step() is called; previously it stayed None
        # unless set_task() was used, so any comment/question raised AttributeError.
        self.author = SimulatedAuthor(task)
        self.reset()

    def set_task(self, task: str):
        """Switch to *task* and create a matching simulated author.

        Does not reset the episode; call reset() afterwards to load the new PR.

        Raises:
            ValueError: if *task* is not one of the known task names.
        """
        if task not in self._TASK_SPECS:
            raise ValueError(f"Unknown task: {task}")
        self.task = task
        self.author = SimulatedAuthor(task)

    def reset(self) -> Observation:
        """Start a new episode for ``self.task`` and return the first observation.

        Raises:
            RuntimeError: if no task is set.
        """
        if self.task is None:
            raise RuntimeError("Task not set. Call set_task() first.")
        self.step_count = 0
        self.agent_comment = None
        self.done = False
        self.test_results = None
        spec = self._TASK_SPECS.get(self.task, self._TASK_SPECS["hardest"])
        self.pr_title = spec["pr_title"]
        self.pr_description = spec["pr_description"]
        self.code_snippet = spec["code_snippet"]
        self.comments = []
        # Copy the lists so per-episode mutation cannot leak into the shared table.
        self.expected_keywords = list(spec["expected_keywords"])
        self.expert_comment = spec["expert_comment"]
        self.expected_fix_keywords = list(spec["expected_fix_keywords"])
        return self._get_observation()

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Apply one agent action and return ``(observation, reward, done, info)``.

        Raises:
            RuntimeError: if the episode has already finished.
        """
        if self.done:
            raise RuntimeError("Episode already finished")
        reward = 0.0
        info: Dict[str, Any] = {}
        if action.action_type == "write_comment":
            self.agent_comment = action.comment_text or ""
            reward = 0.2  # dense bonus for writing
            reward += grade_comment(self.agent_comment, self.expected_keywords, self.expert_comment)
            author_response = self.author.respond(self.agent_comment)
            self.comments.append(f"Agent: {self.agent_comment}")
            self.comments.append(f"Author: {author_response}")
            self.done = True
        elif action.action_type == "ask_question":
            if not action.question:
                reward = -0.1
            else:
                reward = 0.1 + grade_question(action.question)
                # Pass an empty comment positionally so the call is valid whether or
                # not SimulatedAuthor.respond gives agent_comment a default.
                answer = self.author.respond("", agent_question=action.question)
                self.comments.append(f"Agent: {action.question}")
                self.comments.append(f"Author: {answer}")
            # Questions keep the episode open. step_count is advanced once below;
            # a second increment here used to count every question twice.
        elif action.action_type == "propose_fix":
            if not action.fix_code:
                reward = -0.2
            else:
                test_score = run_unit_tests(action.fix_code, self.task)
                # Keyword match gives partial credit even when the CI tests fail.
                kw_score = grade_fix(action.fix_code, self.expected_fix_keywords, None)
                # Combined score: 70% tests, 30% keywords, on top of a 0.3 base.
                combined_score = 0.7 * test_score + 0.3 * kw_score
                reward = 0.3 + combined_score
                self.test_results = f"CI tests passed: {test_score:.0%}, Keywords: {kw_score:.0%}"
            # Any fix proposal, even an empty one, ends the episode.
            self.done = True
        elif action.action_type == "skip":
            reward = -0.1
            self.done = True
        elif action.action_type == "done":
            reward = -0.5
            self.done = True
        else:
            # Unknown action types are penalised and end the episode.
            reward = -0.2
            self.done = True
        self.step_count += 1
        return self._get_observation(), Reward(value=reward), self.done, info

    def _get_observation(self) -> Observation:
        """Snapshot the current PR state as an Observation (comments list copied)."""
        return Observation(
            pr_title=self.pr_title,
            pr_description=self.pr_description,
            code_snippet=self.code_snippet,
            comments=self.comments.copy(),
            test_results=self.test_results,
            step=self.step_count,
            done=self.done,
        )

    def state(self) -> State:
        """Return the current environment state (same fields as the observation)."""
        return State(
            pr_title=self.pr_title,
            pr_description=self.pr_description,
            code_snippet=self.code_snippet,
            comments=self.comments.copy(),
            test_results=self.test_results,
            step=self.step_count,
            done=self.done,
        )