100XZX001 committed on
Commit
7f2021e
·
verified ·
1 Parent(s): a40bd87

Update redteam.py

Browse files
Files changed (1) hide show
  1. redteam.py +200 -24
redteam.py CHANGED
@@ -1,11 +1,11 @@
1
- # redteam.py – AST-based bug injection (no dataset, always modifies given code)
2
  import ast
3
  import random
4
  from dataclasses import dataclass, field
5
  from typing import Tuple, Optional, List, Dict
6
 
7
  # ----------------------------------------------------------------------
8
- # AST-based bug injector
9
  # ----------------------------------------------------------------------
10
  class ASTBugInjector(ast.NodeTransformer):
11
  def __init__(self, bug_type: str):
@@ -13,43 +13,214 @@ class ASTBugInjector(ast.NodeTransformer):
13
  self.bug_type = bug_type
14
  self.modified = False
15
 
 
16
  def visit_If(self, node: ast.If):
 
17
  if self.bug_type == "null_check" and not self.modified:
18
  if node.body and len(node.body) == 1:
19
  self.modified = True
20
- return node.body[0] # remove the if, directly execute body
 
 
 
 
 
 
 
 
21
  return self.generic_visit(node)
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def visit_For(self, node: ast.For):
24
- if self.bug_type == "off_by_one" and not self.modified:
25
- if isinstance(node.iter, ast.Call) and isinstance(node.iter.func, ast.Name):
26
- if node.iter.func.id == "range":
27
- # Change range(x) to range(1, x-1) to introduce off-by-one
 
28
  new_iter = ast.Call(
29
  func=ast.Name(id='range', ctx=ast.Load()),
30
  args=[
31
  ast.Constant(value=1),
32
- ast.BinOp(
33
- left=node.iter.args[0],
34
- op=ast.Sub(),
35
- right=ast.Constant(value=1)
36
- )
37
  ],
38
  keywords=[]
39
  )
40
  node.iter = new_iter
41
  self.modified = True
 
 
 
 
 
 
 
 
42
  return self.generic_visit(node)
43
 
44
  def visit_BinOp(self, node: ast.BinOp):
45
- if self.bug_type == "wrong_operator" and not self.modified:
46
- if isinstance(node.op, ast.Add):
47
- node.op = ast.Sub()
 
 
 
 
 
 
 
 
48
  self.modified = True
49
  return self.generic_visit(node)
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  # ----------------------------------------------------------------------
52
- # RedTeam Controller
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  # ----------------------------------------------------------------------
54
  @dataclass
55
  class RedTeam:
@@ -63,17 +234,24 @@ class RedTeam:
63
 
64
  def inject_bug(self, original_code: str) -> Tuple[str, str, str, str]:
65
  """
66
- Always modifies the given original_code using an AST bug.
67
- Returns (buggy_code, bug_type, description, oracle_fix).
68
- oracle_fix is the original (correct) code.
69
  """
70
- bug_types = ["null_check", "off_by_one", "wrong_operator"]
71
- bug_type = self._random.choice(bug_types)
 
 
 
 
 
 
 
 
72
 
 
73
  try:
74
  tree = ast.parse(original_code)
75
  except SyntaxError:
76
- # If the code can't be parsed, return it unchanged
77
  return original_code, "parse_error", "Syntax error in original code", original_code
78
 
79
  injector = ASTBugInjector(bug_type)
@@ -85,13 +263,11 @@ class RedTeam:
85
  oracle_fix = original_code
86
  description = f"AST bug: {bug_type}"
87
  else:
88
- # Fallback: no injection possible (e.g., code doesn't contain the target pattern)
89
  buggy_code = original_code
90
  oracle_fix = original_code
91
  bug_type = "no_op"
92
  description = "No suitable code structure found for injection"
93
 
94
- # Add noise
95
  if self._random.random() < self.noise_prob:
96
  buggy_code += "\n# TODO: refactor later"
97
 
 
1
+ # redteam.py – Task‑aware bug injection (25 bugs, 5 difficulty levels)
2
  import ast
3
  import random
4
  from dataclasses import dataclass, field
5
  from typing import Tuple, Optional, List, Dict
6
 
7
  # ----------------------------------------------------------------------
8
+ # 1. AST Bug Injector (extended for all simple bugs)
9
  # ----------------------------------------------------------------------
10
  class ASTBugInjector(ast.NodeTransformer):
