# code-review-openenv / environment.py
# Commit 8846f87 (verified) by 100XZX001: "Update environment.py"
from typing import Tuple, Dict, Any, List, Optional
from models import Observation, Action, Reward, State
from grader import grade_comment, grade_question, grade_fix
import sys
import io
import contextlib
# ------------------------- Simulated CI / Unit tests -------------------------
def _exec_fix(fix_code: str, namespace: Dict[str, Any]) -> Dict[str, Any]:
    """Execute ``fix_code`` once as a module body in ``namespace`` and return it.

    One dict serves as both globals and locals, so functions defined by the
    fix can see each other and any pre-seeded fixtures when they are called.
    """
    # SECURITY NOTE(review): agent-supplied code runs in-process via exec()
    # with full builtins and no timeout -- this is not a sandbox.
    exec(compile(fix_code, "<fix>", "exec"), namespace)
    return namespace


def _check_easy(fix_code: str) -> bool:
    """get_user must still return known users and must not raise for a missing id."""
    fixture = {"alice": "Alice"}
    namespace = _exec_fix(fix_code, {"users": dict(fixture)})
    get_user = namespace.get("get_user")
    if not callable(get_user):
        return False
    # Re-seed after exec so a fix cannot pass by rebinding ``users`` itself.
    namespace["users"] = dict(fixture)
    if get_user("alice") != "Alice":
        return False
    get_user("bob")  # a KeyError here (the unfixed behaviour) propagates -> 0.0
    return True


def _check_medium(fix_code: str) -> bool:
    """The rewritten loop must call process() once per item, in order."""
    processed: List[Any] = []
    _exec_fix(fix_code, {"items": [1, 2, 3], "process": processed.append})
    return processed == [1, 2, 3]


def _check_hard(fix_code: str) -> bool:
    """calculate_average must return 0 for [] and still average non-empty input."""
    calculate_average = _exec_fix(fix_code, {}).get("calculate_average")
    if not callable(calculate_average):
        return False
    # A ZeroDivisionError (the unfixed behaviour) propagates -> 0.0.
    return calculate_average([]) == 0 and calculate_average([1, 2, 3]) == 2


def _check_harder(fix_code: str) -> bool:
    """The fix must execute cleanly and mention a lock."""
    _exec_fix(fix_code, {})
    return "lock" in fix_code.lower()


def _check_hardest(fix_code: str) -> bool:
    """The fix must execute cleanly and mention a consistent lock order."""
    import threading  # only used to stand in for the snippet's two locks

    _exec_fix(
        fix_code,
        {"lock1": threading.Lock(), "lock2": threading.Lock(), "do_work": lambda: None},
    )
    source = fix_code.lower()
    return "same order" in source or "lock order" in source


# Task id -> check. Any other task id falls back to the "hardest" check.
_TASK_CHECKS = {
    "easy": _check_easy,
    "medium": _check_medium,
    "hard": _check_hard,
    "harder": _check_harder,
    "hardest": _check_hardest,
}


def run_unit_tests(fix_code: str, task: str) -> float:
    """
    Run a small task-specific check against an agent-proposed fix.

    The fix is executed exactly once in a fresh namespace that is pre-seeded
    with the names the task's original snippet relies on (``users`` for
    "easy", ``items``/``process`` for "medium", ``lock1``/``lock2``/``do_work``
    for "hardest"). A per-task check then inspects the resulting functions
    and/or the fix source. Anything the fix prints is captured and discarded.

    Args:
        fix_code: Python source proposed by the agent.
        task: "easy", "medium", "hard" or "harder"; any other value is
            treated as "hardest".

    Returns:
        1.0 if the check passes, otherwise 0.0. A fix that fails to compile,
        raises, or calls sys.exit() also scores 0.0.

    SECURITY NOTE(review): ``fix_code`` is executed with ``exec`` and full
    builtins, in this process, with no timeout. An infinite loop hangs the
    environment and hostile code runs with this process's privileges.
    Isolate it (subprocess + timeout, container) before using untrusted agents.
    """
    check = _TASK_CHECKS.get(task, _check_hardest)
    sink = io.StringIO()
    try:
        with contextlib.redirect_stdout(sink), contextlib.redirect_stderr(sink):
            passed = bool(check(fix_code))
    except (Exception, SystemExit):
        # Broken or failing fixes score zero instead of crashing the env.
        return 0.0
    return 1.0 if passed else 0.0
