Spaces:
Sleeping
Sleeping
File size: 4,398 Bytes
3ba81b5 51457b7 3ba81b5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 | from __future__ import annotations
import random
from dataclasses import dataclass
from seed_bank import SEED_BANK, SeedSpec
try:
from server.bug_injector import inject_bug
from server.executor import execute_code
from server.graders import compute_ast_distance
except ImportError:
from .server.bug_injector import inject_bug
from .server.executor import execute_code
from .server.graders import compute_ast_distance
# Names of the mutation operators tried for every seed, in this order.
# Each name is passed verbatim to inject_bug as the operator to apply.
V1_BUG_OPERATORS = (
    "wrong_operator",
    "wrong_builtin",
    "condition_negation",
    "off_by_one",
    "loop_boundary_shift",
    "slice_boundary_corruption",
)
# Default cap on the number of training samples kept per seed
# (applied after the holdout samples have been taken off the top).
MAX_VERIFIED_BUGS_PER_SEED = 4
# Default number of top-ranked samples per seed reserved for evaluation.
HOLDOUT_BUGS_PER_SEED = 1
# Number of seeded mutation attempts made for each (seed, operator) pair.
MAX_MUTATION_ATTEMPTS = 4
# Per-operator base term of the difficulty score used for ranking;
# an operator that is missing from this table contributes 0.
BUG_OPERATOR_PRIORITY = {
    "loop_boundary_shift": 6,
    "slice_boundary_corruption": 5,
    "condition_negation": 4,
    "wrong_operator": 3,
    "off_by_one": 2,
    "wrong_builtin": 1,
}
@dataclass(frozen=True)
class BugSample:
    """One verified buggy variant of a seed program (immutable record)."""

    # Identifier of the seed this variant was derived from.
    seed_id: str
    # The seed's unmodified source code.
    original_code: str
    # The mutated source code returned by the bug injector.
    buggy_code: str
    # Name of the mutation operator that produced buggy_code.
    bug_operator: str
    # Output of running the seed's tests against buggy_code, truncated to
    # 500 characters by the collector; empty string when there was no output.
    execution_result: str
@dataclass(frozen=True)
class BugBank:
    """Immutable train/eval split of verified bug samples across all seeds."""

    # Samples left for training after each seed's holdout was removed.
    train_samples: tuple[BugSample, ...]
    # Highest-ranked samples of each seed, held out for evaluation.
    eval_samples: tuple[BugSample, ...]
def validate_seed(seed: SeedSpec) -> None:
    """Check that a seed's reference code passes the seed's own tests.

    Args:
        seed: Seed whose ``original_code`` is executed against ``test``.

    Raises:
        ValueError: If execution reports a syntax error or the tests do
            not pass.
    """
    outcome = execute_code(seed.original_code, seed.test)
    # Healthy seed: no syntax error and the tests passed.
    if not outcome.syntax_error and outcome.passed:
        return
    raise ValueError(f"Seed {seed.seed_id} does not pass its canonical tests.")
def build_bug_bank(
    seeds: tuple[SeedSpec, ...] = SEED_BANK,
    max_verified_bugs_per_seed: int = MAX_VERIFIED_BUGS_PER_SEED,
    holdout_bugs_per_seed: int = HOLDOUT_BUGS_PER_SEED,
) -> BugBank:
    """Build the train/eval bug bank from a collection of seeds.

    Each seed is validated, its verified bugs are collected and ranked by
    descending difficulty score, the top ``holdout_bugs_per_seed`` go to the
    eval split and the next ``max_verified_bugs_per_seed`` go to the train
    split.

    Args:
        seeds: Seeds to process, in order.
        max_verified_bugs_per_seed: Maximum number of training samples kept
            per seed.
        holdout_bugs_per_seed: Number of top-ranked samples per seed that
            are reserved for evaluation.

    Returns:
        A ``BugBank`` holding both splits as tuples.

    Raises:
        ValueError: If a seed fails validation, or if it yields no more
            verified bugs than the holdout size (nothing would be left for
            training).
    """
    held_out: list[BugSample] = []
    training: list[BugSample] = []
    # End index of the training slice; it does not depend on the seed.
    train_stop = holdout_bugs_per_seed + max_verified_bugs_per_seed

    for seed in seeds:
        validate_seed(seed)

        def difficulty(sample: BugSample) -> float:
            # Used immediately by sorted() below, so capturing the loop
            # variable here is safe.
            return _bug_difficulty_score(seed, sample)

        ranked = sorted(_collect_verified_bugs(seed), key=difficulty, reverse=True)
        if len(ranked) <= holdout_bugs_per_seed:
            raise ValueError(
                f"Seed {seed.seed_id} only produced {len(ranked)} verified bugs."
            )

        # Hardest bugs first: the head is held out, the following slice trains.
        held_out += ranked[:holdout_bugs_per_seed]
        training += ranked[holdout_bugs_per_seed:train_stop]

    return BugBank(
        train_samples=tuple(training),
        eval_samples=tuple(held_out),
    )
