Spaces:
Sleeping
Sleeping
File size: 3,523 Bytes
0cc95e3 3a26e23 0cc95e3 3a26e23 8be8d83 3a26e23 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 | from typing import Tuple, Dict, Any, Optional
from models import Observation, Action, Reward, State
class CodeReviewEnv:
    """Single-step code-review environment.

    Each episode presents one pull-request snippet (``pr_code``) together with
    pre-seeded reviewer ``comments``. The agent takes exactly one action and
    every action ends the episode. Writing a comment is rewarded by how many
    task-specific keywords it mentions; skipping, declaring done, or any other
    action type yields a fixed negative reward.

    All per-task data lives in ``_TASKS`` so that task validation
    (``set_task``), episode setup (``reset``) and grading (``_grade_comment``)
    cannot drift apart.
    """

    # Per-task fixtures: the PR snippet shown to the agent, the pre-seeded
    # reviewer comments, and the keywords a good review comment should mention.
    # Sequences are tuples so no instance can mutate this shared class data;
    # reset() copies the comments into a fresh list for each episode.
    _TASKS: Dict[str, Dict[str, Any]] = {
        "easy": {
            "pr_code": "def get_user(id):\n return users[id] # missing null check",
            "comments": ("Looks good!", "Maybe add a comment?"),
            "keywords": ("null", "key", "missing", "check", "exists", "handle"),
        },
        "medium": {
            # NOTE(review): the snippet's own "O(n^2)" remark looks wrong -- the
            # loop is O(n); the graded issue appears to be the index-based loop
            # (see the "enumerate"/"range" keywords). Left unchanged because it
            # is task data the agent observes.
            "pr_code": "for i in range(len(items)):\n process(items[i])\n# O(n^2) when it could be O(n)",
            "comments": ("Nice code",),
            "keywords": ("enumerate", "for item in", "range", "inefficient", "optimize"),
        },
        "hard": {
            "pr_code": "def calculate_average(data):\n total = sum(data)\n return total / len(data) # what if data is empty?",
            "comments": ("LGTM",),
            "keywords": ("empty", "zero", "length", "check", "handle", "exception"),
        },
    }

    # Number of distinct keyword hits that earns the maximum quality score (1.0).
    _KEYWORDS_FOR_FULL_CREDIT = 3

    def __init__(self, task: str = "easy"):
        """Create the environment for ``task`` and immediately start an episode.

        Raises:
            RuntimeError: from ``reset()`` if ``task`` is not a known task name.
        """
        self.task = task
        self.reset()

    @classmethod
    def _lookup_task(cls, task: Any) -> Optional[Dict[str, Any]]:
        """Return the ``_TASKS`` entry for ``task``, or ``None`` if unknown."""
        try:
            return cls._TASKS.get(task)
        except TypeError:
            # Unhashable values can never name a task; treat them as unknown
            # rather than leaking a TypeError to callers.
            return None

    def set_task(self, task: str):
        """Select the task used by the next ``reset()``.

        This only records the task name; the current episode is not reset.

        Raises:
            ValueError: if ``task`` is not one of the known task names.
        """
        if self._lookup_task(task) is None:
            raise ValueError(f"Unknown task: {task}")
        self.task = task

    def reset(self) -> Observation:
        """Start a fresh episode for the current task and return its observation.

        Raises:
            RuntimeError: if no task is set, or the task name is not known.
        """
        if self.task is None:
            raise RuntimeError("Task not set. Call set_task() first.")
        # Episode counters are cleared before the task lookup, so they are
        # reset even if the lookup below fails.
        self.step_count = 0
        self.agent_comment = None
        self.done = False
        spec = self._lookup_task(self.task)
        if spec is None:
            raise RuntimeError(f"Invalid task: {self.task}")
        self.pr_code = spec["pr_code"]
        self.comments = list(spec["comments"])
        return self._get_observation()

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        Rewards:
            * ``"write_comment"``: 0.2 plus the comment's keyword quality score,
              capped at 1.0. The text is stored as ``agent_comment``; a missing
              or empty ``comment_text`` is stored as ``""``.
            * ``"skip"``: -0.1
            * ``"done"``: -0.5
            * any other action type: -0.2

        Every action ends the episode.

        Raises:
            RuntimeError: if the episode has already finished.
        """
        if self.done:
            raise RuntimeError("Episode already finished")
        info: Dict[str, Any] = {}
        if action.action_type == "write_comment":
            self.agent_comment = action.comment_text or ""
            quality_score = self._grade_comment(self.agent_comment)
            # Flat bonus for engaging with the review, capped at 1.0 overall.
            reward = min(1.0, 0.2 + quality_score)
        elif action.action_type == "skip":
            reward = -0.1
        elif action.action_type == "done":
            reward = -0.5
        else:
            reward = -0.2
        self.done = True  # episodes are single-step: every action is terminal
        self.step_count += 1
        obs = self._get_observation()
        return obs, Reward(value=reward), self.done, info

    def _grade_comment(self, comment: str) -> float:
        """Score ``comment`` in ``[0, 1]`` by keyword coverage for the current task.

        Each task keyword that appears in the comment counts once; matching is a
        case-insensitive plain substring test, so e.g. ``"key"`` also matches
        ``"keyboard"``. Reaching ``_KEYWORDS_FOR_FULL_CREDIT`` hits gives 1.0.
        An unknown task scores 0.0.
        """
        spec = self._lookup_task(self.task)
        if spec is None:
            return 0.0
        text = comment.lower()
        matched = sum(1 for kw in spec["keywords"] if kw in text)
        return min(1.0, matched / self._KEYWORDS_FOR_FULL_CREDIT)

    def _snapshot(self) -> Dict[str, Any]:
        """Collect the episode fields shared by observations and state."""
        return {
            "pr_code": self.pr_code,
            "comments": self.comments,
            "agent_comment": self.agent_comment,
            "step": self.step_count,
            "done": self.done,
        }

    def _get_observation(self) -> Observation:
        """Build the agent-facing observation of the current episode."""
        return Observation(**self._snapshot())

    def state(self) -> State:
        """Return the current episode state (same fields as the observation)."""
        return State(**self._snapshot())