11
  def __init__(self, bug_type: str):
 
13
  self.bug_type = bug_type
14
  self.modified = False
15
 
16
+ # --- Easy: null_check, simple_typo, string_index, default_value, empty_return ---
17
  def visit_If(self, node: ast.If):
18
+ # null_check: remove the if-guard
19
  if self.bug_type == "null_check" and not self.modified:
20
  if node.body and len(node.body) == 1:
21
  self.modified = True
22
+ return node.body[0]
23
+ # division_by_zero_empty: remove the empty check
24
+ if self.bug_type == "division_by_zero_empty" and not self.modified:
25
+ # pattern: any `if not <name>: ...` guard (e.g. `if not data: return 0`) – we delete the entire if, whatever its body
26
+ if (isinstance(node.test, ast.UnaryOp) and
27
+ isinstance(node.test.op, ast.Not) and
28
+ isinstance(node.test.operand, ast.Name)):
29
+ self.modified = True
30
+ return None # signal to remove this node from parent
31
  return self.generic_visit(node)
32
 
33
+ def visit_Name(self, node: ast.Name):
34
+ if self.bug_type == "simple_typo" and not self.modified:
35
+ if node.id == "users":
36
+ self.modified = True
37
+ return ast.Name(id="usres", ctx=node.ctx)
38
+ return self.generic_visit(node)
39
+
40
+ def visit_Subscript(self, node: ast.Subscript):
41
+ if self.bug_type == "string_index" and not self.modified:
42
+ if isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Constant):
43
+ old_val = node.slice.value.value
44
+ if isinstance(old_val, int):
45
+ self.modified = True
46
+ node.slice = ast.Index(value=ast.Constant(value=old_val + 1))
47
+ return self.generic_visit(node)
48
+
49
+ def visit_Call(self, node: ast.Call):
50
+ # default_value: change dict.get(key) to dict[key] (no default)
51
+ if self.bug_type == "default_value" and not self.modified:
52
+ if (isinstance(node.func, ast.Attribute) and
53
+ node.func.attr == "get" and len(node.args) == 1):
54
+ self.modified = True
55
+ return ast.Subscript(
56
+ value=node.func.value,
57
+ slice=ast.Index(value=node.args[0]),
58
+ ctx=node.ctx
59
+ )
60
+ # abs_usage: remove abs()
61
+ if self.bug_type == "abs_usage" and not self.modified:
62
+ if isinstance(node.func, ast.Name) and node.func.id == "abs":
63
+ self.modified = True
64
+ return node.args[0]
65
+ return self.generic_visit(node)
66
+
67
+ def visit_FunctionDef(self, node: ast.FunctionDef):
68
+ # empty_return: insert a premature return None
69
+ if self.bug_type == "empty_return" and not self.modified:
70
+ self.modified = True
71
+ node.body.insert(0, ast.Return(value=ast.Constant(value=None)))
72
+ return self.generic_visit(node)
73
+
74
+ # --- Medium: off_by_one, loop_skip, sign_error, swap_args, uninitialised_var ---
75
  def visit_For(self, node: ast.For):
76
+ if (self.bug_type in ("off_by_one", "loop_skip")) and not self.modified:
77
+ if (isinstance(node.iter, ast.Call) and
78
+ isinstance(node.iter.func, ast.Name) and
79
+ node.iter.func.id == "range"):
80
+ if self.bug_type == "off_by_one":
81
  new_iter = ast.Call(
82
  func=ast.Name(id='range', ctx=ast.Load()),
83
  args=[
84
  ast.Constant(value=1),
85
+ ast.BinOp(left=node.iter.args[0], op=ast.Sub(), right=ast.Constant(value=1))
 
 
 
 
86
  ],
87
  keywords=[]
88
  )
89
  node.iter = new_iter
90
  self.modified = True
91
+ elif self.bug_type == "loop_skip" and len(node.iter.args) == 1:
92
+ new_iter = ast.Call(
93
+ func=ast.Name(id='range', ctx=ast.Load()),
94
+ args=[ast.BinOp(left=node.iter.args[0], op=ast.Sub(), right=ast.Constant(value=1))],
95
+ keywords=[]
96
+ )
97
+ node.iter = new_iter
98
+ self.modified = True
99
  return self.generic_visit(node)
100
 
101
  def visit_BinOp(self, node: ast.BinOp):
