# CapStoneRAG10 / dataset_loader.py
# Developer
# Initial commit for HuggingFace Spaces - RAG Capstone Project with Qdrant Cloud
# 1d10b0a
"""Dataset loader for RAG Bench datasets."""
import os
from typing import List, Dict, Optional
from datasets import load_dataset
import pandas as pd
from tqdm import tqdm
class RAGBenchLoader:
    """Load and manage RAG Bench datasets.

    Every sample returned by this class is a plain dict with the keys
    ``question``, ``answer``, ``context``, ``documents``, ``dataset`` and
    ``ground_truth_scores`` (plus ``metadata`` when the raw item has it).
    """

    SUPPORTED_DATASETS = [
        'covidqa',
        'cuad',
        'delucionqa',
        'emanual',
        'expertqa',
        'finqa',
        'hagrid',
        'hotpotqa',
        'msmarco',
        'pubmedqa',
        'tatqa',
        'techqa'
    ]

    # Raw fields that may carry the source documents, in priority order:
    # documents > retrieved_contexts > context.
    _DOCUMENT_FIELDS = ("documents", "retrieved_contexts", "context")

    # (candidate raw field names, canonical metric name). For each metric the
    # first candidate that is present and convertible to a float wins.
    _SCORE_FIELDS = (
        (("relevance_score", "context_relevance", "relevance", "R"),
         "context_relevance"),
        (("utilization_score", "context_utilization", "utilization", "T"),
         "context_utilization"),
        (("completeness_score", "completeness", "C"), "completeness"),
        (("adherence_score", "adherence", "A", "overall_supported"),
         "adherence"),
    )

    def __init__(self, cache_dir: str = "./data_cache"):
        """Initialize the dataset loader.

        Args:
            cache_dir: Directory to cache downloaded datasets. It is created
                if it does not exist yet.
        """
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def load_dataset(self, dataset_name: str, split: str = "test",
                     max_samples: Optional[int] = None) -> List[Dict]:
        """Load a RAG Bench dataset from rungalileo/ragbench.

        If loading or processing fails for any reason, the error is printed
        and synthetic sample data is returned instead (best-effort fallback
        for testing).

        Args:
            dataset_name: Name of the dataset to load
            split: Dataset split (train/validation/test)
            max_samples: Maximum number of samples to load; ``None`` loads
                the whole split

        Returns:
            List of dictionaries containing dataset samples

        Raises:
            ValueError: If ``dataset_name`` is not in ``SUPPORTED_DATASETS``.
        """
        if dataset_name not in self.SUPPORTED_DATASETS:
            raise ValueError(f"Unsupported dataset: {dataset_name}. "
                             f"Supported: {self.SUPPORTED_DATASETS}")

        print(f"Loading {dataset_name} dataset ({split} split) from rungalileo/ragbench...")

        try:
            # Inside a method body the bare name ``load_dataset`` resolves to
            # the module-level function imported from ``datasets``, not to
            # this method.
            dataset = load_dataset("rungalileo/ragbench", dataset_name,
                                   split=split, cache_dir=self.cache_dir)

            if max_samples is None:
                samples = dataset
            else:
                samples = dataset.select(
                    range(min(max_samples, len(dataset))))

            processed_data = [
                self._process_ragbench_item(item, dataset_name)
                for item in tqdm(samples, desc=f"Processing {dataset_name}")
            ]

            print(f"Loaded {len(processed_data)} samples from {dataset_name}")
            return processed_data
        except Exception as e:
            # Deliberately broad: any failure degrades to sample data.
            print(f"Error loading {dataset_name}: {str(e)}")
            print("Falling back to sample data for testing...")
            # Only default to 10 when no limit was given, so an explicit
            # max_samples=0 is honoured instead of producing 10 samples.
            fallback_count = 10 if max_samples is None else max_samples
            return self._create_sample_data(dataset_name, fallback_count)

    @classmethod
    def _extract_documents(cls, item: Dict) -> List[str]:
        """Return the item's documents as a list of strings.

        The first field of ``_DOCUMENT_FIELDS`` whose value is not ``None``
        is used. A list value is converted element-wise with ``str``; any
        other value becomes a single-element list.
        """
        for field_name in cls._DOCUMENT_FIELDS:
            raw = item.get(field_name)
            if raw is None:
                # Missing or null: try the next candidate rather than
                # emitting the literal document "None".
                continue
            if isinstance(raw, list):
                return [str(entry) for entry in raw]
            return [str(raw)]
        return []

    @staticmethod
    def _coerce_score(value) -> Optional[float]:
        """Convert a raw score value to float, or ``None`` if not convertible.

        Booleans map to 1.0/0.0, the strings true/yes and false/no
        (case-insensitive) map to 1.0 and 0.0, and everything else goes
        through ``float()``.
        """
        # bool must be checked explicitly so True/False are handled as flags.
        if isinstance(value, bool):
            return 1.0 if value else 0.0
        if isinstance(value, str):
            lowered = value.lower()
            if lowered in ('true', 'yes'):
                return 1.0
            if lowered in ('false', 'no'):
                return 0.0
        try:
            return float(value)
        except (ValueError, TypeError):
            return None

    def _process_ragbench_item(self, item: Dict, dataset_name: str) -> Dict:
        """Process a single RAGBench dataset item into standardized format.

        Args:
            item: Raw dataset item
            dataset_name: Name of the dataset

        Returns:
            Processed item dictionary with the keys ``question``, ``answer``,
            ``context`` (documents joined by a single space), ``documents``,
            ``dataset`` and ``ground_truth_scores`` (canonical metric name ->
            float, empty when no score field is found), plus ``metadata``
            when the raw item has that key.
        """
        documents = self._extract_documents(item)

        ground_truth_scores = {}
        for field_names, metric_name in self._SCORE_FIELDS:
            for field_name in field_names:
                if field_name not in item:
                    continue
                score = self._coerce_score(item[field_name])
                if score is not None:
                    ground_truth_scores[metric_name] = score
                    break  # Found this metric, move on to the next one

        processed = {
            "question": item.get("question", ""),
            "answer": item.get("answer", ""),
            "context": " ".join(documents),  # For embedding and retrieval
            "documents": documents,  # Original documents list
            "dataset": dataset_name,
            "ground_truth_scores": ground_truth_scores,
        }

        # Store additional metadata if available
        if "metadata" in item:
            processed["metadata"] = item["metadata"]

        return processed

    def load_all_datasets(self, split: str = "test",
                          max_samples: Optional[int] = None) -> Dict[str, List[Dict]]:
        """Load all RAGBench datasets.

        Args:
            split: Dataset split to load
            max_samples: Maximum samples per dataset

        Returns:
            Dictionary mapping dataset names to their data. A dataset whose
            load raises is mapped to an empty list.
        """
        all_data = {}
        for dataset_name in self.SUPPORTED_DATASETS:
            print(f"\n{'='*50}")
            print(f"Loading {dataset_name}...")
            print(f"{'='*50}")

            try:
                all_data[dataset_name] = self.load_dataset(
                    dataset_name, split, max_samples)
            except Exception as e:
                print(f"Failed to load {dataset_name}: {str(e)}")
                all_data[dataset_name] = []

        return all_data

    def _create_sample_data(self, dataset_name: str, num_samples: int) -> List[Dict]:
        """Create sample data for testing when actual dataset is unavailable.

        Args:
            dataset_name: Name embedded in the generated text and stored
                under the ``dataset`` key
            num_samples: Number of synthetic samples to generate

        Returns:
            List of synthetic samples using the same keys as
            ``_process_ragbench_item``, with empty ``ground_truth_scores``.
        """
        sample_data = []
        for i in range(num_samples):
            # Create multiple sample documents per question
            sample_docs = [
                f"Document 1: This is the first sample document {i+1} for {dataset_name} dataset. "
                f"It contains relevant information to answer the question.",
                f"Document 2: This is the second sample document {i+1} providing additional context. "
                f"It includes more details about the topic.",
                f"Document 3: This is the third sample document {i+1} with supplementary information."
            ]
            sample_data.append({
                "question": f"Sample question {i+1} for {dataset_name}?",
                "answer": f"Sample answer {i+1}",
                "documents": sample_docs,
                "context": " ".join(sample_docs),  # Combined for backward compatibility
                "dataset": dataset_name,
                # Same schema as real records; synthetic data has no scores.
                "ground_truth_scores": {},
            })
        return sample_data

    def get_test_data(self, dataset_name: str, num_samples: int = 100) -> List[Dict]:
        """Get test data for TRACE evaluation.

        Args:
            dataset_name: Name of the dataset
            num_samples: Number of test samples

        Returns:
            List of test samples
        """
        return self.load_dataset(dataset_name, split="test", max_samples=num_samples)

    def get_test_data_size(self, dataset_name: str) -> int:
        """Get the total number of test samples available in a dataset.

        The size is read from the dataset builder's split info when
        available, preferring ``test``, then ``validation``, then the first
        listed split. Without split info the test split is loaded and
        counted. On any error the message is printed and 100 is returned.

        Args:
            dataset_name: Name of the dataset

        Returns:
            Total number of test samples available, or 100 on failure
        """
        try:
            from datasets import load_dataset_builder

            # Load dataset builder to get dataset info
            builder = load_dataset_builder("rungalileo/ragbench", dataset_name,
                                           cache_dir=self.cache_dir)

            splits = getattr(builder.info, 'splits', None)
            if splits:
                for preferred in ('test', 'validation'):
                    if preferred in splits:
                        return splits[preferred].num_examples
                # Neither preferred split exists: use the first available one
                first_split = next(iter(splits))
                return splits[first_split].num_examples

            # Fallback: load full test dataset to count. No
            # trust_remote_code: the main loading path reads this dataset
            # without it, so remote code execution is not needed here.
            ds = load_dataset("rungalileo/ragbench", dataset_name,
                              split="test", cache_dir=self.cache_dir)
            return len(ds)
        except Exception as e:
            print(f"Error getting test data size: {e}")
            # Return a reasonable default
            return 100