100XZX001 committed on
Commit
667b6db
·
verified ·
1 Parent(s): a064cf3

Update redteam.py

Browse files
Files changed (1) hide show
  1. redteam.py +97 -141
redteam.py CHANGED
@@ -1,142 +1,98 @@
1
- # redteam.py – AST-based bug injection + dataset examples + noise
2
- import ast
3
- import random
4
- from dataclasses import dataclass, field
5
- from typing import Tuple, Optional, List, Dict
6
-
7
- # ----------------------------------------------------------------------
8
- # AST-based bug injector
9
- # ----------------------------------------------------------------------
10
class ASTBugInjector(ast.NodeTransformer):
    """Apply at most one bug of the requested kind to a parsed syntax tree.

    Supported ``bug_type`` values:

    * ``"null_check"`` - the first ``if`` encountered whose body is a single
      statement is replaced by that statement (the test and any ``else``
      branch are discarded).
    * ``"off_by_one"`` - the first ``for ... in range(...)`` encountered gets
      its start moved up by one and its stop moved down by one.
    * ``"wrong_operator"`` - the first ``+`` binary operator encountered
      becomes ``-``.

    ``modified`` is set to ``True`` once a mutation has been applied, so
    callers can tell whether the tree contained a matching construct.
    """

    def __init__(self, bug_type: str):
        super().__init__()
        self.bug_type = bug_type
        self.modified = False

    def visit_If(self, node: ast.If):
        """Drop the guard of a single-statement ``if`` in ``null_check`` mode."""
        if self.bug_type == "null_check" and not self.modified:
            if len(node.body) == 1:
                self.modified = True
                # Returning the body statement splices it in place of the `if`.
                return node.body[0]
        return self.generic_visit(node)

    def visit_For(self, node: ast.For):
        """Skew the bounds of a ``range(...)`` loop in ``off_by_one`` mode."""
        if self.bug_type == "off_by_one" and not self.modified:
            skewed = self._skewed_range_args(node.iter)
            if skewed is not None:
                node.iter = ast.Call(
                    func=ast.Name(id="range", ctx=ast.Load()),
                    args=skewed,
                    keywords=[],
                )
                self.modified = True
        return self.generic_visit(node)

    @staticmethod
    def _skewed_range_args(call):
        """Return off-by-one arguments for a ``range`` call, or ``None``.

        ``range(stop)`` becomes ``range(1, stop - 1)``; ``range(start, stop)``
        and ``range(start, stop, step)`` become
        ``range(start + 1, stop - 1[, step])``. ``None`` is returned when
        ``call`` is not a plain ``range(...)`` call with one to three
        positional, non-starred arguments, so malformed or unpacked calls are
        left untouched instead of raising or producing invalid code.
        """
        if not (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Name)
            and call.func.id == "range"
        ):
            return None
        args = call.args
        if not 1 <= len(args) <= 3:
            return None
        if any(isinstance(arg, ast.Starred) for arg in args):
            return None

        if len(args) == 1:
            # Implicit start of 0 shifted up by one.
            start = ast.Constant(value=1)
            stop_expr = args[0]
            tail = []
        else:
            start = ast.BinOp(
                left=args[0], op=ast.Add(), right=ast.Constant(value=1)
            )
            stop_expr = args[1]
            tail = list(args[2:])  # keep an explicit step unchanged
        stop = ast.BinOp(
            left=stop_expr, op=ast.Sub(), right=ast.Constant(value=1)
        )
        return [start, stop, *tail]

    def visit_BinOp(self, node: ast.BinOp):
        """Swap the first ``+`` for ``-`` in ``wrong_operator`` mode."""
        if self.bug_type == "wrong_operator" and not self.modified:
            if isinstance(node.op, ast.Add):
                node.op = ast.Sub()
                self.modified = True
        return self.generic_visit(node)
