# debugZero / bug_bank.py
# Uploaded by The-Fool-09 using huggingface_hub (commit 51457b7, verified).
from __future__ import annotations

import random
from dataclasses import dataclass

from seed_bank import SEED_BANK, SeedSpec

try:
    from server.bug_injector import inject_bug
    from server.executor import execute_code
    from server.graders import compute_ast_distance
except ImportError:
    from .server.bug_injector import inject_bug
    from .server.executor import execute_code
    from .server.graders import compute_ast_distance
# Mutation operators handed to ``inject_bug``; every seed tries them in this order.
V1_BUG_OPERATORS = (
    "wrong_operator",
    "wrong_builtin",
    "condition_negation",
    "off_by_one",
    "loop_boundary_shift",
    "slice_boundary_corruption",
)

# Per-seed limits used while building the bank.
MAX_VERIFIED_BUGS_PER_SEED = 4  # default cap on training samples kept per seed
HOLDOUT_BUGS_PER_SEED = 1  # default number of samples per seed routed to eval
MAX_MUTATION_ATTEMPTS = 4  # reseeded injection attempts per (seed, operator)

# Base difficulty weight per operator. Higher weights rank a bug as harder, and
# the highest-ranked bugs of each seed are the ones held out for evaluation.
BUG_OPERATOR_PRIORITY = dict(
    loop_boundary_shift=6,
    slice_boundary_corruption=5,
    condition_negation=4,
    wrong_operator=3,
    off_by_one=2,
    wrong_builtin=1,
)
@dataclass(frozen=True)
class BugSample:
    """A single mutant of a seed program, kept together with its provenance.

    Instances are immutable (``frozen=True``).
    """

    # Identifier of the seed this mutant was derived from.
    seed_id: str
    # Unmodified source code of the seed.
    original_code: str
    # Mutated source code.
    buggy_code: str
    # Name of the mutation operator that produced ``buggy_code``.
    bug_operator: str
    # Output captured when running the seed's tests against ``buggy_code``.
    execution_result: str
@dataclass(frozen=True)
class BugBank:
    """Immutable pair of train and eval bug-sample splits."""

    # Samples intended for training.
    train_samples: tuple[BugSample, ...]
    # Samples held out for evaluation.
    eval_samples: tuple[BugSample, ...]
def validate_seed(seed: SeedSpec) -> None:
    """Check that a seed's reference implementation passes its own tests.

    Runs ``seed.original_code`` against ``seed.test`` via ``execute_code``.

    Raises:
        ValueError: if execution reports a syntax error or the tests do not pass.
    """
    outcome = execute_code(seed.original_code, seed.test)
    # Healthy means: it parsed, and every canonical test passed.
    seed_is_healthy = not outcome.syntax_error and outcome.passed
    if not seed_is_healthy:
        raise ValueError(f"Seed {seed.seed_id} does not pass its canonical tests.")
def build_bug_bank(
    seeds: tuple[SeedSpec, ...] = SEED_BANK,
    max_verified_bugs_per_seed: int = MAX_VERIFIED_BUGS_PER_SEED,
    holdout_bugs_per_seed: int = HOLDOUT_BUGS_PER_SEED,
) -> BugBank:
    """Build a train/eval bug bank from ``seeds``.

    For every seed:
      1. validate that its reference code passes its own tests;
      2. collect the mutants that break those tests;
      3. rank them by ``_bug_difficulty_score``, highest first;
      4. send the top ``holdout_bugs_per_seed`` to the eval split;
      5. send the next ``max_verified_bugs_per_seed`` to the train split.

    Args:
        seeds: Seed programs to mutate.
        max_verified_bugs_per_seed: Maximum number of training samples kept
            per seed. Must be non-negative.
        holdout_bugs_per_seed: Number of samples per seed reserved for
            evaluation. Must be non-negative.

    Returns:
        A ``BugBank`` whose splits list samples seed by seed, in ``seeds`` order.

    Raises:
        ValueError: if either count is negative, if a seed fails its own
            tests, or if a seed yields no verified bug beyond its holdout quota.
    """
    # Negative counts would make the slices below index from the end of the
    # list, letting train and eval silently share the same samples.
    if max_verified_bugs_per_seed < 0:
        raise ValueError(
            "max_verified_bugs_per_seed must be non-negative, "
            f"got {max_verified_bugs_per_seed}."
        )
    if holdout_bugs_per_seed < 0:
        raise ValueError(
            f"holdout_bugs_per_seed must be non-negative, got {holdout_bugs_per_seed}."
        )

    train_samples: list[BugSample] = []
    eval_samples: list[BugSample] = []
    train_end = holdout_bugs_per_seed + max_verified_bugs_per_seed
    for seed in seeds:
        validate_seed(seed)
        # sorted() is stable even with reverse=True, so ties keep collection
        # order and the split is deterministic.
        ranked_samples = sorted(
            _collect_verified_bugs(seed),
            key=lambda sample: _bug_difficulty_score(seed, sample),
            reverse=True,
        )
        # Demand at least one sample beyond the holdout so every seed also
        # contributes to the training split.
        if len(ranked_samples) <= holdout_bugs_per_seed:
            raise ValueError(
                f"Seed {seed.seed_id} only produced {len(ranked_samples)} verified bugs."
            )
        eval_samples.extend(ranked_samples[:holdout_bugs_per_seed])
        train_samples.extend(ranked_samples[holdout_bugs_per_seed:train_end])
    return BugBank(
        train_samples=tuple(train_samples),
        eval_samples=tuple(eval_samples),
    )
