Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import random | |
| from dataclasses import dataclass | |
| from seed_bank import SEED_BANK, SeedSpec | |
| try: | |
| from server.bug_injector import inject_bug | |
| from server.executor import execute_code | |
| from server.graders import compute_ast_distance | |
| except ImportError: | |
| from .server.bug_injector import inject_bug | |
| from .server.executor import execute_code | |
| from .server.graders import compute_ast_distance | |
| V1_BUG_OPERATORS = ( | |
| "wrong_operator", | |
| "wrong_builtin", | |
| "condition_negation", | |
| "off_by_one", | |
| "loop_boundary_shift", | |
| "slice_boundary_corruption", | |
| ) | |
| MAX_VERIFIED_BUGS_PER_SEED = 4 | |
| HOLDOUT_BUGS_PER_SEED = 1 | |
| MAX_MUTATION_ATTEMPTS = 4 | |
| BUG_OPERATOR_PRIORITY = { | |
| "loop_boundary_shift": 6, | |
| "slice_boundary_corruption": 5, | |
| "condition_negation": 4, | |
| "wrong_operator": 3, | |
| "off_by_one": 2, | |
| "wrong_builtin": 1, | |
| } | |
| class BugSample: | |
| seed_id: str | |
| original_code: str | |
| buggy_code: str | |
| bug_operator: str | |
| execution_result: str | |
| class BugBank: | |
| train_samples: tuple[BugSample, ...] | |
| eval_samples: tuple[BugSample, ...] | |
| def validate_seed(seed: SeedSpec) -> None: | |
| result = execute_code(seed.original_code, seed.test) | |
| if result.syntax_error or not result.passed: | |
| raise ValueError(f"Seed {seed.seed_id} does not pass its canonical tests.") | |
| def build_bug_bank( | |
| seeds: tuple[SeedSpec, ...] = SEED_BANK, | |
| max_verified_bugs_per_seed: int = MAX_VERIFIED_BUGS_PER_SEED, | |
| holdout_bugs_per_seed: int = HOLDOUT_BUGS_PER_SEED, | |
| ) -> BugBank: | |
| train_samples: list[BugSample] = [] | |
| eval_samples: list[BugSample] = [] | |
| for seed in seeds: | |
| validate_seed(seed) | |
| verified_samples = _collect_verified_bugs(seed) | |
| verified_samples = sorted( | |
| verified_samples, | |
| key=lambda sample: _bug_difficulty_score(seed, sample), | |
| reverse=True, | |
| ) | |
| if len(verified_samples) <= holdout_bugs_per_seed: | |
| raise ValueError( | |
| f"Seed {seed.seed_id} only produced {len(verified_samples)} verified bugs." | |
| ) | |
| eval_samples.extend(verified_samples[:holdout_bugs_per_seed]) | |
| train_samples.extend( | |
| verified_samples[ | |
| holdout_bugs_per_seed : holdout_bugs_per_seed + max_verified_bugs_per_seed | |
| ] | |
| ) | |
| return BugBank( | |
| train_samples=tuple(train_samples), | |
| eval_samples=tuple(eval_samples), | |
| ) | |
| def _collect_verified_bugs(seed: SeedSpec) -> list[BugSample]: | |
| verified_samples: list[BugSample] = [] | |
| seen_codes: set[str] = set() | |
| for bug_operator in V1_BUG_OPERATORS: | |
| for attempt in range(MAX_MUTATION_ATTEMPTS): | |
| random.seed(f"{seed.seed_id}:{bug_operator}:{attempt}") | |
| buggy_code, changed = inject_bug(seed.original_code, bug_operator) | |
| if not changed: | |
| continue | |
| if buggy_code in seen_codes: | |
| continue | |
| result = execute_code(buggy_code, seed.test) | |
| if result.syntax_error or result.passed: | |
| continue | |
| seen_codes.add(buggy_code) | |
| verified_samples.append( | |
| BugSample( | |
| seed_id=seed.seed_id, | |
| original_code=seed.original_code, | |
| buggy_code=buggy_code, | |
| bug_operator=bug_operator, | |
| execution_result=result.output[:500] if result.output else "", | |
| ) | |
| ) | |
| return verified_samples | |
| def _bug_difficulty_score(seed: SeedSpec, sample: BugSample) -> float: | |
| operator_score = BUG_OPERATOR_PRIORITY.get(sample.bug_operator, 0) | |
| ast_similarity = compute_ast_distance(seed.original_code, sample.buggy_code) | |
| execution_lines = _count_nonempty_lines(sample.execution_result) | |
| # Bias toward bugs that preserve the function shape but still require a real local repair. | |
| local_repair_score = ast_similarity | |
| execution_signal = min(execution_lines / 4.0, 1.0) | |
| return float(operator_score) + local_repair_score + execution_signal | |
| def _count_nonempty_lines(text: str) -> int: | |
| return sum(1 for line in text.splitlines() if line.strip()) | |