# ------------------------- Simulated PR Author -------------------------
class SimulatedAuthor:
    """Rule-based stand-in for the PR author, replying to the agent's questions and comments.

    Replies are chosen by simple keyword matching on the lower-cased text.
    NOTE(review): ``task`` is stored but not used yet, so the canned answers
    (which describe the user-lookup PR) are returned for every task.
    """

    def __init__(self, task: str):
        self.task = task

    def respond(self, agent_comment: Optional[str] = None, agent_question: Optional[str] = None) -> str:
        """Return the author's reply.

        Args:
            agent_comment: Review comment text. It is consulted only when no
                question is given, and defaults to None so callers can pass
                just ``agent_question=...``.
            agent_question: Question for the author. It takes precedence when
                non-empty.

        Returns:
            A canned reply string.
        """
        if agent_question:
            q = agent_question.lower()
            if "what" in q and "purpose" in q:
                return "The purpose is to retrieve a user safely."
            if "expected" in q:
                return "It should return the user or raise KeyError."
            return "Could you be more specific?"
        # No question: a generic reaction to the review comment (None-safe).
        if "good" in (agent_comment or "").lower():
            return "Thanks for the feedback!"
        return "I'll consider your suggestion."
# ------------------------- Main Environment -------------------------
class CodeReviewEnv:
    """
    Single-PR code-review environment with a gym-style reset()/step() API.

    Each episode presents one pull request (title, description, code snippet)
    selected by ``task``. The agent may ask the simulated author questions,
    which keeps the episode running. Writing a review comment, proposing a
    fix, skipping, declaring done, or sending an unknown action ends it.

    NOTE(review): every non-empty question earns a 0.1 base reward plus
    grade_question(), and nothing limits the number of question turns.
    Consider a step cap so agents cannot farm reward.
    """

    # Task id -> PR content and grading references. set_task() only accepts
    # these keys. A task passed to __init__ is not validated, though, and
    # reset() serves the "hardest" PR for any unknown id.
    _TASKS: Dict[str, Dict[str, Any]] = {
        "easy": {
            "pr_title": "Fix missing null check in user lookup",
            "pr_description": "The current code does not handle missing user IDs. It raises a KeyError.",
            "code_snippet": "def get_user(id):\n    return users[id] # missing null check",
            "expected_keywords": ["null", "key", "missing", "check", "exists", "handle"],
            "expert_comment": "Add a check to ensure the key exists before accessing the dictionary to avoid KeyError.",
            "expected_fix_keywords": ["if id in users"],
        },
        "medium": {
            "pr_title": "Improve loop efficiency",
            "pr_description": "The loop uses `range(len(items))` which is inefficient and less readable.",
            # NOTE(review): the "O(n^2)" remark is inaccurate. List indexing is
            # O(1), so this loop is already O(n); the real issue is style.
            "code_snippet": "for i in range(len(items)):\n    process(items[i])\n# O(n^2) when it could be O(n)",
            "expected_keywords": ["enumerate", "for item in", "range", "inefficient", "optimize"],
            "expert_comment": "Use `for item in items:` for a more Pythonic and efficient loop.",
            "expected_fix_keywords": ["for item in items", "for i, item in enumerate"],
        },
        "hard": {
            "pr_title": "Handle division by zero in average calculation",
            "pr_description": "The function crashes when the input list is empty.",
            "code_snippet": "def calculate_average(data):\n    total = sum(data)\n    return total / len(data) # what if data is empty?",
            "expected_keywords": ["empty", "zero", "length", "check", "handle", "exception"],
            "expert_comment": "Check if the list is empty and return a sensible default (e.g., 0) or raise a descriptive error.",
            "expected_fix_keywords": ["if not data", "if len(data)==0"],
        },
        "harder": {
            "pr_title": "Fix race condition in counter increment",
            "pr_description": "Multiple threads increment a counter without synchronization, causing lost updates.",
            "code_snippet": "counter = 0\ndef increment():\n    global counter\n    counter += 1\n# called from multiple threads",
            "expected_keywords": ["thread", "lock", "synchronization", "atomic", "race", "concurrent"],
            # NOTE(review): Python's threading module has no `atomic`; consider
            # rewording this reference answer.
            "expert_comment": "Use a threading.Lock to protect the counter increment, or use an atomic operation like `threading.atomic`.",
            "expected_fix_keywords": ["lock", "threading.Lock", "with lock"],
        },
        "hardest": {
            "pr_title": "Fix deadlock in database transaction",
            "pr_description": "Two threads acquire locks in opposite order, leading to potential deadlock.",
            "code_snippet": "with lock1:\n    with lock2:\n        do_work()\n# another thread does lock2 then lock1",
            "expected_keywords": ["deadlock", "lock order", "acquire", "release", "trylock", "timeout"],
            "expert_comment": "Ensure all threads acquire locks in the same order to prevent deadlock. Consider using a timeout or a single lock.",
            "expected_fix_keywords": ["same order", "lock order", "acquire lock1 then lock2"],
        },
    }

    def __init__(self, task: str = "easy"):
        self.task = task
        # Create the author immediately; step() calls self.author.respond().
        self.author = SimulatedAuthor(task)
        self.reset()

    def set_task(self, task: str):
        """Switch to ``task``. The new PR is loaded by the next reset().

        Raises:
            ValueError: if ``task`` is not a known task id.
        """
        if task not in self._TASKS:
            raise ValueError(f"Unknown task: {task}")
        self.task = task
        self.author = SimulatedAuthor(task)

    def reset(self) -> Observation:
        """Start a new episode for the current task and return its first observation.

        Raises:
            RuntimeError: if no task is set.
        """
        if self.task is None:
            raise RuntimeError("Task not set. Call set_task() first.")
        self.step_count = 0
        self.agent_comment = None
        self.done = False
        self.test_results = None
        spec = self._TASKS.get(self.task, self._TASKS["hardest"])
        self.pr_title = spec["pr_title"]
        self.pr_description = spec["pr_description"]
        self.code_snippet = spec["code_snippet"]
        self.comments = []  # fresh per episode; step() appends the dialogue here
        # Copy the lists so per-episode mutation cannot leak into the class-level table.
        self.expected_keywords = list(spec["expected_keywords"])
        self.expert_comment = spec["expert_comment"]
        self.expected_fix_keywords = list(spec["expected_fix_keywords"])
        return self._get_observation()

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Apply one agent action.

        Returns:
            ``(observation, reward, done, info)``; ``info`` is currently always empty.

        Raises:
            RuntimeError: if the episode has already finished.
        """
        if self.done:
            raise RuntimeError("Episode already finished")
        reward = 0.0
        info: Dict[str, Any] = {}
        if action.action_type == "write_comment":
            self.agent_comment = action.comment_text or ""
            reward = 0.2  # dense bonus for writing anything
            quality_score = grade_comment(self.agent_comment, self.expected_keywords, self.expert_comment)
            reward += quality_score
            author_response = self.author.respond(self.agent_comment)
            self.comments.append(f"Agent: {self.agent_comment}")
            self.comments.append(f"Author: {author_response}")
            self.done = True
        elif action.action_type == "ask_question":
            if not action.question:
                reward = -0.1  # empty question: penalised, episode continues
            else:
                q_score = grade_question(action.question)
                reward = 0.1 + q_score
                # agent_comment is passed explicitly: a question turn has no comment.
                answer = self.author.respond("", agent_question=action.question)
                self.comments.append(f"Agent: {action.question}")
                self.comments.append(f"Author: {answer}")
            # Episode continues; step_count is advanced once below, like every action.
        elif action.action_type == "propose_fix":
            if not action.fix_code:
                reward = -0.2
            else:
                test_score = run_unit_tests(action.fix_code, self.task)
                # Keyword overlap adds partial credit alongside the test score.
                kw_score = grade_fix(action.fix_code, self.expected_fix_keywords, None)
                # Combined score: 70% tests, 30% keywords, on top of a 0.3 base.
                combined_score = 0.7 * test_score + 0.3 * kw_score
                reward = 0.3 + combined_score
                self.test_results = f"CI tests passed: {test_score:.0%}, Keywords: {kw_score:.0%}"
            self.done = True
        elif action.action_type == "skip":
            reward = -0.1
            self.done = True
        elif action.action_type == "done":
            reward = -0.5
            self.done = True
        else:
            # Unknown action type: penalise and end the episode.
            reward = -0.2
            self.done = True
        self.step_count += 1
        obs = self._get_observation()
        return obs, Reward(value=reward), self.done, info

    def _snapshot(self) -> Dict[str, Any]:
        """Collect the fields shared by Observation and State (comments are copied)."""
        return {
            "pr_title": self.pr_title,
            "pr_description": self.pr_description,
            "code_snippet": self.code_snippet,
            "comments": self.comments.copy(),
            "test_results": self.test_results,
            "step": self.step_count,
            "done": self.done,
        }

    def _get_observation(self) -> Observation:
        """Build the agent-facing observation for the current episode state."""
        return Observation(**self._snapshot())

    def state(self) -> State:
        """Return the current environment state (same fields as the observation)."""
        return State(**self._snapshot())