Update redteam.py
Browse files- redteam.py +200 -24
redteam.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
| 1 |
-
# redteam.py –
|
| 2 |
import ast
|
| 3 |
import random
|
| 4 |
from dataclasses import dataclass, field
|
| 5 |
from typing import Tuple, Optional, List, Dict
|
| 6 |
|
| 7 |
# ----------------------------------------------------------------------
|
| 8 |
-
# AST
|
| 9 |
# ----------------------------------------------------------------------
|
| 10 |
class ASTBugInjector(ast.NodeTransformer):
|
| 11 |
def __init__(self, bug_type: str):
|
|
@@ -13,43 +13,214 @@ class ASTBugInjector(ast.NodeTransformer):
|
|
| 13 |
self.bug_type = bug_type
|
| 14 |
self.modified = False
|
| 15 |
|
|
|
|
| 16 |
def visit_If(self, node: ast.If):
|
|
|
|
| 17 |
if self.bug_type == "null_check" and not self.modified:
|
| 18 |
if node.body and len(node.body) == 1:
|
| 19 |
self.modified = True
|
| 20 |
-
return node.body[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
return self.generic_visit(node)
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
def visit_For(self, node: ast.For):
|
| 24 |
-
if self.bug_type
|
| 25 |
-
if isinstance(node.iter, ast.Call) and
|
| 26 |
-
|
| 27 |
-
|
|
|
|
| 28 |
new_iter = ast.Call(
|
| 29 |
func=ast.Name(id='range', ctx=ast.Load()),
|
| 30 |
args=[
|
| 31 |
ast.Constant(value=1),
|
| 32 |
-
ast.BinOp(
|
| 33 |
-
left=node.iter.args[0],
|
| 34 |
-
op=ast.Sub(),
|
| 35 |
-
right=ast.Constant(value=1)
|
| 36 |
-
)
|
| 37 |
],
|
| 38 |
keywords=[]
|
| 39 |
)
|
| 40 |
node.iter = new_iter
|
| 41 |
self.modified = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
return self.generic_visit(node)
|
| 43 |
|
| 44 |
def visit_BinOp(self, node: ast.BinOp):
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
self.modified = True
|
| 49 |
return self.generic_visit(node)
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
# ----------------------------------------------------------------------
|
| 52 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
# ----------------------------------------------------------------------
|
| 54 |
@dataclass
|
| 55 |
class RedTeam:
|
|
@@ -63,17 +234,24 @@ class RedTeam:
|
|
| 63 |
|
| 64 |
def inject_bug(self, original_code: str) -> Tuple[str, str, str, str]:
|
| 65 |
"""
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
oracle_fix is the original (correct) code.
|
| 69 |
"""
|
| 70 |
-
|
| 71 |
-
bug_type = self._random.choice(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
|
|
|
| 73 |
try:
|
| 74 |
tree = ast.parse(original_code)
|
| 75 |
except SyntaxError:
|
| 76 |
-
# If the code can't be parsed, return it unchanged
|
| 77 |
return original_code, "parse_error", "Syntax error in original code", original_code
|
| 78 |
|
| 79 |
injector = ASTBugInjector(bug_type)
|
|
@@ -85,13 +263,11 @@ class RedTeam:
|
|
| 85 |
oracle_fix = original_code
|
| 86 |
description = f"AST bug: {bug_type}"
|
| 87 |
else:
|
| 88 |
-
# Fallback: no injection possible (e.g., code doesn't contain the target pattern)
|
| 89 |
buggy_code = original_code
|
| 90 |
oracle_fix = original_code
|
| 91 |
bug_type = "no_op"
|
| 92 |
description = "No suitable code structure found for injection"
|
| 93 |
|
| 94 |
-
# Add noise
|
| 95 |
if self._random.random() < self.noise_prob:
|
| 96 |
buggy_code += "\n# TODO: refactor later"
|
| 97 |
|
|
|
|
| 1 |
+
# redteam.py – Task‑aware bug injection (25 bugs, 5 difficulty levels)
|
| 2 |
import ast
|
| 3 |
import random
|
| 4 |
from dataclasses import dataclass, field
|
| 5 |
from typing import Tuple, Optional, List, Dict
|
| 6 |
|
| 7 |
# ----------------------------------------------------------------------
|
| 8 |
+
# 1. AST Bug Injector (extended for all simple bugs)
|
| 9 |
# ----------------------------------------------------------------------
|
| 10 |
class ASTBugInjector(ast.NodeTransformer):
|
| 11 |
def __init__(self, bug_type: str):
|
|
|
|
| 13 |
self.bug_type = bug_type
|
| 14 |
self.modified = False
|
| 15 |
|
| 16 |
+
# --- Easy: null_check, simple_typo, string_index, default_value, empty_return ---
|
| 17 |
def visit_If(self, node: ast.If):
    """Inject an ``if``-related bug at the first matching ``if`` statement.

    Handled bug types:

    * ``null_check`` -- an ``if`` whose body is exactly one statement is
      replaced by that statement, so it runs unconditionally.
    * ``division_by_zero_empty`` / ``division_by_zero_zero`` -- a guard of
      the form ``if not <name>: ...`` is dropped.  Both ids are accepted
      because the bug table marks ``division_by_zero_zero`` as using the
      same injector, and the injector is constructed with the bug id.

    Returns:
        The replacement node when a bug was injected, otherwise the result
        of ``generic_visit`` so nested statements are still traversed.
    """
    if not self.modified:
        if self.bug_type == "null_check" and len(node.body) == 1:
            self.modified = True
            return node.body[0]

        if self.bug_type in ("division_by_zero_empty", "division_by_zero_zero"):
            test = node.test
            if (isinstance(test, ast.UnaryOp)
                    and isinstance(test.op, ast.Not)
                    and isinstance(test.operand, ast.Name)):
                self.modified = True
                # Replace with ``pass`` instead of returning None: deleting
                # the node outright can leave the enclosing body empty,
                # which unparses to syntactically invalid code.
                return ast.copy_location(ast.Pass(), node)

    return self.generic_visit(node)
|
| 32 |
|
| 33 |
+
def visit_Name(self, node: ast.Name):
    """Misspell the first reference to ``users`` as ``usres``.

    Only active for the ``simple_typo`` bug type and only until one
    injection has happened; every other name is traversed unchanged.
    """
    typo_pending = self.bug_type == "simple_typo" and not self.modified
    if typo_pending and node.id == "users":
        self.modified = True
        # Keep the original context (load/store/del) on the renamed node.
        return ast.Name(id="usres", ctx=node.ctx)
    return self.generic_visit(node)
|
| 39 |
+
|
| 40 |
+
def visit_Subscript(self, node: ast.Subscript):
    """Shift the first constant integer index by one (``s[0]`` -> ``s[1]``).

    Only active for the ``string_index`` bug type and only until one
    injection has happened.

    Since Python 3.9 a simple index is stored directly in ``node.slice``;
    older versions wrap it in an ``ast.Index`` node.  Both layouts are
    handled, without referencing the deprecated ``ast.Index`` class.
    """
    if self.bug_type == "string_index" and not self.modified:
        index = node.slice
        # Python < 3.9 only: unwrap the Index(value=...) wrapper.
        if type(index).__name__ == "Index":
            index = index.value
        if (isinstance(index, ast.Constant)
                and isinstance(index.value, int)
                # bool is a subclass of int; x[True] is not a numeric index.
                and not isinstance(index.value, bool)):
            index.value += 1
            self.modified = True
    return self.generic_visit(node)
|
| 48 |
+
|
| 49 |
+
def visit_Call(self, node: ast.Call):
    """Inject a call-related bug at the first matching call.

    Handled bug types:

    * ``default_value`` -- ``obj.get(key)`` becomes ``obj[key]``, turning a
      tolerant lookup into one that raises on a missing key.
    * ``abs_usage`` -- ``abs(expr)`` is replaced by ``expr``.

    Only calls with exactly one plain positional argument (no keywords, no
    ``*args``) are rewritten, so the replacement is always a valid
    expression.  The subscript is built with the Python 3.9+ AST layout,
    where the index expression is stored directly in ``slice``.
    """
    single_plain_arg = (
        len(node.args) == 1
        and not node.keywords
        and not isinstance(node.args[0], ast.Starred)
    )
    if not self.modified and single_plain_arg:
        func = node.func
        arg = node.args[0]

        if (self.bug_type == "default_value"
                and isinstance(func, ast.Attribute)
                and func.attr == "get"):
            self.modified = True
            # ast.Call has no ``ctx`` attribute; the call being replaced is
            # always read as a value, so the subscript is a Load.
            replacement = ast.Subscript(value=func.value, slice=arg, ctx=ast.Load())
            return ast.copy_location(replacement, node)

        if (self.bug_type == "abs_usage"
                and isinstance(func, ast.Name)
                and func.id == "abs"):
            self.modified = True
            return arg

    return self.generic_visit(node)
|
| 66 |
+
|
| 67 |
+
def visit_FunctionDef(self, node: ast.FunctionDef):
    """Insert a premature ``return None`` into the first function visited.

    Only active for the ``empty_return`` bug type and only until one
    injection has happened.  The statement is placed after the docstring
    when the function has one, so the docstring stays a docstring and the
    only behavioural change is the early return.
    """
    if self.bug_type == "empty_return" and not self.modified:
        self.modified = True
        has_docstring = ast.get_docstring(node, clean=False) is not None
        early_return = ast.copy_location(
            ast.Return(value=ast.Constant(value=None)), node
        )
        node.body.insert(1 if has_docstring else 0, early_return)
    return self.generic_visit(node)
|
| 73 |
+
|
| 74 |
+
# --- Medium: off_by_one, loop_skip, sign_error, swap_args, uninitialised_var ---
|
| 75 |
def visit_For(self, node: ast.For):
    """Corrupt the bounds of the first ``for ... in range(...)`` loop.

    Handled bug types:

    * ``off_by_one`` -- ``range(n)`` becomes ``range(1, n - 1)``; for
      ``range(start, stop[, step])`` only the stop is shortened to
      ``stop - 1``, so the real start and step are kept.
    * ``loop_skip`` -- ``range(n)`` becomes ``range(n - 1)``; calls with
      more than one argument are left alone.

    Calls with no arguments, keyword arguments or ``*args`` are not
    touched, because no valid replacement can be built from them.
    """
    if self.bug_type in ("off_by_one", "loop_skip") and not self.modified:
        call = node.iter
        is_plain_range = (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Name)
            and call.func.id == "range"
            and bool(call.args)
            and not call.keywords
            and not any(isinstance(arg, ast.Starred) for arg in call.args)
        )
        if is_plain_range:
            new_args = None
            if len(call.args) == 1:
                shortened = ast.BinOp(
                    left=call.args[0], op=ast.Sub(), right=ast.Constant(value=1)
                )
                if self.bug_type == "off_by_one":
                    new_args = [ast.Constant(value=1), shortened]
                else:
                    new_args = [shortened]
            elif self.bug_type == "off_by_one":
                # args[0] is the *start* here; the bound to corrupt is args[1].
                new_args = list(call.args)
                new_args[1] = ast.BinOp(
                    left=call.args[1], op=ast.Sub(), right=ast.Constant(value=1)
                )

            if new_args is not None:
                new_iter = ast.Call(
                    func=ast.Name(id="range", ctx=ast.Load()),
                    args=new_args,
                    keywords=[],
                )
                node.iter = ast.copy_location(new_iter, call)
                self.modified = True
    return self.generic_visit(node)
|
| 100 |
|
| 101 |
def visit_BinOp(self, node: ast.BinOp):
    """Swap the operator of the first eligible binary expression.

    * ``wrong_operator`` / ``sign_error``: ``+`` becomes ``-`` and ``-``
      becomes ``+``.
    * ``float_precision``: ``/`` becomes ``//``.

    Once one operator has been replaced, later expressions are only
    traversed.
    """
    if self.modified:
        return self.generic_visit(node)

    replacement = None
    if self.bug_type in ("wrong_operator", "sign_error"):
        if isinstance(node.op, ast.Add):
            replacement = ast.Sub()
        elif isinstance(node.op, ast.Sub):
            replacement = ast.Add()
    elif self.bug_type == "float_precision" and isinstance(node.op, ast.Div):
        replacement = ast.FloorDiv()

    if replacement is not None:
        node.op = replacement
        self.modified = True
    return self.generic_visit(node)
|
| 115 |
|
| 116 |
+
def visit_arguments(self, node: ast.arguments):
    """Exchange the first two positional parameters of a signature.

    Only active for the ``swap_args`` bug type, only until one injection
    has happened, and only when at least two positional parameters exist.
    """
    should_swap = (
        self.bug_type == "swap_args"
        and not self.modified
        and len(node.args) >= 2
    )
    if should_swap:
        self.modified = True
        first, second = node.args[:2]
        node.args[:2] = [second, first]
    return self.generic_visit(node)
|
| 122 |
+
|
| 123 |
+
def visit_Assign(self, node: ast.Assign):
    """Drop the first assignment statement, leaving ``pass`` in its place.

    Only active for the ``uninitialised_var`` bug type and only until one
    injection has happened; later assignments are traversed unchanged.
    """
    if self.modified or self.bug_type != "uninitialised_var":
        return self.generic_visit(node)
    self.modified = True
    # A Pass placeholder keeps the enclosing body non-empty.
    return ast.Pass()
|
| 129 |
+
|
| 130 |
# ----------------------------------------------------------------------
|
| 131 |
+
# 2. Bug database (25 bugs, categorized by difficulty)
|
| 132 |
+
# ----------------------------------------------------------------------
|
| 133 |
+
# Bug catalogue: difficulty level -> bug id -> bug specification.
#
# * {"type": "ast", "bug_type": ...} entries are produced by rewriting the
#   caller's code with ASTBugInjector.
# * {"type": "template", "buggy": ..., "oracle": ...} entries carry a fixed
#   buggy program and its reference fix as source text.  Every template is
#   written with 4-space indentation so that it parses as valid Python.
BUG_DB = {
    "easy": {
        "null_check": {"type": "ast", "bug_type": "null_check"},
        "simple_typo": {"type": "ast", "bug_type": "simple_typo"},
        "string_index": {"type": "ast", "bug_type": "string_index"},
        "default_value": {"type": "ast", "bug_type": "default_value"},
        "empty_return": {"type": "ast", "bug_type": "empty_return"},
    },
    "medium": {
        "off_by_one": {"type": "ast", "bug_type": "off_by_one"},
        "loop_skip": {"type": "ast", "bug_type": "loop_skip"},
        "sign_error": {"type": "ast", "bug_type": "sign_error"},
        "swap_args": {"type": "ast", "bug_type": "swap_args"},
        "uninitialised_var": {"type": "ast", "bug_type": "uninitialised_var"},
    },
    "hard": {
        "division_by_zero_empty": {"type": "ast", "bug_type": "division_by_zero_empty"},
        # Shares the injector of division_by_zero_empty.
        "division_by_zero_zero": {"type": "ast", "bug_type": "division_by_zero_empty"},
        "float_precision": {"type": "ast", "bug_type": "float_precision"},
        "abs_usage": {"type": "ast", "bug_type": "abs_usage"},
        # NOTE(review): no ASTBugInjector visitor for "round_error" is
        # visible in this file -- confirm whether one still has to be added.
        "round_error": {"type": "ast", "bug_type": "round_error"},
    },
    "harder": {
        "missing_lock": {
            "type": "template",
            "buggy": (
                "counter = 0\n"
                "def increment():\n"
                "    global counter\n"
                "    counter += 1"
            ),
            "oracle": (
                "counter = 0\n"
                "import threading\n"
                "lock = threading.Lock()\n"
                "def increment():\n"
                "    global counter\n"
                "    with lock:\n"
                "        counter += 1"
            ),
        },
        "double_lock": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "lock = threading.Lock()\n"
                "def do_work():\n"
                "    lock.acquire()\n"
                "    lock.acquire()\n"
                "    print('working')\n"
                "    lock.release()"
            ),
            "oracle": (
                "import threading\n"
                "lock = threading.Lock()\n"
                "def do_work():\n"
                "    with lock:\n"
                "        print('working')"
            ),
        },
        # NOTE(review): the oracle's ``count += 1`` is still a
        # read-modify-write without a lock -- confirm this is the intended fix.
        "global_nonatomic": {
            "type": "template",
            "buggy": (
                "count = 0\n"
                "def add():\n"
                "    global count\n"
                "    count = count + 1"
            ),
            "oracle": (
                "count = 0\n"
                "def add():\n"
                "    global count\n"
                "    count += 1"
            ),
        },
        "thread_safe_list": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "items = []\n"
                "def append_item(item):\n"
                "    items.append(item)"
            ),
            "oracle": (
                "import threading\n"
                "items = []\n"
                "lock = threading.Lock()\n"
                "def append_item(item):\n"
                "    with lock:\n"
                "        items.append(item)"
            ),
        },
        "volatile_read": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "stop = False\n"
                "def worker():\n"
                "    while not stop:\n"
                "        pass"
            ),
            "oracle": (
                "import threading\n"
                "stop = False\n"
                "lock = threading.Lock()\n"
                "def worker():\n"
                "    while True:\n"
                "        with lock:\n"
                "            if stop:\n"
                "                break"
            ),
        },
    },
    "hardest": {
        "deadlock_order": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "lock1 = threading.Lock()\n"
                "lock2 = threading.Lock()\n"
                "def thread1():\n"
                "    with lock1:\n"
                "        with lock2:\n"
                "            pass\n"
                "def thread2():\n"
                "    with lock2:\n"
                "        with lock1:\n"
                "            pass"
            ),
            "oracle": (
                "import threading\n"
                "lock1 = threading.Lock()\n"
                "lock2 = threading.Lock()\n"
                "def thread1():\n"
                "    with lock1:\n"
                "        with lock2:\n"
                "            pass\n"
                "def thread2():\n"
                "    with lock1:\n"
                "        with lock2:\n"
                "            pass"
            ),
        },
        "nested_lock_timeout": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "lock = threading.Lock()\n"
                "def work():\n"
                "    lock.acquire()\n"
                "    # critical section\n"
                "    lock.release()"
            ),
            # The ``pass`` is required: a ``try`` body holding only a
            # comment is a syntax error.
            "oracle": (
                "import threading\n"
                "lock = threading.Lock()\n"
                "def work():\n"
                "    if lock.acquire(timeout=1):\n"
                "        try:\n"
                "            # critical section\n"
                "            pass\n"
                "        finally:\n"
                "            lock.release()"
            ),
        },
        "fork_join": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "def worker():\n"
                "    pass\n"
                "t = threading.Thread(target=worker)\n"
                "t.start()"
            ),
            "oracle": (
                "import threading\n"
                "def worker():\n"
                "    pass\n"
                "t = threading.Thread(target=worker)\n"
                "t.start()\n"
                "t.join()"
            ),
        },
        "mutex_release": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "lock = threading.Lock()\n"
                "def thread_A():\n"
                "    lock.acquire()\n"
                "    lock.release()\n"
                "def thread_B():\n"
                "    lock.release()"
            ),
            "oracle": (
                "import threading\n"
                "lock = threading.Lock()\n"
                "def thread_A():\n"
                "    with lock:\n"
                "        pass\n"
                "def thread_B():\n"
                "    with lock:\n"
                "        pass"
            ),
        },
        "race_on_init": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "items = []\n"
                "def init():\n"
                "    global items\n"
                "    items = [1,2,3]\n"
                "t = threading.Thread(target=init)\n"
                "t.start()\n"
                "print(items)"
            ),
            "oracle": (
                "import threading\n"
                "items = []\n"
                "def init():\n"
                "    global items\n"
                "    items = [1,2,3]\n"
                "t = threading.Thread(target=init)\n"
                "t.start()\n"
                "t.join()\n"
                "print(items)"
            ),
        },
    },
}
|
| 210 |
+
|
| 211 |
+
# ----------------------------------------------------------------------
|
| 212 |
+
# 3. Derived helpers
|
| 213 |
+
# ----------------------------------------------------------------------
|
| 214 |
+
# Difficulty level -> list of the bug ids offered at that level.
TASK_BUG_MAP = {level: list(bugs) for level, bugs in BUG_DB.items()}

