# redteam.py – Task‑aware bug injection (25 bugs, 5 difficulty levels)
import ast
import random
from dataclasses import dataclass, field
from typing import Tuple, Optional, List, Dict


# ----------------------------------------------------------------------
# 1. AST Bug Injector (extended for all simple bugs)
# ----------------------------------------------------------------------
class ASTBugInjector(ast.NodeTransformer):
    """Inject a single bug of kind ``bug_type`` into an AST.

    The transformer mutates at most one site: the first node (in visiting
    order) that matches the pattern for ``bug_type``. After ``visit()``,
    ``modified`` tells whether a site was found and changed.

    Args:
        bug_type: Name of the mutation to apply (e.g. ``"null_check"``,
            ``"off_by_one"``). Unknown names leave the tree untouched.
    """

    def __init__(self, bug_type: str):
        super().__init__()
        self.bug_type = bug_type
        self.modified = False

    def generic_visit(self, node):
        """Visit children, then keep statement suites syntactically valid.

        A visitor may delete a statement by returning ``None``. If that
        statement was the only one in its parent's ``body``, the parent
        would unparse to invalid code, so an emptied suite gets a ``pass``.
        """
        node = super().generic_visit(node)
        body = getattr(node, "body", None)
        # Module bodies may legitimately be empty; expression nodes such as
        # Lambda/IfExp have a non-list ``body`` and are skipped by the check.
        if isinstance(body, list) and not body and not isinstance(node, ast.Module):
            body.append(ast.Pass())
        return node

    # --- Easy: null_check, simple_typo, string_index, default_value, empty_return ---
    def visit_If(self, node: ast.If):
        """Apply ``null_check`` / ``division_by_zero_empty`` to an ``if``."""
        # null_check: remove the if-guard, keeping its single body statement
        if self.bug_type == "null_check" and not self.modified:
            if node.body and len(node.body) == 1:
                self.modified = True
                return node.body[0]
        # division_by_zero_empty: remove the empty check
        if self.bug_type == "division_by_zero_empty" and not self.modified:
            # pattern: ``if not data: return 0`` – we delete the entire if
            test = node.test
            if (
                isinstance(test, ast.UnaryOp)
                and isinstance(test.op, ast.Not)
                and isinstance(test.operand, ast.Name)
            ):
                self.modified = True
                return None  # signal to remove this node from parent
        return self.generic_visit(node)

    def visit_Name(self, node: ast.Name):
        """Apply ``simple_typo``: misspell the first ``users`` as ``usres``."""
        if self.bug_type == "simple_typo" and not self.modified:
            if node.id == "users":
                self.modified = True
                return ast.copy_location(ast.Name(id="usres", ctx=node.ctx), node)
        return self.generic_visit(node)

    def visit_Subscript(self, node: ast.Subscript):
        """Apply ``string_index``: bump a constant integer index by one."""
        if self.bug_type == "string_index" and not self.modified:
            index = node.slice
            # Since Python 3.9 the index expression is stored directly in
            # ``slice``; it is no longer wrapped in ``ast.Index``.
            if (
                isinstance(index, ast.Constant)
                and isinstance(index.value, int)
                and not isinstance(index.value, bool)  # True/False are ints too
            ):
                self.modified = True
                node.slice = ast.copy_location(
                    ast.Constant(value=index.value + 1), index
                )
        return self.generic_visit(node)

    def visit_Call(self, node: ast.Call):
        """Apply ``default_value`` / ``abs_usage`` to a call expression."""
        # default_value: change dict.get(key) to dict[key] (no default)
        if self.bug_type == "default_value" and not self.modified:
            if (
                isinstance(node.func, ast.Attribute)
                and node.func.attr == "get"
                and len(node.args) == 1
                and not node.keywords
            ):
                self.modified = True
                # A Call has no ``ctx``; the replacement is always a read.
                return ast.copy_location(
                    ast.Subscript(
                        value=node.func.value,
                        slice=node.args[0],
                        ctx=ast.Load(),
                    ),
                    node,
                )
        # abs_usage: remove abs(), leaving its argument
        if self.bug_type == "abs_usage" and not self.modified:
            if (
                isinstance(node.func, ast.Name)
                and node.func.id == "abs"
                and node.args  # ``abs()`` has nothing to unwrap
            ):
                self.modified = True
                return node.args[0]
        return self.generic_visit(node)

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Apply ``empty_return``: make the function return ``None`` at once."""
        # empty_return: insert a premature return None
        if self.bug_type == "empty_return" and not self.modified:
            self.modified = True
            node.body.insert(0, ast.Return(value=ast.Constant(value=None)))
        return self.generic_visit(node)

    # --- Medium: off_by_one, loop_skip, sign_error, swap_args, uninitialised_var ---
    def visit_For(self, node: ast.For):
        """Apply ``off_by_one`` / ``loop_skip`` to a ``for ... in range(...)``.

        * ``off_by_one``: ``range(n)`` becomes ``range(1, n - 1)``; with an
          explicit start, ``range(a, b[, s])`` becomes ``range(a, b - 1[, s])``.
        * ``loop_skip``: ``range(n)`` becomes ``range(n - 1)`` (single-argument
          form only).
        """
        if self.bug_type in ("off_by_one", "loop_skip") and not self.modified:
            call = node.iter
            if (
                isinstance(call, ast.Call)
                and isinstance(call.func, ast.Name)
                and call.func.id == "range"
                and call.args  # ``range()`` has no bound to shift
                and not any(isinstance(arg, ast.Starred) for arg in call.args)
            ):
                args = call.args
                new_args = None
                if self.bug_type == "off_by_one":
                    if len(args) == 1:
                        new_args = [ast.Constant(value=1), self._minus_one(args[0])]
                    else:
                        # Keep start (and step); only the stop bound shifts.
                        new_args = [args[0], self._minus_one(args[1]), *args[2:]]
                elif len(args) == 1:  # loop_skip
                    new_args = [self._minus_one(args[0])]
                if new_args is not None:
                    node.iter = ast.Call(
                        func=ast.Name(id="range", ctx=ast.Load()),
                        args=new_args,
                        keywords=[],
                    )
                    self.modified = True
        return self.generic_visit(node)

    @staticmethod
    def _minus_one(expr: ast.expr) -> ast.BinOp:
        """Return the AST for ``expr - 1``."""
        return ast.BinOp(left=expr, op=ast.Sub(), right=ast.Constant(value=1))

    def visit_BinOp(self, node: ast.BinOp):
        """Apply operator mutations to a binary operation.

        ``sign_error`` and ``wrong_operator`` flip ``+`` and ``-``;
        ``float_precision`` turns ``/`` into ``//``.
        """
        if not self.modified:
            if self.bug_type in ("wrong_operator", "sign_error"):
                if isinstance(node.op, ast.Add):
                    node.op = ast.Sub()
                    self.modified = True
                elif isinstance(node.op, ast.Sub):
                    node.op = ast.Add()
                    self.modified = True
            elif self.bug_type == "float_precision" and isinstance(node.op, ast.Div):
                node.op = ast.FloorDiv()
                self.modified = True
        return self.generic_visit(node)

    def visit_arguments(self, node: ast.arguments):
        """Apply ``swap_args``: swap the first two positional parameters."""
        # swap_args: swap first two arguments of a function
        if self.bug_type == "swap_args" and not self.modified and len(node.args) >= 2:
            self.modified = True
            node.args[0], node.args[1] = node.args[1], node.args[0]
        return self.generic_visit(node)

    def visit_Assign(self, node: ast.Assign):
        """Apply ``uninitialised_var``: drop the first assignment."""
        # uninitialised_var: remove an assignment statement (replaced with Pass)
        if self.bug_type == "uninitialised_var" and not self.modified:
            self.modified = True
            return ast.copy_location(ast.Pass(), node)
        return self.generic_visit(node)


# ----------------------------------------------------------------------
# 2. Bug database (25 bugs, categorized by difficulty)
# ----------------------------------------------------------------------
# Each entry is either
#   {"type": "ast", "bug_type": <ASTBugInjector mutation name>}  or
#   {"type": "template", "buggy": <source>, "oracle": <fixed source>}.
BUG_DB: Dict[str, Dict[str, Dict[str, str]]] = {
    "easy": {
        "null_check": {"type": "ast", "bug_type": "null_check"},
        "simple_typo": {"type": "ast", "bug_type": "simple_typo"},
        "string_index": {"type": "ast", "bug_type": "string_index"},
        "default_value": {"type": "ast", "bug_type": "default_value"},
        "empty_return": {"type": "ast", "bug_type": "empty_return"},
    },
    "medium": {
        "off_by_one": {"type": "ast", "bug_type": "off_by_one"},
        "loop_skip": {"type": "ast", "bug_type": "loop_skip"},
        "sign_error": {"type": "ast", "bug_type": "sign_error"},
        "swap_args": {"type": "ast", "bug_type": "swap_args"},
        "uninitialised_var": {"type": "ast", "bug_type": "uninitialised_var"},
    },
    "hard": {
        "division_by_zero_empty": {"type": "ast", "bug_type": "division_by_zero_empty"},
        # same injector
        "division_by_zero_zero": {"type": "ast", "bug_type": "division_by_zero_empty"},
        "float_precision": {"type": "ast", "bug_type": "float_precision"},
        "abs_usage": {"type": "ast", "bug_type": "abs_usage"},
        # can be extended: ASTBugInjector has no round_error visitor yet,
        # so selecting this bug currently yields a "no_op" result.
        "round_error": {"type": "ast", "bug_type": "round_error"},
    },
    "harder": {
        "missing_lock": {
            "type": "template",
            "buggy": (
                "counter = 0\n"
                "def increment():\n"
                "    global counter\n"
                "    counter += 1"
            ),
            "oracle": (
                "counter = 0\n"
                "import threading\n"
                "lock = threading.Lock()\n"
                "def increment():\n"
                "    global counter\n"
                "    with lock:\n"
                "        counter += 1"
            ),
        },
        "double_lock": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "lock = threading.Lock()\n"
                "def do_work():\n"
                "    lock.acquire()\n"
                "    lock.acquire()\n"
                "    print('working')\n"
                "    lock.release()"
            ),
            "oracle": (
                "import threading\n"
                "lock = threading.Lock()\n"
                "def do_work():\n"
                "    with lock:\n"
                "        print('working')"
            ),
        },
        "global_nonatomic": {
            "type": "template",
            "buggy": (
                "count = 0\n"
                "def add():\n"
                "    global count\n"
                "    count = count + 1"
            ),
            "oracle": (
                "count = 0\n"
                "def add():\n"
                "    global count\n"
                "    count += 1"
            ),
        },
        "thread_safe_list": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "items = []\n"
                "def append_item(item):\n"
                "    items.append(item)"
            ),
            "oracle": (
                "import threading\n"
                "items = []\n"
                "lock = threading.Lock()\n"
                "def append_item(item):\n"
                "    with lock:\n"
                "        items.append(item)"
            ),
        },
        "volatile_read": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "stop = False\n"
                "def worker():\n"
                "    while not stop:\n"
                "        pass"
            ),
            "oracle": (
                "import threading\n"
                "stop = False\n"
                "lock = threading.Lock()\n"
                "def worker():\n"
                "    while True:\n"
                "        with lock:\n"
                "            if stop:\n"
                "                break"
            ),
        },
    },
    "hardest": {
        "deadlock_order": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "lock1 = threading.Lock()\n"
                "lock2 = threading.Lock()\n"
                "def thread1():\n"
                "    with lock1:\n"
                "        with lock2:\n"
                "            pass\n"
                "def thread2():\n"
                "    with lock2:\n"
                "        with lock1:\n"
                "            pass"
            ),
            "oracle": (
                "import threading\n"
                "lock1 = threading.Lock()\n"
                "lock2 = threading.Lock()\n"
                "def thread1():\n"
                "    with lock1:\n"
                "        with lock2:\n"
                "            pass\n"
                "def thread2():\n"
                "    with lock1:\n"
                "        with lock2:\n"
                "            pass"
            ),
        },
        "nested_lock_timeout": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "lock = threading.Lock()\n"
                "def work():\n"
                "    lock.acquire()\n"
                "    # critical section\n"
                "    lock.release()"
            ),
            # The ``pass`` is required: a ``try`` suite holding only a
            # comment is a syntax error.
            "oracle": (
                "import threading\n"
                "lock = threading.Lock()\n"
                "def work():\n"
                "    if lock.acquire(timeout=1):\n"
                "        try:\n"
                "            # critical section\n"
                "            pass\n"
                "        finally:\n"
                "            lock.release()"
            ),
        },
        "fork_join": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "def worker():\n"
                "    pass\n"
                "t = threading.Thread(target=worker)\n"
                "t.start()"
            ),
            "oracle": (
                "import threading\n"
                "def worker():\n"
                "    pass\n"
                "t = threading.Thread(target=worker)\n"
                "t.start()\n"
                "t.join()"
            ),
        },
        "mutex_release": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "lock = threading.Lock()\n"
                "def thread_A():\n"
                "    lock.acquire()\n"
                "    lock.release()\n"
                "def thread_B():\n"
                "    lock.release()"
            ),
            "oracle": (
                "import threading\n"
                "lock = threading.Lock()\n"
                "def thread_A():\n"
                "    with lock:\n"
                "        pass\n"
                "def thread_B():\n"
                "    with lock:\n"
                "        pass"
            ),
        },
        "race_on_init": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "items = []\n"
                "def init():\n"
                "    global items\n"
                "    items = [1,2,3]\n"
                "t = threading.Thread(target=init)\n"
                "t.start()\n"
                "print(items)"
            ),
            "oracle": (
                "import threading\n"
                "items = []\n"
                "def init():\n"
                "    global items\n"
                "    items = [1,2,3]\n"
                "t = threading.Thread(target=init)\n"
                "t.start()\n"
                "t.join()\n"
                "print(items)"
            ),
        },
    },
}

# ----------------------------------------------------------------------
# 3. Derived helpers
# ----------------------------------------------------------------------
# Difficulty level -> list of bug ids available at that level.
TASK_BUG_MAP: Dict[str, List[str]] = {
    level: list(bugs.keys()) for level, bugs in BUG_DB.items()
}

# Template bug id -> (buggy source, oracle source).
TEMPLATE_BUGS: Dict[str, Tuple[str, str]] = {
    bug_id: (bug["buggy"], bug["oracle"])
    for bugs in BUG_DB.values()
    for bug_id, bug in bugs.items()
    if bug["type"] == "template"
}

# AST bug id -> mutation name understood by ASTBugInjector. The two differ
# when several bug ids share one injector (e.g. division_by_zero_zero).
_AST_BUG_TYPES: Dict[str, str] = {
    bug_id: bug["bug_type"]
    for bugs in BUG_DB.values()
    for bug_id, bug in bugs.items()
    if bug["type"] == "ast"
}


# ----------------------------------------------------------------------
# 4. RedTeam Controller (task‑aware)
# ----------------------------------------------------------------------
@dataclass
class RedTeam:
    """Pick and inject a bug appropriate for a task's difficulty level.

    Attributes:
        task: Difficulty level, a key of ``BUG_DB``. Unknown levels fall
            back to the ``null_check`` bug.
        seed: Seed for the private random generator; ``None`` seeds from
            system entropy.
        noise_prob: Probability of appending a distractor comment to the
            buggy code.
    """

    task: str
    seed: Optional[int] = 42
    noise_prob: float = 0.2
    _random: random.Random = field(init=False)

    def __post_init__(self):
        self._random = random.Random(self.seed)

    def inject_bug(self, original_code: str) -> Tuple[str, str, str, str]:
        """
        Returns: (buggy_code, bug_type, description, oracle_fix)
        Selects a bug appropriate for the task difficulty.

        Template bugs ignore ``original_code`` and return the stored
        buggy/oracle pair. AST bugs mutate ``original_code``; the oracle is
        then the untouched input. ``bug_type`` is ``"parse_error"`` when the
        input does not parse and ``"no_op"`` when no matching code structure
        was found (in both cases the code is returned unchanged).
        """
        bug_list = TASK_BUG_MAP.get(self.task, ["null_check"])
        bug_type = self._random.choice(bug_list)

        # Template bug: return hardcoded buggy + oracle
        if bug_type in TEMPLATE_BUGS:
            buggy_code, oracle_code = TEMPLATE_BUGS[bug_type]
            description = f"Template bug: {bug_type}"
            if self._random.random() < self.noise_prob:
                buggy_code += "\n# TODO: refactor later"
            return buggy_code, bug_type, description, oracle_code

        # AST injection
        try:
            tree = ast.parse(original_code)
        except SyntaxError:
            return original_code, "parse_error", "Syntax error in original code", original_code

        # Resolve the bug id to its injector mutation; ids may alias one.
        injector = ASTBugInjector(_AST_BUG_TYPES.get(bug_type, bug_type))
        modified_tree = injector.visit(tree)
        ast.fix_missing_locations(modified_tree)

        if injector.modified:
            buggy_code = ast.unparse(modified_tree)
            oracle_fix = original_code
            description = f"AST bug: {bug_type}"
        else:
            buggy_code = original_code
            oracle_fix = original_code
            bug_type = "no_op"
            description = "No suitable code structure found for injection"

        if self._random.random() < self.noise_prob:
            buggy_code += "\n# TODO: refactor later"

        return buggy_code, bug_type, description, oracle_fix