"""Dataset loader for RAG Bench datasets."""

import os
from typing import List, Dict, Optional

from datasets import load_dataset
import pandas as pd
from tqdm import tqdm


class RAGBenchLoader:
    """Load and manage RAG Bench datasets.

    Downloads datasets from the Hugging Face hub (``rungalileo/ragbench``),
    normalizes each record into a common dict schema, and falls back to
    synthetic sample data when a download fails.
    """

    # Dataset configs available under rungalileo/ragbench.
    SUPPORTED_DATASETS = [
        'covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa',
        'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa',
        'tatqa', 'techqa'
    ]

    def __init__(self, cache_dir: str = "./data_cache"):
        """Initialize the dataset loader.

        Args:
            cache_dir: Directory to cache downloaded datasets
        """
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def load_dataset(self, dataset_name: str, split: str = "test",
                     max_samples: Optional[int] = None) -> List[Dict]:
        """Load a RAG Bench dataset from rungalileo/ragbench.

        Args:
            dataset_name: Name of the dataset to load
            split: Dataset split (train/validation/test)
            max_samples: Maximum number of samples to load

        Returns:
            List of dictionaries containing dataset samples

        Raises:
            ValueError: If ``dataset_name`` is not in SUPPORTED_DATASETS.
        """
        if dataset_name not in self.SUPPORTED_DATASETS:
            raise ValueError(f"Unsupported dataset: {dataset_name}. "
                             f"Supported: {self.SUPPORTED_DATASETS}")

        print(f"Loading {dataset_name} dataset ({split} split) from rungalileo/ragbench...")

        try:
            # NOTE: this method shadows the imported ``datasets.load_dataset``;
            # the bare call below still resolves to the module-level global,
            # not to this method, so behavior is correct.
            dataset = load_dataset("rungalileo/ragbench", dataset_name,
                                   split=split, cache_dir=self.cache_dir)

            # Truncate to max_samples if requested (select keeps Dataset type).
            if max_samples is None:
                samples = dataset
            else:
                samples = dataset.select(range(min(max_samples, len(dataset))))

            processed_data = []
            for item in tqdm(samples, desc=f"Processing {dataset_name}"):
                processed_data.append(self._process_ragbench_item(item, dataset_name))

            print(f"Loaded {len(processed_data)} samples from {dataset_name}")
            return processed_data

        except Exception as e:
            # Deliberate best-effort: network/hub failures degrade to
            # synthetic data so downstream testing can proceed.
            print(f"Error loading {dataset_name}: {str(e)}")
            print("Falling back to sample data for testing...")
            # FIX: was ``max_samples or 10``, which silently turned an
            # explicit max_samples=0 into 10; only None means "default".
            fallback_count = max_samples if max_samples is not None else 10
            return self._create_sample_data(dataset_name, fallback_count)

    def _process_ragbench_item(self, item: Dict, dataset_name: str) -> Dict:
        """Process a single RAGBench dataset item into standardized format.

        Args:
            item: Raw dataset item
            dataset_name: Name of the dataset

        Returns:
            Processed item dictionary with keys: question, answer, context,
            documents, dataset, ground_truth_scores, and (optionally) metadata.
        """
        # RAGBench datasets typically have: question, documents, answer,
        # and retrieved_contexts.
        processed = {
            "question": item.get("question", ""),
            "answer": item.get("answer", ""),
            "context": "",          # Flattened text used for embedding/retrieval
            "documents": [],        # Original documents list
            "dataset": dataset_name,
            "ground_truth_scores": {}  # Pre-computed RAGBench evaluation scores
        }

        # Extract documents - RAGBench uses 'documents' as primary source
        # for embeddings. Priority: documents > retrieved_contexts > context.
        if "documents" in item:
            if isinstance(item["documents"], list):
                processed["documents"] = [str(doc) for doc in item["documents"]]
                processed["context"] = " ".join(processed["documents"])
            else:
                processed["documents"] = [str(item["documents"])]
                processed["context"] = str(item["documents"])
        elif "retrieved_contexts" in item:
            if isinstance(item["retrieved_contexts"], list):
                processed["documents"] = [str(ctx) for ctx in item["retrieved_contexts"]]
                processed["context"] = " ".join(processed["documents"])
            else:
                processed["documents"] = [str(item["retrieved_contexts"])]
                processed["context"] = str(item["retrieved_contexts"])
        elif "context" in item:
            if isinstance(item["context"], list):
                processed["documents"] = [str(ctx) for ctx in item["context"]]
                processed["context"] = " ".join(processed["documents"])
            else:
                processed["documents"] = [str(item["context"])]
                processed["context"] = str(item["context"])

        # Extract ground-truth evaluation scores pre-computed by the
        # RAGBench paper. Each canonical metric may appear under several
        # possible field names, tried in order.
        ground_truth_scores = {}
        score_fields = [
            # (possible_field_names, canonical_metric_name)
            (["relevance_score", "context_relevance", "relevance", "R"], "context_relevance"),
            (["utilization_score", "context_utilization", "utilization", "T"], "context_utilization"),
            (["completeness_score", "completeness", "C"], "completeness"),
            (["adherence_score", "adherence", "A", "overall_supported"], "adherence"),
        ]

        for field_names, metric_name in score_fields:
            for field_name in field_names:
                if field_name in item:
                    try:
                        score_value = item[field_name]
                        # bool must be checked before numeric coercion
                        # (bool is an int subclass in Python).
                        if isinstance(score_value, bool):
                            # Boolean adherence: True=1.0, False=0.0
                            score_value = 1.0 if score_value else 0.0
                        elif isinstance(score_value, str):
                            if score_value.lower() in ['true', 'yes']:
                                score_value = 1.0
                            elif score_value.lower() in ['false', 'no']:
                                score_value = 0.0
                            else:
                                score_value = float(score_value)
                        ground_truth_scores[metric_name] = float(score_value)
                        break  # Found this metric, move to next
                    except (ValueError, TypeError):
                        continue  # Unparseable value; try the next field name

        if ground_truth_scores:
            processed["ground_truth_scores"] = ground_truth_scores

        # Preserve any extra metadata the dataset carries.
        if "metadata" in item:
            processed["metadata"] = item["metadata"]

        return processed

    def load_all_datasets(self, split: str = "test",
                          max_samples: Optional[int] = None) -> Dict[str, List[Dict]]:
        """Load all RAGBench datasets.

        Args:
            split: Dataset split to load
            max_samples: Maximum samples per dataset

        Returns:
            Dictionary mapping dataset names to their data (empty list on
            failure for a given dataset).
        """
        all_data = {}
        for dataset_name in self.SUPPORTED_DATASETS:
            print(f"\n{'='*50}")
            print(f"Loading {dataset_name}...")
            print(f"{'='*50}")
            try:
                all_data[dataset_name] = self.load_dataset(dataset_name, split, max_samples)
            except Exception as e:
                # One bad dataset must not abort the full sweep.
                print(f"Failed to load {dataset_name}: {str(e)}")
                all_data[dataset_name] = []
        return all_data

    def _create_sample_data(self, dataset_name: str, num_samples: int) -> List[Dict]:
        """Create sample data for testing when actual dataset is unavailable.

        Returns samples with the same key schema that
        ``_process_ragbench_item`` produces, so downstream code behaves
        identically on real and fallback data.
        """
        sample_data = []
        for i in range(num_samples):
            # Create multiple sample documents per question
            sample_docs = [
                f"Document 1: This is the first sample document {i+1} for {dataset_name} dataset. "
                f"It contains relevant information to answer the question.",
                f"Document 2: This is the second sample document {i+1} providing additional context. "
                f"It includes more details about the topic.",
                f"Document 3: This is the third sample document {i+1} with supplementary information."
            ]
            sample_data.append({
                "question": f"Sample question {i+1} for {dataset_name}?",
                "answer": f"Sample answer {i+1}",
                "documents": sample_docs,
                "context": " ".join(sample_docs),  # Combined for backward compatibility
                "dataset": dataset_name,
                # FIX: real samples always carry this key; include it here
                # too so fallback data matches the processed-item schema.
                "ground_truth_scores": {}
            })
        return sample_data

    def get_test_data(self, dataset_name: str, num_samples: int = 100) -> List[Dict]:
        """Get test data for TRACE evaluation.

        Args:
            dataset_name: Name of the dataset
            num_samples: Number of test samples

        Returns:
            List of test samples
        """
        return self.load_dataset(dataset_name, split="test", max_samples=num_samples)

    def get_test_data_size(self, dataset_name: str) -> int:
        """Get the total number of test samples available in a dataset.

        Args:
            dataset_name: Name of the dataset

        Returns:
            Total number of test samples available (falls back to 100 if the
            size cannot be determined).
        """
        try:
            # Lazy import: only needed on this path.
            from datasets import load_dataset_builder

            # Query dataset metadata without downloading the data itself.
            builder = load_dataset_builder("rungalileo/ragbench", dataset_name)

            # Prefer test, then validation, then whichever split exists.
            if hasattr(builder.info, 'splits') and builder.info.splits:
                if 'test' in builder.info.splits:
                    return builder.info.splits['test'].num_examples
                elif 'validation' in builder.info.splits:
                    return builder.info.splits['validation'].num_examples
                else:
                    first_split = list(builder.info.splits.keys())[0]
                    return builder.info.splits[first_split].num_examples

            # Fallback: load the full test split just to count it.
            ds = load_dataset("rungalileo/ragbench", dataset_name,
                              split="test", trust_remote_code=True)
            return len(ds)

        except Exception as e:
            print(f"Error getting test data size: {e}")
            # Return a reasonable default
            return 100