def _collect_verified_bugs(seed: SeedSpec) -> list[BugSample]:
    """Generate mutants of ``seed`` and keep those its tests reject.

    Each operator in ``V1_BUG_OPERATORS`` gets up to ``MAX_MUTATION_ATTEMPTS``
    attempts. Each attempt reseeds the RNG with a deterministic string before
    calling ``inject_bug``.

    A mutant is kept only if all of the following hold:
      * ``inject_bug`` reports a change;
      * the code differs from the original and from every earlier candidate;
      * running ``seed.test`` against it produces no syntax error;
      * the tests do not pass.

    The global ``random`` state is saved on entry and restored on exit. The
    per-attempt reseeding therefore does not leak a predictable RNG state into
    the rest of the process.

    Returns:
        Verified samples in generation order (operator order, then attempt).
    """
    verified_samples: list[BugSample] = []
    # Starting with the original code skips no-op mutants without a sandbox
    # run. A validated seed passes its tests, so it would be rejected anyway.
    seen_codes: set[str] = {seed.original_code}
    saved_rng_state = random.getstate()
    try:
        for bug_operator in V1_BUG_OPERATORS:
            for attempt in range(MAX_MUTATION_ATTEMPTS):
                # Deterministic reseed per attempt. Presumably inject_bug draws
                # from the global ``random`` module (TODO: confirm in
                # server.bug_injector).
                random.seed(f"{seed.seed_id}:{bug_operator}:{attempt}")
                buggy_code, changed = inject_bug(seed.original_code, bug_operator)
                if not changed or buggy_code in seen_codes:
                    continue
                # Record before executing so a repeated, already-rejected
                # mutant is not sent through the executor again.
                seen_codes.add(buggy_code)
                result = execute_code(buggy_code, seed.test)
                if result.syntax_error or result.passed:
                    continue
                verified_samples.append(
                    BugSample(
                        seed_id=seed.seed_id,
                        original_code=seed.original_code,
                        buggy_code=buggy_code,
                        bug_operator=bug_operator,
                        # Truncate the captured output to keep samples compact.
                        execution_result=result.output[:500] if result.output else "",
                    )
                )
    finally:
        random.setstate(saved_rng_state)
    return verified_samples
def _bug_difficulty_score(seed: SeedSpec, sample: BugSample) -> float:
    """Score a verified bug; larger scores rank as harder.

    The score is the sum of three terms:
      * the operator weight from ``BUG_OPERATOR_PRIORITY`` (0 for unknown operators);
      * ``compute_ast_distance(seed.original_code, sample.buggy_code)``;
      * the number of non-empty lines in ``sample.execution_result``, divided
        by 4 and capped at 1.0.
    """
    weight = BUG_OPERATOR_PRIORITY.get(sample.bug_operator, 0)
    # NOTE(review): the original code treats this value as an AST *similarity*,
    # favoring bugs that preserve the function shape but still require a real
    # local repair. The helper's name suggests a distance, so confirm which
    # direction server.graders.compute_ast_distance actually returns.
    shape_term = compute_ast_distance(seed.original_code, sample.buggy_code)
    output_term = min(_count_nonempty_lines(sample.execution_result) / 4.0, 1.0)
    return float(weight) + shape_term + output_term
def _count_nonempty_lines(text: str) -> int:
return sum(1 for line in text.splitlines() if line.strip())