File size: 13,898 Bytes
1588266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
# redteam.py – Task‑aware bug injection (25 bugs, 5 difficulty levels)
import ast
import random
from dataclasses import dataclass, field
from typing import Tuple, Optional, List, Dict

# ----------------------------------------------------------------------
# 1. AST Bug Injector (extended for all simple bugs)
# ----------------------------------------------------------------------
class ASTBugInjector(ast.NodeTransformer):
    """Apply at most one AST-level bug of type ``bug_type`` to a syntax tree.

    Usage: ``tree = injector.visit(tree)``, then check ``injector.modified``.
    The first node matching the bug pattern is mutated and ``modified`` is
    set; every later visitor sees the flag and leaves the rest of the tree
    alone.  An unrecognised ``bug_type`` changes nothing.  Freshly built
    nodes may lack source positions, so run ``ast.fix_missing_locations``
    before ``ast.unparse`` or ``compile``.

    Written for Python 3.9+ (``ast.unparse`` is required), where subscript
    slices are plain expression nodes rather than ``ast.Index`` wrappers.
    """

    def __init__(self, bug_type: str):
        super().__init__()
        self.bug_type = bug_type
        # True once a bug has been injected; every visitor checks it so that
        # only a single mutation is ever applied.
        self.modified = False

    # --- small builders / matchers ---
    @staticmethod
    def _minus_one(expr: ast.expr) -> ast.BinOp:
        """Build the expression ``expr - 1``."""
        return ast.BinOp(left=expr, op=ast.Sub(), right=ast.Constant(value=1))

    @staticmethod
    def _first_arg_of(node: ast.Call, func_name: str) -> Optional[ast.expr]:
        """Return the first positional argument of a call to ``func_name``.

        Returns None if ``node`` calls something else, has no positional
        arguments, or starts with ``*args`` (unwrapping a Starred node would
        produce invalid code).
        """
        if (isinstance(node.func, ast.Name)
                and node.func.id == func_name
                and node.args
                and not isinstance(node.args[0], ast.Starred)):
            return node.args[0]
        return None

    @staticmethod
    def _range_args(node: ast.expr) -> Optional[List[ast.expr]]:
        """Return the arguments of a plain ``range(...)`` call, else None.

        Calls using keywords or ``*args`` are rejected because their
        arguments cannot be rewritten positionally.
        """
        if (isinstance(node, ast.Call)
                and isinstance(node.func, ast.Name)
                and node.func.id == "range"
                and node.args
                and not node.keywords
                and not any(isinstance(arg, ast.Starred) for arg in node.args)):
            return node.args
        return None

    # --- Easy: null_check, simple_typo, string_index, default_value, empty_return ---
    def visit_If(self, node: ast.If):
        # null_check: drop the guard and keep its single body statement.
        # Any ``else`` branch is discarded together with the test.
        if self.bug_type == "null_check" and not self.modified:
            if len(node.body) == 1:
                self.modified = True
                return node.body[0]
        # division_by_zero_empty: delete an ``if not <name>: ...`` guard.
        # Returning None removes the statement from the parent's body; if it
        # was the only statement there, the unparsed result will not re-parse.
        if self.bug_type == "division_by_zero_empty" and not self.modified:
            test = node.test
            if (isinstance(test, ast.UnaryOp)
                    and isinstance(test.op, ast.Not)
                    and isinstance(test.operand, ast.Name)):
                self.modified = True
                return None
        return self.generic_visit(node)

    def visit_Name(self, node: ast.Name):
        # simple_typo: misspell the first occurrence of the name ``users``.
        if self.bug_type == "simple_typo" and not self.modified:
            if node.id == "users":
                self.modified = True
                return ast.copy_location(ast.Name(id="usres", ctx=node.ctx), node)
        return self.generic_visit(node)

    def visit_Subscript(self, node: ast.Subscript):
        # string_index: shift a literal integer index by one (s[0] -> s[1]).
        # On 3.9+ ``node.slice`` is the index expression itself.
        if self.bug_type == "string_index" and not self.modified:
            index = node.slice
            if (isinstance(index, ast.Constant)
                    and isinstance(index.value, int)
                    and not isinstance(index.value, bool)):
                self.modified = True
                node.slice = ast.copy_location(ast.Constant(value=index.value + 1), index)
        return self.generic_visit(node)

    def visit_Call(self, node: ast.Call):
        if not self.modified:
            # default_value: obj.get(key) -> obj[key]; for dicts this raises
            # KeyError instead of returning None for a missing key.
            if (self.bug_type == "default_value"
                    and isinstance(node.func, ast.Attribute)
                    and node.func.attr == "get"
                    and len(node.args) == 1
                    and not node.keywords
                    and not isinstance(node.args[0], ast.Starred)):
                self.modified = True
                # ast.Call has no ``ctx``; the new subscript is read, hence Load.
                return ast.copy_location(
                    ast.Subscript(value=node.func.value, slice=node.args[0], ctx=ast.Load()),
                    node,
                )
            # abs_usage / round_error: unwrap abs(x) / round(x[, n]) to plain x.
            stripped = {"abs_usage": "abs", "round_error": "round"}.get(self.bug_type)
            if stripped is not None:
                inner = self._first_arg_of(node, stripped)
                if inner is not None:
                    self.modified = True
                    return inner
        return self.generic_visit(node)

    def visit_FunctionDef(self, node: ast.FunctionDef):
        # empty_return: insert a premature ``return None`` as the first statement.
        if self.bug_type == "empty_return" and not self.modified:
            self.modified = True
            node.body.insert(0, ast.Return(value=ast.Constant(value=None)))
        return self.generic_visit(node)

    # ``async def`` bodies accept the same premature return.
    visit_AsyncFunctionDef = visit_FunctionDef

    # --- Medium: off_by_one, loop_skip, sign_error, swap_args, uninitialised_var ---
    def visit_For(self, node: ast.For):
        if self.bug_type in ("off_by_one", "loop_skip") and not self.modified:
            args = self._range_args(node.iter)
            new_args = None
            if args is not None:
                if self.bug_type == "off_by_one":
                    if len(args) == 1:
                        # range(n) -> range(1, n - 1)
                        new_args = [ast.Constant(value=1), self._minus_one(args[0])]
                    else:
                        # range(a, b[, s]) -> range(a, b - 1[, s]): keep the
                        # start and step, shorten the stop.
                        new_args = [args[0], self._minus_one(args[1]), *args[2:]]
                elif len(args) == 1:
                    # loop_skip: range(n) -> range(n - 1), skipping the last item.
                    new_args = [self._minus_one(args[0])]
            if new_args is not None:
                node.iter = ast.copy_location(
                    ast.Call(func=ast.Name(id="range", ctx=ast.Load()), args=new_args, keywords=[]),
                    node.iter,
                )
                self.modified = True
        return self.generic_visit(node)

    def visit_BinOp(self, node: ast.BinOp):
        # sign_error / wrong_operator: swap Add <-> Sub.
        # float_precision: true division -> floor division.
        if not self.modified:
            if self.bug_type in ("wrong_operator", "sign_error"):
                if isinstance(node.op, ast.Add):
                    node.op = ast.Sub()
                    self.modified = True
                elif isinstance(node.op, ast.Sub):
                    node.op = ast.Add()
                    self.modified = True
            elif self.bug_type == "float_precision" and isinstance(node.op, ast.Div):
                node.op = ast.FloorDiv()
                self.modified = True
        return self.generic_visit(node)

    def visit_arguments(self, node: ast.arguments):
        # swap_args: swap the first two positional parameters of the first
        # function/lambda that has at least two (``self`` included).
        if self.bug_type == "swap_args" and not self.modified and len(node.args) >= 2:
            self.modified = True
            node.args[0], node.args[1] = node.args[1], node.args[0]
        return self.generic_visit(node)

    def visit_Assign(self, node: ast.Assign):
        # uninitialised_var: replace the first assignment statement with ``pass``.
        if self.bug_type == "uninitialised_var" and not self.modified:
            self.modified = True
            return ast.copy_location(ast.Pass(), node)
        return self.generic_visit(node)

