AIMLxDIV committed on
Commit
548872c
·
1 Parent(s): 58b79ef

Add codereview_env/env.py

Browse files
Files changed (1) hide show
  1. codereview_env/env.py +160 -0
codereview_env/env.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from codereview_env.models import (
2
+ TaskId, Action, Observation, StepResult, ResetResult,
3
+ ActionType, ActionRecord, EpisodeResult
4
+ )
5
+ from codereview_env.scenario_bank import get_scenario
6
+ from codereview_env.graders.grader_utils import find_best_match
7
+ from codereview_env.graders.bug_grader import grade_bug_detection
8
+ from codereview_env.graders.security_grader import grade_security_audit
9
+ from codereview_env.graders.arch_grader import grade_architectural_review
10
+
11
class CodeReviewEnv:
    """Episodic environment in which an agent reviews a pull-request scenario.

    Lifecycle: call ``reset()`` to load a scenario, then call ``step()`` with
    one ``Action`` at a time until the returned ``StepResult.done`` is true,
    and finally ``get_final_result()`` for the episode summary.

    An episode ends when any of the following holds after a step:

    * the action is ``ActionType.APPROVE`` or ``ActionType.REQUEST_CHANGES``;
    * ``step_count`` has reached the task's entry in ``TASK_MAX_STEPS``;
    * the noise budget (starts at 5, reduced by one for every ``FLAG_ISSUE``
      action that matches no ground-truth issue) has dropped to zero.
    """

    # Maximum number of steps allowed per task before the episode is ended.
    TASK_MAX_STEPS = {
        TaskId.BUG_DETECTION: 10,
        TaskId.SECURITY_AUDIT: 15,
        TaskId.ARCHITECTURAL_REVIEW: 20,
    }

    def __init__(self):
        """Create an environment with no active episode."""
        # None until reset() is called; afterwards a dict of episode state.
        self._state = None

    def reset(self, task_id: TaskId, seed: int = 42) -> ResetResult:
        """Start a new episode for ``task_id`` and return its first observation.

        Args:
            task_id: Task to run; must be a key of ``TASK_MAX_STEPS``.
            seed: Forwarded to ``get_scenario`` together with ``task_id``.

        Returns:
            A ``ResetResult`` carrying the initial observation, the task id,
            the seed and the scenario's ``hash`` attribute.

        Raises:
            KeyError: If ``task_id`` has no entry in ``TASK_MAX_STEPS``.
        """
        scenario = get_scenario(task_id, seed)
        self._state = {
            "task_id": task_id,
            "seed": seed,
            "scenario": scenario,
            "step_count": 0,
            "noise_budget": 5,
            "max_steps": self.TASK_MAX_STEPS[task_id],
            "history": [],
            "running_score": 0.0,
            "done": False,
            "issues_found": set(),  # Set of ground truth issue IDs
            "false_positives": [],  # List of action bodies that were FPs
        }
        return ResetResult(
            observation=self._build_obs(),
            task_id=task_id,
            seed=seed,
            scenario_hash=scenario.hash,
        )

    def step(self, action: Action) -> StepResult:
        """Apply one agent action and advance the episode.

        Args:
            action: The action to record and score.

        Returns:
            A ``StepResult`` with the new observation, the reward, the
            ``done`` flag and an ``info`` dict (``step``, ``score``,
            ``noise_budget``, ``issues_found_count``).

        Raises:
            RuntimeError: If ``reset()`` has not been called or the current
                episode has already finished.
        """
        if self._state is None or self._state["done"]:
            raise RuntimeError("Episode is done or not initialized. Call reset().")

        s = self._state
        s["step_count"] += 1

        # Record the action first: the graders score the full history, so the
        # current action must already be part of it when _apply_action runs.
        s["history"].append(ActionRecord(
            action_type=action.action_type,
            body=action.body,
            filename=action.filename,
            line_number=action.line_number,
            severity=action.severity,
            category=action.category,
            verdict=action.verdict,
        ))

        # Apply logic
        reward_delta = self._apply_action(action)
        s["running_score"] += reward_delta

        # Check termination
        s["done"] = (
            action.action_type in (ActionType.APPROVE, ActionType.REQUEST_CHANGES)
            or s["step_count"] >= s["max_steps"]
            or s["noise_budget"] <= 0
        )

        # NOTE(review): ``reward`` is the cumulative running score (rounded),
        # not the per-step delta computed above -- confirm this is what
        # consumers of StepResult expect.
        return StepResult(
            observation=self._build_obs(),
            reward=round(s["running_score"], 4),
            done=s["done"],
            info={
                "step": s["step_count"],
                "score": s["running_score"],
                "noise_budget": s["noise_budget"],
                "issues_found_count": len(s["issues_found"]),
            },
        )

    def _build_obs(self) -> Observation:
        """Build the observation describing the current episode state."""
        s = self._state
        sc = s["scenario"]
        return Observation(
            task_id=s["task_id"],
            pr_title=sc.pr_title,
            pr_description=sc.pr_description,
            # The combined diff is every changed file's patch, newline-joined.
            diff="\n".join(f.patch for f in sc.files_changed),
            files_changed=sc.files_changed,
            step_count=s["step_count"],
            max_steps=s["max_steps"],
            # Pass a snapshot: step() keeps appending to the live list, and an
            # observation that was already returned should not change later
            # (relevant if Observation stores the list object as given).
            history=list(s["history"]),
            noise_budget=s["noise_budget"],
        )

    def _grade(self) -> float:
        """Score the current history with the grader for the active task."""
        s = self._state
        sc = s["scenario"]
        if s["task_id"] == TaskId.BUG_DETECTION:
            return grade_bug_detection(sc, s["history"])
        if s["task_id"] == TaskId.SECURITY_AUDIT:
            return grade_security_audit(sc, s["history"])
        # Any other task id falls through to the architectural grader.
        return grade_architectural_review(sc, s["history"])

    def _apply_action(self, action: Action) -> float:
        """Update issue tracking for ``action`` and return the score change.

        For ``FLAG_ISSUE`` actions, ``find_best_match`` is asked for a
        ground-truth issue not yet found. A truthy result has its ``id``
        added to ``issues_found``; otherwise the noise budget is reduced by
        one and the action body is logged as a false positive.

        Returns:
            The freshly graded score minus the stored running score. The
            caller is responsible for adding it to ``running_score``.
        """
        s = self._state
        sc = s["scenario"]

        if action.action_type == ActionType.FLAG_ISSUE:
            matched_gt = find_best_match(action, sc.ground_truth_issues, s["issues_found"])
            if matched_gt:
                s["issues_found"].add(matched_gt.id)
            else:
                s["noise_budget"] -= 1
                s["false_positives"].append(action.body)

        # Recalculate full score based on current history
        return self._grade() - s["running_score"]

    def get_final_result(self) -> EpisodeResult:
        """Summarise the current episode.

        Returns:
            An ``EpisodeResult`` with the graded final score (rounded to four
            decimals), found and missed ground-truth issue ids, the recorded
            false positives, and ``verdict_correct``. ``verdict_correct`` is
            only set for ``ARCHITECTURAL_REVIEW`` episodes whose last action
            was an approve/request-changes verdict and whose scenario has at
            least one issue with a ``required_verdict``; otherwise it is
            ``None``.

        Raises:
            RuntimeError: If ``reset()`` has never been called.
        """
        if self._state is None:
            # Fail the same way step() does instead of with an opaque
            # "'NoneType' object is not subscriptable" TypeError.
            raise RuntimeError("Episode is not initialized. Call reset().")

        s = self._state
        sc = s["scenario"]

        # Calculate missed issues
        all_gt_ids = {gt.id for gt in sc.ground_truth_issues}
        missed_ids = list(all_gt_ids - s["issues_found"])

        # Calculate official score via specialized graders
        final_score = self._grade()

        # The grader already accounts for the verdict on architectural tasks;
        # it is recomputed here only to populate the result schema.
        verdict_correct = None
        if s["task_id"] == TaskId.ARCHITECTURAL_REVIEW:
            final_action = s["history"][-1] if s["history"] else None
            if final_action and final_action.action_type in (ActionType.APPROVE, ActionType.REQUEST_CHANGES):
                required_verdicts = [gt.required_verdict for gt in sc.ground_truth_issues if gt.required_verdict]
                if required_verdicts:
                    # Only the first required verdict is compared.
                    verdict_correct = final_action.verdict == required_verdicts[0]

        return EpisodeResult(
            task_id=s["task_id"],
            seed=s["seed"],
            total_steps=s["step_count"],
            final_score=round(final_score, 4),
            issues_found=list(s["issues_found"]),
            issues_missed=missed_ids,
            # Snapshot so the returned result does not alias internal state.
            false_positives=list(s["false_positives"]),
            verdict_correct=verdict_correct,
        )