Spaces:
Sleeping
Sleeping
File size: 10,756 Bytes
8846f87 f9889e8 7694e55 8846f87 f9889e8 8846f87 f9889e8 8846f87 f9889e8 5fdf7b4 f9889e8 8846f87 f9889e8 7694e55 f9889e8 7694e55 f9889e8 5fdf7b4 7694e55 5fdf7b4 7694e55 f9889e8 5fdf7b4 7694e55 5fdf7b4 7694e55 f9889e8 5fdf7b4 7694e55 5fdf7b4 7694e55 5fdf7b4 7694e55 5fdf7b4 7694e55 5fdf7b4 7694e55 5fdf7b4 7694e55 f9889e8 5fdf7b4 59a2092 f9889e8 8846f87 f9889e8 7694e55 8846f87 7694e55 8846f87 7694e55 8846f87 7694e55 f9889e8 5fdf7b4 7694e55 f9889e8 5fdf7b4 7694e55 f9889e8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 | from typing import Tuple, Dict, Any, List, Optional
from models import Observation, Action, Reward, State
from grader import grade_comment, grade_question, grade_fix
import sys
import io
import contextlib
# ------------------------- Simulated CI / Unit tests -------------------------
def run_unit_tests(fix_code: str, task: str) -> float:
    """
    Run a small task-specific check against a proposed fix.

    The fix is executed exactly once in a fresh namespace that is pre-seeded
    with the free names the task's code snippet relies on (``users``,
    ``items``/``process``, ``lock1``/``lock2``/``do_work``, ...). The task's
    check then inspects that same namespace, so functions defined by the fix
    are visible to it.

    Args:
        fix_code: Python source proposed by the agent.
        task: "easy", "medium", "hard" or "harder"; any other value is graded
            with the "hardest" check.

    Returns:
        1.0 if the fix executes and the task check passes, otherwise 0.0.
        Any exception (including ``SystemExit``) raised by the fix or by the
        check counts as a failure.

    NOTE(security): ``exec`` runs agent-supplied code in-process with full
    builtins. This is NOT a sandbox and there is no timeout: a fix that loops
    forever hangs the caller. Isolate it in a subprocess if that matters.
    """
    namespace = _task_namespace(task)
    check = _TASK_CHECKS.get(task, _check_hardest)
    captured = io.StringIO()
    try:
        # Swallow anything the fix prints so it cannot pollute the caller's output.
        with contextlib.redirect_stdout(captured), contextlib.redirect_stderr(captured):
            exec(fix_code, namespace)
            passed = check(fix_code, namespace)
    except (Exception, SystemExit):
        return 0.0
    return 1.0 if passed else 0.0


def _task_namespace(task: str) -> Dict[str, Any]:
    """Build the globals dict a fix for ``task`` executes in, seeded with the snippet's free names."""
    import threading  # only needed for the concurrency tasks

    if task == "easy":
        return {"users": {"alice": "Alice"}}
    if task == "medium":
        processed = []  # records every value handed to ``process``
        return {"items": [1, 2, 3], "process": processed.append, "_processed": processed}
    if task == "hard":
        return {}
    if task == "harder":
        return {"threading": threading, "counter": 0}
    # "hardest" and any unrecognised task.
    return {
        "threading": threading,
        "lock1": threading.Lock(),
        "lock2": threading.Lock(),
        "do_work": lambda: None,
    }


def _check_easy(fix_code: str, ns: Dict[str, Any]) -> bool:
    """``get_user`` must return known users and must not raise for a missing ID."""
    ns["users"] = {"alice": "Alice"}  # reset in case the fix rebound ``users``
    get_user = ns["get_user"]
    if get_user("alice") != "Alice":
        return False
    get_user("bob")  # a KeyError (or any error) propagates and fails the run
    return True


def _check_medium(fix_code: str, ns: Dict[str, Any]) -> bool:
    """The loop must hand every item to ``process`` exactly once, in order."""
    return ns["_processed"] == [1, 2, 3]