# ----------------------------------------------------------------------
# 2. Bug database (25 bugs, categorized by difficulty)
# ----------------------------------------------------------------------
# Difficulty level -> bug id -> spec.
#   {"type": "ast", "bug_type": ...}       : mutate the caller's code via ASTBugInjector
#   {"type": "template", "buggy", "oracle"}: hard-coded snippet and its reference fix
# Every template "buggy"/"oracle" string must itself be valid Python source.
BUG_DB = {
    "easy": {
        "null_check":    {"type": "ast", "bug_type": "null_check"},
        "simple_typo":   {"type": "ast", "bug_type": "simple_typo"},
        "string_index":  {"type": "ast", "bug_type": "string_index"},
        "default_value": {"type": "ast", "bug_type": "default_value"},
        "empty_return":  {"type": "ast", "bug_type": "empty_return"},
    },
    "medium": {
        "off_by_one":     {"type": "ast", "bug_type": "off_by_one"},
        "loop_skip":      {"type": "ast", "bug_type": "loop_skip"},
        "sign_error":     {"type": "ast", "bug_type": "sign_error"},
        "swap_args":      {"type": "ast", "bug_type": "swap_args"},
        "uninitialised_var": {"type": "ast", "bug_type": "uninitialised_var"},
    },
    "hard": {
        "division_by_zero_empty": {"type": "ast", "bug_type": "division_by_zero_empty"},
        "division_by_zero_zero":  {"type": "ast", "bug_type": "division_by_zero_empty"},  # same injector
        "float_precision":        {"type": "ast", "bug_type": "float_precision"},
        "abs_usage":              {"type": "ast", "bug_type": "abs_usage"},
        "round_error":            {"type": "ast", "bug_type": "round_error"},  # can be extended
    },
    "harder": {
        "missing_lock": {
            "type": "template",
            "buggy": "counter = 0\ndef increment():\n    global counter\n    counter += 1",
            "oracle": "counter = 0\nimport threading\nlock = threading.Lock()\ndef increment():\n    global counter\n    with lock:\n        counter += 1",
        },
        "double_lock": {
            "type": "template",
            "buggy": "import threading\nlock = threading.Lock()\ndef do_work():\n    lock.acquire()\n    lock.acquire()\n    print('working')\n    lock.release()",
            "oracle": "import threading\nlock = threading.Lock()\ndef do_work():\n    with lock:\n        print('working')",
        },
        "global_nonatomic": {
            "type": "template",
            "buggy": "count = 0\ndef add():\n    global count\n    count = count + 1",
            # ``count += 1`` on a global int is a read-modify-write just like
            # ``count = count + 1``; only holding a lock makes it safe.
            "oracle": "import threading\ncount = 0\nlock = threading.Lock()\ndef add():\n    global count\n    with lock:\n        count += 1",
        },
        "thread_safe_list": {
            "type": "template",
            "buggy": "import threading\nitems = []\ndef append_item(item):\n    items.append(item)",
            "oracle": "import threading\nitems = []\nlock = threading.Lock()\ndef append_item(item):\n    with lock:\n        items.append(item)",
        },
        "volatile_read": {
            "type": "template",
            "buggy": "import threading\nstop = False\ndef worker():\n    while not stop:\n        pass",
            "oracle": "import threading\nstop = False\nlock = threading.Lock()\ndef worker():\n    while True:\n        with lock:\n            if stop:\n                break",
        },
    },
    "hardest": {
        "deadlock_order": {
            "type": "template",
            "buggy": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n    with lock1:\n        with lock2:\n            pass\ndef thread2():\n    with lock2:\n        with lock1:\n            pass",
            "oracle": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n    with lock1:\n        with lock2:\n            pass\ndef thread2():\n    with lock1:\n        with lock2:\n            pass",
        },
        "nested_lock_timeout": {
            "type": "template",
            "buggy": "import threading\nlock = threading.Lock()\ndef work():\n    lock.acquire()\n    # critical section\n    lock.release()",
            # A ``try:`` body containing only a comment is a SyntaxError, so
            # the placeholder critical section needs an explicit ``pass``.
            "oracle": "import threading\nlock = threading.Lock()\ndef work():\n    if lock.acquire(timeout=1):\n        try:\n            # critical section\n            pass\n        finally:\n            lock.release()",
        },
        "fork_join": {
            "type": "template",
            "buggy": "import threading\ndef worker():\n    pass\nt = threading.Thread(target=worker)\nt.start()",
            "oracle": "import threading\ndef worker():\n    pass\nt = threading.Thread(target=worker)\nt.start()\nt.join()",
        },
        "mutex_release": {
            "type": "template",
            "buggy": "import threading\nlock = threading.Lock()\ndef thread_A():\n    lock.acquire()\n    lock.release()\ndef thread_B():\n    lock.release()",
            "oracle": "import threading\nlock = threading.Lock()\ndef thread_A():\n    with lock:\n        pass\ndef thread_B():\n    with lock:\n        pass",
        },
        "race_on_init": {
            "type": "template",
            "buggy": "import threading\nitems = []\ndef init():\n    global items\n    items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nprint(items)",
            "oracle": "import threading\nitems = []\ndef init():\n    global items\n    items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nt.join()\nprint(items)",
        },
    },
}

