Spaces:
Sleeping
Sleeping
Update environment.py
Browse files
- environment.py +49 -26
environment.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
from typing import Tuple, Dict, Any
|
| 2 |
from models import Observation, Action, Reward, State
|
|
|
|
| 3 |
|
| 4 |
class CodeReviewEnv:
|
| 5 |
def __init__(self, task: str = "easy"):
|
|
@@ -7,7 +8,7 @@ class CodeReviewEnv:
|
|
| 7 |
self.reset()
|
| 8 |
|
| 9 |
def set_task(self, task: str):
|
| 10 |
-
if task not in ["easy", "medium", "hard"]:
|
| 11 |
raise ValueError(f"Unknown task: {task}")
|
| 12 |
self.task = task
|
| 13 |
|
|
@@ -18,17 +19,46 @@ class CodeReviewEnv:
|
|
| 18 |
self.agent_comment = None
|
| 19 |
self.done = False
|
| 20 |
|
|
|
|
| 21 |
if self.task == "easy":
|
| 22 |
-
self.
|
|
|
|
|
|
|
| 23 |
self.comments = ["Looks good!", "Maybe add a comment?"]
|
|
|
|
|
|
|
|
|
|
| 24 |
elif self.task == "medium":
|
| 25 |
-
self.
|
|
|
|
|
|
|
| 26 |
self.comments = ["Nice code"]
|
|
|
|
|
|
|
|
|
|
| 27 |
elif self.task == "hard":
|
| 28 |
-
self.
|
|
|
|
|
|
|
| 29 |
self.comments = ["LGTM"]
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
return self._get_observation()
|
| 34 |
|
|
@@ -41,8 +71,13 @@ class CodeReviewEnv:
|
|
| 41 |
|
| 42 |
if action.action_type == "write_comment":
|
| 43 |
self.agent_comment = action.comment_text or ""
|
| 44 |
-
reward = 0.2
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
reward += quality_score
|
| 47 |
if reward > 1.0:
|
| 48 |
reward = 1.0
|
|
@@ -61,25 +96,11 @@ class CodeReviewEnv:
|
|
| 61 |
obs = self._get_observation()
|
| 62 |
return obs, Reward(value=reward), self.done, info
|
| 63 |
|
| 64 |
-
def _grade_comment(self, comment: str) -> float:
|
| 65 |
-
if self.task == "easy":
|
| 66 |
-
keywords = ["null", "key", "missing", "check", "exists", "handle"]
|
| 67 |
-
matched = sum(1 for kw in keywords if kw in comment.lower())
|
| 68 |
-
return min(1.0, matched / 3)
|
| 69 |
-
elif self.task == "medium":
|
| 70 |
-
keywords = ["enumerate", "for item in", "range", "inefficient", "optimize"]
|
| 71 |
-
matched = sum(1 for kw in keywords if kw in comment.lower())
|
| 72 |
-
return min(1.0, matched / 3)
|
| 73 |
-
elif self.task == "hard":
|
| 74 |
-
keywords = ["empty", "zero", "length", "check", "handle", "exception"]
|
| 75 |
-
matched = sum(1 for kw in keywords if kw in comment.lower())
|
| 76 |
-
return min(1.0, matched / 3)
|
| 77 |
-
else:
|
| 78 |
-
return 0.0
|
| 79 |
-
|
| 80 |
def _get_observation(self) -> Observation:
|
| 81 |
return Observation(
|
| 82 |
-
|
|
|
|
|
|
|
| 83 |
comments=self.comments,
|
| 84 |
agent_comment=self.agent_comment,
|
| 85 |
step=self.step_count,
|
|
@@ -88,7 +109,9 @@ class CodeReviewEnv:
|
|
| 88 |
|
| 89 |
def state(self) -> State:
|
| 90 |
return State(
|
| 91 |
-
|
|
|
|
|
|
|
| 92 |
comments=self.comments,
|
| 93 |
agent_comment=self.agent_comment,
|
| 94 |
step=self.step_count,
|
|
|
|
| 1 |
from typing import Tuple, Dict, Any
|
| 2 |
from models import Observation, Action, Reward, State
|
| 3 |
+
from grader import grade_comment
|
| 4 |
|
| 5 |
class CodeReviewEnv:
|
| 6 |
def __init__(self, task: str = "easy"):
|
|
|
|
| 8 |
self.reset()
|
| 9 |
|
| 10 |
def set_task(self, task: str):
|
| 11 |
+
if task not in ["easy", "medium", "hard", "harder", "hardest"]:
|
| 12 |
raise ValueError(f"Unknown task: {task}")
|
| 13 |
self.task = task
|
| 14 |
|
|
|
|
| 19 |
self.agent_comment = None
|
| 20 |
self.done = False
|
| 21 |
|
| 22 |
+
# Task definitions with richer context
|
| 23 |
if self.task == "easy":
|
| 24 |
+
self.pr_title = "Fix missing null check in user lookup"
|
| 25 |
+
self.pr_description = "The current code does not handle missing user IDs. It raises a KeyError."
|
| 26 |
+
self.code_snippet = "def get_user(id):\n return users[id] # missing null check"
|
| 27 |
self.comments = ["Looks good!", "Maybe add a comment?"]
|
| 28 |
+
self.expected_keywords = ["null", "key", "missing", "check", "exists", "handle"]
|
| 29 |
+
self.expert_comment = "Add a check to ensure the key exists before accessing the dictionary to avoid KeyError."
|
| 30 |
+
|
| 31 |
elif self.task == "medium":
|
| 32 |
+
self.pr_title = "Improve loop efficiency"
|
| 33 |
+
self.pr_description = "The loop uses `range(len(items))` which is inefficient and less readable."
|
| 34 |
+
self.code_snippet = "for i in range(len(items)):\n process(items[i])\n# O(n^2) when it could be O(n)"
|
| 35 |
self.comments = ["Nice code"]
|
| 36 |
+
self.expected_keywords = ["enumerate", "for item in", "range", "inefficient", "optimize"]
|
| 37 |
+
self.expert_comment = "Use `for item in items:` for a more Pythonic and efficient loop."
|
| 38 |
+
|
| 39 |
elif self.task == "hard":
|
| 40 |
+
self.pr_title = "Handle division by zero in average calculation"
|
| 41 |
+
self.pr_description = "The function crashes when the input list is empty."
|
| 42 |
+
self.code_snippet = "def calculate_average(data):\n total = sum(data)\n return total / len(data) # what if data is empty?"
|
| 43 |
self.comments = ["LGTM"]
|
| 44 |
+
self.expected_keywords = ["empty", "zero", "length", "check", "handle", "exception"]
|
| 45 |
+
self.expert_comment = "Check if the list is empty and return a sensible default (e.g., 0) or raise a descriptive error."
|
| 46 |
+
|
| 47 |
+
elif self.task == "harder":
|
| 48 |
+
self.pr_title = "Fix race condition in counter increment"
|
| 49 |
+
self.pr_description = "Multiple threads increment a counter without synchronization, causing lost updates."
|
| 50 |
+
self.code_snippet = "counter = 0\ndef increment():\n global counter\n counter += 1\n# called from multiple threads"
|
| 51 |
+
self.comments = ["Simple code, should be fine?"]
|
| 52 |
+
self.expected_keywords = ["thread", "lock", "synchronization", "atomic", "race", "concurrent"]
|
| 53 |
+
self.expert_comment = "Use a threading.Lock to protect the counter increment, or use an atomic operation like `threading.atomic`."
|
| 54 |
+
|
| 55 |
+
else: # hardest
|
| 56 |
+
self.pr_title = "Fix deadlock in database transaction"
|
| 57 |
+
self.pr_description = "Two threads acquire locks in opposite order, leading to potential deadlock."
|
| 58 |
+
self.code_snippet = "with lock1:\n with lock2:\n do_work()\n# another thread does lock2 then lock1"
|
| 59 |
+
self.comments = ["Seems okay?"]
|
| 60 |
+
self.expected_keywords = ["deadlock", "lock order", "acquire", "release", "trylock", "timeout"]
|
| 61 |
+
self.expert_comment = "Ensure all threads acquire locks in the same order to prevent deadlock. Consider using a timeout or a single lock."
|
| 62 |
|
| 63 |
return self._get_observation()
|
| 64 |
|
|
|
|
| 71 |
|
| 72 |
if action.action_type == "write_comment":
|
| 73 |
self.agent_comment = action.comment_text or ""
|
| 74 |
+
reward = 0.2 # dense bonus for writing
|
| 75 |
+
# Semantic grader
|
| 76 |
+
quality_score = grade_comment(
|
| 77 |
+
self.agent_comment,
|
| 78 |
+
self.expected_keywords,
|
| 79 |
+
self.expert_comment
|
| 80 |
+
)
|
| 81 |
reward += quality_score
|
| 82 |
if reward > 1.0:
|
| 83 |
reward = 1.0
|
|
|
|
| 96 |
obs = self._get_observation()
|
| 97 |
return obs, Reward(value=reward), self.done, info
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
def _get_observation(self) -> Observation:
|
| 100 |
return Observation(
|
| 101 |
+
pr_title=self.pr_title,
|
| 102 |
+
pr_description=self.pr_description,
|
| 103 |
+
code_snippet=self.code_snippet,
|
| 104 |
comments=self.comments,
|
| 105 |
agent_comment=self.agent_comment,
|
| 106 |
step=self.step_count,
|
|
|
|
| 109 |
|
| 110 |
def state(self) -> State:
|
| 111 |
return State(
|
| 112 |
+
pr_title=self.pr_title,
|
| 113 |
+
pr_description=self.pr_description,
|
| 114 |
+
code_snippet=self.code_snippet,
|
| 115 |
comments=self.comments,
|
| 116 |
agent_comment=self.agent_comment,
|
| 117 |
step=self.step_count,
|