def _check_hard(fix_code: str, ns: Dict[str, Any]) -> bool:
    """``calculate_average`` must still average real data and return 0 for ``[]``."""
    calculate_average = ns["calculate_average"]
    # A ZeroDivisionError on [] propagates and fails the run.
    return calculate_average([1, 2, 3]) == 2 and calculate_average([]) == 0


def _check_harder(fix_code: str, ns: Dict[str, Any]) -> bool:
    """The fix must mention a lock; if it defines ``increment`` that must run cleanly."""
    if "lock" not in fix_code.lower():
        return False
    increment = ns.get("increment")
    if callable(increment):
        increment()
    return True


def _check_hardest(fix_code: str, ns: Dict[str, Any]) -> bool:
    """The fix must spell out a consistent lock-acquisition order."""
    text = fix_code.lower()
    return "same order" in text or "lock order" in text


# task name -> check(fix_code, namespace) -> bool
_TASK_CHECKS = {
    "easy": _check_easy,
    "medium": _check_medium,
    "hard": _check_hard,
    "harder": _check_harder,
    "hardest": _check_hardest,
}
# ------------------------- Simulated PR Author -------------------------
class SimulatedAuthor:
    """Responds to the agent's questions and comments as if they were the PR author.

    Answers to "purpose" and "expected behaviour" questions are specific to the
    task this author was created for; unknown tasks get the "easy" answers.
    """

    # task -> (answer to a "what is the purpose" question, answer to an "expected" question)
    _ANSWERS: Dict[str, Tuple[str, str]] = {
        "easy": (
            "The purpose is to retrieve a user safely.",
            "It should return the user, or handle a missing ID without raising KeyError.",
        ),
        "medium": (
            "The purpose is to process every item in the list.",
            "It should call process() once for each item, in order.",
        ),
        "hard": (
            "The purpose is to compute the average of a list of numbers.",
            "It should return the average, or a sensible default such as 0 for an empty list.",
        ),
        "harder": (
            "The purpose is to count calls to increment() made from several threads.",
            "Every call should increase the counter by exactly one, even when threads run concurrently.",
        ),
        "hardest": (
            "The purpose is to run do_work() while holding both locks.",
            "It should finish without deadlocking when several threads run it at the same time.",
        ),
    }

    def __init__(self, task: str):
        self.task = task

    def respond(self, agent_comment: str = "", agent_question: Optional[str] = None) -> str:
        """Return the author's reply.

        Args:
            agent_comment: Review comment text; used only when no question is asked.
                Defaults to "" so callers can pass just ``agent_question=...``.
            agent_question: Question text; when non-empty it takes precedence.

        Returns:
            A canned reply chosen by simple keyword matching.
        """
        if agent_question:
            q = agent_question.lower()
            purpose, expected = self._ANSWERS.get(self.task, self._ANSWERS["easy"])
            if "what" in q and "purpose" in q:
                return purpose
            if "expected" in q:
                return expected
            return "Could you be more specific?"
        # Generic response to a comment; tolerate a None comment.
        if "good" in (agent_comment or "").lower():
            return "Thanks for the feedback!"
        return "I'll consider your suggestion."