# ----------------------------------------------------------------------
# 3. Derived helpers
# ----------------------------------------------------------------------
# TASK_BUG_MAP: difficulty level -> bug ids in BUG_DB declaration order.
# TEMPLATE_BUGS: template bug id -> (buggy_code, oracle_code).
TASK_BUG_MAP = {}
TEMPLATE_BUGS = {}
for level, bugs in BUG_DB.items():
    TASK_BUG_MAP[level] = [name for name in bugs]
    for bug_id, bug in bugs.items():
        if bug["type"] != "template":
            continue
        TEMPLATE_BUGS[bug_id] = (bug["buggy"], bug["oracle"])

# ----------------------------------------------------------------------
# 4. RedTeam Controller (task‑aware)
# ----------------------------------------------------------------------
@dataclass
class RedTeam:
    """Task-aware bug injector built on ``BUG_DB``.

    ``task`` selects a difficulty level (a key of ``TASK_BUG_MAP``); unknown
    tasks fall back to the single ``null_check`` bug.  All randomness flows
    through a private ``random.Random(seed)``, so a fixed ``seed`` reproduces
    the same sequence of injections.
    """

    task: str
    seed: Optional[int] = 42
    # Probability of appending a distracting "# TODO" comment to the buggy code.
    noise_prob: float = 0.2
    # Private RNG, created in __post_init__ from ``seed``.
    _random: random.Random = field(init=False)

    def __post_init__(self):
        self._random = random.Random(self.seed)

    def _add_noise(self, code: str) -> str:
        """Append the noise comment with probability ``noise_prob``.

        Always consumes exactly one draw from the RNG, so the random sequence
        does not depend on which branch ``inject_bug`` took.
        """
        if self._random.random() < self.noise_prob:
            return code + "\n# TODO: refactor later"
        return code

    @staticmethod
    def _parses(code: str) -> bool:
        """Return True if ``code`` is syntactically valid Python."""
        try:
            ast.parse(code)
        except SyntaxError:
            return False
        return True

    def inject_bug(self, original_code: str) -> Tuple[str, str, str, str]:
        """Inject one bug appropriate for the task difficulty.

        Returns:
            ``(buggy_code, bug_type, description, oracle_fix)``.

            * Template bugs ignore ``original_code`` and return the hard-coded
              buggy snippet together with its oracle.
            * AST bugs return the mutated source as produced by
              ``ast.unparse`` (comments and formatting of the original are not
              preserved); ``oracle_fix`` is ``original_code`` verbatim.
            * If ``original_code`` does not parse, it is returned unchanged
              with ``bug_type == "parse_error"`` and no noise is added.
            * If no matching structure exists, or the mutation would yield
              code that no longer parses (e.g. deleting the only statement of
              a block), the code is returned unchanged with
              ``bug_type == "no_op"``.
        """
        bug_list = TASK_BUG_MAP.get(self.task, ["null_check"])
        bug_type = self._random.choice(bug_list)

        # Template bug: return hardcoded buggy + oracle
        if bug_type in TEMPLATE_BUGS:
            buggy_code, oracle_code = TEMPLATE_BUGS[bug_type]
            description = f"Template bug: {bug_type}"
            return self._add_noise(buggy_code), bug_type, description, oracle_code

        # AST injection
        try:
            tree = ast.parse(original_code)
        except SyntaxError:
            return original_code, "parse_error", "Syntax error in original code", original_code

        injector = ASTBugInjector(bug_type)
        modified_tree = ast.fix_missing_locations(injector.visit(tree))
        candidate = ast.unparse(modified_tree) if injector.modified else None

        if candidate is not None and self._parses(candidate):
            buggy_code = candidate
            description = f"AST bug: {bug_type}"
        elif candidate is not None:
            # The mutation left an invalid tree (e.g. an empty body); never
            # hand out source that cannot even be parsed.
            buggy_code = original_code
            bug_type = "no_op"
            description = "Injection produced invalid code; original left unchanged"
        else:
            buggy_code = original_code
            bug_type = "no_op"
            description = "No suitable code structure found for injection"

        return self._add_noise(buggy_code), bug_type, description, original_code