Update redteam.py
Browse files- redteam.py +97 -141
redteam.py
CHANGED
|
@@ -1,142 +1,98 @@
|
|
| 1 |
-
# redteam.py – AST-based bug injection
|
| 2 |
-
import ast
|
| 3 |
-
import random
|
| 4 |
-
from dataclasses import dataclass, field
|
| 5 |
-
from typing import Tuple, Optional, List, Dict
|
| 6 |
-
|
| 7 |
-
# ----------------------------------------------------------------------
|
| 8 |
-
# AST-based bug injector
|
| 9 |
-
# ----------------------------------------------------------------------
|
| 10 |
-
class ASTBugInjector(ast.NodeTransformer):
|
| 11 |
-
def __init__(self, bug_type: str):
|
| 12 |
-
super().__init__()
|
| 13 |
-
self.bug_type = bug_type
|
| 14 |
-
self.modified = False
|
| 15 |
-
|
| 16 |
-
def visit_If(self, node: ast.If):
|
| 17 |
-
if self.bug_type == "null_check" and not self.modified:
|
| 18 |
-
if node.body and len(node.body) == 1:
|
| 19 |
-
self.modified = True
|
| 20 |
-
return node.body[0]
|
| 21 |
-
return self.generic_visit(node)
|
| 22 |
-
|
| 23 |
-
def visit_For(self, node: ast.For):
|
| 24 |
-
if self.bug_type == "off_by_one" and not self.modified:
|
| 25 |
-
if isinstance(node.iter, ast.Call) and isinstance(node.iter.func, ast.Name):
|
| 26 |
-
if node.iter.func.id == "range":
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
ast.
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
#
|
| 52 |
-
#
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
"
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
def inject_bug(self, original_code: str) -> Tuple[str, str, str, str]:
|
| 100 |
-
"""
|
| 101 |
-
Returns:
|
| 102 |
-
(buggy_code, bug_type, description, oracle_fix)
|
| 103 |
-
"""
|
| 104 |
-
# Decide injection mode
|
| 105 |
-
use_dataset = self._random.random() < self.dataset_prob
|
| 106 |
-
|
| 107 |
-
if use_dataset:
|
| 108 |
-
example = self._random.choice(DATASET_EXAMPLES)
|
| 109 |
-
buggy_code = example["buggy"]
|
| 110 |
-
oracle_fix = example["original"]
|
| 111 |
-
bug_type = example["bug_type"]
|
| 112 |
-
description = f"Dataset bug: {bug_type}"
|
| 113 |
-
else:
|
| 114 |
-
bug_types = ["null_check", "off_by_one", "wrong_operator"]
|
| 115 |
-
bug_type = self._random.choice(bug_types)
|
| 116 |
-
|
| 117 |
-
try:
|
| 118 |
-
tree = ast.parse(original_code)
|
| 119 |
-
except SyntaxError:
|
| 120 |
-
return original_code, "parse_error", "Syntax error", original_code
|
| 121 |
-
|
| 122 |
-
injector = ASTBugInjector(bug_type)
|
| 123 |
-
modified_tree = injector.visit(tree)
|
| 124 |
-
ast.fix_missing_locations(modified_tree)
|
| 125 |
-
|
| 126 |
-
if injector.modified:
|
| 127 |
-
buggy_code = ast.unparse(modified_tree)
|
| 128 |
-
oracle_fix = original_code
|
| 129 |
-
else:
|
| 130 |
-
# fallback: no injection
|
| 131 |
-
buggy_code = original_code
|
| 132 |
-
oracle_fix = original_code
|
| 133 |
-
bug_type = "no_op"
|
| 134 |
-
description = "No modification applied"
|
| 135 |
-
|
| 136 |
-
description = f"AST bug: {bug_type}"
|
| 137 |
-
|
| 138 |
-
# Add noise
|
| 139 |
-
if self._random.random() < self.noise_prob:
|
| 140 |
-
buggy_code += "\n# TODO: refactor later"
|
| 141 |
-
|
| 142 |
return buggy_code, bug_type, description, oracle_fix
|
|
|
|
| 1 |
+
# redteam.py – AST-based bug injection (no dataset, always modifies given code)
|
| 2 |
+
import ast
|
| 3 |
+
import random
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from typing import Tuple, Optional, List, Dict
|
| 6 |
+
|
| 7 |
+
# ----------------------------------------------------------------------
|
| 8 |
+
# AST-based bug injector
|
| 9 |
+
# ----------------------------------------------------------------------
|
| 10 |
+
class ASTBugInjector(ast.NodeTransformer):
|
| 11 |
+
def __init__(self, bug_type: str):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.bug_type = bug_type
|
| 14 |
+
self.modified = False
|
| 15 |
+
|
| 16 |
+
def visit_If(self, node: ast.If):
|
| 17 |
+
if self.bug_type == "null_check" and not self.modified:
|
| 18 |
+
if node.body and len(node.body) == 1:
|
| 19 |
+
self.modified = True
|
| 20 |
+
return node.body[0] # remove the if, directly execute body
|
| 21 |
+
return self.generic_visit(node)
|
| 22 |
+
|
| 23 |
+
def visit_For(self, node: ast.For):
|
| 24 |
+
if self.bug_type == "off_by_one" and not self.modified:
|
| 25 |
+
if isinstance(node.iter, ast.Call) and isinstance(node.iter.func, ast.Name):
|
| 26 |
+
if node.iter.func.id == "range":
|
| 27 |
+
# Change range(x) to range(1, x-1) to introduce off-by-one
|
| 28 |
+
new_iter = ast.Call(
|
| 29 |
+
func=ast.Name(id='range', ctx=ast.Load()),
|
| 30 |
+
args=[
|
| 31 |
+
ast.Constant(value=1),
|
| 32 |
+
ast.BinOp(
|
| 33 |
+
left=node.iter.args[0],
|
| 34 |
+
op=ast.Sub(),
|
| 35 |
+
right=ast.Constant(value=1)
|
| 36 |
+
)
|
| 37 |
+
],
|
| 38 |
+
keywords=[]
|
| 39 |
+
)
|
| 40 |
+
node.iter = new_iter
|
| 41 |
+
self.modified = True
|
| 42 |
+
return self.generic_visit(node)
|
| 43 |
+
|
| 44 |
+
def visit_BinOp(self, node: ast.BinOp):
|
| 45 |
+
if self.bug_type == "wrong_operator" and not self.modified:
|
| 46 |
+
if isinstance(node.op, ast.Add):
|
| 47 |
+
node.op = ast.Sub()
|
| 48 |
+
self.modified = True
|
| 49 |
+
return self.generic_visit(node)
|
| 50 |
+
|
| 51 |
+
# ----------------------------------------------------------------------
|
| 52 |
+
# RedTeam Controller
|
| 53 |
+
# ----------------------------------------------------------------------
|
| 54 |
+
@dataclass
|
| 55 |
+
class RedTeam:
|
| 56 |
+
task: str
|
| 57 |
+
seed: Optional[int] = 42
|
| 58 |
+
noise_prob: float = 0.2
|
| 59 |
+
_random: random.Random = field(init=False)
|
| 60 |
+
|
| 61 |
+
def __post_init__(self):
|
| 62 |
+
self._random = random.Random(self.seed)
|
| 63 |
+
|
| 64 |
+
def inject_bug(self, original_code: str) -> Tuple[str, str, str, str]:
|
| 65 |
+
"""
|
| 66 |
+
Always modifies the given original_code using an AST bug.
|
| 67 |
+
Returns (buggy_code, bug_type, description, oracle_fix).
|
| 68 |
+
oracle_fix is the original (correct) code.
|
| 69 |
+
"""
|
| 70 |
+
bug_types = ["null_check", "off_by_one", "wrong_operator"]
|
| 71 |
+
bug_type = self._random.choice(bug_types)
|
| 72 |
+
|
| 73 |
+
try:
|
| 74 |
+
tree = ast.parse(original_code)
|
| 75 |
+
except SyntaxError:
|
| 76 |
+
# If the code can't be parsed, return it unchanged
|
| 77 |
+
return original_code, "parse_error", "Syntax error in original code", original_code
|
| 78 |
+
|
| 79 |
+
injector = ASTBugInjector(bug_type)
|
| 80 |
+
modified_tree = injector.visit(tree)
|
| 81 |
+
ast.fix_missing_locations(modified_tree)
|
| 82 |
+
|
| 83 |
+
if injector.modified:
|
| 84 |
+
buggy_code = ast.unparse(modified_tree)
|
| 85 |
+
oracle_fix = original_code
|
| 86 |
+
description = f"AST bug: {bug_type}"
|
| 87 |
+
else:
|
| 88 |
+
# Fallback: no injection possible (e.g., code doesn't contain the target pattern)
|
| 89 |
+
buggy_code = original_code
|
| 90 |
+
oracle_fix = original_code
|
| 91 |
+
bug_type = "no_op"
|
| 92 |
+
description = "No suitable code structure found for injection"
|
| 93 |
+
|
| 94 |
+
# Add noise
|
| 95 |
+
if self._random.random() < self.noise_prob:
|
| 96 |
+
buggy_code += "\n# TODO: refactor later"
|
| 97 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
return buggy_code, bug_type, description, oracle_fix
|