49
-
50
- # ----------------------------------------------------------------------
51
- # Dataset-driven realistic bugs
52
- # ----------------------------------------------------------------------
53
# Hand-written code pairs. Each entry holds a "bug_type" label, the correct
# source under "original", and the defective variant under "buggy".
DATASET_EXAMPLES: List[Dict] = [
    {
        # The buggy variant uses a mutable default argument and writes into it
        # directly instead of working on a copy.
        "bug_type": "mutation_side_effect",
        "original": """def update_config_file(filepath, new_pair, default_config=None):
    if default_config is None:
        default_config = {}
    config = default_config.copy()
    key, value = new_pair
    config[key] = value
    return config""",
        "buggy": """def update_config_file(filepath, new_pair, default_config={}):
    config = default_config
    key, value = new_pair
    config[key] = value
    return config""",
    },
    {
        # The buggy variant omits the `attempts += 1` increment, so the loop
        # condition never changes.
        "bug_type": "infinite_loop",
        "original": """def retry(attempts, max_retries):
    while attempts < max_retries:
        success = False
        if success:
            break
        attempts += 1""",
        "buggy": """def retry(attempts, max_retries):
    while attempts < max_retries:
        success = False
        if success:
            break""",
    }
]
84
-
85
- # ----------------------------------------------------------------------
86
- # RedTeam Controller
87
- # ----------------------------------------------------------------------
88
@dataclass
class RedTeam:
    """Produce buggy code samples from a dataset or by AST mutation.

    Attributes:
        task: Task label; not read by any method in this class.
        seed: Seed for the private random generator (``None`` leaves it
            unseeded).
        noise_prob: Probability of appending a distractor comment to the
            returned buggy code.
        dataset_prob: Probability of returning a ``DATASET_EXAMPLES`` pair
            instead of mutating the supplied code.
    """

    task: str
    seed: Optional[int] = 42
    noise_prob: float = 0.2
    dataset_prob: float = 0.4  # probability of using dataset instead of AST
    _random: random.Random = field(init=False)

    def __post_init__(self):
        # A private generator keeps results reproducible per instance.
        self._random = random.Random(self.seed)

    def inject_bug(self, original_code: str) -> Tuple[str, str, str, str]:
        """Return ``(buggy_code, bug_type, description, oracle_fix)``.

        With probability ``dataset_prob`` a random dataset example is
        returned and ``original_code`` is ignored. Otherwise one AST bug type
        is chosen at random and applied to ``original_code``:

        * if the code does not parse, it is returned unchanged with
          ``bug_type == "parse_error"``;
        * if the chosen mutation finds nothing to change, the code is
          returned unchanged with ``bug_type == "no_op"`` and the description
          ``"No modification applied"``.
        """
        # Decide injection mode
        use_dataset = self._random.random() < self.dataset_prob

        if use_dataset:
            example = self._random.choice(DATASET_EXAMPLES)
            buggy_code = example["buggy"]
            oracle_fix = example["original"]
            bug_type = example["bug_type"]
            description = f"Dataset bug: {bug_type}"
        else:
            bug_types = ["null_check", "off_by_one", "wrong_operator"]
            bug_type = self._random.choice(bug_types)

            try:
                tree = ast.parse(original_code)
            except SyntaxError:
                return original_code, "parse_error", "Syntax error", original_code

            injector = ASTBugInjector(bug_type)
            modified_tree = injector.visit(tree)
            ast.fix_missing_locations(modified_tree)

            oracle_fix = original_code
            if injector.modified:
                buggy_code = ast.unparse(modified_tree)
                description = f"AST bug: {bug_type}"
            else:
                # Fallback: no injection. The description is set here and must
                # not be overwritten afterwards, otherwise callers would see
                # "AST bug: no_op".
                buggy_code = original_code
                bug_type = "no_op"
                description = "No modification applied"

        # Add noise
        if self._random.random() < self.noise_prob:
            buggy_code += "\n# TODO: refactor later"

        return buggy_code, bug_type, description, oracle_fix
 
