100XZX001 commited on
Commit
54bfec3
·
verified ·
1 Parent(s): 6016948

Update grader.py

Browse files
Files changed (1) hide show
  1. grader.py +147 -141
grader.py CHANGED
@@ -1,142 +1,148 @@
1
- # grader.py – Production‑grade, continuous reward, exploit‑aware
2
- import ast
3
- import subprocess
4
- import tempfile
5
- import os
6
- import re
7
- from dataclasses import dataclass
8
- from typing import Optional
9
-
10
- @dataclass
11
- class RigorousGrader:
12
- bug_id: str
13
- oracle_code: Optional[str] = None
14
-
15
- def grade_fix(self, proposed_fix: str) -> float:
16
- """
17
- Returns a smooth reward in [0,1] based on:
18
- - Syntax validity
19
- - Proportion of tests passed (continuous, not binary)
20
- - Lint quality (with conservative fallback)
21
- - Structural similarity to oracle (anti‑gaming)
22
- - Exploit detection (hardcoded outputs / no real change)
23
- """
24
- # 1. Syntax check (binary – non‑negotiable)
25
- try:
26
- ast.parse(proposed_fix)
27
- except SyntaxError:
28
- return 0.0 # hard zero, not negative (RL stable)
29
-
30
- # 2. Exploit detection: trivial or hardcoded fixes
31
- if self._is_exploit(proposed_fix):
32
- return 0.0
33
-
34
- # 3. Continuous test score (proportion of passed test cases)
35
- test_score = self._run_continuous_tests(proposed_fix)
36
-
37
- # 4. Lint score (continuous, fallback 0.0 not 0.5)
38
- lint_score = self._get_lint_score(proposed_fix)
39
-
40
- # 5. Oracle similarity (structural, not gameable)
41
- oracle_score = self._ast_similarity(proposed_fix) if self.oracle_code else 0.0
42
-
43
- # Weighted combination (all continuous)
44
- final = (0.5 * test_score) + (0.3 * lint_score) + (0.2 * oracle_score)
45
- return max(0.0, min(1.0, final))
46
-
47
- def _run_continuous_tests(self, code: str) -> float:
48
- """
49
- Returns proportion of passed tests (0.0 to 1.0).
50
- Uses multiple test cases per bug type.
51
- """
52
- test_cases = self._get_test_cases()
53
- if not test_cases:
54
- return 0.0
55
-
56
- passed = 0
57
- for test_input, expected in test_cases:
58
- if self._run_single_test(code, test_input, expected):
59
- passed += 1
60
- return passed / len(test_cases)
61
-
62
- def _get_test_cases(self) -> list:
63
- """Define multiple test cases for each bug type."""
64
- if self.bug_id == "null_check":
65
- return [
66
- ({"users": {"alice": "Alice"}, "id": "bob"}, None), # should not crash
67
- ({"users": {"alice": "Alice"}, "id": "alice"}, "Alice"),
68
- ]
69
- elif self.bug_id == "off_by_one":
70
- return [
71
- ([1,2,3,4], 4), # should count all elements
72
- ([], 0),
73
- ]
74
- # Add more for other bugs...
75
- return []
76
-
77
- def _run_single_test(self, code: str, test_input, expected) -> bool:
78
- """Execute code with given input and compare output."""
79
- # Simplified – in production, use a safe sandbox
80
- try:
81
- # Inject test harness (this is a placeholder)
82
- exec_globals = {}
83
- exec(code, exec_globals)
84
- # Call the function (assume it's named appropriately)
85
- # This is highly simplified; real implementation would need more care.
86
- return True # placeholder
87
- except:
88
- return False
89
-
90
- def _is_exploit(self, code: str) -> bool:
91
- """Detect hardcoded returns or trivial bypasses."""
92
- lower = code.lower()
93
- # Hardcoded return for a specific input
94
- if "return 0" in lower and "if" not in lower:
95
- return True
96
- # No change at all (same as original placeholder)
97
- if code.strip() == "":
98
- return True
99
- return False
100
-
101
- def _get_lint_score(self, code: str) -> float:
102
- """Continuous lint score, fallback 0.0 on error."""
103
- try:
104
- with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
105
- f.write(code)
106
- f.flush()
107
- tmp_path = f.name
108
- result = subprocess.run(
109
- ['pylint', tmp_path, '--score=y', '--exit-zero'],
110
- capture_output=True,
111
- text=True,
112
- timeout=5
113
- )
114
- match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
115
- if match:
116
- score = float(match.group(1)) / 10.0
117
- else:
118
- score = 0.0 # was 0.5 – now conservative
119
- return max(0.0, min(1.0, score))
120
- except Exception:
121
- return 0.0
122
- finally:
123
- try:
124
- os.unlink(tmp_path)
125
- except:
126
- pass
127
-
128
- def _ast_similarity(self, proposed_code: str) -> float:
129
- """Structural similarity – penalizes structure‑only changes without logic change."""
130
- if not self.oracle_code:
131
- return 0.0
132
- try:
133
- tree_prop = ast.parse(proposed_code)
134
- tree_oracle = ast.parse(self.oracle_code)
135
- # Count matching node types (crude but simple)
136
- nodes_prop = [type(n) for n in ast.walk(tree_prop)]
137
- nodes_oracle = [type(n) for n in ast.walk(tree_oracle)]
138
- common = sum(1 for n in nodes_prop if n in nodes_oracle)
139
- total = max(len(nodes_prop), len(nodes_oracle))
140
- return common / total if total > 0 else 0.0
141
- except:
 
 
 
 
 
 
142
  return 0.0
 