# Bug id -> (buggy_code, oracle_code) for every template-type entry, across
# all levels.  Built with a comprehension so no loop variables are left
# behind in the module namespace.
TEMPLATE_BUGS = {
    bug_id: (bug["buggy"], bug["oracle"])
    for bugs in BUG_DB.values()
    for bug_id, bug in bugs.items()
    if bug["type"] == "template"
}
|
| 221 |
+
|
| 222 |
+
# ----------------------------------------------------------------------
|
| 223 |
+
# 4. RedTeam Controller (task‑aware)
|
| 224 |
# ----------------------------------------------------------------------
|
| 225 |
@dataclass
|
| 226 |
class RedTeam:
|
|
|
|
| 234 |
|
| 235 |
def inject_bug(self, original_code: str) -> Tuple[str, str, str, str]:
|
| 236 |
"""
|
| 237 |
+
Returns: (buggy_code, bug_type, description, oracle_fix)
|
| 238 |
+
Selects a bug appropriate for the task difficulty.
|
|
|
|
| 239 |
"""
|
| 240 |
+
bug_list = TASK_BUG_MAP.get(self.task, ["null_check"])
|
| 241 |
+
bug_type = self._random.choice(bug_list)
|
| 242 |
+
|
| 243 |
+
# Template bug: return hardcoded buggy + oracle
|
| 244 |
+
if bug_type in TEMPLATE_BUGS:
|
| 245 |
+
buggy_code, oracle_code = TEMPLATE_BUGS[bug_type]
|
| 246 |
+
description = f"Template bug: {bug_type}"
|
| 247 |
+
if self._random.random() < self.noise_prob:
|
| 248 |
+
buggy_code += "\n# TODO: refactor later"
|
| 249 |
+
return buggy_code, bug_type, description, oracle_code
|
| 250 |
|
| 251 |
+
# AST injection
|
| 252 |
try:
|
| 253 |
tree = ast.parse(original_code)
|
| 254 |
except SyntaxError:
|
|
|
|
| 255 |
return original_code, "parse_error", "Syntax error in original code", original_code
|
| 256 |
|
| 257 |
injector = ASTBugInjector(bug_type)
|
|
|
|
| 263 |
oracle_fix = original_code
|
| 264 |
description = f"AST bug: {bug_type}"
|
| 265 |
else:
|
|
|
|
| 266 |
buggy_code = original_code
|
| 267 |
oracle_fix = original_code
|
| 268 |
bug_type = "no_op"
|
| 269 |
description = "No suitable code structure found for injection"
|
| 270 |
|
|
|
|
| 271 |
if self._random.random() < self.noise_prob:
|
| 272 |
buggy_code += "\n# TODO: refactor later"
|
| 273 |
|