OpenEnv / server /environment.py
krishnachoudhary-hclguvi
Reorganize project structure to match openenv-course multi-mode deployment guidelines
62f081e unverified
import os
import sys
import json
from datasets import load_dataset
class CodeReviewEnv:
def __init__(self, dataset_name="Krish-05/krish-bug-detect-fix", split="train", difficulty="medium"):
self.benchmark_name = "krish_bug_detect_benchmark"
self.dataset_name = dataset_name
self.split = split
self.difficulty = difficulty
self.task_name = f"code_review_task_{difficulty}"
self.steps_taken = 0
self.rewards = []
self.max_steps = 4
self.current_sample = None
self.correct_comments = 0
self._load_dataset()
def _load_dataset(self):
try:
self.dataset = load_dataset(self.dataset_name, split=self.split)
# Filter dataset by difficulty based on code length
filtered_ds = []
for item in self.dataset:
num_lines = len(item['modified_code'].split('\n'))
if self.difficulty == "easy" and num_lines <= 10:
filtered_ds.append(item)
elif self.difficulty == "medium" and 10 < num_lines <= 30:
filtered_ds.append(item)
elif self.difficulty == "hard" and num_lines > 30:
filtered_ds.append(item)
self.filtered_datasets = filtered_ds if filtered_ds else list(self.dataset)
self.current_idx = 0
except Exception as e:
print(f"Error loading dataset: {e}")
self.dataset = None
self.filtered_datasets = []
def reset(self):
self.steps_taken = 0
self.rewards = []
self.correct_comments = 0
if not self.filtered_datasets:
return "Error: Dataset not loaded or empty."
self.current_sample = self.filtered_datasets[self.current_idx % len(self.filtered_datasets)]
self.current_idx += 1
buggy_code_raw = self.current_sample.get('modified_code', 'No code found')
# Enumerate lines 1-indexed for the agent to review
enumerated_lines = [f"{i+1} | {line}" for i, line in enumerate(buggy_code_raw.split('\n'))]
buggy_code = '\n'.join(enumerated_lines)
observation = f"""You are a strict code reviewer. Please review the following code identifying any bugs. Note that line numbers are provided on the left.
{buggy_code}
Available actions:
1. COMMENT <line_number> <issue_description>
2. APPROVE
3. REQUEST_CHANGES
"""
return observation
def step(self, action):
self.steps_taken += 1
done = False
reward = 0.0
action = action.strip()
true_bug_line = self.current_sample.get('number_of_line', -1)
if action.startswith("COMMENT"):
try:
parts = action.split(' ', 2)
line_str = parts[1]
# Strip punctuation just in case
for p in ['.', ':', ',']:
line_str = line_str.replace(p, '')
comment_line = int(line_str)
# Deterministic Grader Check
if comment_line == true_bug_line:
reward = 0.8 # High intermediate reward for locating the exact bug line
self.correct_comments += 1
obs = f"Valid bug identified on line {comment_line}. Will you APPROVE or REQUEST_CHANGES?"
else:
reward = -0.2 # False positive penalty
obs = f"Line {comment_line} appears correct. False positive. Continue review or APPROVE/REQUEST_CHANGES."
except Exception as e:
reward = -0.1
obs = f"Malformed format. Use COMMENT <line_number> <description>. Action parsed error: {e}"
elif action.startswith("APPROVE"):
if self.correct_comments == 0: # Missed the bug entirely
reward = -1.0
obs = "You approved buggy code. deployment failed."
else: # Found the bug but still approved?
reward = -0.5
obs = "You found a bug but approved the PR anyway."
done = True
elif action.startswith("REQUEST_CHANGES"):
if self.correct_comments > 0: # Perfectly identified the bug and rejected the PR
reward = 1.0
obs = "Code review complete. Changes requested successfully."
else: # Rejected without pointing out the correct bug
reward = -0.5
obs = "You requested changes without accurately commenting on the bug line."
done = True
else:
reward = -0.1
obs = "Invalid action format. Use COMMENT <line_number>, APPROVE, or REQUEST_CHANGES."
if self.steps_taken >= self.max_steps:
done = True
if not action.startswith("APPROVE") and not action.startswith("REQUEST_CHANGES"):
obs = "Time limit exceeded. PR auto-closed."
self.rewards.append(reward)
formatted_reward = f"{reward:.2f}"
return obs, formatted_reward, done, None