102
+ # sign_error / wrong_operator: flip Add<->Sub (both directions); float_precision: Div->FloorDiv
103
+ if not self.modified:
104
+ if self.bug_type in ("wrong_operator", "sign_error"):
105
+ if isinstance(node.op, ast.Add):
106
+ node.op = ast.Sub()
107
+ self.modified = True
108
+ elif isinstance(node.op, ast.Sub):
109
+ node.op = ast.Add()
110
+ self.modified = True
111
+ elif self.bug_type == "float_precision" and isinstance(node.op, ast.Div):
112
+ node.op = ast.FloorDiv()
113
  self.modified = True
114
  return self.generic_visit(node)
115
 
116
+ def visit_arguments(self, node: ast.arguments):
117
+ # swap_args: swap first two arguments of a function
118
+ if self.bug_type == "swap_args" and not self.modified and len(node.args) >= 2:
119
+ self.modified = True
120
+ node.args[0], node.args[1] = node.args[1], node.args[0]
121
+ return self.generic_visit(node)
122
+
123
+ def visit_Assign(self, node: ast.Assign):
124
+ # uninitialised_var: remove an assignment statement (replaced with Pass)
125
+ if self.bug_type == "uninitialised_var" and not self.modified:
126
+ self.modified = True
127
+ return ast.Pass()
128
+ return self.generic_visit(node)
129
+
130
  # ----------------------------------------------------------------------
131
+ # 2. Bug database (25 bugs, categorized by difficulty)
132
+ # ----------------------------------------------------------------------
133
+ BUG_DB = {
134
+ "easy": {
135
+ "null_check": {"type": "ast", "bug_type": "null_check"},
136
+ "simple_typo": {"type": "ast", "bug_type": "simple_typo"},
137
+ "string_index": {"type": "ast", "bug_type": "string_index"},
138
+ "default_value": {"type": "ast", "bug_type": "default_value"},
139
+ "empty_return": {"type": "ast", "bug_type": "empty_return"},
140
+ },
141
+ "medium": {
142
+ "off_by_one": {"type": "ast", "bug_type": "off_by_one"},
143
+ "loop_skip": {"type": "ast", "bug_type": "loop_skip"},
144
+ "sign_error": {"type": "ast", "bug_type": "sign_error"},
145
+ "swap_args": {"type": "ast", "bug_type": "swap_args"},
146
+ "uninitialised_var": {"type": "ast", "bug_type": "uninitialised_var"},
147
+ },
148
+ "hard": {
149
+ "division_by_zero_empty": {"type": "ast", "bug_type": "division_by_zero_empty"},
150
+ "division_by_zero_zero": {"type": "ast", "bug_type": "division_by_zero_empty"}, # same injector
151
+ "float_precision": {"type": "ast", "bug_type": "float_precision"},
152
+ "abs_usage": {"type": "ast", "bug_type": "abs_usage"},
153
+ "round_error": {"type": "ast", "bug_type": "round_error"}, # can be extended
154
+ },
155
+ "harder": {
156
+ "missing_lock": {
157
+ "type": "template",
158
+ "buggy": "counter = 0\ndef increment():\n global counter\n counter += 1",
159
+ "oracle": "counter = 0\nimport threading\nlock = threading.Lock()\ndef increment():\n global counter\n with lock:\n counter += 1",
160
+ },
161
+ "double_lock": {
162
+ "type": "template",
163
+ "buggy": "import threading\nlock = threading.Lock()\ndef do_work():\n lock.acquire()\n lock.acquire()\n print('working')\n lock.release()",
164
+ "oracle": "import threading\nlock = threading.Lock()\ndef do_work():\n with lock:\n print('working')",
165
+ },
166
+ "global_nonatomic": {
167
+ "type": "template",
168
+ "buggy": "count = 0\ndef add():\n global count\n count = count + 1",
169
+ "oracle": "count = 0\ndef add():\n global count\n count += 1",
170
+ },
171
+ "thread_safe_list": {
172
+ "type": "template",
173
+ "buggy": "import threading\nitems = []\ndef append_item(item):\n items.append(item)",
174
+ "oracle": "import threading\nitems = []\nlock = threading.Lock()\ndef append_item(item):\n with lock:\n items.append(item)",
175
+ },
176
+ "volatile_read": {
177
+ "type": "template",
178
+ "buggy": "import threading\nstop = False\ndef worker():\n while not stop:\n pass",
179
+ "oracle": "import threading\nstop = False\nlock = threading.Lock()\ndef worker():\n while True:\n with lock:\n if stop:\n break",
180
+ },
181
+ },
182
+ "hardest": {
183
+ "deadlock_order": {
184
+ "type": "template",
185
+ "buggy": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n with lock1:\n with lock2:\n pass\ndef thread2():\n with lock2:\n with lock1:\n pass",
186
+ "oracle": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n with lock1:\n with lock2:\n pass\ndef thread2():\n with lock1:\n with lock2:\n pass",
187
+ },
188
+ "nested_lock_timeout": {
189
+ "type": "template",
190
+ "buggy": "import threading\nlock = threading.Lock()\ndef work():\n lock.acquire()\n # critical section\n lock.release()",
191
+ "oracle": "import threading\nlock = threading.Lock()\ndef work():\n if lock.acquire(timeout=1):\n try:\n # critical section\n finally:\n lock.release()",
192
+ },
193
+ "fork_join": {
194
+ "type": "template",
195
+ "buggy": "import threading\ndef worker():\n pass\nt = threading.Thread(target=worker)\nt.start()",
196
+ "oracle": "import threading\ndef worker():\n pass\nt = threading.Thread(target=worker)\nt.start()\nt.join()",
197
+ },
198
+ "mutex_release": {
199
+ "type": "template",
200
+ "buggy": "import threading\nlock = threading.Lock()\ndef thread_A():\n lock.acquire()\n lock.release()\ndef thread_B():\n lock.release()",
201
+ "oracle": "import threading\nlock = threading.Lock()\ndef thread_A():\n with lock:\n pass\ndef thread_B():\n with lock:\n pass",
202
+ },
203
+ "race_on_init": {
204
+ "type": "template",
205
+ "buggy": "import threading\nitems = []\ndef init():\n global items\n items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nprint(items)",
206
+ "oracle": "import threading\nitems = []\ndef init():\n global items\n items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nt.join()\nprint(items)",
207
+ },
208
+ },
209
+ }
210
+
211
+ # ----------------------------------------------------------------------
212
+ # 3. Derived helpers
213
+ # ----------------------------------------------------------------------
214
+ TASK_BUG_MAP = {level: list(bugs.keys()) for level, bugs in BUG_DB.items()}
215
+
216
+ TEMPLATE_BUGS = {}
217
+ for level, bugs in BUG_DB.items():
218
+ for bug_id, bug in bugs.items():
219
+ if bug["type"] == "template":
220
+ TEMPLATE_BUGS[bug_id] = (bug["buggy"], bug["oracle"])
221
+
222
+ # ----------------------------------------------------------------------
223
+ # 4. RedTeam Controller (task‑aware)
224
  # ----------------------------------------------------------------------