def _collect_verified_bugs(seed: SeedSpec) -> list[BugSample]:
    """Generate distinct mutants of a seed that actually fail its tests.

    For every operator in ``V1_BUG_OPERATORS`` up to
    ``MAX_MUTATION_ATTEMPTS`` mutations are attempted. A mutant is kept only
    if the injector reports a change, its code has not been kept before, and
    running the seed's tests against it neither reports a syntax error nor
    passes.

    The global ``random`` generator is reseeded before each attempt so the
    mutations are reproducible; the generator's previous state is restored
    before returning so callers' randomness is not disturbed.

    Args:
        seed: Seed whose ``original_code`` is mutated and whose ``test`` is
            used for verification.

    Returns:
        Verified samples in discovery order (operator order, then attempt
        order).
    """
    verified_samples: list[BugSample] = []
    seen_codes: set[str] = set()

    # Reseeding below mutates the process-wide RNG; snapshot it so it can be
    # put back even if injection or execution raises.
    saved_rng_state = random.getstate()
    try:
        for bug_operator in V1_BUG_OPERATORS:
            for attempt in range(MAX_MUTATION_ATTEMPTS):
                # Deterministic per (seed, operator, attempt) seed.
                # NOTE(review): presumably inject_bug draws from the global
                # random module - confirm against its implementation.
                random.seed(f"{seed.seed_id}:{bug_operator}:{attempt}")
                buggy_code, changed = inject_bug(seed.original_code, bug_operator)
                if not changed:
                    continue
                if buggy_code in seen_codes:
                    continue

                result = execute_code(buggy_code, seed.test)
                # Reject mutants that do not parse or that the tests cannot
                # tell apart from the original.
                if result.syntax_error or result.passed:
                    continue

                seen_codes.add(buggy_code)
                verified_samples.append(
                    BugSample(
                        seed_id=seed.seed_id,
                        original_code=seed.original_code,
                        buggy_code=buggy_code,
                        bug_operator=bug_operator,
                        # Keep at most 500 characters of the test output.
                        execution_result=result.output[:500] if result.output else "",
                    )
                )
    finally:
        random.setstate(saved_rng_state)

    return verified_samples
def _bug_difficulty_score(seed: SeedSpec, sample: BugSample) -> float:
    """Score a verified bug for ranking; higher scores sort first.

    The score is the sum of three terms:

    * the operator's entry in ``BUG_OPERATOR_PRIORITY`` (0 if absent),
    * the value of ``compute_ast_distance`` between the seed's original
      code and the buggy code,
    * the number of non-empty lines in the execution output divided by 4,
      capped at 1.0.

    NOTE(review): the helper is named "distance" but its value is treated
    as a similarity (higher is better) - confirm which direction it
    actually measures.
    """
    priority_term = BUG_OPERATOR_PRIORITY.get(sample.bug_operator, 0)
    # Bias toward bugs that preserve the function shape but still require a real local repair.
    shape_term = compute_ast_distance(seed.original_code, sample.buggy_code)
    output_lines = _count_nonempty_lines(sample.execution_result)
    # Four or more non-empty output lines saturate this term at 1.0.
    signal_term = output_lines / 4.0 if output_lines < 4 else 1.0
    return float(priority_term) + shape_term + signal_term
def _count_nonempty_lines(text: str) -> int:
return sum(1 for line in text.splitlines() if line.strip())
|