1
+ # grader.py – Production‑grade, continuous reward, exploit‑aware
2
+ import ast
3
+ import subprocess
4
+ import tempfile
5
+ import os
6
+ import re
7
+ import sys
8
+ import json
9
+ from dataclasses import dataclass
10
+ from typing import Optional
11
+
12
+ @dataclass
13
+ class RigorousGrader:
14
+ bug_id: str
15
+ oracle_code: Optional[str] = None
16
+
17
+ def grade_fix(self, proposed_fix: str) -> float:
18
+ """Returns a smooth reward in [0,1]."""
19
+ # Syntax check
20
+ try:
21
+ ast.parse(proposed_fix)
22
+ except SyntaxError:
23
+ return 0.0
24
+
25
+ # Exploit detection (optional)
26
+ if self._is_exploit(proposed_fix):
27
+ return 0.0
28
+
29
+ # Continuous test score
30
+ test_score = self._run_continuous_tests(proposed_fix)
31
+
32
+ # Lint score
33
+ lint_score = self._get_lint_score(proposed_fix)
34
+
35
+ # Oracle similarity
36
+ oracle_score = self._ast_similarity(proposed_fix) if self.oracle_code else 0.0
37
+
38
+ # Weighted combination
39
+ final = (0.5 * test_score) + (0.3 * lint_score) + (0.2 * oracle_score)
40
+ return max(0.0, min(1.0, final))
41
+
42
+ def _run_continuous_tests(self, code: str) -> float:
43
+ """Proportion of passed test cases."""
44
+ test_cases = self._get_test_cases()
45
+ if not test_cases:
46
+ return 0.0
47
+ passed = 0
48
+ for test_input, expected in test_cases:
49
+ if self._run_single_test(code, test_input, expected):
50
+ passed += 1
51
+ return passed / len(test_cases)
52
+
53
+ def _get_test_cases(self) -> list:
54
+ """Define multiple test cases per bug type."""
55
+ if self.bug_id == "null_check":
56
+ return [
57
+ ({"users": {"alice": "Alice"}, "id": "bob"}, None),
58
+ ({"users": {"alice": "Alice"}, "id": "alice"}, "Alice"),
59
+ ]
60
+ elif self.bug_id == "off_by_one":
61
+ return [
62
+ ([1, 2, 3, 4], 4),
63
+ ([], 0),
64
+ ]
65
+ # Extend for other bugs …
66
+ return []
67
+
68
+ def _run_single_test(self, code: str, test_input, expected) -> bool:
69
+ """Execute the agent's code with test_input and compare to expected."""
70
+ harness = f"""
71
+ import json
72
+ {code}
73
+ try:
74
+ result = fix({json.dumps(test_input)})
75
+ print(json.dumps({{"ok": True, "result": result}}))
76
+ except Exception as e:
77
+ print(json.dumps({{"ok": False, "error": str(e)}}))
78
+ """
79
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
80
+ f.write(harness)
81
+ tmp_path = f.name
82
+ try:
83
+ result = subprocess.run(
84
+ [sys.executable, tmp_path],
85
+ capture_output=True, text=True, timeout=5
86
+ )
87
+ data = json.loads(result.stdout.strip())
88
+ if data.get("ok") and data["result"] == expected:
89
+ return True
90
+ return False
91
+ except (json.JSONDecodeError, subprocess.TimeoutExpired, Exception):
92
+ return False
93
+ finally:
94
+ try:
95
+ os.unlink(tmp_path)
96
+ except:
97
+ pass
98
+
99
+ def _is_exploit(self, code: str) -> bool:
100
+ """Detect hardcoded returns or trivial bypasses."""
101
+ lower = code.lower()
102
+ if "return 0" in lower and "if" not in lower:
103
+ return True
104
+ if code.strip() == "":
105
+ return True
106
+ return False
107
+
108
+ def _get_lint_score(self, code: str) -> float:
109
+ """Continuous lint score, fallback 0.0 on error."""
110
+ try:
111
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
112
+ f.write(code)
113
+ f.flush()
114
+ tmp_path = f.name
115
+ result = subprocess.run(
116
+ ['pylint', tmp_path, '--score=y', '--exit-zero'],
117
+ capture_output=True,
118
+ text=True,
119
+ timeout=5
120
+ )
121
+ match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
122
+ if match:
123
+ score = float(match.group(1)) / 10.0
124
+ else:
125
+ score = 0.0
126
+ return max(0.0, min(1.0, score))
127
+ except Exception:
128
+ return 0.0
129
+ finally:
130
+ try:
131
+ os.unlink(tmp_path)
132
+ except:
133
+ pass
134
+
135
+ def _ast_similarity(self, proposed_code: str) -> float:
136
+ """Structural similarity to oracle."""
137
+ if not self.oracle_code:
138
+ return 0.0
139
+ try:
140
+ tree_prop = ast.parse(proposed_code)
141
+ tree_oracle = ast.parse(self.oracle_code)
142
+ nodes_prop = [type(n) for n in ast.walk(tree_prop)]
143
+ nodes_oracle = [type(n) for n in ast.walk(tree_oracle)]
144
+ common = sum(1 for n in nodes_prop if n in nodes_oracle)
145
+ total = max(len(nodes_prop), len(nodes_oracle))
146
+ return common / total if total > 0 else 0.0
147
+ except:
148
  return 0.0