File size: 5,310 Bytes
2182d10 3d38d37 2182d10 3d38d37 2182d10 3d38d37 2182d10 3d38d37 2182d10 3d38d37 2182d10 3d38d37 2182d10 3d38d37 2182d10 3d38d37 2182d10 3d38d37 2182d10 3d38d37 2182d10 62f081e 2182d10 62f081e 2182d10 3d38d37 2182d10 3d38d37 2182d10 3d38d37 2182d10 3d38d37 2182d10 3d38d37 2182d10 3d38d37 2182d10 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 | import os
import sys
import json
from datasets import load_dataset
class CodeReviewEnv:
    """Text-based code-review environment backed by a Hugging Face dataset.

    Each episode presents one sample's ``modified_code`` with 1-indexed line
    numbers. The agent answers with ``COMMENT <line> <text>``, ``APPROVE`` or
    ``REQUEST_CHANGES``; comments are graded against the sample's
    ``number_of_line`` field. An episode ends on ``APPROVE``,
    ``REQUEST_CHANGES`` or after ``max_steps`` actions.
    """

    # Punctuation removed from the line-number token of a COMMENT action, so
    # that tokens such as "12:" or "12." still parse as the integer 12.
    _LINE_PUNCTUATION = str.maketrans("", "", ".:,")

    def __init__(self, dataset_name="Krish-05/krish-bug-detect-fix", split="train", difficulty="medium"):
        """Configure the environment and load the dataset.

        Args:
            dataset_name: Identifier passed to ``load_dataset``.
            split: Dataset split passed to ``load_dataset``.
            difficulty: ``"easy"``, ``"medium"`` or ``"hard"``; selects samples
                by the number of lines in ``modified_code``. Any other value
                matches no sample, in which case the whole split is used.
        """
        self.benchmark_name = "krish_bug_detect_benchmark"
        self.dataset_name = dataset_name
        self.split = split
        self.difficulty = difficulty
        self.task_name = f"code_review_task_{difficulty}"
        self.steps_taken = 0
        self.rewards = []
        self.max_steps = 4
        self.current_sample = None
        self.correct_comments = 0
        # Defined up front so the attributes exist even if loading fails.
        self.dataset = None
        self.filtered_datasets = []
        self.current_idx = 0
        self._load_dataset()

    def _matches_difficulty(self, item):
        """Return True if ``item``'s code length falls in the difficulty bucket."""
        num_lines = len(item['modified_code'].split('\n'))
        if self.difficulty == "easy":
            return num_lines <= 10
        if self.difficulty == "medium":
            return 10 < num_lines <= 30
        if self.difficulty == "hard":
            return num_lines > 30
        return False

    def _load_dataset(self):
        """Load the dataset and keep the samples matching ``self.difficulty``.

        If no sample matches, every sample of the split is kept. Loading is
        best-effort: on any failure the error is printed and the environment
        is left with an empty sample list, which ``reset`` reports as an error
        observation instead of raising.
        """
        self.current_idx = 0
        try:
            self.dataset = load_dataset(self.dataset_name, split=self.split)
            # Filter dataset by difficulty based on code length
            matching = [item for item in self.dataset if self._matches_difficulty(item)]
            self.filtered_datasets = matching if matching else list(self.dataset)
        except Exception as e:
            # Deliberately broad: this is the boundary to the external loader
            # and a failure here must not prevent constructing the environment.
            print(f"Error loading dataset: {e}")
            self.dataset = None
            self.filtered_datasets = []

    def reset(self):
        """Start a new episode on the next sample and return its observation.

        Samples are served round-robin from the filtered list.

        Returns:
            The prompt shown to the agent (line-numbered code plus the list of
            available actions), or an error string when no samples are loaded.
        """
        self.steps_taken = 0
        self.rewards = []
        self.correct_comments = 0
        if not self.filtered_datasets:
            return "Error: Dataset not loaded or empty."

        self.current_sample = self.filtered_datasets[self.current_idx % len(self.filtered_datasets)]
        self.current_idx += 1

        buggy_code_raw = self.current_sample.get('modified_code', 'No code found')
        # Enumerate lines 1-indexed for the agent to review
        buggy_code = '\n'.join(
            f"{number} | {line}"
            for number, line in enumerate(buggy_code_raw.split('\n'), start=1)
        )
        return (
            "You are a strict code reviewer. Please review the following code identifying any bugs. "
            "Note that line numbers are provided on the left.\n"
            f"{buggy_code}\n"
            "Available actions:\n"
            "1. COMMENT <line_number> <issue_description>\n"
            "2. APPROVE\n"
            "3. REQUEST_CHANGES\n"
        )

    def _true_bug_line(self):
        """Return the current sample's ``number_of_line`` (``-1`` if absent).

        A numeric string is converted to ``int`` so it can be compared with
        the parsed comment line; any other value is returned unchanged.
        """
        raw = self.current_sample.get('number_of_line', -1)
        if isinstance(raw, str):
            try:
                return int(raw.strip())
            except ValueError:
                return -1
        return raw

    def _grade_comment(self, action):
        """Grade a ``COMMENT`` action and return ``(reward, observation)``."""
        try:
            # Split on any whitespace run so extra spaces or tabs between the
            # keyword and the line number do not make a valid action malformed.
            line_token = action.split(None, 2)[1]
            # Strip punctuation just in case
            comment_line = int(line_token.translate(self._LINE_PUNCTUATION))
        except (IndexError, ValueError) as e:
            return -0.1, f"Malformed format. Use COMMENT <line_number> <description>. Action parsed error: {e}"

        # Deterministic Grader Check
        if comment_line == self._true_bug_line():
            # High intermediate reward for locating the exact bug line.
            # NOTE(review): repeating the same correct comment earns this
            # reward again on every step - confirm that is intended.
            self.correct_comments += 1
            return 0.8, f"Valid bug identified on line {comment_line}. Will you APPROVE or REQUEST_CHANGES?"
        # False positive penalty
        return -0.2, f"Line {comment_line} appears correct. False positive. Continue review or APPROVE/REQUEST_CHANGES."

    def step(self, action):
        """Apply one agent action to the current episode.

        Args:
            action: Action text; leading/trailing whitespace is ignored.

        Returns:
            A tuple ``(observation, reward, done, info)`` where ``reward`` is
            the step reward formatted as a string with two decimals and
            ``info`` is always ``None``. The numeric reward is also appended
            to ``self.rewards``. If no sample is active (``reset`` was not
            called or could not load a sample) an error observation is
            returned with reward ``"0.00"`` and ``done=True``, and no step is
            counted.
        """
        if self.current_sample is None:
            return "Error: No active sample. Call reset() before step().", "0.00", True, None

        self.steps_taken += 1
        done = False
        action = action.strip()

        if action.startswith("COMMENT"):
            reward, obs = self._grade_comment(action)
        elif action.startswith("APPROVE"):
            if self.correct_comments == 0:  # Missed the bug entirely
                reward = -1.0
                obs = "You approved buggy code. deployment failed."
            else:  # Found the bug but still approved?
                reward = -0.5
                obs = "You found a bug but approved the PR anyway."
            done = True
        elif action.startswith("REQUEST_CHANGES"):
            if self.correct_comments > 0:  # Perfectly identified the bug and rejected the PR
                reward = 1.0
                obs = "Code review complete. Changes requested successfully."
            else:  # Rejected without pointing out the correct bug
                reward = -0.5
                obs = "You requested changes without accurately commenting on the bug line."
            done = True
        else:
            reward = -0.1
            obs = "Invalid action format. Use COMMENT <line_number>, APPROVE, or REQUEST_CHANGES."

        if self.steps_taken >= self.max_steps:
            done = True
            # A terminal action keeps its own message; anything else is
            # reported as a timeout (its reward is still recorded).
            if not action.startswith(("APPROVE", "REQUEST_CHANGES")):
                obs = "Time limit exceeded. PR auto-closed."

        self.rewards.append(reward)
        formatted_reward = f"{reward:.2f}"
        return obs, formatted_reward, done, None
|