|
|
import json |
|
|
from typing import List, Dict |
|
|
|
|
|
|
|
|
class DataLoader:
    """Load question/answer datasets stored as JSON or JSON Lines files."""

    @staticmethod
    def load_data(file_path: str) -> List[Dict[str, str]]:
        """
        Load data from a JSON or JSON Lines file.

        A path ending in ``.jsonl`` is parsed as one JSON object per line
        (blank lines are ignored); any other path is parsed as a single
        JSON document.

        Expected format: [{"query": "...", "answer": "..."}, ...]

        Args:
            file_path: Path of the file to read.

        Returns:
            The parsed list of items, each containing at least the
            'query' and 'answer' keys.

        Raises:
            ValueError: If the top-level value is not a list, or if any
                item is not an object with both 'query' and 'answer'
                fields. Malformed JSON raises ``json.JSONDecodeError``,
                which is itself a ``ValueError`` subclass.
            OSError: If the file cannot be opened.
        """
        # JSON text is UTF-8 by specification; do not depend on the
        # platform's default encoding (which is not UTF-8 everywhere).
        with open(file_path, 'r', encoding='utf-8') as f:
            if file_path.endswith('.jsonl'):
                # Skip blank lines (e.g. a trailing empty line at EOF),
                # which json.loads would otherwise reject.
                data = [json.loads(line) for line in f if line.strip()]
            else:
                data = json.load(f)

        # A top-level object or scalar would otherwise be iterated
        # key-by-key / fail obscurely in the validation loop below.
        if not isinstance(data, list):
            raise ValueError("Dataset must be a list of items with 'query' and 'answer' fields")

        for item in data:
            # The isinstance check prevents substring matching on string
            # items and TypeError on non-container items.
            if not isinstance(item, dict) or 'query' not in item or 'answer' not in item:
                raise ValueError("Each data item must have 'query' and 'answer' fields")

        return data

    @staticmethod
    def load_math_dataset(file_path: str) -> List[Dict[str, str]]:
        """Load MATH or GSM8K format dataset.

        Delegates to :meth:`load_data`, so the file must already use the
        'query'/'answer' item format; no dataset-specific conversion is
        performed here.
        """
        return DataLoader.load_data(file_path)

    @staticmethod
    def load_mmlu_dataset(file_path: str) -> List[Dict[str, str]]:
        """Load MMLU-Pro format dataset.

        Delegates to :meth:`load_data`, so the file must already use the
        'query'/'answer' item format; no dataset-specific conversion is
        performed here.
        """
        return DataLoader.load_data(file_path)