225
  @dataclass
226
  class RedTeam:
 
234
 
235
  def inject_bug(self, original_code: str) -> Tuple[str, str, str, str]:
236
  """
237
+ Returns: (buggy_code, bug_type, description, oracle_fix)
238
+ Selects a bug appropriate for the task difficulty.
 
239
  """
240
+ bug_list = TASK_BUG_MAP.get(self.task, ["null_check"])
241
+ bug_type = self._random.choice(bug_list)
242
+
243
+ # Template bug: return hardcoded buggy + oracle
244
+ if bug_type in TEMPLATE_BUGS:
245
+ buggy_code, oracle_code = TEMPLATE_BUGS[bug_type]
246
+ description = f"Template bug: {bug_type}"
247
+ if self._random.random() < self.noise_prob:
248
+ buggy_code += "\n# TODO: refactor later"
249
+ return buggy_code, bug_type, description, oracle_code
250
 
251
+ # AST injection
252
  try:
253
  tree = ast.parse(original_code)
254
  except SyntaxError:
 
255
  return original_code, "parse_error", "Syntax error in original code", original_code
256
 
257
  injector = ASTBugInjector(bug_type)
 
263
  oracle_fix = original_code
264
  description = f"AST bug: {bug_type}"
265
  else:
 
266
  buggy_code = original_code
267
  oracle_fix = original_code
268
  bug_type = "no_op"
269
  description = "No suitable code structure found for injection"
270
 
 
271
  if self._random.random() < self.noise_prob:
272
  buggy_code += "\n# TODO: refactor later"
273