# Source path: migratron/code_migration/dataset_loader.py
# Repository page header (not part of the module): "amrithanandini's picture",
# commit "integrated backend and frontend" (1b35d41).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Dataset loader for TimeMachine-bench JSONL files."""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
@dataclass
class Task:
    """One TimeMachine-bench record describing a single migration task.

    Instances are created by unpacking a decoded JSONL object straight into
    the constructor, so every field name here must match a key of the record.
    The meaning of each field is defined by the dataset itself; this class
    only fixes the names, their order and their declared types.
    """

    # --- repository ---------------------------------------------------
    repo_name: str
    repo_url: str
    commit_hash: str

    # --- patches ------------------------------------------------------
    patch: str
    test_patch: str
    gold_patch: str

    # --- reproduction / migration targets -----------------------------
    reproduction_target_date: str
    reproduction_target_version: str
    migration_target_date: str
    migration_target_version: str

    # --- environment-related fields -----------------------------------
    dockerfile: str
    version_source: str
    script_source: str
    dependency_versions: str

    # --- tests --------------------------------------------------------
    test_type: str
    test_files: str
    test_count: int

    # --- size and classification --------------------------------------
    related_modules: str
    py_file_count: int
    total_loc_python: int
    difficulty: str
    license: str
# Bundled dataset used when no explicit path is given; it is resolved against
# this module's own directory: <module dir>/data/timemachine-bench-verified.jsonl
_DEFAULT_DATASET_PATH = Path(__file__).resolve().parent.joinpath(
    "data", "timemachine-bench-verified.jsonl"
)
class DatasetLoader:
    """Load and query TimeMachine-bench tasks from a JSONL file.

    The file is parsed eagerly by the constructor. The query helpers and the
    ``len()`` / indexing support all operate on the in-memory list built at
    that point.
    """

    def __init__(self, dataset_path: str | Path | None = None) -> None:
        """Parse the dataset found at *dataset_path*.

        Args:
            dataset_path: Location of the JSONL file. Any falsy value
                (``None`` or an empty string) selects the bundled default
                dataset.

        Raises:
            FileNotFoundError: If the dataset file does not exist.
            ValueError: If a line contains malformed JSON.
            TypeError: If a line is valid JSON but not a usable Task record.
        """
        self._path = Path(dataset_path) if dataset_path else _DEFAULT_DATASET_PATH
        self._tasks: list[Task] = self.load()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def load(self) -> list[Task]:
        """Parse the JSONL file and return a list of Task records.

        Blank (whitespace-only) lines are skipped. A new list is built on
        every call; the list cached on the instance is only assigned by the
        constructor.

        Raises:
            FileNotFoundError: If the dataset file does not exist.
            ValueError: If a line contains malformed JSON (includes file path
                and 1-based line number in the message).
            TypeError: If a line is valid JSON but is not a JSON object, or
                its keys do not match the Task fields (includes file path
                and 1-based line number in the message).
        """
        path = self._path
        if not path.exists():
            raise FileNotFoundError(f"Dataset file not found: {path}")

        tasks: list[Task] = []
        # "utf-8-sig" decodes plain UTF-8 unchanged but drops a leading BOM,
        # which json.loads would otherwise reject on the first line.
        with open(path, "r", encoding="utf-8-sig") as fh:
            for line_no, raw_line in enumerate(fh, start=1):
                raw_line = raw_line.strip()
                if not raw_line:
                    continue
                try:
                    record = json.loads(raw_line)
                except json.JSONDecodeError as exc:
                    raise ValueError(
                        f"Malformed JSON at {path} line {line_no}: {exc}"
                    ) from exc
                # Only a JSON object can be unpacked into Task(**record);
                # report anything else with its location instead of letting
                # a context-free TypeError escape.
                if not isinstance(record, dict):
                    raise TypeError(
                        f"Expected a JSON object at {path} line {line_no}, "
                        f"got {type(record).__name__}"
                    )
                try:
                    task = Task(**record)
                except TypeError as exc:
                    # Missing or unexpected keys: keep the exception type but
                    # say which line of which file is at fault.
                    raise TypeError(
                        f"Invalid task record at {path} line {line_no}: {exc}"
                    ) from exc
                tasks.append(task)
        return tasks

    def filter_by_difficulty(self, difficulty: str) -> list[Task]:
        """Return tasks whose ``difficulty`` equals *difficulty* exactly."""
        return [t for t in self._tasks if t.difficulty == difficulty]

    def get_by_repo_name(self, repo_name: str) -> Task | None:
        """Return the first task with the given repo_name, or None."""
        return next((t for t in self._tasks if t.repo_name == repo_name), None)

    def __len__(self) -> int:
        """Return the number of loaded tasks."""
        return len(self._tasks)

    def __getitem__(self, index: int) -> Task:
        """Return the task at *index*, delegating to the underlying list."""
        return self._tasks[index]