| | """Additional benchmark adapters (dpaia EE-Dataset, Multi-SWE-bench, SWE-bench |
| | Multilingual, CrossCodeEval, RepoBench, McEval, MultiPL-E, Defects4J).""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import json |
| | from typing import Any |
| |
|
| | from adapters import DatasetAdapter |
| | from adapters.code_editing import SWEBenchLiteAdapter |
| |
|
| | |
# Helper hooks that start out unset in this module. The adapters below call
# ``_highlight_code`` directly, so presumably it is rebound to a real callable
# elsewhere before any adapter method runs — TODO confirm where that happens.
_highlight_code = _code_offset = _extract_test_classes = None
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
| |
|
class DPAIAEEDatasetAdapter(DatasetAdapter):
    """Adapter for DPAIA EE-Dataset rows: SWE-bench-style patch tasks.

    ``rows`` is a list of plain dicts. Every field read here (``instance_id``,
    ``repo``, ``tags``, ``patch``, ``test_patch``, ``FAIL_TO_PASS``,
    ``PASS_TO_PASS``, ``problem_statement``, ``base_commit``) is optional and
    falls back to an empty or index-derived value.
    """

    slug = "dpaia-ee"
    display_name = "DPAIA EE-Dataset"
    has_ground_truth = False
    has_tasks = False

    def __init__(self, rows: list[dict[str, Any]]):
        self._rows = rows

    def problem_count(self) -> int:
        return len(self._rows)

    @staticmethod
    def _entry_point(row: dict[str, Any], idx: int) -> str:
        """Return the row's repo, or ``dpaia_<idx>`` when it is missing or empty."""
        return row.get("repo") or f"dpaia_{idx}"

    @staticmethod
    def _source_label(row: dict[str, Any]) -> str:
        """Return up to three tags joined by ", ", falling back to "DPAIA".

        Shared by the summary and detail views so both always show the same label.
        """
        tags = row.get("tags")
        if isinstance(tags, list):
            label = ", ".join(str(tag) for tag in tags[:3])
        else:
            label = str(tags) if tags else ""
        return label or "DPAIA"

    @staticmethod
    def _parse_test_list(value: Any) -> list[Any]:
        """Normalize a FAIL_TO_PASS / PASS_TO_PASS field to a list of test ids.

        Accepts an existing list, a JSON-encoded list string, a single bare test
        name, or a missing/empty value (which yields ``[]``).
        """
        if value is None or value == "":
            return []
        if isinstance(value, str):
            try:
                parsed = json.loads(value)
            except json.JSONDecodeError:
                return [value]  # not JSON: the whole string is one test id
            if isinstance(parsed, list):
                return parsed
            # JSON scalar: a quoted name decodes to str; anything else keep verbatim.
            return [parsed] if isinstance(parsed, str) else [value]
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    def get_problem_summary(self, idx: int) -> dict[str, Any]:
        """Return the lightweight listing entry for row ``idx``."""
        row = self._rows[idx]
        return {
            "idx": idx,
            "task_id": row.get("instance_id", str(idx)),
            "entry_point": self._entry_point(row, idx),
            "num_inputs": 0,
            "source": self._source_label(row),
        }

    def get_problem_detail(self, idx: int) -> dict[str, Any]:
        """Return the full view of row ``idx``; the gold patch doubles as ``code``."""
        row = self._rows[idx]
        patch = row.get("patch", "")
        return {
            "idx": idx,
            "task_id": row.get("instance_id", str(idx)),
            "entry_point": self._entry_point(row, idx),
            "code": patch,
            "highlighted_code": "",
            "inputs": [],
            "outputs": [],
            "test": None,
            "tasks": [],
            "source": self._source_label(row),
            "has_ground_truth": False,
            "has_tasks": False,
            "description": row.get("problem_statement", ""),
            "patch": patch,
            "test_patch": row.get("test_patch", ""),
            "fail_to_pass": self._parse_test_list(row.get("FAIL_TO_PASS")),
            "pass_to_pass": self._parse_test_list(row.get("PASS_TO_PASS")),
            "repo": row.get("repo", ""),
            "base_commit": row.get("base_commit", ""),
        }
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
| |
|
class MultiSWEBenchAdapter(DatasetAdapter):
    """Adapter for Multi-SWE-bench rows (GitHub repository repair tasks).

    Rows are plain dicts. Fields read here: ``instance_id``, ``org``, ``repo``,
    ``number``, ``title``, ``body``, ``fix_patch``, ``test_patch``, ``hints``
    and ``_language`` (presumably attached by the loader — TODO confirm).
    """

    slug = "multiswebench"
    display_name = "Multi-SWE-bench"
    has_ground_truth = False
    has_tasks = False

    def __init__(self, rows: list[dict[str, Any]]):
        self._rows = rows

    def problem_count(self) -> int:
        return len(self._rows)

    @staticmethod
    def _full_repo(row: dict[str, Any]) -> str:
        """Return ``"org/repo"`` when both parts are set, otherwise the repo (or "")."""
        org = row.get("org") or ""
        repo = row.get("repo") or ""
        return f"{org}/{repo}" if org and repo else repo

    @staticmethod
    def _entry_point(instance_id: Any, idx: int) -> str:
        """Return the text after the last ``"__"`` in ``instance_id``, or ``mswe_<idx>``."""
        return str(instance_id).split("__")[-1] if instance_id else f"mswe_{idx}"

    @staticmethod
    def _source(row: dict[str, Any], full_repo: str) -> str:
        """Language label, else the repo, else "unknown" (shared by both views)."""
        return row.get("_language") or full_repo or "unknown"

    def get_problem_summary(self, idx: int) -> dict[str, Any]:
        """Return the lightweight listing entry for row ``idx``."""
        row = self._rows[idx]
        instance_id = row.get("instance_id", str(idx))
        return {
            "idx": idx,
            "task_id": instance_id,
            "entry_point": self._entry_point(instance_id, idx),
            "num_inputs": 0,
            "source": self._source(row, self._full_repo(row)),
        }

    def get_problem_detail(self, idx: int) -> dict[str, Any]:
        """Return the full view of row ``idx``; the fix patch doubles as ``code``."""
        row = self._rows[idx]
        patch = row.get("fix_patch", "")
        instance_id = row.get("instance_id", str(idx))
        full_repo = self._full_repo(row)
        number = row.get("number", "")

        # Title and body are separated by a blank line; empty parts are skipped
        # so a missing title does not leave a leading "\n\n".
        description = "\n\n".join(
            part for part in (row.get("title") or "", row.get("body") or "") if part
        )

        links: dict[str, str] = {}
        if full_repo:
            links["repo_url"] = f"https://github.com/{full_repo}"
            if number:
                # Key name kept for callers even though the URL targets /pull/.
                links["issue_url"] = f"https://github.com/{full_repo}/pull/{number}"

        return {
            "idx": idx,
            "task_id": instance_id,
            "entry_point": self._entry_point(instance_id, idx),
            "code": patch,
            "highlighted_code": "",
            "inputs": [],
            "outputs": [],
            "test": None,
            "tasks": [],
            "source": self._source(row, full_repo),
            "has_ground_truth": False,
            "has_tasks": False,
            "description": description,
            "patch": patch,
            "test_patch": row.get("test_patch", ""),
            "fail_to_pass": [],
            "pass_to_pass": [],
            "repo": full_repo,
            "hints": row.get("hints", ""),
            **links,
        }
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
| |
|
class SWEBenchMultilingualAdapter(SWEBenchLiteAdapter):
    """SWE-bench Multilingual, served through the SWE-bench Lite adapter.

    Only the registry identifiers differ; all row handling is inherited
    unchanged from ``SWEBenchLiteAdapter``.
    """

    display_name = "SWE-bench Multilingual"
    slug = "swebenchmultilingual"
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
| |
|
class CrossCodeEvalAdapter(DatasetAdapter):
    """Adapter for CrossCodeEval fill-in-the-middle rows.

    Each row supplies ``prompt`` (code before the hole), ``groundtruth`` (the
    expected completion), ``right_context`` (code after the hole),
    ``language`` and a nested ``metadata`` dict that holds ``task_id``.
    """

    slug = "crosscodeeval"
    display_name = "CrossCodeEval"
    has_ground_truth = False
    has_tasks = False

    def __init__(self, rows: list[dict[str, Any]]):
        self._rows = rows

    def problem_count(self) -> int:
        return len(self._rows)

    @staticmethod
    def _get_metadata(row: dict, key: str, default: str = "") -> str:
        """Extract a value from the nested metadata dict, coerced to ``str``.

        Returns ``default`` when ``metadata`` is not a dict or the value is
        missing/None, so callers can always use string methods on the result.
        """
        meta = row.get("metadata", {})
        if not isinstance(meta, dict):
            return default
        value = meta.get(key, default)
        return default if value is None else str(value)

    @staticmethod
    def _entry_point(task_id: str, idx: int) -> str:
        """Return the last ``/``-separated component of ``task_id``, or ``cceval_<idx>``."""
        return task_id.rsplit("/", 1)[-1] if task_id else f"cceval_{idx}"

    def get_problem_summary(self, idx: int) -> dict[str, Any]:
        """Return the lightweight listing entry for row ``idx``."""
        row = self._rows[idx]
        task_id = self._get_metadata(row, "task_id", str(idx))
        return {
            "idx": idx,
            "task_id": task_id,
            "entry_point": self._entry_point(task_id, idx),
            "num_inputs": 0,
            "source": row.get("language", "unknown"),
        }

    def get_problem_detail(self, idx: int) -> dict[str, Any]:
        """Return the FIM view of row ``idx``: hole view, ground truth and merged code."""
        row = self._rows[idx]
        prompt = row.get("prompt", "")
        reference = row.get("groundtruth", "")
        right_context = row.get("right_context", "")
        lang = row.get("language", "python")
        lang_key = lang.lower()
        task_id = self._get_metadata(row, "task_id", str(idx))

        # Display view marks the hole; merged view splices the reference in place.
        display_code = prompt + "\n/* [HOLE] */\n" + right_context
        merged_code = prompt + reference + right_context

        # The reference is appended directly (no newline inserted), so it begins
        # on the line where the prompt ends.
        gt_start_line = prompt.count("\n") + 1
        # A trailing newline ends the reference's last line rather than opening
        # a new one, so it must not add to the highlighted line count.
        gt_line_count = reference.rstrip("\n").count("\n") + 1 if reference else 0
        gt_end_line = gt_start_line + gt_line_count - 1

        return {
            "idx": idx,
            "task_id": task_id,
            "entry_point": self._entry_point(task_id, idx),
            "code": display_code,
            "highlighted_code": _highlight_code(display_code, language=lang_key),
            "inputs": [],
            "outputs": [],
            "test": None,
            "tasks": [],
            "source": lang,
            "has_ground_truth": False,
            "has_tasks": False,
            "fim_prefix": prompt,
            "fim_ground_truth": reference,
            "fim_ground_truth_highlighted": _highlight_code(reference, language=lang_key)
            if reference
            else "",
            "fim_merged_code": merged_code,
            "fim_merged_highlighted": _highlight_code(
                merged_code,
                highlight_lines=list(range(gt_start_line, gt_end_line + 1)),
                language=lang_key,
            )
            if merged_code
            else "",
            "fim_gt_start_line": gt_start_line,
            "fim_gt_end_line": gt_end_line,
            "language": lang,
        }
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
| |
|
class RepoBenchAdapter(DatasetAdapter):
    """Adapter for RepoBench next-line completion rows.

    Each row supplies the in-file code (``all_code``, or ``context`` as a
    fallback), the expected ``next_line`` (or ``gold_snippet_code``),
    ``repo_name``, ``file_path`` and optionally ``language`` / ``_setting``.
    """

    slug = "repobench"
    display_name = "RepoBench"
    has_ground_truth = False
    has_tasks = False

    def __init__(self, rows: list[dict[str, Any]]):
        self._rows = rows

    def problem_count(self) -> int:
        return len(self._rows)

    @staticmethod
    def _entry_point(row: dict[str, Any], idx: int) -> str:
        """Return the file name from ``file_path``, or ``repobench_<idx>`` if absent."""
        path = row.get("file_path") or f"repobench_{idx}"
        return str(path).rsplit("/", 1)[-1]

    def get_problem_summary(self, idx: int) -> dict[str, Any]:
        """Return the lightweight listing entry for row ``idx``."""
        row = self._rows[idx]
        return {
            "idx": idx,
            "task_id": str(row.get("repo_name", idx)),
            "entry_point": self._entry_point(row, idx),
            "num_inputs": 0,
            "source": row.get("language", row.get("_setting", "unknown")),
        }

    def get_problem_detail(self, idx: int) -> dict[str, Any]:
        """Return the FIM view of row ``idx``: code, hole, and the expected next line."""
        row = self._rows[idx]
        context = row.get("all_code", row.get("context", ""))
        if not isinstance(context, str):
            # NOTE(review): presumably a structured value (e.g. cross-file
            # snippets) rather than code text; it cannot be spliced into a view.
            context = ""
        next_line = row.get("next_line", row.get("gold_snippet_code", ""))
        if not isinstance(next_line, str):
            next_line = ""
        lang = row.get("language", "python")
        lang_key = lang.lower()

        display_code = context + "\n/* [HOLE] */\n" if context else ""
        if context and next_line:
            merged_code = context + "\n" + next_line
        else:
            # Without context the ground truth alone is the merged view; it then
            # starts on line 1, matching gt_start_line below.
            merged_code = context or next_line

        # Lines 1..count+1 hold the context; the inserted "\n" puts next_line after it.
        gt_start_line = context.count("\n") + 2 if context else 1
        # A trailing newline ends the last ground-truth line; don't count it twice.
        gt_line_count = next_line.rstrip("\n").count("\n") + 1 if next_line else 0
        gt_end_line = gt_start_line + gt_line_count - 1

        return {
            "idx": idx,
            "task_id": str(row.get("repo_name", idx)),
            "entry_point": self._entry_point(row, idx),
            "code": display_code,
            "highlighted_code": _highlight_code(display_code, language=lang_key)
            if display_code
            else "",
            "inputs": [],
            "outputs": [],
            "test": None,
            "tasks": [],
            "source": row.get("_setting", lang),
            "has_ground_truth": False,
            "has_tasks": False,
            "fim_prefix": context,
            "fim_ground_truth": next_line,
            "fim_ground_truth_highlighted": _highlight_code(next_line, language=lang_key)
            if next_line
            else "",
            "fim_merged_code": merged_code,
            "fim_merged_highlighted": _highlight_code(
                merged_code,
                highlight_lines=list(range(gt_start_line, gt_end_line + 1)),
                language=lang_key,
            )
            if merged_code
            else "",
            "fim_gt_start_line": gt_start_line,
            "fim_gt_end_line": gt_end_line,
            "language": lang,
        }
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
| |
|
class McEvalAdapter(DatasetAdapter):
    """Adapter for McEval, a multi-language code generation benchmark.

    ``hf_dataset`` is an indexable collection of dict-like rows with
    ``task_id``, ``entry_point``, ``prompt``, ``canonical_solution``,
    ``test`` and ``language`` fields.
    """

    slug = "mceval"
    display_name = "McEval"
    has_ground_truth = False
    has_tasks = False

    # McEval language names (lower-cased) whose highlighter key differs;
    # presumably Pygments lexer aliases. Unlisted names are used as-is.
    _LANG_ALIASES = {
        "c++": "cpp",
        "c#": "csharp",
        "objective-c": "objectivec",
        "visual basic": "vb.net",
    }

    def __init__(self, hf_dataset):
        self._ds = hf_dataset

    def problem_count(self) -> int:
        return len(self._ds)

    @staticmethod
    def _entry_point(row: Any, idx: int) -> str:
        """Return ``entry_point``, else ``task_id``, else ``mceval_<idx>``."""
        return row.get("entry_point", row.get("task_id", f"mceval_{idx}"))

    def get_problem_summary(self, idx: int) -> dict[str, Any]:
        """Return the lightweight listing entry for row ``idx``."""
        row = self._ds[idx]
        return {
            "idx": idx,
            "task_id": row.get("task_id", str(idx)),
            "entry_point": self._entry_point(row, idx),
            "num_inputs": 0,
            "source": row.get("language", "unknown"),
        }

    def get_problem_detail(self, idx: int) -> dict[str, Any]:
        """Return the full view of row ``idx``: prompt + canonical solution as ``code``."""
        row = self._ds[idx]
        # ``or ""`` guards against explicit None values breaking concatenation.
        prompt = row.get("prompt") or ""
        code = prompt + (row.get("canonical_solution") or "")
        lang = row.get("language", "python")
        lowered = lang.lower()
        lang_key = self._LANG_ALIASES.get(lowered, lowered)

        return {
            "idx": idx,
            "task_id": row.get("task_id", str(idx)),
            "entry_point": self._entry_point(row, idx),
            "code": code,
            "highlighted_code": _highlight_code(code, language=lang_key),
            "inputs": [],
            "outputs": [],
            "test": row.get("test", ""),
            "tasks": [],
            "source": lang,
            "has_ground_truth": False,
            "has_tasks": False,
            "description": prompt,
            "language": lang,
        }
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
| |
|
class MultiPLEAdapter(DatasetAdapter):
    """Adapter for MultiPL-E: each benchmark problem rendered in several languages.

    ``datasets_by_lang`` maps a MultiPL-E language code (e.g. ``"py"``) to an
    indexable dataset of dict-like rows with ``name``, ``prompt`` and ``tests``.
    The first language in the mapping defines the problem list. Other languages
    are matched to it by the row ``name``, because nothing guarantees that the
    per-language datasets have the same length or order.
    """

    slug = "multiple"
    display_name = "MultiPL-E"
    has_ground_truth = False
    has_tasks = False

    # Language codes this adapter has labels and highlighter keys for.
    LANGUAGES = ["py", "cpp", "java", "js", "ts", "go", "rs", "cs", "rb", "lua"]

    _LANG_LABELS = {
        "py": "Python",
        "cpp": "C++",
        "java": "Java",
        "js": "JavaScript",
        "ts": "TypeScript",
        "go": "Go",
        "rs": "Rust",
        "cs": "C#",
        "rb": "Ruby",
        "lua": "Lua",
    }
    _LANG_PYGMENTS = {
        "py": "python",
        "cpp": "cpp",
        "java": "java",
        "js": "javascript",
        "ts": "typescript",
        "go": "go",
        "rs": "rust",
        "cs": "csharp",
        "rb": "ruby",
        "lua": "lua",
    }

    def __init__(self, datasets_by_lang: dict[str, Any]):
        if not datasets_by_lang:
            raise ValueError("MultiPLEAdapter needs at least one language dataset")
        self._by_lang = datasets_by_lang
        first_lang = next(iter(self._by_lang))
        self._count = len(self._by_lang[first_lang])
        # Lazily built per-language maps: problem name -> row position.
        self._name_index: dict[str, dict[Any, int]] = {}

    def problem_count(self) -> int:
        return self._count

    def _row_for(self, lang: str, name: Any, idx: int) -> dict[str, Any] | None:
        """Return ``lang``'s row for the problem called ``name``, or None if absent.

        Without a ``name`` to match on, fall back to the row at position ``idx``
        when that position exists in ``lang``'s dataset.
        """
        ds = self._by_lang[lang]
        if name is None:
            return ds[idx] if -len(ds) <= idx < len(ds) else None
        index = self._name_index.get(lang)
        if index is None:
            index = {}
            for pos in range(len(ds)):
                index.setdefault(ds[pos].get("name"), pos)  # first occurrence wins
            self._name_index[lang] = index
        pos = index.get(name)
        return None if pos is None else ds[pos]

    def get_problem_summary(self, idx: int) -> dict[str, Any]:
        """Return the listing entry for problem ``idx`` (taken from the first language)."""
        first_lang = next(iter(self._by_lang))
        row = self._by_lang[first_lang][idx]
        return {
            "idx": idx,
            "task_id": row.get("name", str(idx)),
            "entry_point": row.get("name", f"multiple_{idx}"),
            "num_inputs": len(self._by_lang),
            "source": "MultiPL-E",
        }

    def get_problem_detail(self, idx: int) -> dict[str, Any]:
        """Return problem ``idx`` with one prompt/test entry per available language.

        Languages that have no row for this problem are omitted from
        ``lang_solutions``. The top-level code/test come from Python when it
        has this problem, otherwise from the first language.
        """
        first_lang = next(iter(self._by_lang))
        row = self._by_lang[first_lang][idx]
        name = row.get("name")

        lang_solutions = []
        for lang in self._by_lang:
            lrow = row if lang == first_lang else self._row_for(lang, name, idx)
            if lrow is None:
                continue  # no translation of this problem in this language
            prompt = lrow.get("prompt", "")
            lang_key = self._LANG_PYGMENTS.get(lang, lang)
            lang_solutions.append(
                {
                    "language": lang,
                    "language_label": self._LANG_LABELS.get(lang, lang),
                    "code": prompt,
                    "highlighted_code": _highlight_code(prompt, language=lang_key),
                    "test": lrow.get("tests", ""),
                }
            )

        default_lang, default_row = first_lang, row
        if first_lang != "py" and "py" in self._by_lang:
            py_row = self._row_for("py", name, idx)
            if py_row is not None:
                default_lang, default_row = "py", py_row
        default_code = default_row.get("prompt", "")

        return {
            "idx": idx,
            "task_id": row.get("name", str(idx)),
            "entry_point": row.get("name", f"multiple_{idx}"),
            "code": default_code,
            "highlighted_code": _highlight_code(
                default_code,
                language=self._LANG_PYGMENTS.get(default_lang, default_lang),
            ),
            "inputs": [],
            "outputs": [],
            "test": default_row.get("tests", ""),
            "tasks": [],
            "source": "MultiPL-E",
            "has_ground_truth": False,
            "has_tasks": False,
            "lang_solutions": lang_solutions,
        }
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
| |
|
class Defects4JAdapter(DatasetAdapter):
    """Adapter for Defects4J function-level bug/fix pairs (Java).

    ``hf_dataset`` is an indexable collection of dict-like rows with
    ``bug_id``, ``func_before`` (buggy code) and ``func_after`` (fixed code).
    """

    slug = "defects4j"
    display_name = "Defects4J"
    has_ground_truth = False
    has_tasks = False

    def __init__(self, hf_dataset):
        self._ds = hf_dataset

    def problem_count(self) -> int:
        return len(self._ds)

    @staticmethod
    def _project_from_bug_id(bug_id: str) -> str:
        """Extract the project name from a bug id, e.g. 'Compress-35' -> 'Compress'.

        Ids without a '-' are returned unchanged; non-string ids are stringified.
        """
        return str(bug_id).rsplit("-", 1)[0]

    def get_problem_summary(self, idx: int) -> dict[str, Any]:
        """Return the lightweight listing entry for row ``idx``."""
        row = self._ds[idx]
        bug_id = row.get("bug_id", str(idx))
        project = self._project_from_bug_id(bug_id)
        return {
            "idx": idx,
            "task_id": bug_id,
            "entry_point": project,
            "num_inputs": 0,
            "source": project,
        }

    def get_problem_detail(self, idx: int) -> dict[str, Any]:
        """Return the buggy/fixed pair for row ``idx``; the fixed code doubles as ``code``."""
        row = self._ds[idx]
        bug_id = row.get("bug_id", str(idx))
        project = self._project_from_bug_id(bug_id)
        buggy = row.get("func_before", "")
        fixed = row.get("func_after", "")
        # Highlight each snippet once; the fixed highlight is reused twice below.
        buggy_highlighted = _highlight_code(buggy, language="java") if buggy else ""
        fixed_highlighted = _highlight_code(fixed, language="java") if fixed else ""
        return {
            "idx": idx,
            "task_id": bug_id,
            "entry_point": project,
            "code": fixed,
            "highlighted_code": fixed_highlighted,
            "inputs": [],
            "outputs": [],
            "test": None,
            "tasks": [],
            "source": project,
            "has_ground_truth": False,
            "has_tasks": False,
            "description": "",
            "buggy_code": buggy,
            "buggy_highlighted_code": buggy_highlighted,
            "fixed_code": fixed,
            "fixed_highlighted_code": fixed_highlighted,
            "bug_category": "Bug Fix",
            "bug_subtype": project,
            "bug_explanation": "",
            "language": "Java",
        }
| |
|