# ------------------------- Main Environment -------------------------
class CodeReviewEnv:
    """Single-PR code-review environment with a simulated author and simulated CI.

    Each episode presents one pull request chosen by ``task``. The agent acts via
    ``step(action)`` where ``action.action_type`` is one of:

    * ``"write_comment"`` -- graded review comment; the author replies; episode ends.
    * ``"ask_question"``  -- graded question; the author answers; episode continues.
    * ``"propose_fix"``   -- fix scored 70% by ``run_unit_tests`` and 30% by keyword
      match; episode ends (an empty fix is penalised and also ends it).
    * ``"skip"`` / ``"done"`` / anything else -- penalty; episode ends.
    """

    # Penalties for actions that end the episode without any review work.
    _TERMINAL_PENALTIES: Dict[str, float] = {"skip": -0.1, "done": -0.5}
    _UNKNOWN_ACTION_PENALTY = -0.2

    # Per-task PR definitions. Keyword collections are tuples here and are copied
    # into fresh lists on every reset() so episodes never share mutable state.
    _TASK_SPECS: Dict[str, Dict[str, Any]] = {
        "easy": {
            "pr_title": "Fix missing null check in user lookup",
            "pr_description": "The current code does not handle missing user IDs. It raises a KeyError.",
            "code_snippet": "def get_user(id):\n    return users[id] # missing null check",
            "expected_keywords": ("null", "key", "missing", "check", "exists", "handle"),
            "expert_comment": "Add a check to ensure the key exists before accessing the dictionary to avoid KeyError.",
            "expected_fix_keywords": ("if id in users",),
        },
        "medium": {
            "pr_title": "Improve loop efficiency",
            "pr_description": "The loop uses `range(len(items))` which is inefficient and less readable.",
            # NOTE(review): the "O(n^2)" remark is part of the PR text under review;
            # the range(len(...)) loop itself is O(n).
            "code_snippet": "for i in range(len(items)):\n    process(items[i])\n# O(n^2) when it could be O(n)",
            "expected_keywords": ("enumerate", "for item in", "range", "inefficient", "optimize"),
            "expert_comment": "Use `for item in items:` for a more Pythonic and efficient loop.",
            "expected_fix_keywords": ("for item in items", "for i, item in enumerate"),
        },
        "hard": {
            "pr_title": "Handle division by zero in average calculation",
            "pr_description": "The function crashes when the input list is empty.",
            "code_snippet": "def calculate_average(data):\n    total = sum(data)\n    return total / len(data) # what if data is empty?",
            "expected_keywords": ("empty", "zero", "length", "check", "handle", "exception"),
            "expert_comment": "Check if the list is empty and return a sensible default (e.g., 0) or raise a descriptive error.",
            "expected_fix_keywords": ("if not data", "if len(data)==0"),
        },
        "harder": {
            "pr_title": "Fix race condition in counter increment",
            "pr_description": "Multiple threads increment a counter without synchronization, causing lost updates.",
            "code_snippet": "counter = 0\ndef increment():\n    global counter\n    counter += 1\n# called from multiple threads",
            "expected_keywords": ("thread", "lock", "synchronization", "atomic", "race", "concurrent"),
            # Python has no ``threading.atomic``; reference a primitive that exists.
            "expert_comment": "Use a threading.Lock to protect the counter increment (e.g. `with lock: counter += 1`) so the read-modify-write is atomic.",
            "expected_fix_keywords": ("lock", "threading.Lock", "with lock"),
        },
        "hardest": {
            "pr_title": "Fix deadlock in database transaction",
            "pr_description": "Two threads acquire locks in opposite order, leading to potential deadlock.",
            "code_snippet": "with lock1:\n    with lock2:\n        do_work()\n# another thread does lock2 then lock1",
            "expected_keywords": ("deadlock", "lock order", "acquire", "release", "trylock", "timeout"),
            "expert_comment": "Ensure all threads acquire locks in the same order to prevent deadlock. Consider using a timeout or a single lock.",
            "expected_fix_keywords": ("same order", "lock order", "acquire lock1 then lock2"),
        },
    }

    def __init__(self, task: str = "easy"):
        """Create the environment for ``task`` and start the first episode.

        Raises:
            ValueError: if ``task`` is not a known task name.
            RuntimeError: if ``task`` is None (no task to reset into).
        """
        self.task = None
        self.author = None
        if task is not None:
            self.set_task(task)  # validates the name and creates the simulated author
        self.reset()

    def set_task(self, task: str):
        """Select the PR to review; the new PR is loaded on the next reset()."""
        if task not in self._TASK_SPECS:
            raise ValueError(f"Unknown task: {task}")
        self.task = task
        self.author = SimulatedAuthor(task)

    def reset(self) -> Observation:
        """Start a new episode for the current task and return the first observation."""
        if self.task is None:
            raise RuntimeError("Task not set. Call set_task() first.")
        # Keep the simulated author in sync, e.g. if ``task`` was assigned directly.
        if self.author is None or getattr(self.author, "task", None) != self.task:
            self.author = SimulatedAuthor(self.task)
        self.step_count = 0
        self.agent_comment = None
        self.done = False
        self.test_results = None
        self.comments = []
        # Unrecognised task names are treated as "hardest".
        spec = self._TASK_SPECS.get(self.task, self._TASK_SPECS["hardest"])
        self.pr_title = spec["pr_title"]
        self.pr_description = spec["pr_description"]
        self.code_snippet = spec["code_snippet"]
        self.expected_keywords = list(spec["expected_keywords"])
        self.expert_comment = spec["expert_comment"]
        self.expected_fix_keywords = list(spec["expected_fix_keywords"])
        return self._get_observation()

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Apply one agent action.

        Returns:
            ``(observation, reward, done, info)``; ``info`` is currently always empty.

        Raises:
            RuntimeError: if the episode has already finished.
        """
        if self.done:
            raise RuntimeError("Episode already finished")
        action_type = action.action_type
        if action_type == "write_comment":
            reward = self._write_comment(action.comment_text or "")
        elif action_type == "ask_question":
            reward = self._ask_question(action.question)
        elif action_type == "propose_fix":
            reward = self._propose_fix(action.fix_code)
        else:
            reward = self._TERMINAL_PENALTIES.get(action_type, self._UNKNOWN_ACTION_PENALTY)
            self.done = True
        self.step_count += 1  # exactly once per step, whatever the action
        return self._get_observation(), Reward(value=reward), self.done, {}

    def _write_comment(self, comment: str) -> float:
        """Grade a review comment, record the exchange with the author, end the episode."""
        self.agent_comment = comment
        reward = 0.2  # dense bonus for writing
        reward += grade_comment(comment, self.expected_keywords, self.expert_comment)
        author_response = self.author.respond(comment)
        self.comments.append(f"Agent: {comment}")
        self.comments.append(f"Author: {author_response}")
        self.done = True
        return reward

    def _ask_question(self, question: Optional[str]) -> float:
        """Grade a question and record the author's answer; the episode continues."""
        if not question:
            return -0.1
        reward = 0.1 + grade_question(question)
        # Pass an empty comment positionally so this works with any respond() signature.
        answer = self.author.respond("", question)
        self.comments.append(f"Agent: {question}")
        self.comments.append(f"Author: {answer}")
        return reward

    def _propose_fix(self, fix_code: Optional[str]) -> float:
        """Score a proposed fix (70% CI tests, 30% keywords) and end the episode."""
        self.done = True
        if not fix_code:
            return -0.2
        test_score = run_unit_tests(fix_code, self.task)
        kw_score = grade_fix(fix_code, self.expected_fix_keywords, None)
        self.test_results = f"CI tests passed: {test_score:.0%}, Keywords: {kw_score:.0%}"
        combined_score = 0.7 * test_score + 0.3 * kw_score
        return 0.3 + combined_score

    def _snapshot(self) -> Dict[str, Any]:
        """Fields shared by Observation and State; ``comments`` is copied defensively."""
        return {
            "pr_title": self.pr_title,
            "pr_description": self.pr_description,
            "code_snippet": self.code_snippet,
            "comments": self.comments.copy(),
            "test_results": self.test_results,
            "step": self.step_count,
            "done": self.done,
        }

    def _get_observation(self) -> Observation:
        return Observation(**self._snapshot())

    def state(self) -> State:
        """Return the current environment state."""
        return State(**self._snapshot())