Spaces:
Sleeping
Sleeping
Update environment.py
Browse files- environment.py +55 -23
environment.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
from typing import Tuple, Dict, Any
|
| 2 |
from models import Observation, Action, Reward, State
|
| 3 |
-
from grader import grade_comment
|
| 4 |
|
| 5 |
class CodeReviewEnv:
|
| 6 |
def __init__(self, task: str = "easy"):
|
|
@@ -18,47 +18,49 @@ class CodeReviewEnv:
|
|
| 18 |
self.step_count = 0
|
| 19 |
self.agent_comment = None
|
| 20 |
self.done = False
|
|
|
|
| 21 |
|
| 22 |
-
# Task definitions
|
| 23 |
if self.task == "easy":
|
| 24 |
self.pr_title = "Fix missing null check in user lookup"
|
| 25 |
self.pr_description = "The current code does not handle missing user IDs. It raises a KeyError."
|
| 26 |
self.code_snippet = "def get_user(id):\n return users[id] # missing null check"
|
| 27 |
-
self.comments = [
|
| 28 |
self.expected_keywords = ["null", "key", "missing", "check", "exists", "handle"]
|
| 29 |
self.expert_comment = "Add a check to ensure the key exists before accessing the dictionary to avoid KeyError."
|
| 30 |
-
|
| 31 |
elif self.task == "medium":
|
| 32 |
self.pr_title = "Improve loop efficiency"
|
| 33 |
self.pr_description = "The loop uses `range(len(items))` which is inefficient and less readable."
|
| 34 |
self.code_snippet = "for i in range(len(items)):\n process(items[i])\n# O(n^2) when it could be O(n)"
|
| 35 |
-
self.comments = [
|
| 36 |
self.expected_keywords = ["enumerate", "for item in", "range", "inefficient", "optimize"]
|
| 37 |
self.expert_comment = "Use `for item in items:` for a more Pythonic and efficient loop."
|
| 38 |
-
|
| 39 |
elif self.task == "hard":
|
| 40 |
self.pr_title = "Handle division by zero in average calculation"
|
| 41 |
self.pr_description = "The function crashes when the input list is empty."
|
| 42 |
self.code_snippet = "def calculate_average(data):\n total = sum(data)\n return total / len(data) # what if data is empty?"
|
| 43 |
-
self.comments = [
|
| 44 |
self.expected_keywords = ["empty", "zero", "length", "check", "handle", "exception"]
|
| 45 |
self.expert_comment = "Check if the list is empty and return a sensible default (e.g., 0) or raise a descriptive error."
|
| 46 |
-
|
| 47 |
elif self.task == "harder":
|
| 48 |
self.pr_title = "Fix race condition in counter increment"
|
| 49 |
self.pr_description = "Multiple threads increment a counter without synchronization, causing lost updates."
|
| 50 |
self.code_snippet = "counter = 0\ndef increment():\n global counter\n counter += 1\n# called from multiple threads"
|
| 51 |
-
self.comments = [
|
| 52 |
self.expected_keywords = ["thread", "lock", "synchronization", "atomic", "race", "concurrent"]
|
| 53 |
self.expert_comment = "Use a threading.Lock to protect the counter increment, or use an atomic operation like `threading.atomic`."
|
| 54 |
-
|
| 55 |
else: # hardest
|
| 56 |
self.pr_title = "Fix deadlock in database transaction"
|
| 57 |
self.pr_description = "Two threads acquire locks in opposite order, leading to potential deadlock."
|
| 58 |
self.code_snippet = "with lock1:\n with lock2:\n do_work()\n# another thread does lock2 then lock1"
|
| 59 |
-
self.comments = [
|
| 60 |
self.expected_keywords = ["deadlock", "lock order", "acquire", "release", "trylock", "timeout"]
|
| 61 |
self.expert_comment = "Ensure all threads acquire locks in the same order to prevent deadlock. Consider using a timeout or a single lock."
|
|
|
|
| 62 |
|
| 63 |
return self._get_observation()
|
| 64 |
|
|
@@ -72,16 +74,34 @@ class CodeReviewEnv:
|
|
| 72 |
if action.action_type == "write_comment":
|
| 73 |
self.agent_comment = action.comment_text or ""
|
| 74 |
reward = 0.2 # dense bonus for writing
|
| 75 |
-
|
| 76 |
-
quality_score = grade_comment(
|
| 77 |
-
self.agent_comment,
|
| 78 |
-
self.expected_keywords,
|
| 79 |
-
self.expert_comment
|
| 80 |
-
)
|
| 81 |
reward += quality_score
|
| 82 |
-
if reward > 1.0:
|
| 83 |
-
reward = 1.0
|
| 84 |
self.done = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
elif action.action_type == "skip":
|
| 86 |
reward = -0.1
|
| 87 |
self.done = True
|
|
@@ -96,13 +116,25 @@ class CodeReviewEnv:
|
|
| 96 |
obs = self._get_observation()
|
| 97 |
return obs, Reward(value=reward), self.done, info
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
def _get_observation(self) -> Observation:
|
| 100 |
return Observation(
|
| 101 |
pr_title=self.pr_title,
|
| 102 |
pr_description=self.pr_description,
|
| 103 |
code_snippet=self.code_snippet,
|
| 104 |
-
comments=self.comments,
|
| 105 |
-
|
| 106 |
step=self.step_count,
|
| 107 |
done=self.done
|
| 108 |
)
|
|
@@ -112,8 +144,8 @@ class CodeReviewEnv:
|
|
| 112 |
pr_title=self.pr_title,
|
| 113 |
pr_description=self.pr_description,
|
| 114 |
code_snippet=self.code_snippet,
|
| 115 |
-
comments=self.comments,
|
| 116 |
-
|
| 117 |
step=self.step_count,
|
| 118 |
done=self.done
|
| 119 |
)
|
|
|
|
| 1 |
from typing import Tuple, Dict, Any
|
| 2 |
from models import Observation, Action, Reward, State
|
| 3 |
+
from grader import grade_comment, grade_question, grade_fix
|
| 4 |
|
| 5 |
class CodeReviewEnv:
|
| 6 |
def __init__(self, task: str = "easy"):
|
|
|
|
| 18 |
self.step_count = 0
|
| 19 |
self.agent_comment = None
|
| 20 |
self.done = False
|
| 21 |
+
self.test_results = None
|
| 22 |
|
| 23 |
+
# Task definitions (same as before)
|
| 24 |
if self.task == "easy":
|
| 25 |
self.pr_title = "Fix missing null check in user lookup"
|
| 26 |
self.pr_description = "The current code does not handle missing user IDs. It raises a KeyError."
|
| 27 |
self.code_snippet = "def get_user(id):\n return users[id] # missing null check"
|
| 28 |
+
self.comments = []
|
| 29 |
self.expected_keywords = ["null", "key", "missing", "check", "exists", "handle"]
|
| 30 |
self.expert_comment = "Add a check to ensure the key exists before accessing the dictionary to avoid KeyError."
|
| 31 |
+
self.expected_fix_keywords = ["if id in users"]
|
| 32 |
elif self.task == "medium":
|
| 33 |
self.pr_title = "Improve loop efficiency"
|
| 34 |
self.pr_description = "The loop uses `range(len(items))` which is inefficient and less readable."
|
| 35 |
self.code_snippet = "for i in range(len(items)):\n process(items[i])\n# O(n^2) when it could be O(n)"
|
| 36 |
+
self.comments = []
|
| 37 |
self.expected_keywords = ["enumerate", "for item in", "range", "inefficient", "optimize"]
|
| 38 |
self.expert_comment = "Use `for item in items:` for a more Pythonic and efficient loop."
|
| 39 |
+
self.expected_fix_keywords = ["for item in items", "for i, item in enumerate"]
|
| 40 |
elif self.task == "hard":
|
| 41 |
self.pr_title = "Handle division by zero in average calculation"
|
| 42 |
self.pr_description = "The function crashes when the input list is empty."
|
| 43 |
self.code_snippet = "def calculate_average(data):\n total = sum(data)\n return total / len(data) # what if data is empty?"
|
| 44 |
+
self.comments = []
|
| 45 |
self.expected_keywords = ["empty", "zero", "length", "check", "handle", "exception"]
|
| 46 |
self.expert_comment = "Check if the list is empty and return a sensible default (e.g., 0) or raise a descriptive error."
|
| 47 |
+
self.expected_fix_keywords = ["if not data", "if len(data)==0"]
|
| 48 |
elif self.task == "harder":
|
| 49 |
self.pr_title = "Fix race condition in counter increment"
|
| 50 |
self.pr_description = "Multiple threads increment a counter without synchronization, causing lost updates."
|
| 51 |
self.code_snippet = "counter = 0\ndef increment():\n global counter\n counter += 1\n# called from multiple threads"
|
| 52 |
+
self.comments = []
|
| 53 |
self.expected_keywords = ["thread", "lock", "synchronization", "atomic", "race", "concurrent"]
|
| 54 |
self.expert_comment = "Use a threading.Lock to protect the counter increment, or use an atomic operation like `threading.atomic`."
|
| 55 |
+
self.expected_fix_keywords = ["lock", "threading.Lock", "with lock"]
|
| 56 |
else: # hardest
|
| 57 |
self.pr_title = "Fix deadlock in database transaction"
|
| 58 |
self.pr_description = "Two threads acquire locks in opposite order, leading to potential deadlock."
|
| 59 |
self.code_snippet = "with lock1:\n with lock2:\n do_work()\n# another thread does lock2 then lock1"
|
| 60 |
+
self.comments = []
|
| 61 |
self.expected_keywords = ["deadlock", "lock order", "acquire", "release", "trylock", "timeout"]
|
| 62 |
self.expert_comment = "Ensure all threads acquire locks in the same order to prevent deadlock. Consider using a timeout or a single lock."
|
| 63 |
+
self.expected_fix_keywords = ["same order", "lock order", "acquire lock1 then lock2"]
|
| 64 |
|
| 65 |
return self._get_observation()
|
| 66 |
|
|
|
|
| 74 |
if action.action_type == "write_comment":
|
| 75 |
self.agent_comment = action.comment_text or ""
|
| 76 |
reward = 0.2 # dense bonus for writing
|
| 77 |
+
quality_score = grade_comment(self.agent_comment, self.expected_keywords, self.task)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
reward += quality_score
|
|
|
|
|
|
|
| 79 |
self.done = True
|
| 80 |
+
|
| 81 |
+
elif action.action_type == "ask_question":
|
| 82 |
+
if not action.question:
|
| 83 |
+
reward = -0.1
|
| 84 |
+
else:
|
| 85 |
+
q_score = grade_question(action.question)
|
| 86 |
+
reward = 0.1 + q_score # small bonus + quality
|
| 87 |
+
# Simulate a helpful answer
|
| 88 |
+
answer = self._answer_question(action.question)
|
| 89 |
+
self.comments.append(f"Agent: {action.question}")
|
| 90 |
+
self.comments.append(f"Env: {answer}")
|
| 91 |
+
self.step_count += 1
|
| 92 |
+
# Episode continues, not done
|
| 93 |
+
|
| 94 |
+
elif action.action_type == "propose_fix":
|
| 95 |
+
if not action.fix_code:
|
| 96 |
+
reward = -0.2
|
| 97 |
+
else:
|
| 98 |
+
# We'll use a simple keyword check for demonstration
|
| 99 |
+
# In a full version, you'd run unit tests
|
| 100 |
+
fix_score = grade_fix(action.fix_code, self.expected_fix_keywords, None)
|
| 101 |
+
reward = 0.3 + fix_score
|
| 102 |
+
self.test_results = f"Fix evaluated with score {fix_score:.2f}"
|
| 103 |
+
self.done = True
|
| 104 |
+
|
| 105 |
elif action.action_type == "skip":
|
| 106 |
reward = -0.1
|
| 107 |
self.done = True
|
|
|
|
| 116 |
obs = self._get_observation()
|
| 117 |
return obs, Reward(value=reward), self.done, info
|
| 118 |
|
| 119 |
+
def _answer_question(self, question: str) -> str:
|
| 120 |
+
# Simple rule‑based answers – you can expand
|
| 121 |
+
q = question.lower()
|
| 122 |
+
if "what" in q and "purpose" in q:
|
| 123 |
+
return "The purpose of this function is to retrieve a user by ID from a dictionary."
|
| 124 |
+
elif "expected" in q:
|
| 125 |
+
return "The function should return the user object if the ID exists, otherwise raise a KeyError."
|
| 126 |
+
elif "how" in q and "fix" in q:
|
| 127 |
+
return "You might consider adding a check for missing keys or using a safer dictionary method like `get`."
|
| 128 |
+
else:
|
| 129 |
+
return "I'm not sure. Could you be more specific?"
|
| 130 |
+
|
| 131 |
def _get_observation(self) -> Observation:
|
| 132 |
return Observation(
|
| 133 |
pr_title=self.pr_title,
|
| 134 |
pr_description=self.pr_description,
|
| 135 |
code_snippet=self.code_snippet,
|
| 136 |
+
comments=self.comments.copy(),
|
| 137 |
+
test_results=self.test_results,
|
| 138 |
step=self.step_count,
|
| 139 |
done=self.done
|
| 140 |
)
|
|
|
|
| 144 |
pr_title=self.pr_title,
|
| 145 |
pr_description=self.pr_description,
|
| 146 |
code_snippet=self.code_snippet,
|
| 147 |
+
comments=self.comments.copy(),
|
| 148 |
+
test_results=self.test_results,
|
| 149 |
step=self.step_count,
|
| 150 |
done=self.done
|
| 151 |
)
|