1
+ # redteam.py – AST-based bug injection (no dataset, always modifies given code)
2
+ import ast
3
+ import random
4
+ from dataclasses import dataclass, field
5
+ from typing import Tuple, Optional, List, Dict
6
+
7
+ # ----------------------------------------------------------------------
8
+ # AST-based bug injector
9
+ # ----------------------------------------------------------------------
10
class ASTBugInjector(ast.NodeTransformer):
    """Apply at most one bug of the requested kind to a parsed syntax tree.

    Supported ``bug_type`` values:

    * ``"null_check"`` - the first ``if`` encountered whose body is a single
      statement is replaced by that statement (the test and any ``else``
      branch are discarded).
    * ``"off_by_one"`` - the first ``for ... in range(...)`` encountered gets
      its start moved up by one and its stop moved down by one.
    * ``"wrong_operator"`` - the first ``+`` binary operator encountered
      becomes ``-``.

    ``modified`` is set to ``True`` once a mutation has been applied, so
    callers can tell whether the tree contained a matching construct.
    """

    def __init__(self, bug_type: str):
        super().__init__()
        self.bug_type = bug_type
        self.modified = False

    def visit_If(self, node: ast.If):
        """Drop the guard of a single-statement ``if`` in ``null_check`` mode."""
        if self.bug_type == "null_check" and not self.modified:
            if len(node.body) == 1:
                self.modified = True
                # Remove the if and execute its body directly: returning the
                # statement splices it in place of the `if` node.
                return node.body[0]
        return self.generic_visit(node)

    def visit_For(self, node: ast.For):
        """Skew the bounds of a ``range(...)`` loop in ``off_by_one`` mode."""
        if self.bug_type == "off_by_one" and not self.modified:
            skewed = self._skewed_range_args(node.iter)
            if skewed is not None:
                node.iter = ast.Call(
                    func=ast.Name(id="range", ctx=ast.Load()),
                    args=skewed,
                    keywords=[],
                )
                self.modified = True
        return self.generic_visit(node)

    @staticmethod
    def _skewed_range_args(call):
        """Return off-by-one arguments for a ``range`` call, or ``None``.

        ``range(stop)`` becomes ``range(1, stop - 1)``; ``range(start, stop)``
        and ``range(start, stop, step)`` become
        ``range(start + 1, stop - 1[, step])``. ``None`` is returned when
        ``call`` is not a plain ``range(...)`` call with one to three
        positional, non-starred arguments, so malformed or unpacked calls are
        left untouched instead of raising or producing invalid code.
        """
        if not (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Name)
            and call.func.id == "range"
        ):
            return None
        args = call.args
        if not 1 <= len(args) <= 3:
            return None
        if any(isinstance(arg, ast.Starred) for arg in args):
            return None

        if len(args) == 1:
            # Implicit start of 0 shifted up by one.
            start = ast.Constant(value=1)
            stop_expr = args[0]
            tail = []
        else:
            start = ast.BinOp(
                left=args[0], op=ast.Add(), right=ast.Constant(value=1)
            )
            stop_expr = args[1]
            tail = list(args[2:])  # keep an explicit step unchanged
        stop = ast.BinOp(
            left=stop_expr, op=ast.Sub(), right=ast.Constant(value=1)
        )
        return [start, stop, *tail]

    def visit_BinOp(self, node: ast.BinOp):
        """Swap the first ``+`` for ``-`` in ``wrong_operator`` mode."""
        if self.bug_type == "wrong_operator" and not self.modified:
            if isinstance(node.op, ast.Add):
                node.op = ast.Sub()
                self.modified = True
        return self.generic_visit(node)
50
+
51
+ # ----------------------------------------------------------------------
52
+ # RedTeam Controller
53
+ # ----------------------------------------------------------------------
54
@dataclass
class RedTeam:
    """Inject a single AST-level bug into caller-supplied source code.

    Attributes:
        task: Task label; not read by any method in this class.
        seed: Seed for the private random generator (``None`` leaves it
            unseeded).
        noise_prob: Probability of appending a distractor comment to the
            returned buggy code.
    """

    task: str
    seed: Optional[int] = 42
    noise_prob: float = 0.2
    _random: random.Random = field(init=False)

    def __post_init__(self):
        # A private generator keeps results reproducible per instance.
        self._random = random.Random(self.seed)

    def inject_bug(self, original_code: str) -> Tuple[str, str, str, str]:
        """Mutate ``original_code`` with one AST bug.

        A bug type is picked at random. If the code has no construct that
        type can mutate, the remaining bug types are tried in their listed
        order, so the code is modified whenever any supported mutation
        applies.

        Returns:
            ``(buggy_code, bug_type, description, oracle_fix)`` where
            ``oracle_fix`` is always the original (correct) code.
            ``bug_type`` is ``"parse_error"`` when the code cannot be parsed
            and ``"no_op"`` when no bug type applies; in the parse-error case
            the code is returned unchanged.
        """
        bug_types = ["null_check", "off_by_one", "wrong_operator"]
        preferred = self._random.choice(bug_types)

        try:
            ast.parse(original_code)
        except (SyntaxError, ValueError):
            # ValueError covers sources with null bytes on Python versions
            # that do not report them as SyntaxError.
            return original_code, "parse_error", "Syntax error in original code", original_code

        # Preferred type first, then the others as fallbacks.
        candidates = [preferred] + [bt for bt in bug_types if bt != preferred]

        buggy_code = original_code
        bug_type = "no_op"
        description = "No suitable code structure found for injection"
        for candidate in candidates:
            injector = ASTBugInjector(candidate)
            # Fresh parse per attempt: the transformer rewrites its tree in place.
            mutated_tree = injector.visit(ast.parse(original_code))
            if injector.modified:
                ast.fix_missing_locations(mutated_tree)
                buggy_code = ast.unparse(mutated_tree)
                bug_type = candidate
                description = f"AST bug: {candidate}"
                break

        # Add noise
        if self._random.random() < self.noise_prob:
            buggy_code += "\n# TODO: refactor later"

        return buggy_code, bug_type, description, original_code