"""Dataset loader for RAGBench datasets."""

import os
from typing import Dict, List, Optional

from datasets import load_dataset
from tqdm import tqdm


class RAGBenchLoader:
    """Load and manage RAGBench datasets."""

    SUPPORTED_DATASETS = [
        'covidqa',
        'cuad',
        'delucionqa',
        'emanual',
        'expertqa',
        'finqa',
        'hagrid',
        'hotpotqa',
        'msmarco',
        'pubmedqa',
        'tatqa',
        'techqa',
    ]

    def __init__(self, cache_dir: str = "./data_cache"):
        """Initialize the dataset loader.

        Args:
            cache_dir: Directory in which to cache downloaded datasets.
        """
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def load_dataset(self, dataset_name: str, split: str = "test",
                     max_samples: Optional[int] = None) -> List[Dict]:
        """Load a RAGBench dataset from rungalileo/ragbench.

        Args:
            dataset_name: Name of the dataset to load.
            split: Dataset split to load (train/validation/test).
            max_samples: Maximum number of samples to load; if None, the
                full split is loaded.

        Returns:
            List of dictionaries containing standardized dataset samples.
        """
        if dataset_name not in self.SUPPORTED_DATASETS:
            raise ValueError(f"Unsupported dataset: {dataset_name}. "
                             f"Supported: {self.SUPPORTED_DATASETS}")

        print(f"Loading {dataset_name} dataset ({split} split) from rungalileo/ragbench...")

        try:
            dataset = load_dataset("rungalileo/ragbench", dataset_name, split=split,
                                   cache_dir=self.cache_dir)

            # Truncate to the first max_samples rows if a limit was given.
            samples = dataset if max_samples is None else dataset.select(
                range(min(max_samples, len(dataset))))

            processed_data = []
            for item in tqdm(samples, desc=f"Processing {dataset_name}"):
                processed_data.append(self._process_ragbench_item(item, dataset_name))

            print(f"Loaded {len(processed_data)} samples from {dataset_name}")
            return processed_data

        except Exception as e:
            # Keep the pipeline runnable offline by falling back to
            # synthetic placeholder samples.
            print(f"Error loading {dataset_name}: {e}")
            print("Falling back to sample data for testing...")
            return self._create_sample_data(dataset_name, max_samples or 10)

    def _process_ragbench_item(self, item: Dict, dataset_name: str) -> Dict:
        """Process a single RAGBench dataset item into a standardized format.

        Args:
            item: Raw dataset item.
            dataset_name: Name of the dataset.

        Returns:
            Processed item dictionary.
        """
        processed = {
            "question": item.get("question", ""),
            "answer": item.get("answer", ""),
            "context": "",
            "documents": [],
            "dataset": dataset_name
        }

        # RAGBench configs expose retrieved passages under different field
        # names; check the known variants in priority order and normalize
        # them to a list of document strings plus a single joined context.
        for field in ("documents", "retrieved_contexts", "context"):
            if field in item:
                value = item[field]
                if isinstance(value, list):
                    processed["documents"] = [str(doc) for doc in value]
                    processed["context"] = " ".join(processed["documents"])
                else:
                    processed["documents"] = [str(value)]
                    processed["context"] = str(value)
                break

        if "metadata" in item:
            processed["metadata"] = item["metadata"]

        return processed
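
    # Example of the standardized item the method above returns. The field
    # values here are illustrative placeholders, not real dataset content:
    # {
    #     "question": "When was the vaccine approved?",
    #     "answer": "In 2021.",
    #     "documents": ["passage one ...", "passage two ..."],
    #     "context": "passage one ... passage two ...",
    #     "dataset": "covidqa",
    #     "metadata": {...},          # only present if the raw item has it
    # }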

    def load_all_datasets(self, split: str = "test",
                          max_samples: Optional[int] = None) -> Dict[str, List[Dict]]:
        """Load all supported RAGBench datasets.

        Args:
            split: Dataset split to load.
            max_samples: Maximum samples per dataset.

        Returns:
            Dictionary mapping dataset names to their data.
        """
        all_data = {}
        for dataset_name in self.SUPPORTED_DATASETS:
            print(f"\n{'='*50}")
            print(f"Loading {dataset_name}...")
            print(f"{'='*50}")
            try:
                all_data[dataset_name] = self.load_dataset(dataset_name, split, max_samples)
            except Exception as e:
                # A failed dataset should not abort the whole run.
                print(f"Failed to load {dataset_name}: {e}")
                all_data[dataset_name] = []

        return all_data

    def _create_sample_data(self, dataset_name: str, num_samples: int) -> List[Dict]:
        """Create synthetic sample data for testing when the actual dataset is unavailable."""
        sample_data = []
        for i in range(num_samples):
            sample_docs = [
                f"Document 1: This is the first sample document {i+1} for the {dataset_name} dataset. "
                f"It contains relevant information to answer the question.",
                f"Document 2: This is the second sample document {i+1} providing additional context. "
                f"It includes more details about the topic.",
                f"Document 3: This is the third sample document {i+1} with supplementary information."
            ]

            sample_data.append({
                "question": f"Sample question {i+1} for {dataset_name}?",
                "answer": f"Sample answer {i+1}",
                "documents": sample_docs,
                "context": " ".join(sample_docs),
                "dataset": dataset_name
            })
        return sample_data

    def get_test_data(self, dataset_name: str, num_samples: int = 100) -> List[Dict]:
        """Get test data for TRACE evaluation.

        Args:
            dataset_name: Name of the dataset.
            num_samples: Number of test samples.

        Returns:
            List of test samples.
        """
        return self.load_dataset(dataset_name, split="test", max_samples=num_samples)
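

if __name__ == "__main__":
    # Minimal usage sketch. It assumes network access to the Hugging Face
    # Hub; without it, load_dataset() falls back to synthetic samples.
    # The dataset name and sample count are arbitrary example values.
    loader = RAGBenchLoader()
    data = loader.get_test_data("hotpotqa", num_samples=5)
    for sample in data:
        print(f"Q: {sample['question']}")
        print(f"A: {sample['answer'][:80]}")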