File size: 5,310 Bytes
2182d10
 
 
 
 
 
3d38d37
2182d10
 
 
3d38d37
 
 
2182d10
 
3d38d37
 
2182d10
3d38d37
 
2182d10
 
 
 
 
3d38d37
 
 
 
 
 
 
 
 
 
 
 
 
2182d10
 
 
 
3d38d37
2182d10
 
 
 
3d38d37
2182d10
3d38d37
 
2182d10
3d38d37
 
 
 
2182d10
3d38d37
 
 
2182d10
62f081e
2182d10
 
 
 
 
 
 
62f081e
2182d10
 
 
 
 
 
 
 
3d38d37
2182d10
 
3d38d37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2182d10
3d38d37
 
 
 
 
 
 
2182d10
 
3d38d37
 
 
 
 
 
 
2182d10
 
 
 
3d38d37
2182d10
 
 
3d38d37
 
2182d10
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
import sys
import json
from datasets import load_dataset

class CodeReviewEnv:
    """Episode-based code-review environment backed by a bug-detection dataset.

    Each episode presents one dataset sample's ``modified_code`` with
    line numbers prepended. The agent answers with text actions:

    * ``COMMENT <line_number> <issue_description>``: +0.8 when the line
      equals the sample's ``number_of_line``, -0.2 otherwise, -0.1 when
      the action cannot be parsed.
    * ``APPROVE``: ends the episode; -1.0 if no correct comment was made,
      -0.5 if one was.
    * ``REQUEST_CHANGES``: ends the episode; +1.0 if a correct comment was
      made, -0.5 otherwise.

    Any other action costs -0.1. The episode is also ended once
    ``max_steps`` actions have been taken.
    """

    # Punctuation removed from the line-number token, so "12:" or "12." parse as 12.
    _LINE_PUNCTUATION = str.maketrans('', '', '.:,')

    def __init__(self, dataset_name="Krish-05/krish-bug-detect-fix", split="train", difficulty="medium"):
        """Store the configuration and load the dataset.

        Args:
            dataset_name: Identifier passed to ``load_dataset``.
            split: Dataset split passed to ``load_dataset``.
            difficulty: ``"easy"``, ``"medium"`` or ``"hard"``; selects
                samples by the number of lines in ``modified_code``.
        """
        self.benchmark_name = "krish_bug_detect_benchmark"
        self.dataset_name = dataset_name
        self.split = split
        self.difficulty = difficulty
        self.task_name = f"code_review_task_{difficulty}"

        self.steps_taken = 0
        self.rewards = []
        self.max_steps = 4

        self.current_sample = None
        self.correct_comments = 0

        # Defined up front so these attributes exist even when loading fails.
        self.dataset = None
        self.filtered_datasets = []
        self.current_idx = 0

        self._load_dataset()

    def _matches_difficulty(self, num_lines):
        """Return True when ``num_lines`` falls in the configured difficulty bucket."""
        if self.difficulty == "easy":
            return num_lines <= 10
        if self.difficulty == "medium":
            return 10 < num_lines <= 30
        if self.difficulty == "hard":
            return num_lines > 30
        # Unknown difficulty: nothing matches, so the caller falls back to
        # the full dataset.
        return False

    def _load_dataset(self):
        """Load the dataset and keep the samples matching ``self.difficulty``.

        If no sample matches, every sample is kept. Any failure is printed
        and leaves the environment with an empty sample list instead of
        raising, so ``reset()`` can report the problem.
        """
        try:
            self.dataset = load_dataset(self.dataset_name, split=self.split)

            # Filter dataset by difficulty based on code length
            filtered_ds = [
                item
                for item in self.dataset
                if self._matches_difficulty(len(item['modified_code'].split('\n')))
            ]

            self.filtered_datasets = filtered_ds if filtered_ds else list(self.dataset)
            self.current_idx = 0
        except Exception as e:
            # Deliberate best-effort boundary: loading can fail for many
            # reasons (missing package, network, schema), all handled alike.
            print(f"Error loading dataset: {e}")
            self.dataset = None
            self.filtered_datasets = []
            self.current_idx = 0

    def _true_bug_line(self):
        """Return the current sample's bug line as an int, or None if unusable.

        ``None`` is used instead of a numeric sentinel so that no parsed
        comment line can ever equal a missing ground truth.
        """
        raw_line = self.current_sample.get('number_of_line')
        try:
            # Coerce so a numeric string still compares equal to the parsed
            # comment line.
            # NOTE(review): assumes number_of_line uses the same 1-based
            # numbering shown to the agent - confirm against the dataset.
            return int(raw_line)
        except (TypeError, ValueError):
            return None

    def reset(self):
        """Start a new episode and return its observation text.

        Samples are served round-robin from the filtered list.

        Returns:
            The review prompt containing the line-numbered code, or an
            error string when no samples are available.
        """
        self.steps_taken = 0
        self.rewards = []
        self.correct_comments = 0

        if not self.filtered_datasets:
            # Make sure a later step() sees that no episode is active.
            self.current_sample = None
            return "Error: Dataset not loaded or empty."

        self.current_sample = self.filtered_datasets[self.current_idx % len(self.filtered_datasets)]
        self.current_idx += 1

        buggy_code_raw = self.current_sample.get('modified_code', 'No code found')

        # Enumerate lines 1-indexed for the agent to review
        buggy_code = '\n'.join(
            f"{number} | {line}"
            for number, line in enumerate(buggy_code_raw.split('\n'), start=1)
        )

        # Assembled piecewise so the whitespace-only second line of the
        # prompt is explicit and survives editors that trim trailing spaces.
        observation = (
            "You are a strict code reviewer. Please review the following code identifying any bugs. "
            "Note that line numbers are provided on the left.\n"
            "        \n"
            f"{buggy_code}\n"
            "\n"
            "Available actions:\n"
            "1. COMMENT <line_number> <issue_description>\n"
            "2. APPROVE\n"
            "3. REQUEST_CHANGES\n"
        )
        return observation

    def step(self, action):
        """Apply one agent action to the current episode.

        Args:
            action: Action text (``COMMENT ...``, ``APPROVE`` or
                ``REQUEST_CHANGES``). Non-string values are treated as an
                invalid action.

        Returns:
            A tuple ``(observation, reward, done, info)`` where ``reward``
            is the step reward formatted with two decimals (a string) and
            ``info`` is always ``None``.
        """
        if self.current_sample is None:
            # No episode is active (reset() not called, or it failed).
            # Report that instead of crashing on None.get below.
            return "Error: No active sample. Call reset() first.", "0.00", True, None

        self.steps_taken += 1
        done = False
        reward = 0.0

        action = action.strip() if isinstance(action, str) else ""
        true_bug_line = self._true_bug_line()

        if action.startswith("COMMENT"):
            try:
                # Split on any whitespace run so doubled spaces or tabs
                # between the fields are tolerated.
                parts = action.split(None, 2)
                # Strip punctuation just in case
                line_str = parts[1].translate(self._LINE_PUNCTUATION)
                comment_line = int(line_str)

                # Deterministic Grader Check
                # NOTE(review): repeating the same correct comment is
                # rewarded every time - confirm this is intended.
                if true_bug_line is not None and comment_line == true_bug_line:
                    reward = 0.8  # High intermediate reward for locating the exact bug line
                    self.correct_comments += 1
                    obs = f"Valid bug identified on line {comment_line}. Will you APPROVE or REQUEST_CHANGES?"
                else:
                    reward = -0.2  # False positive penalty
                    obs = f"Line {comment_line} appears correct. False positive. Continue review or APPROVE/REQUEST_CHANGES."
            except (IndexError, ValueError) as e:
                # Missing line-number token or a token that is not an integer.
                reward = -0.1
                obs = f"Malformed format. Use COMMENT <line_number> <description>. Action parsed error: {e}"

        elif action.startswith("APPROVE"):
            if self.correct_comments == 0:  # Missed the bug entirely
                reward = -1.0
                obs = "You approved buggy code. deployment failed."
            else:  # Found the bug but still approved?
                reward = -0.5
                obs = "You found a bug but approved the PR anyway."
            done = True

        elif action.startswith("REQUEST_CHANGES"):
            if self.correct_comments > 0:  # Perfectly identified the bug and rejected the PR
                reward = 1.0
                obs = "Code review complete. Changes requested successfully."
            else:  # Rejected without pointing out the correct bug
                reward = -0.5
                obs = "You requested changes without accurately commenting on the bug line."
            done = True

        else:
            reward = -0.1
            obs = "Invalid action format. Use COMMENT <line_number>, APPROVE, or REQUEST_CHANGES."

        if self.steps_taken >= self.max_steps:
            done = True
            # A terminal action keeps its own message; anything else is
            # replaced by the timeout notice.
            if not action.startswith(("APPROVE", "REQUEST_CHANGES")):
                obs = "Time limit exceeded. PR auto-closed."

        self.rewards.append(reward)
        formatted_reward = f"{reward:.2f}"

        return obs, formatted_reward, done, None